From 7b1cb402b44296a0b8b7635a254b8d008e86c730 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 11:24:22 +0000 Subject: [PATCH] Add a few clip and detection transforms for outputs --- .../outputs/transforms/clip_transforms.py | 139 +++++++++++++++++- .../transforms/detection_transforms.py | 124 +++++++++++++++- .../test_transform/test_clip_transforms.py | 121 +++++++++++++++ .../test_detection_transforms.py | 117 +++++++++++++++ .../test_output_transform_pipeline.py | 137 +++++++++++++++++ 5 files changed, 633 insertions(+), 5 deletions(-) create mode 100644 tests/test_outputs/test_transform/test_clip_transforms.py create mode 100644 tests/test_outputs/test_transform/test_detection_transforms.py create mode 100644 tests/test_outputs/test_transform/test_output_transform_pipeline.py diff --git a/src/batdetect2/outputs/transforms/clip_transforms.py b/src/batdetect2/outputs/transforms/clip_transforms.py index fddfa9a..93a8016 100644 --- a/src/batdetect2/outputs/transforms/clip_transforms.py +++ b/src/batdetect2/outputs/transforms/clip_transforms.py @@ -1,13 +1,16 @@ from typing import Annotated, Literal from pydantic import Field +from soundevent.geometry import compute_bounds -from batdetect2.core.registries import ( +from batdetect2.core import ( + BaseConfig, ImportConfig, Registry, add_import_config, ) from batdetect2.outputs.types import ClipDetectionsTransform +from batdetect2.postprocess.types import ClipDetections, Detection __all__ = [ "ClipDetectionsTransformConfig", @@ -25,7 +28,139 @@ class ClipDetectionsTransformImportConfig(ImportConfig): name: Literal["import"] = "import" +class RemoveAboveNyquistConfig(BaseConfig): + """Configuration for `RemoveAboveNyquist`. + + Defines parameters for removing detections above the Nyquist frequency. + + Attributes + ---------- + name : Literal["remove_above_nyquist"] + The unique identifier for this transform type. + min_freq : float + The minimum frequency (in Hz) for detections to be kept. + """ + + name: Literal["remove_above_nyquist"] = "remove_above_nyquist" + mode: Literal["low_freq", "high_freq"] = "high_freq" + buffer: float = 0 + + +class RemoveAboveNyquist: + def __init__(self, mode: Literal["low_freq", "high_freq"], buffer: float): + self.mode = mode + self.buffer = buffer + + def __call__(self, detections: ClipDetections) -> ClipDetections: + recording = detections.clip.recording + nyquist = recording.samplerate / 2 + threshold = nyquist - self.buffer + + return ClipDetections( + clip=detections.clip, + detections=[ + detection + for detection in detections.detections + if self._is_below_threshold(detection, threshold) + ], + ) + + def _is_below_threshold( + self, + detection: Detection, + threshold: float, + ) -> bool: + _, low_freq, _, high_freq = compute_bounds(detection.geometry) + + if self.mode == "low_freq": + return low_freq < threshold + + return high_freq < threshold + + @clip_transforms.register(RemoveAboveNyquistConfig) + @staticmethod + def from_config(config: RemoveAboveNyquistConfig): + return RemoveAboveNyquist( + mode=config.mode, + buffer=config.buffer, + ) + + +class RemoveAtEdgesConfig(BaseConfig): + """Configuration for `RemoveAtEdges`. + + Defines parameters for removing detections at the edges of the clip. + + Attributes + ---------- + name : Literal["remove_at_edges"] + The unique identifier for this transform type. + buffer : float + The amount of time (in seconds) to remove detections from the edge. + mode : Literal["start_time", "end_time", "both"] + Criteria for removing detections at the edges of the clip. + If "start_time", remove detections with a start time within the + buffer. If "end_time", remove detections with an end time within + the buffer. If "both", remove detections with a start time within + the buffer or an end time within the buffer. + """ + + name: Literal["remove_at_edges"] = "remove_at_edges" + buffer: float = 0.1 + mode: Literal["start_time", "end_time", "both"] = "both" + + +class RemoveAtEdges: + def __init__( + self, + buffer: float, + mode: Literal["start_time", "end_time", "both"], + ): + self.buffer = buffer + self.mode = mode + + def __call__(self, detections: ClipDetections) -> ClipDetections: + clip = detections.clip + start = clip.start_time + self.buffer + end = clip.end_time - self.buffer + + return ClipDetections( + clip=detections.clip, + detections=[ + detection + for detection in detections.detections + if self._is_within_buffer(detection, start, end) + ], + ) + + def _is_within_buffer( + self, + detection: Detection, + start: float, + end: float, + ) -> bool: + start_time, _, end_time, _ = compute_bounds(detection.geometry) + + if self.mode == "start_time": + return (start_time >= start) and (start_time <= end) + + if self.mode == "end_time": + return (end_time >= start) and (end_time <= end) + + return (start_time >= start) and (end_time <= end) + + @clip_transforms.register(RemoveAtEdgesConfig) + @staticmethod + def from_config(config: RemoveAtEdgesConfig): + return RemoveAtEdges( + buffer=config.buffer, + mode=config.mode, + ) + + ClipDetectionsTransformConfig = Annotated[ - ClipDetectionsTransformImportConfig, + ClipDetectionsTransformImportConfig + | RemoveAboveNyquistConfig + | RemoveAtEdgesConfig, Field(discriminator="name"), ] diff --git a/src/batdetect2/outputs/transforms/detection_transforms.py b/src/batdetect2/outputs/transforms/detection_transforms.py index f5a9fcb..b5fcc1b 100644 --- a/src/batdetect2/outputs/transforms/detection_transforms.py +++ b/src/batdetect2/outputs/transforms/detection_transforms.py @@ -3,9 +3,10 @@ from dataclasses import replace from typing import Annotated, Literal from pydantic import Field -from soundevent.geometry import shift_geometry +from soundevent.geometry import compute_bounds, shift_geometry -from batdetect2.core.registries import ( +from batdetect2.core import ( + BaseConfig, ImportConfig, Registry, add_import_config, @@ -31,8 +32,125 @@ class DetectionTransformImportConfig(ImportConfig): name: Literal["import"] = "import" +class FilterByFrequencyConfig(BaseConfig): + """Configuration for `FilterByFrequency`. + + Defines parameters for filtering detections by frequency. + + Attributes + ---------- + name : Literal["filter_by_frequency"] + The unique identifier for this transform type. + min_freq : float + The minimum frequency (in Hz) for detections to be kept. + max_freq : float + The maximum frequency (in Hz) for detections to be kept. + mode : Literal["low_freq", "high_freq", "both"] + Criteria for filtering detections by frequency. + If "low_freq", keep detections with a low frequency within the + specified range. If "high_freq", keep detections with a high + frequency within the specified range. If "both", keep detections + with a low frequency within the specified range or a high frequency + within the specified range. + """ + + name: Literal["filter_by_frequency"] = "filter_by_frequency" + min_freq: float = 0 + max_freq: float = float("inf") + mode: Literal["low_freq", "high_freq", "both"] = "both" + + +class FilterByFrequency: + def __init__( + self, + min_freq: float = 0, + max_freq: float = float("inf"), + mode: Literal["low_freq", "high_freq", "both"] = "both", + ): + self.min_freq = min_freq + self.max_freq = max_freq + self.mode = mode + + def __call__(self, detection: Detection) -> Detection | None: + if self._is_within_frequency_range(detection): + return detection + + def _is_within_frequency_range(self, detection: Detection) -> bool: + _, low_freq, _, high_freq = compute_bounds(detection.geometry) + + if self.mode == "low_freq": + return (low_freq >= self.min_freq) and (low_freq <= self.max_freq) + + if self.mode == "high_freq": + return (high_freq >= self.min_freq) and ( + high_freq <= self.max_freq + ) + + return (low_freq >= self.min_freq) or (high_freq <= self.max_freq) + + @detection_transforms.register(FilterByFrequencyConfig) + @staticmethod + def from_config(config: FilterByFrequencyConfig): + return FilterByFrequency( + min_freq=config.min_freq, + max_freq=config.max_freq, + mode=config.mode, + ) + + +class FilterByDurationConfig(BaseConfig): + """Configuration for `FilterByDuration`. + + Defines parameters for filtering detections by duration. + + Attributes + ---------- + name : Literal["filter_by_duration"] + The unique identifier for this transform type. + min_duration : float + The minimum duration (in seconds) for detections to be kept. + max_duration : float + The maximum duration (in seconds) for detections to be kept. + """ + + name: Literal["filter_by_duration"] = "filter_by_duration" + min_duration: float = 0 + max_duration: float = float("inf") + + +class FilterByDuration: + def __init__( + self, + min_duration: float = 0, + max_duration: float = float("inf"), + ): + self.min_duration = min_duration + self.max_duration = max_duration + + def __call__(self, detection: Detection) -> Detection | None: + if self._is_within_duration_range(detection): + return detection + + def _is_within_duration_range(self, detection: Detection) -> bool: + start_time, _, end_time, _ = compute_bounds(detection.geometry) + duration = end_time - start_time + return (duration >= self.min_duration) and ( + duration <= self.max_duration + ) + + @detection_transforms.register(FilterByDurationConfig) + @staticmethod + def from_config(config: FilterByDurationConfig): + return FilterByDuration( + min_duration=config.min_duration, + max_duration=config.max_duration, + ) + + DetectionTransformConfig = Annotated[ - DetectionTransformImportConfig, + DetectionTransformImportConfig + | FilterByFrequencyConfig + | FilterByDurationConfig, Field(discriminator="name"), ] diff --git a/tests/test_outputs/test_transform/test_clip_transforms.py b/tests/test_outputs/test_transform/test_clip_transforms.py new file mode 100644 index 0000000..445502b --- /dev/null +++ b/tests/test_outputs/test_transform/test_clip_transforms.py @@ -0,0 +1,121 @@ +import numpy as np +from soundevent import data + +from batdetect2.outputs.transforms.clip_transforms import ( + RemoveAboveNyquist, + RemoveAtEdges, +) +from batdetect2.postprocess.types import ClipDetections, Detection + + +def _detection( + start_time: float, + low_freq: float, + end_time: float, + high_freq: float, +) -> Detection: + return Detection( + geometry=data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq] + ), + detection_score=0.9, + class_scores=np.array([0.9]), + features=np.array([1.0, 2.0]), + ) + + +def test_remove_above_nyquist_high_freq_mode(clip: data.Clip) -> None: + # Nyquist should be at 128kHz + assert clip.recording.samplerate == 256_000 + + transform = RemoveAboveNyquist(mode="high_freq", buffer=0) + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(0.1, 10_000, 0.2, 120_000), + _detection(0.1, 10_000, 0.2, 130_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 + + +def test_remove_above_nyquist_low_freq_mode(clip: data.Clip) -> None: + transform = RemoveAboveNyquist(mode="low_freq", buffer=0) + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(0.1, 120_000, 0.2, 140_000), + _detection(0.1, 130_000, 0.2, 140_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 + + +def test_remove_above_nyquist_respects_buffer(clip: data.Clip) -> None: + transform = RemoveAboveNyquist(mode="high_freq", buffer=5_000) + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(0.1, 10_000, 0.2, 122_000), + _detection(0.1, 10_000, 0.2, 124_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 + + +def test_remove_at_edges_start_mode(clip: data.Clip) -> None: + clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0}) + transform = RemoveAtEdges(buffer=1.0, mode="start_time") + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(10.2, 20_000, 10.4, 30_000), + _detection(11.2, 20_000, 11.4, 30_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 + + +def test_remove_at_edges_end_mode(clip: data.Clip) -> None: + clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0}) + transform = RemoveAtEdges(buffer=1.0, mode="end_time") + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(11.2, 20_000, 19.8, 30_000), + _detection(11.2, 20_000, 18.6, 30_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 + + +def test_remove_at_edges_both_mode(clip: data.Clip) -> None: + clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0}) + transform = RemoveAtEdges(buffer=1.0, mode="both") + prediction = ClipDetections( + clip=clip, + detections=[ + _detection(10.2, 20_000, 18.5, 30_000), + _detection(11.2, 20_000, 18.5, 30_000), + _detection(11.2, 20_000, 19.8, 30_000), + ], + ) + + out = transform(prediction) + + assert len(out.detections) == 1 diff --git a/tests/test_outputs/test_transform/test_detection_transforms.py b/tests/test_outputs/test_transform/test_detection_transforms.py new file mode 100644 index 0000000..a6cfe66 --- /dev/null +++ b/tests/test_outputs/test_transform/test_detection_transforms.py @@ -0,0 +1,117 @@ +import numpy as np +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.outputs.transforms.detection_transforms import ( + FilterByDuration, + FilterByDurationConfig, + FilterByFrequency, + FilterByFrequencyConfig, + detection_transforms, + shift_detection_time, + shift_detections_to_start_time, +) +from batdetect2.postprocess.types import Detection + + +def _detection( + start_time: float, + low_freq: float, + end_time: float, + high_freq: float, +) -> Detection: + return Detection( + geometry=data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq] + ), + detection_score=0.9, + class_scores=np.array([0.9]), + features=np.array([1.0, 2.0]), + ) + + +def test_shift_detection_time_moves_geometry_by_offset() -> None: + detection = _detection(0.1, 20_000, 0.2, 30_000) + + shifted = shift_detection_time(detection, time=2.5) + start, low, end, high = compute_bounds(shifted.geometry) + + assert np.isclose(start, 2.6) + assert np.isclose(end, 2.7) + assert np.isclose(low, 20_000) + assert np.isclose(high, 30_000) + + +def test_shift_detections_to_start_time_zero_is_identity() -> None: + detections = [_detection(0.1, 20_000, 0.2, 30_000)] + + shifted = shift_detections_to_start_time(detections, start_time=0) + + assert len(shifted) == 1 + assert shifted[0] is detections[0] + + +def test_filter_by_frequency_low_freq_mode() -> None: + transform = FilterByFrequency( + min_freq=20_000, + max_freq=40_000, + mode="low_freq", + ) + + assert transform(_detection(0.1, 25_000, 0.2, 60_000)) is not None + assert transform(_detection(0.1, 10_000, 0.2, 60_000)) is None + + +def test_filter_by_frequency_high_freq_mode() -> None: + transform = FilterByFrequency( + min_freq=20_000, + max_freq=40_000, + mode="high_freq", + ) + + assert transform(_detection(0.1, 10_000, 0.2, 35_000)) is not None + assert transform(_detection(0.1, 10_000, 0.2, 60_000)) is None + + +def test_filter_by_frequency_both_mode_current_semantics() -> None: + transform = FilterByFrequency( + min_freq=20_000, + max_freq=40_000, + mode="both", + ) + + # low >= min passes + assert transform(_detection(0.1, 25_000, 0.2, 80_000)) is not None + # high <= max passes + assert transform(_detection(0.1, 10_000, 0.2, 35_000)) is not None + # neither condition passes + assert transform(_detection(0.1, 10_000, 0.2, 80_000)) is None + + +def test_filter_by_duration_keeps_within_range() -> None: + transform = FilterByDuration(min_duration=0.04, max_duration=0.06) + + kept = transform(_detection(0.1, 20_000, 0.15, 30_000)) + removed = transform(_detection(0.1, 20_000, 0.2, 30_000)) + + assert kept is not None + assert removed is None + + +def test_detection_transform_registry_builds_builtin_transforms() -> None: + frequency_transform = detection_transforms.build( + FilterByFrequencyConfig( + min_freq=20_000, + max_freq=40_000, + mode="high_freq", + ) + ) + duration_transform = detection_transforms.build( + FilterByDurationConfig( + min_duration=0.01, + max_duration=0.2, + ) + ) + + assert callable(frequency_transform) + assert callable(duration_transform) diff --git a/tests/test_outputs/test_transform/test_output_transform_pipeline.py b/tests/test_outputs/test_transform/test_output_transform_pipeline.py new file mode 100644 index 0000000..737d2f9 --- /dev/null +++ b/tests/test_outputs/test_transform/test_output_transform_pipeline.py @@ -0,0 +1,137 @@ +import numpy as np +import torch +from soundevent import data +from soundevent.geometry import compute_bounds + +from batdetect2.outputs import build_output_transform +from batdetect2.postprocess.types import ClipDetectionsTensor +from batdetect2.targets.types import TargetProtocol + + +def _mock_clip_detections_tensor( + *, + time: float, + duration: float, + frequency: float, + bandwidth: float, +) -> ClipDetectionsTensor: + # NOTE: size time is represented in milliseconds. + return ClipDetectionsTensor( + scores=torch.tensor([0.9], dtype=torch.float32), + sizes=torch.tensor( + [[duration * 1_000, bandwidth]], dtype=torch.float32 + ), + class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32), + times=torch.tensor([time], dtype=torch.float32), + frequencies=torch.tensor([frequency], dtype=torch.float32), + features=torch.tensor([[1.0, 2.0]], dtype=torch.float32), + ) + + +def test_pipeline_from_config_applies_detection_and_clip_transforms( + clip: data.Clip, + sample_targets: TargetProtocol, +) -> None: + clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0}) + transform = build_output_transform( + targets=sample_targets, + config={ + "detection_transforms": [ + { + "name": "filter_by_duration", + "min_duration": 0.08, + "max_duration": 0.12, + } + ], + "clip_transforms": [ + { + "name": "remove_at_edges", + "buffer": 0.1, + "mode": "both", + } + ], + }, + ) + + raw = _mock_clip_detections_tensor( + time=0.03, + duration=0.1, + frequency=60_000, + bandwidth=1_000, + ) + + prediction = transform.to_clip_detections(raw, clip=clip) + + # duration filter keeps it, edge filter removes it. + assert len(prediction.detections) == 0 + + +def test_pipeline_keeps_detection_when_all_filters_pass( + clip: data.Clip, + sample_targets: TargetProtocol, +) -> None: + clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0}) + transform = build_output_transform( + targets=sample_targets, + config={ + "detection_transforms": [ + { + "name": "filter_by_duration", + "min_duration": 0.08, + "max_duration": 0.12, + }, + ], + "clip_transforms": [ + { + "name": "remove_at_edges", + "buffer": 0.05, + "mode": "both", + } + ], + }, + ) + + raw = _mock_clip_detections_tensor( + time=0.3, + duration=0.1, + frequency=60_000, + bandwidth=1_000, + ) + + prediction = transform.to_clip_detections(raw, clip=clip) + + assert len(prediction.detections) == 1 + start_time, _, _, _ = compute_bounds(prediction.detections[0].geometry) + assert np.isclose(start_time, 10.3) + + +def test_remove_above_nyquist_uses_clip_recording_metadata( + clip: data.Clip, + sample_targets: TargetProtocol, +) -> None: + clip = clip.model_copy(update={"start_time": 0.0, "end_time": 1.0}) + transform = build_output_transform( + targets=sample_targets, + config={ + "clip_transforms": [ + { + "name": "remove_above_nyquist", + "mode": "high_freq", + "buffer": 0, + } + ] + }, + ) + + raw = _mock_clip_detections_tensor( + time=0.5, + duration=0.05, + frequency=127_500, + bandwidth=2_000, + ) + + prediction = transform.to_clip_detections(raw, clip=clip) + + # clip fixture samplerate is 256_000, nyquist is 128_000, high bound + # becomes 128_500 and must be removed. + assert len(prediction.detections) == 0