diff --git a/pyproject.toml b/pyproject.toml index 60b3ae4..08066ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ exclude = [ "src/batdetect2/plotting/legacy", "src/batdetect2/evaluate/legacy", "src/batdetect2/finetune", + "src/batdetect2/utils", ] [tool.ruff.format] @@ -121,4 +122,5 @@ exclude = [ "src/batdetect2/plotting/legacy", "src/batdetect2/evaluate/legacy", "src/batdetect2/finetune", + "src/batdetect2/utils", ] diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index b9c3b8c..90ed5e7 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -11,17 +11,20 @@ from batdetect2.audio import build_audio_loader from batdetect2.config import BatDetect2Config from batdetect2.core import merge_configs from batdetect2.data import ( - OutputFormatConfig, - build_output_formatter, - get_output_formatter, load_dataset_from_config, ) from batdetect2.data.datasets import Dataset -from batdetect2.data.predictions.base import OutputFormatterProtocol from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.models import Model, build_model +from batdetect2.outputs import ( + OutputFormatConfig, + OutputTransformProtocol, + build_output_formatter, + build_output_transform, + get_output_formatter, +) from batdetect2.postprocess import build_postprocessor, to_raw_predictions from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets @@ -35,6 +38,7 @@ from batdetect2.typing import ( ClipDetections, Detection, EvaluatorProtocol, + OutputFormatterProtocol, PostprocessorProtocol, PreprocessorProtocol, TargetProtocol, @@ -51,6 +55,7 @@ class BatDetect2API: postprocessor: PostprocessorProtocol, evaluator: EvaluatorProtocol, formatter: OutputFormatterProtocol, + output_transform: OutputTransformProtocol, model: Model, ): self.config = config @@ -61,6 +66,7 @@ class BatDetect2API: self.evaluator = evaluator self.model = model self.formatter = formatter + self.output_transform = output_transform self.model.eval() @@ -208,10 +214,16 @@ class BatDetect2API: detections = self.model.postprocessor( outputs, - start_times=[start_time], )[0] + raw_predictions = to_raw_predictions( + detections.numpy(), + targets=self.targets, + ) - return to_raw_predictions(detections.numpy(), targets=self.targets) + return self.output_transform.transform_detections( + raw_predictions, + start_time=start_time, + ) def process_directory( self, @@ -304,7 +316,13 @@ class BatDetect2API: # postprocessor as these may be moved to another device. model = build_model(config=config.model) - formatter = build_output_formatter(targets, config=config.output) + formatter = build_output_formatter( + targets, + config=config.outputs.format, + ) + output_transform = build_output_transform( + config=config.outputs.transform + ) return cls( config=config, @@ -315,6 +333,7 @@ class BatDetect2API: evaluator=evaluator, model=model, formatter=formatter, + output_transform=output_transform, ) @classmethod @@ -351,7 +370,13 @@ class BatDetect2API: evaluator = build_evaluator(config=config.evaluation, targets=targets) - formatter = build_output_formatter(targets, config=config.output) + formatter = build_output_formatter( + targets, + config=config.outputs.format, + ) + output_transform = build_output_transform( + config=config.outputs.transform + ) return cls( config=config, @@ -362,4 +387,5 @@ class BatDetect2API: evaluator=evaluator, model=model, formatter=formatter, + output_transform=output_transform, ) diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 6382b30..c167d09 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -5,14 +5,13 @@ from soundevent.data import PathLike from batdetect2.audio import AudioConfig from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.data.predictions import OutputFormatConfig -from batdetect2.data.predictions.raw import RawOutputConfig from batdetect2.evaluate.config import ( EvaluationConfig, get_default_eval_config, ) from batdetect2.inference.config import InferenceConfig from batdetect2.models import ModelConfig +from batdetect2.outputs import OutputsConfig from batdetect2.train.config import TrainingConfig __all__ = [ @@ -32,7 +31,7 @@ class BatDetect2Config(BaseConfig): model: ModelConfig = Field(default_factory=ModelConfig) audio: AudioConfig = Field(default_factory=AudioConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig) - output: OutputFormatConfig = Field(default_factory=RawOutputConfig) + outputs: OutputsConfig = Field(default_factory=OutputsConfig) def validate_config(config: dict | None) -> BatDetect2Config: diff --git a/src/batdetect2/data/__init__.py b/src/batdetect2/data/__init__.py index e405a0c..f12f8b0 100644 --- a/src/batdetect2/data/__init__.py +++ b/src/batdetect2/data/__init__.py @@ -12,15 +12,6 @@ from batdetect2.data.datasets import ( load_dataset_config, load_dataset_from_config, ) -from batdetect2.data.predictions import ( - BatDetect2OutputConfig, - OutputFormatConfig, - RawOutputConfig, - SoundEventOutputConfig, - build_output_formatter, - get_output_formatter, - load_predictions, -) from batdetect2.data.summary import ( compute_class_summary, extract_recordings_df, @@ -36,6 +27,7 @@ __all__ = [ "BatDetect2OutputConfig", "DatasetConfig", "OutputFormatConfig", + "ParquetOutputConfig", "RawOutputConfig", "SoundEventOutputConfig", "build_output_formatter", diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index b700d7d..c533e8b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -10,6 +10,7 @@ from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.logging import build_logger from batdetect2.models import Model +from batdetect2.outputs import build_output_transform from batdetect2.typing import Detection if TYPE_CHECKING: @@ -61,7 +62,12 @@ def evaluate( experiment_name=experiment_name, run_name=run_name, ) - module = EvaluationModule(model, evaluator) + output_transform = build_output_transform(config=config.outputs.transform) + module = EvaluationModule( + model, + evaluator, + output_transform=output_transform, + ) trainer = Trainer(logger=logger, enable_checkpointing=False) metrics = trainer.test(module, loader) diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index f8d51af..0abd367 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -7,6 +7,7 @@ from torch.utils.data import DataLoader from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.logging import get_image_logger from batdetect2.models import Model +from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions from batdetect2.typing import EvaluatorProtocol from batdetect2.typing.postprocess import ClipDetections @@ -17,11 +18,13 @@ class EvaluationModule(LightningModule): self, model: Model, evaluator: EvaluatorProtocol, + output_transform: OutputTransformProtocol | None = None, ): super().__init__() self.model = model self.evaluator = evaluator + self.output_transform = output_transform or build_output_transform() self.clip_annotations: List[data.ClipAnnotation] = [] self.predictions: List[ClipDetections] = [] @@ -34,10 +37,7 @@ class EvaluationModule(LightningModule): ] outputs = self.model.detector(batch.spec) - clip_detections = self.model.postprocessor( - outputs, - start_times=[ca.clip.start_time for ca in clip_annotations], - ) + clip_detections = self.model.postprocessor(outputs) predictions = [ ClipDetections( clip=clip_annotation.clip, @@ -50,6 +50,7 @@ class EvaluationModule(LightningModule): clip_annotations, clip_detections, strict=False ) ] + predictions = self.output_transform(predictions) self.clip_annotations.extend(clip_annotations) self.predictions.extend(predictions) diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index a32549f..23681b7 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -8,6 +8,7 @@ from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.lightning import InferenceModule from batdetect2.models import Model +from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.targets.targets import build_targets from batdetect2.typing.postprocess import ClipDetections @@ -28,6 +29,7 @@ def run_batch_inference( audio_loader: Optional["AudioLoader"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, + output_transform: Optional[OutputTransformProtocol] = None, num_workers: int | None = None, batch_size: int | None = None, ) -> List[ClipDetections]: @@ -42,6 +44,9 @@ def run_batch_inference( ) targets = targets or build_targets() + output_transform = output_transform or build_output_transform( + config=config.outputs.transform, + ) loader = build_inference_loader( clips, @@ -52,7 +57,10 @@ def run_batch_inference( batch_size=batch_size, ) - module = InferenceModule(model) + module = InferenceModule( + model, + output_transform=output_transform, + ) trainer = Trainer(enable_checkpointing=False, logger=False) outputs = trainer.predict(module, loader) return [ diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 1cd07bd..2db58e6 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -5,14 +5,20 @@ from torch.utils.data import DataLoader from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.models import Model +from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions from batdetect2.typing.postprocess import ClipDetections class InferenceModule(LightningModule): - def __init__(self, model: Model): + def __init__( + self, + model: Model, + output_transform: OutputTransformProtocol | None = None, + ): super().__init__() self.model = model + self.output_transform = output_transform or build_output_transform() def predict_step( self, @@ -26,10 +32,7 @@ class InferenceModule(LightningModule): outputs = self.model.detector(batch.spec) - clip_detections = self.model.postprocessor( - outputs, - start_times=[clip.start_time for clip in clips], - ) + clip_detections = self.model.postprocessor(outputs) predictions = [ ClipDetections( @@ -42,7 +45,7 @@ class InferenceModule(LightningModule): for clip, clip_dets in zip(clips, clip_detections, strict=False) ] - return predictions + return self.output_transform(predictions) def get_dataset(self) -> InferenceDataset: dataloaders = self.trainer.predict_dataloaders diff --git a/src/batdetect2/outputs/__init__.py b/src/batdetect2/outputs/__init__.py new file mode 100644 index 0000000..28e31c8 --- /dev/null +++ b/src/batdetect2/outputs/__init__.py @@ -0,0 +1,31 @@ +from batdetect2.outputs.config import OutputsConfig +from batdetect2.outputs.formats import ( + BatDetect2OutputConfig, + OutputFormatConfig, + ParquetOutputConfig, + RawOutputConfig, + SoundEventOutputConfig, + build_output_formatter, + get_output_formatter, + load_predictions, +) +from batdetect2.outputs.transforms import ( + OutputTransformConfig, + OutputTransformProtocol, + build_output_transform, +) + +__all__ = [ + "BatDetect2OutputConfig", + "OutputFormatConfig", + "OutputTransformConfig", + "OutputTransformProtocol", + "OutputsConfig", + "ParquetOutputConfig", + "RawOutputConfig", + "SoundEventOutputConfig", + "build_output_formatter", + "build_output_transform", + "get_output_formatter", + "load_predictions", +] diff --git a/src/batdetect2/outputs/config.py b/src/batdetect2/outputs/config.py new file mode 100644 index 0000000..be3a78b --- /dev/null +++ b/src/batdetect2/outputs/config.py @@ -0,0 +1,15 @@ +from pydantic import Field + +from batdetect2.core.configs import BaseConfig +from batdetect2.outputs.formats import OutputFormatConfig +from batdetect2.outputs.formats.raw import RawOutputConfig +from batdetect2.outputs.transforms import OutputTransformConfig + +__all__ = ["OutputsConfig"] + + +class OutputsConfig(BaseConfig): + format: OutputFormatConfig = Field(default_factory=RawOutputConfig) + transform: OutputTransformConfig = Field( + default_factory=OutputTransformConfig + ) diff --git a/src/batdetect2/data/predictions/__init__.py b/src/batdetect2/outputs/formats/__init__.py similarity index 78% rename from src/batdetect2/data/predictions/__init__.py rename to src/batdetect2/outputs/formats/__init__.py index 638d9a0..d103a37 100644 --- a/src/batdetect2/data/predictions/__init__.py +++ b/src/batdetect2/outputs/formats/__init__.py @@ -3,23 +3,25 @@ from typing import Annotated from pydantic import Field from soundevent.data import PathLike -from batdetect2.data.predictions.base import ( +from batdetect2.outputs.formats.base import ( OutputFormatterProtocol, - prediction_formatters, + output_formatters, ) -from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig -from batdetect2.data.predictions.parquet import ParquetOutputConfig -from batdetect2.data.predictions.raw import RawOutputConfig -from batdetect2.data.predictions.soundevent import SoundEventOutputConfig +from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig +from batdetect2.outputs.formats.parquet import ParquetOutputConfig +from batdetect2.outputs.formats.raw import RawOutputConfig +from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig from batdetect2.typing import TargetProtocol __all__ = [ - "build_output_formatter", - "get_output_formatter", "BatDetect2OutputConfig", + "OutputFormatConfig", "ParquetOutputConfig", "RawOutputConfig", "SoundEventOutputConfig", + "build_output_formatter", + "get_output_formatter", + "load_predictions", ] @@ -42,7 +44,7 @@ def build_output_formatter( config = config or RawOutputConfig() targets = targets or build_targets() - return prediction_formatters.build(config, targets) + return output_formatters.build(config, targets) def get_output_formatter( @@ -56,7 +58,7 @@ def get_output_formatter( if name is None: raise ValueError("Either config or name must be provided.") - config_class = prediction_formatters.get_config_type(name) + config_class = output_formatters.get_config_type(name) config = config_class() # type: ignore if config.name != name: # type: ignore diff --git a/src/batdetect2/data/predictions/base.py b/src/batdetect2/outputs/formats/base.py similarity index 78% rename from src/batdetect2/data/predictions/base.py rename to src/batdetect2/outputs/formats/base.py index 70e3bf5..5fd5f24 100644 --- a/src/batdetect2/data/predictions/base.py +++ b/src/batdetect2/outputs/formats/base.py @@ -9,6 +9,13 @@ from batdetect2.typing import ( TargetProtocol, ) +__all__ = [ + "OutputFormatterProtocol", + "PredictionFormatterImportConfig", + "make_path_relative", + "output_formatters", +] + def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path: path = Path(path) @@ -25,12 +32,12 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path: return path -prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = ( +output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = ( Registry(name="output_formatter") ) -@add_import_config(prediction_formatters) +@add_import_config(output_formatters) class PredictionFormatterImportConfig(ImportConfig): """Use any callable as a prediction formatter. diff --git a/src/batdetect2/data/predictions/batdetect2.py b/src/batdetect2/outputs/formats/batdetect2.py similarity index 83% rename from src/batdetect2/data/predictions/batdetect2.py rename to src/batdetect2/outputs/formats/batdetect2.py index 4a76541..386d214 100644 --- a/src/batdetect2/data/predictions/batdetect2.py +++ b/src/batdetect2/outputs/formats/batdetect2.py @@ -7,9 +7,9 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig -from batdetect2.data.predictions.base import ( +from batdetect2.outputs.formats.base import ( make_path_relative, - prediction_formatters, + output_formatters, ) from batdetect2.targets import terms from batdetect2.typing import ( @@ -28,76 +28,32 @@ DictWithClass = TypedDict("DictWithClass", {"class": str}) class Annotation(DictWithClass): - """Format of annotations. - - This is the format of a single annotation as expected by the - annotation tool. - """ - start_time: float - """Start time in seconds.""" - end_time: float - """End time in seconds.""" - low_freq: float - """Low frequency in Hz.""" - high_freq: float - """High frequency in Hz.""" - class_prob: float - """Probability of class assignment.""" - det_prob: float - """Probability of detection.""" - individual: str - """Individual ID.""" - event: str - """Type of detected event.""" class FileAnnotation(TypedDict): - """Format of results. - - This is the format of the results expected by the annotation tool. - """ - id: str - """File ID.""" - annotated: bool - """Whether file has been annotated.""" - duration: float - """Duration of audio file.""" - issues: bool - """Whether file has issues.""" - time_exp: float - """Time expansion factor.""" - class_name: str - """Class predicted at file level.""" - notes: str - """Notes of file.""" - annotation: List[Annotation] - """List of annotations.""" - file_path: NotRequired[str] # ty: ignore[invalid-type-form] - """Path to file.""" class BatDetect2OutputConfig(BaseConfig): name: Literal["batdetect2"] = "batdetect2" event_name: str = "Echolocation" - annotation_note: str = "Automatically generated." @@ -156,8 +112,6 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): ] def get_recording_class(self, annotations: List[Annotation]) -> str: - """Get class of recording from annotations.""" - if not annotations: return "" @@ -215,7 +169,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]): **{"class": top_class}, ) - @prediction_formatters.register(BatDetect2OutputConfig) + @output_formatters.register(BatDetect2OutputConfig) @staticmethod def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol): return BatDetect2Formatter( diff --git a/src/batdetect2/data/predictions/parquet.py b/src/batdetect2/outputs/formats/parquet.py similarity index 88% rename from src/batdetect2/data/predictions/parquet.py rename to src/batdetect2/outputs/formats/parquet.py index f024a3b..c6d0034 100644 --- a/src/batdetect2/data/predictions/parquet.py +++ b/src/batdetect2/outputs/formats/parquet.py @@ -9,9 +9,9 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig -from batdetect2.data.predictions.base import ( +from batdetect2.outputs.formats.base import ( make_path_relative, - prediction_formatters, + output_formatters, ) from batdetect2.typing import ( ClipDetections, @@ -59,10 +59,7 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): if not path.parent.exists(): path.parent.mkdir(parents=True) - # Ensure the file has .parquet extension if it's a file path if path.suffix != ".parquet": - # If it's a directory, we might want to save as a partitioned dataset or a single file inside - # For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet' if path.is_dir() or not path.suffix: path = path / "predictions.parquet" @@ -90,7 +87,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): } if self.include_geometry: - # Store geometry as [start_time, low_freq, end_time, high_freq] start_time, low_freq, end_time, high_freq = compute_bounds( pred.geometry ) @@ -98,8 +94,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): row["low_freq"] = low_freq row["end_time"] = end_time row["high_freq"] = high_freq - - # Store full geometry as JSON row["geometry"] = pred.geometry.model_dump_json() if self.include_class_scores: @@ -121,11 +115,9 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): def load(self, path: data.PathLike) -> List[ClipDetections]: path = Path(path) if path.is_dir(): - # Try to find parquet files files = list(path.glob("*.parquet")) if not files: return [] - # Read all and concatenate dfs = [pd.read_parquet(f) for f in files] df = pd.concat(dfs, ignore_index=True) else: @@ -148,7 +140,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): ) predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []} - # Reconstruct geometry if "geometry" in row and row["geometry"]: geometry = data.geometry_validate(row["geometry"]) else: @@ -182,13 +173,14 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]): for clip_data in predictions_by_clip.values(): results.append( ClipDetections( - clip=clip_data["clip"], detections=clip_data["preds"] + clip=clip_data["clip"], + detections=clip_data["preds"], ) ) return results - @prediction_formatters.register(ParquetOutputConfig) + @output_formatters.register(ParquetOutputConfig) @staticmethod def from_config(config: ParquetOutputConfig, targets: TargetProtocol): return ParquetFormatter( diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/outputs/formats/raw.py similarity index 81% rename from src/batdetect2/data/predictions/raw.py rename to src/batdetect2/outputs/formats/raw.py index 6149f11..dc58cd8 100644 --- a/src/batdetect2/data/predictions/raw.py +++ b/src/batdetect2/outputs/formats/raw.py @@ -10,9 +10,9 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig -from batdetect2.data.predictions.base import ( +from batdetect2.outputs.formats.base import ( make_path_relative, - prediction_formatters, + output_formatters, ) from batdetect2.typing import ( ClipDetections, @@ -95,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]): update=dict(path=make_path_relative(recording.path, audio_dir)) ) - data = defaultdict(list) + values = defaultdict(list) for pred in prediction.detections: detection_id = str(uuid4()) - data["detection_id"].append(detection_id) - data["detection_score"].append(pred.detection_score) + values["detection_id"].append(detection_id) + values["detection_score"].append(pred.detection_score) start_time, low_freq, end_time, high_freq = compute_bounds( pred.geometry ) - data["start_time"].append(start_time) - data["end_time"].append(end_time) - data["low_freq"].append(low_freq) - data["high_freq"].append(high_freq) + values["start_time"].append(start_time) + values["end_time"].append(end_time) + values["low_freq"].append(low_freq) + values["high_freq"].append(high_freq) - data["geometry"].append(pred.geometry.model_dump_json()) + values["geometry"].append(pred.geometry.model_dump_json()) top_class_index = int(np.argmax(pred.class_scores)) top_class_score = float(pred.class_scores[top_class_index]) top_class = self.targets.class_names[top_class_index] - data["top_class"].append(top_class) - data["top_class_score"].append(top_class_score) + values["top_class"].append(top_class) + values["top_class_score"].append(top_class_score) - data["class_scores"].append(pred.class_scores) - data["features"].append(pred.features) + values["class_scores"].append(pred.class_scores) + values["features"].append(pred.features) num_features = len(pred.features) data_vars = { - "score": (["detection"], data["detection_score"]), - "start_time": (["detection"], data["start_time"]), - "end_time": (["detection"], data["end_time"]), - "low_freq": (["detection"], data["low_freq"]), - "high_freq": (["detection"], data["high_freq"]), - "top_class": (["detection"], data["top_class"]), - "top_class_score": (["detection"], data["top_class_score"]), + "score": (["detection"], values["detection_score"]), + "start_time": (["detection"], values["start_time"]), + "end_time": (["detection"], values["end_time"]), + "low_freq": (["detection"], values["low_freq"]), + "high_freq": (["detection"], values["high_freq"]), + "top_class": (["detection"], values["top_class"]), + "top_class_score": (["detection"], values["top_class_score"]), } coords = { - "detection": ("detection", data["detection_id"]), + "detection": ("detection", values["detection_id"]), "clip_start": clip.start_time, "clip_end": clip.end_time, "clip_id": str(clip.uuid), } if self.include_class_scores: - class_scores = np.stack(data["class_scores"], axis=0) + class_scores = np.stack(values["class_scores"], axis=0) data_vars["class_scores"] = ( ["detection", "classes"], class_scores, @@ -152,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]): coords["classes"] = ("classes", self.targets.class_names) if self.include_features: - features = np.stack(data["features"], axis=0) + features = np.stack(values["features"], axis=0) data_vars["features"] = (["detection", "feature"], features) coords["feature"] = ("feature", np.arange(num_features)) if self.include_geometry: - data_vars["geometry"] = (["detection"], data["geometry"]) + data_vars["geometry"] = (["detection"], values["geometry"]) return xr.Dataset( data_vars=data_vars, @@ -169,7 +169,6 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]): def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections: clip_data = dataset - clip_id = clip_data.clip_id.item() recording = data.Recording.model_validate_json( clip_data.attrs["recording"] @@ -232,7 +231,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]): detections=sound_events, ) - @prediction_formatters.register(RawOutputConfig) + @output_formatters.register(RawOutputConfig) @staticmethod def from_config(config: RawOutputConfig, targets: TargetProtocol): return RawFormatter( diff --git a/src/batdetect2/data/predictions/soundevent.py b/src/batdetect2/outputs/formats/soundevent.py similarity index 96% rename from src/batdetect2/data/predictions/soundevent.py rename to src/batdetect2/outputs/formats/soundevent.py index 6dae1bc..1be9616 100644 --- a/src/batdetect2/data/predictions/soundevent.py +++ b/src/batdetect2/outputs/formats/soundevent.py @@ -5,8 +5,8 @@ import numpy as np from soundevent import data, io from batdetect2.core import BaseConfig -from batdetect2.data.predictions.base import ( - prediction_formatters, +from batdetect2.outputs.formats.base import ( + output_formatters, ) from batdetect2.typing import ( ClipDetections, @@ -121,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]): return tags - @prediction_formatters.register(SoundEventOutputConfig) + @output_formatters.register(SoundEventOutputConfig) @staticmethod def from_config(config: SoundEventOutputConfig, targets: TargetProtocol): return SoundEventOutputFormatter( diff --git a/src/batdetect2/outputs/transforms.py b/src/batdetect2/outputs/transforms.py new file mode 100644 index 0000000..89ef5dc --- /dev/null +++ b/src/batdetect2/outputs/transforms.py @@ -0,0 +1,89 @@ +from collections.abc import Sequence +from dataclasses import replace +from typing import Protocol + +from soundevent.geometry import shift_geometry + +from batdetect2.core.configs import BaseConfig +from batdetect2.typing import ClipDetections, Detection + +__all__ = [ + "OutputTransform", + "OutputTransformConfig", + "OutputTransformProtocol", + "build_output_transform", +] + + +class OutputTransformConfig(BaseConfig): + shift_time_to_clip_start: bool = True + + +class OutputTransformProtocol(Protocol): + def __call__( + self, + predictions: Sequence[ClipDetections], + ) -> list[ClipDetections]: ... + + def transform_detections( + self, + detections: Sequence[Detection], + start_time: float = 0, + ) -> list[Detection]: ... + + +def shift_detection_time(detection: Detection, time: float) -> Detection: + geometry = shift_geometry(detection.geometry, time=time) + return replace(detection, geometry=geometry) + + +class OutputTransform(OutputTransformProtocol): + def __init__(self, shift_time_to_clip_start: bool = True): + self.shift_time_to_clip_start = shift_time_to_clip_start + + def __call__( + self, + predictions: Sequence[ClipDetections], + ) -> list[ClipDetections]: + return [ + self.transform_prediction(prediction) for prediction in predictions + ] + + def transform_prediction( + self, prediction: ClipDetections + ) -> ClipDetections: + if not self.shift_time_to_clip_start: + return prediction + + detections = self.transform_detections( + prediction.detections, + start_time=prediction.clip.start_time, + ) + return ClipDetections(clip=prediction.clip, detections=detections) + + def transform_detections( + self, + detections: Sequence[Detection], + start_time: float = 0, + ) -> list[Detection]: + if not self.shift_time_to_clip_start or start_time == 0: + return list(detections) + + return [ + shift_detection_time(detection, time=start_time) + for detection in detections + ] + + +def build_output_transform( + config: OutputTransformConfig | dict | None = None, +) -> OutputTransformProtocol: + if config is None: + config = OutputTransformConfig() + + if not isinstance(config, OutputTransformConfig): + config = OutputTransformConfig.model_validate(config) + + return OutputTransform( + shift_time_to_clip_start=config.shift_time_to_clip_start, + ) diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index 2a27560..104a03c 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -64,7 +64,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): def forward( self, output: ModelOutput, - start_times: list[float] | None = None, ) -> list[ClipDetectionsTensor]: detection_heatmap = non_max_suppression( output.detection_probs.detach(), @@ -83,9 +82,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): threshold=self.detection_threshold, ) - if start_times is None: - start_times = [0 for _ in range(len(detections))] - return [ map_detection_to_clip( detection, diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 51c1c59..2383955 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -6,6 +6,7 @@ from soundevent import data from torch.utils.data import DataLoader from batdetect2.logging import get_image_logger +from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess import to_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule @@ -18,10 +19,15 @@ from batdetect2.typing import ( class ValidationMetrics(Callback): - def __init__(self, evaluator: EvaluatorProtocol): + def __init__( + self, + evaluator: EvaluatorProtocol, + output_transform: OutputTransformProtocol | None = None, + ): super().__init__() self.evaluator = evaluator + self.output_transform = output_transform or build_output_transform() self._clip_annotations: List[data.ClipAnnotation] = [] self._predictions: List[ClipDetections] = [] @@ -95,10 +101,7 @@ class ValidationMetrics(Callback): for example_idx in batch.idx ] - clip_detections = model.postprocessor( - outputs, - start_times=[ca.clip.start_time for ca in clip_annotations], - ) + clip_detections = model.postprocessor(outputs) predictions = [ ClipDetections( clip=clip_annotation.clip, @@ -110,6 +113,7 @@ class ValidationMetrics(Callback): clip_annotations, clip_detections, strict=False ) ] + predictions = self.output_transform(predictions) self._clip_annotations.extend(clip_annotations) self._predictions.extend(predictions) diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index 5c95308..08e642d 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -12,7 +12,7 @@ system that deal with model predictions. """ from dataclasses import dataclass -from typing import List, NamedTuple, Protocol, Sequence +from typing import List, NamedTuple, Protocol import numpy as np import torch @@ -101,5 +101,4 @@ class PostprocessorProtocol(Protocol): def __call__( self, output: ModelOutput, - start_times: Sequence[float] | None = None, ) -> List[ClipDetectionsTensor]: ... diff --git a/tests/test_data/test_predictions/test_parquet.py b/tests/test_data/test_predictions/test_parquet.py index d4e38e0..b17e8bb 100644 --- a/tests/test_data/test_predictions/test_parquet.py +++ b/tests/test_data/test_predictions/test_parquet.py @@ -5,7 +5,7 @@ import numpy as np import pytest from soundevent import data -from batdetect2.data.predictions import ( +from batdetect2.outputs.formats import ( ParquetOutputConfig, build_output_formatter, ) diff --git a/tests/test_data/test_predictions/test_raw.py b/tests/test_data/test_predictions/test_raw.py index 8e2f88d..c51b41f 100644 --- a/tests/test_data/test_predictions/test_raw.py +++ b/tests/test_data/test_predictions/test_raw.py @@ -4,7 +4,7 @@ import numpy as np import pytest from soundevent import data -from batdetect2.data.predictions import RawOutputConfig, build_output_formatter +from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter from batdetect2.typing import ( ClipDetections, Detection, diff --git a/tests/test_outputs/test_transform/test_transform.py b/tests/test_outputs/test_transform/test_transform.py new file mode 100644 index 0000000..811da12 --- /dev/null +++ b/tests/test_outputs/test_transform/test_transform.py @@ -0,0 +1,48 @@ +import numpy as np +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.outputs import build_output_transform +from batdetect2.typing import ClipDetections, Detection + + +def test_shift_time_to_clip_start(clip: data.Clip): + clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) + + detection = Detection( + geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]), + detection_score=0.9, + class_scores=np.array([0.9]), + features=np.array([1.0, 2.0]), + ) + + transformed = build_output_transform()( + [ClipDetections(clip=clip, detections=[detection])] + )[0] + + start_time, _, end_time, _ = compute_bounds( + transformed.detections[0].geometry + ) + + assert np.isclose(start_time, 2.6) + assert np.isclose(end_time, 2.7) + + +def test_transform_identity_when_disabled(clip: data.Clip): + clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) + + detection = Detection( + geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]), + detection_score=0.9, + class_scores=np.array([0.9]), + features=np.array([1.0, 2.0]), + ) + + transform = build_output_transform( + config={"shift_time_to_clip_start": False} + ) + transformed = transform( + [ClipDetections(clip=clip, detections=[detection])] + )[0] + + assert transformed.detections[0].geometry == detection.geometry