diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 30dc20e..e7fffc6 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -26,7 +26,7 @@ from batdetect2.outputs import ( get_output_formatter, ) from batdetect2.outputs.types import OutputFormatterProtocol -from batdetect2.postprocess import build_postprocessor, to_raw_predictions +from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess.types import ( ClipDetections, Detection, @@ -215,13 +215,8 @@ class BatDetect2API: detections = self.model.postprocessor( outputs, )[0] - raw_predictions = to_raw_predictions( - detections.numpy(), - targets=self.targets, - ) - - return self.output_transform.transform_detections( - raw_predictions, + return self.output_transform.to_detections( + detections=detections, start_time=start_time, ) @@ -321,7 +316,8 @@ class BatDetect2API: config=config.outputs.format, ) output_transform = build_output_transform( - config=config.outputs.transform + config=config.outputs.transform, + targets=targets, ) return cls( @@ -375,7 +371,8 @@ class BatDetect2API: config=config.outputs.format, ) output_transform = build_output_transform( - config=config.outputs.transform + config=config.outputs.transform, + targets=targets, ) return cls( diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 92cea1e..48b8293 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -61,7 +61,10 @@ def run_evaluate( experiment_name=experiment_name, run_name=run_name, ) - output_transform = build_output_transform(config=output_config.transform) + output_transform = build_output_transform( + config=output_config.transform, + targets=targets, + ) module = EvaluationModule( model, evaluator, diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index c721e44..84ee763 100644 --- a/src/batdetect2/evaluate/lightning.py +++ 
b/src/batdetect2/evaluate/lightning.py @@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform -from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess.types import ClipDetections @@ -24,7 +23,9 @@ class EvaluationModule(LightningModule): self.model = model self.evaluator = evaluator - self.output_transform = output_transform or build_output_transform() + self.output_transform = output_transform or build_output_transform( + targets=evaluator.targets + ) self.clip_annotations: List[data.ClipAnnotation] = [] self.predictions: List[ClipDetections] = [] @@ -39,18 +40,14 @@ class EvaluationModule(LightningModule): outputs = self.model.detector(batch.spec) clip_detections = self.model.postprocessor(outputs) predictions = [ - ClipDetections( + self.output_transform.to_clip_detections( + detections=clip_dets, clip=clip_annotation.clip, - detections=to_raw_predictions( - clip_dets.numpy(), - targets=self.evaluator.targets, - ), ) for clip_annotation, clip_dets in zip( clip_annotations, clip_detections, strict=False ) ] - predictions = self.output_transform(predictions) self.clip_annotations.extend(clip_annotations) self.predictions.extend(predictions) diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index da56712..b96fc12 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -44,6 +44,7 @@ def run_batch_inference( targets = targets or build_targets() output_transform = output_transform or build_output_transform( config=config.outputs.transform, + targets=targets, ) loader = build_inference_loader( diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index c2689d7..2c853a5 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py 
@@ -6,7 +6,6 @@ from torch.utils.data import DataLoader from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.models import Model from batdetect2.outputs import OutputTransformProtocol, build_output_transform -from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess.types import ClipDetections @@ -18,7 +17,9 @@ class InferenceModule(LightningModule): ): super().__init__() self.model = model - self.output_transform = output_transform or build_output_transform() + self.output_transform = output_transform or build_output_transform( + targets=model.targets + ) def predict_step( self, @@ -34,19 +35,14 @@ class InferenceModule(LightningModule): clip_detections = self.model.postprocessor(outputs) - predictions = [ - ClipDetections( + return [ + self.output_transform.to_clip_detections( + detections=clip_dets, clip=clip, - detections=to_raw_predictions( - clip_dets.numpy(), - targets=self.model.targets, - ), ) - for clip, clip_dets in zip(clips, clip_detections, strict=False) + for clip, clip_dets in zip(clips, clip_detections, strict=True) ] - return self.output_transform(predictions) - def get_dataset(self) -> InferenceDataset: dataloaders = self.trainer.predict_dataloaders assert isinstance(dataloaders, DataLoader) diff --git a/src/batdetect2/outputs/__init__.py b/src/batdetect2/outputs/__init__.py index 28e31c8..c6528c1 100644 --- a/src/batdetect2/outputs/__init__.py +++ b/src/batdetect2/outputs/__init__.py @@ -11,9 +11,9 @@ from batdetect2.outputs.formats import ( ) from batdetect2.outputs.transforms import ( OutputTransformConfig, - OutputTransformProtocol, build_output_transform, ) +from batdetect2.outputs.types import OutputTransformProtocol __all__ = [ "BatDetect2OutputConfig", diff --git a/src/batdetect2/outputs/transforms.py b/src/batdetect2/outputs/transforms.py deleted file mode 100644 index de7dafe..0000000 --- a/src/batdetect2/outputs/transforms.py +++ /dev/null @@ -1,89 +0,0 @@ -from collections.abc 
import Sequence -from dataclasses import replace -from typing import Protocol - -from soundevent.geometry import shift_geometry - -from batdetect2.core.configs import BaseConfig -from batdetect2.postprocess.types import ClipDetections, Detection - -__all__ = [ - "OutputTransform", - "OutputTransformConfig", - "OutputTransformProtocol", - "build_output_transform", -] - - -class OutputTransformConfig(BaseConfig): - shift_time_to_clip_start: bool = True - - -class OutputTransformProtocol(Protocol): - def __call__( - self, - predictions: Sequence[ClipDetections], - ) -> list[ClipDetections]: ... - - def transform_detections( - self, - detections: Sequence[Detection], - start_time: float = 0, - ) -> list[Detection]: ... - - -def shift_detection_time(detection: Detection, time: float) -> Detection: - geometry = shift_geometry(detection.geometry, time=time) - return replace(detection, geometry=geometry) - - -class OutputTransform(OutputTransformProtocol): - def __init__(self, shift_time_to_clip_start: bool = True): - self.shift_time_to_clip_start = shift_time_to_clip_start - - def __call__( - self, - predictions: Sequence[ClipDetections], - ) -> list[ClipDetections]: - return [ - self.transform_prediction(prediction) for prediction in predictions - ] - - def transform_prediction( - self, prediction: ClipDetections - ) -> ClipDetections: - if not self.shift_time_to_clip_start: - return prediction - - detections = self.transform_detections( - prediction.detections, - start_time=prediction.clip.start_time, - ) - return ClipDetections(clip=prediction.clip, detections=detections) - - def transform_detections( - self, - detections: Sequence[Detection], - start_time: float = 0, - ) -> list[Detection]: - if not self.shift_time_to_clip_start or start_time == 0: - return list(detections) - - return [ - shift_detection_time(detection, time=start_time) - for detection in detections - ] - - -def build_output_transform( - config: OutputTransformConfig | dict | None = None, -) -> 
OutputTransformProtocol: - if config is None: - config = OutputTransformConfig() - - if not isinstance(config, OutputTransformConfig): - config = OutputTransformConfig.model_validate(config) - - return OutputTransform( - shift_time_to_clip_start=config.shift_time_to_clip_start, - ) diff --git a/src/batdetect2/outputs/transforms/__init__.py b/src/batdetect2/outputs/transforms/__init__.py new file mode 100644 index 0000000..dfcf5b3 --- /dev/null +++ b/src/batdetect2/outputs/transforms/__init__.py @@ -0,0 +1,173 @@ +from collections.abc import Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.core.configs import BaseConfig +from batdetect2.outputs.transforms.clip_transforms import ( + ClipDetectionsTransformConfig, +) +from batdetect2.outputs.transforms.clip_transforms import ( + clip_transforms as clip_transform_registry, +) +from batdetect2.outputs.transforms.decoding import to_detections +from batdetect2.outputs.transforms.detection_transforms import ( + DetectionTransformConfig, + shift_detections_to_start_time, +) +from batdetect2.outputs.transforms.detection_transforms import ( + detection_transforms as detection_transform_registry, +) +from batdetect2.outputs.types import ( + ClipDetectionsTransform, + DetectionTransform, + OutputTransformProtocol, +) +from batdetect2.postprocess.types import ( + ClipDetections, + ClipDetectionsTensor, + Detection, +) +from batdetect2.targets.types import TargetProtocol + +__all__ = [ + "ClipDetectionsTransformConfig", + "DetectionTransformConfig", + "OutputTransform", + "OutputTransformConfig", + "build_output_transform", +] + + +class OutputTransformConfig(BaseConfig): + detection_transforms: list[DetectionTransformConfig] = Field( + default_factory=list + ) + clip_transforms: list[ClipDetectionsTransformConfig] = Field( + default_factory=list + ) + + +class OutputTransform(OutputTransformProtocol): + detection_transform_steps: list[DetectionTransform] + clip_transform_steps: 
list[ClipDetectionsTransform] + +    def __init__( +        self, +        targets: TargetProtocol, +        detection_transform_steps: Sequence[DetectionTransform] = (), +        clip_transform_steps: Sequence[ClipDetectionsTransform] = (), +    ): +        self.targets = targets +        self.detection_transform_steps = list(detection_transform_steps) +        self.clip_transform_steps = list(clip_transform_steps) + +    def __call__( +        self, +        predictions: Sequence[ClipDetections], +    ) -> list[ClipDetections]: +        return [ +            self._transform_prediction(prediction) +            for prediction in predictions +        ] + +    def _transform_prediction( +        self, +        prediction: ClipDetections, +    ) -> ClipDetections: +        detections = shift_detections_to_start_time( +            prediction.detections, +            start_time=prediction.clip.start_time, +        ) +        detections = self.transform_detections(detections) +        return self.transform_clip_detections( +            ClipDetections(clip=prediction.clip, detections=detections) +        ) + +    def to_detections( +        self, +        detections: ClipDetectionsTensor, +        start_time: float = 0, +    ) -> list[Detection]: +        decoded = to_detections(detections.numpy(), targets=self.targets) +        shifted = shift_detections_to_start_time( +            decoded, +            start_time=start_time, +        ) +        return self.transform_detections(shifted) + +    def to_clip_detections( +        self, +        detections: ClipDetectionsTensor, +        clip: data.Clip, +    ) -> ClipDetections: +        prediction = ClipDetections( +            clip=clip, +            detections=self.to_detections( +                detections, +                start_time=clip.start_time, +            ), +        ) +        return self.transform_clip_detections(prediction) + +    def transform_detections( +        self, +        detections: Sequence[Detection], +    ) -> list[Detection]: +        out: list[Detection] = [] +        for detection in detections: +            transformed = self.transform_detection(detection) + +            if transformed is None: +                continue + +            out.append(transformed) + +        return out + +    def transform_detection( +        self, +        detection: Detection, +    ) -> Detection | None: +        for transform in self.detection_transform_steps: +            detection = transform(detection)  # type:
ignore + + if detection is None: + return None + + return detection + + def transform_clip_detections( + self, + prediction: ClipDetections, + ) -> ClipDetections: + for transform in self.clip_transform_steps: + prediction = transform(prediction) + return prediction + + +def build_output_transform( + config: OutputTransformConfig | dict | None = None, + targets: TargetProtocol | None = None, +) -> OutputTransformProtocol: + from batdetect2.targets import build_targets + + if config is None: + config = OutputTransformConfig() + + if not isinstance(config, OutputTransformConfig): + config = OutputTransformConfig.model_validate(config) + + targets = targets or build_targets() + + return OutputTransform( + targets=targets, + detection_transform_steps=[ + detection_transform_registry.build(transform_config) + for transform_config in config.detection_transforms + ], + clip_transform_steps=[ + clip_transform_registry.build(transform_config) + for transform_config in config.clip_transforms + ], + ) diff --git a/src/batdetect2/outputs/transforms/clip_transforms.py b/src/batdetect2/outputs/transforms/clip_transforms.py new file mode 100644 index 0000000..fddfa9a --- /dev/null +++ b/src/batdetect2/outputs/transforms/clip_transforms.py @@ -0,0 +1,31 @@ +from typing import Annotated, Literal + +from pydantic import Field + +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) +from batdetect2.outputs.types import ClipDetectionsTransform + +__all__ = [ + "ClipDetectionsTransformConfig", + "clip_transforms", +] + + +clip_transforms: Registry[ClipDetectionsTransform, []] = Registry( + "clip_detection_transform" +) + + +@add_import_config(clip_transforms) +class ClipDetectionsTransformImportConfig(ImportConfig): + name: Literal["import"] = "import" + + +ClipDetectionsTransformConfig = Annotated[ + ClipDetectionsTransformImportConfig, + Field(discriminator="name"), +] diff --git a/src/batdetect2/postprocess/decoding.py 
b/src/batdetect2/outputs/transforms/decoding.py similarity index 68% rename from src/batdetect2/postprocess/decoding.py rename to src/batdetect2/outputs/transforms/decoding.py index 517522d..f04d3c4 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/outputs/transforms/decoding.py @@ -1,33 +1,28 @@ -"""Decodes extracted detection data into standard soundevent predictions.""" +"""Decode extracted tensors into output-friendly detection objects.""" from typing import List import numpy as np from soundevent import data -from batdetect2.postprocess.types import ( - ClipDetectionsArray, - Detection, -) +from batdetect2.postprocess.types import ClipDetectionsArray, Detection from batdetect2.targets.types import TargetProtocol __all__ = [ - "to_raw_predictions", - "convert_raw_predictions_to_clip_prediction", - "convert_raw_prediction_to_sound_event_prediction", "DEFAULT_CLASSIFICATION_THRESHOLD", + "convert_raw_prediction_to_sound_event_prediction", + "convert_raw_predictions_to_clip_prediction", + "get_class_tags", + "get_generic_tags", + "get_prediction_features", + "to_detections", ] DEFAULT_CLASSIFICATION_THRESHOLD = 0.1 -"""Default threshold applied to classification scores. - -Class predictions with scores below this value are typically ignored during -decoding. 
-""" -def to_raw_predictions( +def to_detections( detections: ClipDetectionsArray, targets: TargetProtocol, ) -> List[Detection]: @@ -69,7 +64,6 @@ def convert_raw_predictions_to_clip_prediction( classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ) -> data.ClipPrediction: - """Convert a list of RawPredictions into a soundevent ClipPrediction.""" return data.ClipPrediction( clip=clip, sound_events=[ @@ -92,7 +86,6 @@ def convert_raw_prediction_to_sound_event_prediction( classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, top_class_only: bool = False, ): - """Convert a single RawPrediction into a soundevent SoundEventPrediction.""" sound_event = data.SoundEvent( recording=recording, geometry=raw_prediction.geometry, @@ -123,7 +116,6 @@ def get_generic_tags( detection_score: float, generic_class_tags: List[data.Tag], ) -> List[data.PredictedTag]: - """Create PredictedTag objects for the generic category.""" return [ data.PredictedTag(tag=tag, score=detection_score) for tag in generic_class_tags @@ -131,7 +123,6 @@ def get_generic_tags( def get_prediction_features(features: np.ndarray) -> List[data.Feature]: - """Convert an extracted feature vector DataArray into soundevent Features.""" return [ data.Feature( term=data.Term( @@ -151,39 +142,11 @@ def get_class_tags( top_class_only: bool = False, threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, ) -> List[data.PredictedTag]: - """Generate specific PredictedTags based on class scores and decoder. - - Filters class scores by the threshold, sorts remaining scores descending, - decodes the class name(s) into base tags using the `sound_event_decoder`, - and creates `PredictedTag` objects associating the class score. Stops after - the first (top) class if `top_class_only` is True. 
- - Parameters - ---------- - class_scores : xr.DataArray - A 1D xarray DataArray containing class probabilities/scores, indexed - by a 'category' coordinate holding the class names. - sound_event_decoder : SoundEventDecoder - Function to map a class name string to a list of base `data.Tag` - objects. - top_class_only : bool, default=False - If True, only generate tags for the single highest-scoring class above - the threshold. - threshold : float, optional - Minimum score for a class to be considered. If None, all classes are - processed (or top-1 if `top_class_only` is True). Defaults to - `DEFAULT_CLASSIFICATION_THRESHOLD`. - - Returns - ------- - List[data.PredictedTag] - A list of `PredictedTag` objects for the class(es) that passed the - threshold, ordered by score if `top_class_only` is False. - """ tags = [] for class_name, score in _iterate_sorted( - class_scores, targets.class_names + class_scores, + targets.class_names, ): if threshold is not None and score < threshold: continue diff --git a/src/batdetect2/outputs/transforms/detection_transforms.py b/src/batdetect2/outputs/transforms/detection_transforms.py new file mode 100644 index 0000000..f5a9fcb --- /dev/null +++ b/src/batdetect2/outputs/transforms/detection_transforms.py @@ -0,0 +1,55 @@ +from collections.abc import Sequence +from dataclasses import replace +from typing import Annotated, Literal + +from pydantic import Field +from soundevent.geometry import shift_geometry + +from batdetect2.core.registries import ( + ImportConfig, + Registry, + add_import_config, +) +from batdetect2.outputs.types import DetectionTransform +from batdetect2.postprocess.types import Detection + +__all__ = [ + "DetectionTransformConfig", + "detection_transforms", + "shift_detection_time", + "shift_detections_to_start_time", +] + + +detection_transforms: Registry[DetectionTransform, []] = Registry( + "detection_transform" +) + + +@add_import_config(detection_transforms) +class 
DetectionTransformImportConfig(ImportConfig): + name: Literal["import"] = "import" + + +DetectionTransformConfig = Annotated[ + DetectionTransformImportConfig, + Field(discriminator="name"), +] + + +def shift_detection_time(detection: Detection, time: float) -> Detection: + geometry = shift_geometry(detection.geometry, time=time) + return replace(detection, geometry=geometry) + + +def shift_detections_to_start_time( + detections: Sequence[Detection], + start_time: float = 0, +) -> list[Detection]: + if start_time == 0: + return list(detections) + + return [ + shift_detection_time(detection, time=start_time) + for detection in detections + ] diff --git a/src/batdetect2/outputs/types.py b/src/batdetect2/outputs/types.py index 6e67fe6..97d3e14 100644 --- a/src/batdetect2/outputs/types.py +++ b/src/batdetect2/outputs/types.py @@ -1,12 +1,20 @@ -from collections.abc import Sequence +from collections.abc import Callable, Sequence from typing import Generic, Protocol, TypeVar +from soundevent import data from soundevent.data import PathLike -from batdetect2.postprocess.types import ClipDetections +from batdetect2.postprocess.types import ( + ClipDetections, + ClipDetectionsTensor, + Detection, +) __all__ = [ + "ClipDetectionsTransform", + "DetectionTransform", "OutputFormatterProtocol", + "OutputTransformProtocol", ] T = TypeVar("T") @@ -23,3 +31,31 @@ class OutputFormatterProtocol(Protocol, Generic[T]): ) -> None: ... def load(self, path: PathLike) -> list[T]: ... + + +DetectionTransform = Callable[[Detection], Detection | None] +ClipDetectionsTransform = Callable[[ClipDetections], ClipDetections] + + +class OutputTransformProtocol(Protocol): + def to_detections( + self, + detections: ClipDetectionsTensor, + start_time: float = 0, + ) -> list[Detection]: ... + + def to_clip_detections( + self, + detections: ClipDetectionsTensor, + clip: data.Clip, + ) -> ClipDetections: ... 
+ + def transform_detections( + self, + detections: Sequence[Detection], + ) -> list[Detection]: ... + + def transform_clip_detections( + self, + prediction: ClipDetections, + ) -> ClipDetections: ... diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 8497c00..20f2744 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -4,10 +4,6 @@ from batdetect2.postprocess.config import ( PostprocessConfig, load_postprocess_config, ) -from batdetect2.postprocess.decoding import ( - convert_raw_predictions_to_clip_prediction, - to_raw_predictions, -) from batdetect2.postprocess.nms import non_max_suppression from batdetect2.postprocess.postprocessor import ( Postprocessor, @@ -18,8 +14,6 @@ __all__ = [ "PostprocessConfig", "Postprocessor", "build_postprocessor", - "convert_raw_predictions_to_clip_prediction", - "to_raw_predictions", "load_postprocess_config", "non_max_suppression", ] diff --git a/src/batdetect2/postprocess/config.py b/src/batdetect2/postprocess/config.py index 7f7e297..69d25d9 100644 --- a/src/batdetect2/postprocess/config.py +++ b/src/batdetect2/postprocess/config.py @@ -2,7 +2,6 @@ from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD from batdetect2.postprocess.nms import NMS_KERNEL_SIZE __all__ = [ @@ -11,6 +10,7 @@ __all__ = [ ] DEFAULT_DETECTION_THRESHOLD = 0.01 +DEFAULT_CLASSIFICATION_THRESHOLD = 0.1 TOP_K_PER_SEC = 100 diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 43fb000..6d17c3b 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger from batdetect2.models.types import ModelOutput from batdetect2.outputs import 
OutputTransformProtocol, build_output_transform -from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess.types import ClipDetections from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule @@ -25,7 +24,7 @@ class ValidationMetrics(Callback): super().__init__() self.evaluator = evaluator - self.output_transform = output_transform or build_output_transform() + self.output_transform = output_transform self._clip_annotations: List[data.ClipAnnotation] = [] self._predictions: List[ClipDetections] = [] @@ -92,6 +91,14 @@ class ValidationMetrics(Callback): dataloader_idx: int = 0, ) -> None: model = pl_module.model + if self.output_transform is None: + self.output_transform = build_output_transform( + targets=model.targets + ) + + output_transform = self.output_transform + assert output_transform is not None + dataset = self.get_dataset(trainer) clip_annotations = [ @@ -101,17 +108,14 @@ class ValidationMetrics(Callback): clip_detections = model.postprocessor(outputs) predictions = [ - ClipDetections( + output_transform.to_clip_detections( + detections=clip_dets, clip=clip_annotation.clip, - detections=to_raw_predictions( - clip_dets.numpy(), targets=model.targets - ), ) for clip_annotation, clip_dets in zip( clip_annotations, clip_detections, strict=False ) ] - predictions = self.output_transform(predictions) self._clip_annotations.extend(clip_annotations) self._predictions.extend(predictions) diff --git a/tests/test_outputs/test_transform/test_transform.py b/tests/test_outputs/test_transform/test_transform.py index bf98ce5..840473f 100644 --- a/tests/test_outputs/test_transform/test_transform.py +++ b/tests/test_outputs/test_transform/test_transform.py @@ -1,12 +1,35 @@ +from dataclasses import replace + import numpy as np +import torch from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.outputs import build_output_transform -from 
batdetect2.postprocess.types import ClipDetections, Detection +from batdetect2.outputs.transforms import OutputTransform +from batdetect2.postprocess.types import ( + ClipDetections, + ClipDetectionsTensor, + Detection, +) +from batdetect2.targets.types import TargetProtocol -def test_shift_time_to_clip_start(clip: data.Clip): +def _mock_clip_detections_tensor() -> ClipDetectionsTensor: + return ClipDetectionsTensor( + scores=torch.tensor([0.9], dtype=torch.float32), + sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32), + class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32), + times=torch.tensor([0.2], dtype=torch.float32), + frequencies=torch.tensor([60_000.0], dtype=torch.float32), + features=torch.tensor([[1.0, 2.0]], dtype=torch.float32), + ) + + +def test_shift_time_to_clip_start( + clip: data.Clip, + sample_targets: TargetProtocol, +): clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) detection = Detection( @@ -16,7 +39,7 @@ def test_shift_time_to_clip_start(clip: data.Clip): features=np.array([1.0, 2.0]), ) - transformed = build_output_transform()( + transformed = OutputTransform(targets=sample_targets)( [ClipDetections(clip=clip, detections=[detection])] )[0] @@ -28,21 +51,71 @@ def test_shift_time_to_clip_start(clip: data.Clip): assert np.isclose(end_time, 2.7) -def test_transform_identity_when_disabled(clip: data.Clip): +def test_to_clip_detections_shifts_by_clip_start( + clip: data.Clip, + sample_targets: TargetProtocol, +): + clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) + transform = build_output_transform(targets=sample_targets) + raw = _mock_clip_detections_tensor() + + shifted = transform.to_clip_detections(detections=raw, clip=clip) + unshifted = transform.to_detections(detections=raw, start_time=0) + + shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry) + unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry) + + assert np.isclose(shifted_start - unshifted_start, 
clip.start_time) + + +def test_detection_and_clip_transforms_applied_in_order( + clip: data.Clip, + sample_targets: TargetProtocol, +): clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) - detection = Detection( + detection_1 = Detection( geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]), - detection_score=0.9, + detection_score=0.5, + class_scores=np.array([0.9]), + features=np.array([1.0, 2.0]), + ) + detection_2 = Detection( + geometry=data.BoundingBox(coordinates=[0.2, 10_000, 0.3, 12_000]), + detection_score=0.7, class_scores=np.array([0.9]), features=np.array([1.0, 2.0]), ) - transform = build_output_transform( - config={"shift_time_to_clip_start": False} + def boost_score(detection: Detection) -> Detection: + return replace( + detection, + detection_score=detection.detection_score + 0.2, + ) + + def keep_high_score(detection: Detection) -> Detection | None: + if detection.detection_score < 0.8: + return None + return detection + + def tag_clip_transform(prediction: ClipDetections) -> ClipDetections: + detections = [ + replace(detection, detection_score=1.0) + for detection in prediction.detections + ] + return replace(prediction, detections=detections) + + transform = OutputTransform( + targets=sample_targets, + detection_transform_steps=[boost_score, keep_high_score], + clip_transform_steps=[tag_clip_transform], ) transformed = transform( - [ClipDetections(clip=clip, detections=[detection])] + [ClipDetections(clip=clip, detections=[detection_1, detection_2])] )[0] - assert transformed.detections[0].geometry == detection.geometry + assert len(transformed.detections) == 1 + assert transformed.detections[0].detection_score == 1.0 + + start_time, _, _, _ = compute_bounds(transformed.detections[0].geometry) + assert np.isclose(start_time, 2.7) diff --git a/tests/test_postprocessing/test_decoding.py b/tests/test_postprocessing/test_decoding.py index 212e10b..3152596 100644 --- a/tests/test_postprocessing/test_decoding.py +++ 
b/tests/test_postprocessing/test_decoding.py @@ -6,7 +6,7 @@ import pytest import xarray as xr from soundevent import data -from batdetect2.postprocess.decoding import ( +from batdetect2.outputs.transforms.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction,