Add a few clip and detection transforms for outputs

This commit is contained in:
mbsantiago 2026-03-18 11:24:22 +00:00
parent 45ae15eed5
commit 7b1cb402b4
5 changed files with 633 additions and 5 deletions

View File

@ -1,13 +1,16 @@
from typing import Annotated, Literal
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.core.registries import (
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.outputs.types import ClipDetectionsTransform
from batdetect2.postprocess.types import ClipDetections, Detection
__all__ = [
"ClipDetectionsTransformConfig",
@ -25,7 +28,139 @@ class ClipDetectionsTransformImportConfig(ImportConfig):
name: Literal["import"] = "import"
class RemoveAboveNyquistConfig(BaseConfig):
    """Configuration for `RemoveAboveNyquist`.

    Defines parameters for removing detections that reach the Nyquist
    frequency (half the sample rate) of the clip's recording.

    Attributes
    ----------
    name : Literal["remove_above_nyquist"]
        The unique identifier for this transform type.
    mode : Literal["low_freq", "high_freq"]
        Which frequency bound of a detection is compared against the
        threshold. If "low_freq", the lowest frequency of the detection
        is used. If "high_freq", the highest frequency is used.
    buffer : float
        Margin (in Hz) subtracted from the Nyquist frequency to obtain
        the threshold. A detection is kept only when the selected bound
        is strictly below ``nyquist - buffer``.
    """

    name: Literal["remove_above_nyquist"] = "remove_above_nyquist"
    mode: Literal["low_freq", "high_freq"] = "high_freq"
    buffer: float = 0
class RemoveAboveNyquist:
    """Clip transform dropping detections that reach the Nyquist frequency.

    The Nyquist frequency is taken as half the sample rate of the
    recording the clip belongs to. A detection survives only when the
    selected frequency bound is strictly below ``nyquist - buffer``.

    Parameters
    ----------
    mode : Literal["low_freq", "high_freq"]
        Frequency bound of the detection compared with the threshold.
    buffer : float
        Margin subtracted from the Nyquist frequency.
    """

    def __init__(self, mode: Literal["low_freq", "high_freq"], buffer: float):
        self.mode = mode
        self.buffer = buffer

    def __call__(self, detections: ClipDetections) -> ClipDetections:
        """Return a new `ClipDetections` holding only the kept detections."""
        samplerate = detections.clip.recording.samplerate
        cutoff = samplerate / 2 - self.buffer
        kept = [
            candidate
            for candidate in detections.detections
            if self._is_below_threshold(candidate, cutoff)
        ]
        return ClipDetections(clip=detections.clip, detections=kept)

    def _is_below_threshold(
        self,
        detection: Detection,
        threshold: float,
    ) -> bool:
        """Tell whether the selected frequency bound is below `threshold`."""
        _, low_freq, _, high_freq = compute_bounds(detection.geometry)
        # Any mode other than "low_freq" falls back to the upper bound.
        reference = low_freq if self.mode == "low_freq" else high_freq
        return reference < threshold

    @clip_transforms.register(RemoveAboveNyquistConfig)
    @staticmethod
    def from_config(config: RemoveAboveNyquistConfig):
        """Build the transform from its configuration object."""
        return RemoveAboveNyquist(mode=config.mode, buffer=config.buffer)
class RemoveAtEdgesConfig(BaseConfig):
    """Configuration for `RemoveAtEdges`.

    Defines parameters for removing detections at the edges of the clip.

    Attributes
    ----------
    name : Literal["remove_at_edges"]
        The unique identifier for this transform type.
    buffer : float
        Width (in seconds) of the margin at the start and at the end of
        the clip inside which detections are removed.
    mode : Literal["start_time", "end_time", "both"]
        Criteria for removing detections at the edges of the clip.
        If "start_time", remove detections whose start time is outside
        ``[clip start + buffer, clip end - buffer]``. If "end_time",
        remove detections whose end time is outside that interval. If
        "both", remove detections that start before
        ``clip start + buffer`` or end after ``clip end - buffer``.
    """

    name: Literal["remove_at_edges"] = "remove_at_edges"
    buffer: float = 0.1
    mode: Literal["start_time", "end_time", "both"] = "both"
class RemoveAtEdges:
    """Clip transform dropping detections too close to the clip edges.

    The clip interval is shrunk by `buffer` on both sides and only the
    detections that fit the shrunk interval, according to `mode`, are
    kept.

    Parameters
    ----------
    buffer : float
        Margin removed from both ends of the clip.
    mode : Literal["start_time", "end_time", "both"]
        Which detection edge(s) must lie inside the shrunk interval.
    """

    def __init__(
        self,
        buffer: float,
        mode: Literal["start_time", "end_time", "both"],
    ):
        self.buffer = buffer
        self.mode = mode

    def __call__(self, detections: ClipDetections) -> ClipDetections:
        """Return a new `ClipDetections` holding only the kept detections."""
        clip = detections.clip
        inner_start = clip.start_time + self.buffer
        inner_end = clip.end_time - self.buffer
        kept = [
            candidate
            for candidate in detections.detections
            if self._is_within_buffer(candidate, inner_start, inner_end)
        ]
        return ClipDetections(clip=clip, detections=kept)

    def _is_within_buffer(
        self,
        detection: Detection,
        start: float,
        end: float,
    ) -> bool:
        """Tell whether the detection should be kept.

        Despite the name, ``True`` means the detection lies inside the
        interval ``[start, end]`` left after trimming the margins.
        """
        start_time, _, end_time, _ = compute_bounds(detection.geometry)

        if self.mode == "start_time":
            return start <= start_time <= end

        if self.mode == "end_time":
            return start <= end_time <= end

        # "both": the whole detection must fit inside the interval.
        return start <= start_time and end_time <= end

    @clip_transforms.register(RemoveAtEdgesConfig)
    @staticmethod
    def from_config(config: RemoveAtEdgesConfig):
        """Build the transform from its configuration object."""
        return RemoveAtEdges(buffer=config.buffer, mode=config.mode)
# Discriminated union of every clip-level transform configuration. The
# union must be the first (type) argument of `Annotated`; anything after
# it is metadata only, so a stray leading member would silently disable
# the built-in transform configs. The `name` field selects the member.
ClipDetectionsTransformConfig = Annotated[
    ClipDetectionsTransformImportConfig
    | RemoveAboveNyquistConfig
    | RemoveAtEdgesConfig,
    Field(discriminator="name"),
]

View File

@ -3,9 +3,10 @@ from dataclasses import replace
from typing import Annotated, Literal
from pydantic import Field
from soundevent.geometry import shift_geometry
from soundevent.geometry import compute_bounds, shift_geometry
from batdetect2.core.registries import (
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
@ -31,8 +32,125 @@ class DetectionTransformImportConfig(ImportConfig):
name: Literal["import"] = "import"
class FilterByFrequencyConfig(BaseConfig):
    """Configuration for `FilterByFrequency`.

    Defines parameters for filtering detections by frequency.

    Attributes
    ----------
    name : Literal["filter_by_frequency"]
        The unique identifier for this transform type.
    min_freq : float
        The minimum frequency (in Hz) for detections to be kept.
        The bound is inclusive.
    max_freq : float
        The maximum frequency (in Hz) for detections to be kept.
        The bound is inclusive.
    mode : Literal["low_freq", "high_freq", "both"]
        Criteria for filtering detections by frequency.
        If "low_freq", keep detections with a low frequency within the
        specified range. If "high_freq", keep detections with a high
        frequency within the specified range. If "both", keep detections
        with a low frequency within the specified range or a high
        frequency within the specified range.
    """

    name: Literal["filter_by_frequency"] = "filter_by_frequency"
    min_freq: float = 0
    max_freq: float = float("inf")
    mode: Literal["low_freq", "high_freq", "both"] = "both"


class FilterByFrequency:
    """Detection transform keeping detections inside a frequency range.

    Parameters
    ----------
    min_freq : float
        Inclusive lower limit of the accepted range.
    max_freq : float
        Inclusive upper limit of the accepted range.
    mode : Literal["low_freq", "high_freq", "both"]
        Which frequency bound of the detection must fall inside
        ``[min_freq, max_freq]``. With "both", it is enough that either
        the low or the high frequency falls inside the range.
    """

    def __init__(
        self,
        min_freq: float = 0,
        max_freq: float = float("inf"),
        mode: Literal["low_freq", "high_freq", "both"] = "both",
    ):
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.mode = mode

    def __call__(self, detection: Detection) -> Detection | None:
        """Return `detection` unchanged if it passes, otherwise ``None``."""
        if self._is_within_frequency_range(detection):
            return detection
        return None

    def _is_within_frequency_range(self, detection: Detection) -> bool:
        """Tell whether the detection satisfies the configured criterion."""
        _, low_freq, _, high_freq = compute_bounds(detection.geometry)
        low_in_range = self.min_freq <= low_freq <= self.max_freq
        high_in_range = self.min_freq <= high_freq <= self.max_freq

        if self.mode == "low_freq":
            return low_in_range

        if self.mode == "high_freq":
            return high_in_range

        # "both": each bound is tested against the full range. Testing
        # only `low >= min_freq or high <= max_freq` would accept
        # detections lying entirely outside the range and would never
        # reject anything when one limit is left at its default.
        return low_in_range or high_in_range

    @detection_transforms.register(FilterByFrequencyConfig)
    @staticmethod
    def from_config(config: FilterByFrequencyConfig):
        """Build the transform from its configuration object."""
        return FilterByFrequency(
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            mode=config.mode,
        )
class FilterByDurationConfig(BaseConfig):
    """Configuration for `FilterByDuration`.

    Defines parameters for filtering detections by duration.

    Attributes
    ----------
    name : Literal["filter_by_duration"]
        The unique identifier for this transform type.
    min_duration : float
        The minimum duration (in seconds) for detections to be kept.
        The bound is inclusive.
    max_duration : float
        The maximum duration (in seconds) for detections to be kept.
        The bound is inclusive.
    """

    name: Literal["filter_by_duration"] = "filter_by_duration"
    min_duration: float = 0
    max_duration: float = float("inf")
class FilterByDuration:
    """Detection transform keeping detections inside a duration range.

    Parameters
    ----------
    min_duration : float
        Inclusive lower limit of the accepted duration.
    max_duration : float
        Inclusive upper limit of the accepted duration.
    """

    def __init__(
        self,
        min_duration: float = 0,
        max_duration: float = float("inf"),
    ):
        self.min_duration = min_duration
        self.max_duration = max_duration

    def __call__(self, detection: Detection) -> Detection | None:
        """Return `detection` unchanged if it passes, otherwise ``None``."""
        return detection if self._is_within_duration_range(detection) else None

    def _is_within_duration_range(self, detection: Detection) -> bool:
        """Tell whether the time extent of the detection is acceptable."""
        onset, _, offset, _ = compute_bounds(detection.geometry)
        length = offset - onset
        return self.min_duration <= length <= self.max_duration

    @detection_transforms.register(FilterByDurationConfig)
    @staticmethod
    def from_config(config: FilterByDurationConfig):
        """Build the transform from its configuration object."""
        return FilterByDuration(
            min_duration=config.min_duration,
            max_duration=config.max_duration,
        )
# Discriminated union of every detection-level transform configuration.
# The union must be the first (type) argument of `Annotated`; anything
# after it is metadata only, so a stray leading member would silently
# disable the built-in transform configs. The `name` field selects the
# member.
DetectionTransformConfig = Annotated[
    DetectionTransformImportConfig
    | FilterByFrequencyConfig
    | FilterByDurationConfig,
    Field(discriminator="name"),
]

View File

@ -0,0 +1,121 @@
import numpy as np
from soundevent import data
from batdetect2.outputs.transforms.clip_transforms import (
RemoveAboveNyquist,
RemoveAtEdges,
)
from batdetect2.postprocess.types import ClipDetections, Detection
def _detection(
    start_time: float,
    low_freq: float,
    end_time: float,
    high_freq: float,
) -> Detection:
    """Build a `Detection` with the given bounding box and dummy scores.

    The arguments follow the bounding-box coordinate order
    ``[start_time, low_freq, end_time, high_freq]``.
    """
    bounding_box = data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq],
    )
    return Detection(
        geometry=bounding_box,
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
def test_remove_above_nyquist_high_freq_mode(clip: data.Clip) -> None:
    """In "high_freq" mode only the upper frequency bound is checked."""
    # Pins the Nyquist frequency of the fixture at 128 kHz.
    assert clip.recording.samplerate == 256_000

    below_nyquist = _detection(0.1, 10_000, 0.2, 120_000)
    above_nyquist = _detection(0.1, 10_000, 0.2, 130_000)
    prediction = ClipDetections(
        clip=clip,
        detections=[below_nyquist, above_nyquist],
    )

    transform = RemoveAboveNyquist(mode="high_freq", buffer=0)
    out = transform(prediction)

    assert len(out.detections) == 1
def test_remove_above_nyquist_low_freq_mode(clip: data.Clip) -> None:
    """In "low_freq" mode only the lower frequency bound is checked.

    Assumes the `clip` fixture is sampled at 256 kHz, so one low bound
    (120 kHz) is below Nyquist and the other (130 kHz) is above it.
    """
    candidates = [
        _detection(0.1, 120_000, 0.2, 140_000),
        _detection(0.1, 130_000, 0.2, 140_000),
    ]
    prediction = ClipDetections(clip=clip, detections=candidates)

    transform = RemoveAboveNyquist(mode="low_freq", buffer=0)
    out = transform(prediction)

    assert len(out.detections) == 1
def test_remove_above_nyquist_respects_buffer(clip: data.Clip) -> None:
    """The buffer lowers the threshold below the Nyquist frequency.

    Assumes the `clip` fixture is sampled at 256 kHz, which puts the
    threshold at 123 kHz for a 5 kHz buffer.
    """
    candidates = [
        _detection(0.1, 10_000, 0.2, 122_000),
        _detection(0.1, 10_000, 0.2, 124_000),
    ]
    prediction = ClipDetections(clip=clip, detections=candidates)

    transform = RemoveAboveNyquist(mode="high_freq", buffer=5_000)
    out = transform(prediction)

    assert len(out.detections) == 1
def test_remove_at_edges_start_mode(clip: data.Clip) -> None:
    """In "start_time" mode a detection starting in the margin is dropped."""
    clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
    candidates = [
        # Starts 0.2 s into the clip, inside the 1 s margin.
        _detection(10.2, 20_000, 10.4, 30_000),
        # Starts 1.2 s into the clip, past the margin.
        _detection(11.2, 20_000, 11.4, 30_000),
    ]
    prediction = ClipDetections(clip=clip, detections=candidates)

    transform = RemoveAtEdges(buffer=1.0, mode="start_time")
    out = transform(prediction)

    assert len(out.detections) == 1
def test_remove_at_edges_end_mode(clip: data.Clip) -> None:
    """In "end_time" mode a detection ending in the margin is dropped."""
    clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
    candidates = [
        # Ends 0.2 s before the clip end, inside the 1 s margin.
        _detection(11.2, 20_000, 19.8, 30_000),
        # Ends 1.4 s before the clip end, outside the margin.
        _detection(11.2, 20_000, 18.6, 30_000),
    ]
    prediction = ClipDetections(clip=clip, detections=candidates)

    transform = RemoveAtEdges(buffer=1.0, mode="end_time")
    out = transform(prediction)

    assert len(out.detections) == 1
def test_remove_at_edges_both_mode(clip: data.Clip) -> None:
    """In "both" mode a detection must clear the margin on both sides."""
    clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
    candidates = [
        # Starts inside the leading margin.
        _detection(10.2, 20_000, 18.5, 30_000),
        # Clears both margins.
        _detection(11.2, 20_000, 18.5, 30_000),
        # Ends inside the trailing margin.
        _detection(11.2, 20_000, 19.8, 30_000),
    ]
    prediction = ClipDetections(clip=clip, detections=candidates)

    transform = RemoveAtEdges(buffer=1.0, mode="both")
    out = transform(prediction)

    assert len(out.detections) == 1

View File

@ -0,0 +1,117 @@
import numpy as np
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs.transforms.detection_transforms import (
FilterByDuration,
FilterByDurationConfig,
FilterByFrequency,
FilterByFrequencyConfig,
detection_transforms,
shift_detection_time,
shift_detections_to_start_time,
)
from batdetect2.postprocess.types import Detection
def _detection(
    start_time: float,
    low_freq: float,
    end_time: float,
    high_freq: float,
) -> Detection:
    """Build a `Detection` with the given bounding box and dummy scores.

    The arguments follow the bounding-box coordinate order
    ``[start_time, low_freq, end_time, high_freq]``.
    """
    bounding_box = data.BoundingBox(
        coordinates=[start_time, low_freq, end_time, high_freq],
    )
    return Detection(
        geometry=bounding_box,
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
def test_shift_detection_time_moves_geometry_by_offset() -> None:
    """Shifting moves both time bounds and leaves frequencies untouched."""
    original = _detection(0.1, 20_000, 0.2, 30_000)

    shifted = shift_detection_time(original, time=2.5)
    bounds = compute_bounds(shifted.geometry)
    start, low, end, high = bounds

    # Time bounds are offset by 2.5 s.
    assert np.isclose(start, 2.6)
    assert np.isclose(end, 2.7)
    # Frequency bounds are unchanged.
    assert np.isclose(low, 20_000)
    assert np.isclose(high, 30_000)
def test_shift_detections_to_start_time_zero_is_identity() -> None:
    """A zero start time hands back the very same detection objects."""
    only_detection = _detection(0.1, 20_000, 0.2, 30_000)
    detections = [only_detection]

    shifted = shift_detections_to_start_time(detections, start_time=0)

    assert len(shifted) == 1
    assert shifted[0] is detections[0]
def test_filter_by_frequency_low_freq_mode() -> None:
    """In "low_freq" mode only the lower frequency bound is tested."""
    transform = FilterByFrequency(
        min_freq=20_000,
        max_freq=40_000,
        mode="low_freq",
    )

    low_inside = _detection(0.1, 25_000, 0.2, 60_000)
    low_below = _detection(0.1, 10_000, 0.2, 60_000)

    assert transform(low_inside) is not None
    assert transform(low_below) is None
def test_filter_by_frequency_high_freq_mode() -> None:
    """In "high_freq" mode only the upper frequency bound is tested."""
    transform = FilterByFrequency(
        min_freq=20_000,
        max_freq=40_000,
        mode="high_freq",
    )

    high_inside = _detection(0.1, 10_000, 0.2, 35_000)
    high_above = _detection(0.1, 10_000, 0.2, 60_000)

    assert transform(high_inside) is not None
    assert transform(high_above) is None
def test_filter_by_frequency_both_mode_current_semantics() -> None:
    """In "both" mode one bound inside the range is enough to be kept."""
    transform = FilterByFrequency(
        min_freq=20_000,
        max_freq=40_000,
        mode="both",
    )

    # Low frequency (25 kHz) lies inside the 20-40 kHz range.
    low_inside = _detection(0.1, 25_000, 0.2, 80_000)
    assert transform(low_inside) is not None

    # High frequency (35 kHz) lies inside the 20-40 kHz range.
    high_inside = _detection(0.1, 10_000, 0.2, 35_000)
    assert transform(high_inside) is not None

    # Neither bound lies inside the range.
    straddling = _detection(0.1, 10_000, 0.2, 80_000)
    assert transform(straddling) is None
def test_filter_by_duration_keeps_within_range() -> None:
    """Detections pass only when their duration is inside the range."""
    transform = FilterByDuration(min_duration=0.04, max_duration=0.06)

    # Roughly 0.05 s long: inside the accepted range.
    short_enough = _detection(0.1, 20_000, 0.15, 30_000)
    # Roughly 0.1 s long: longer than the accepted maximum.
    too_long = _detection(0.1, 20_000, 0.2, 30_000)

    kept = transform(short_enough)
    removed = transform(too_long)

    assert kept is not None
    assert removed is None
def test_detection_transform_registry_builds_builtin_transforms() -> None:
    """The registry turns the built-in configs into callables."""
    frequency_config = FilterByFrequencyConfig(
        min_freq=20_000,
        max_freq=40_000,
        mode="high_freq",
    )
    duration_config = FilterByDurationConfig(
        min_duration=0.01,
        max_duration=0.2,
    )

    frequency_transform = detection_transforms.build(frequency_config)
    duration_transform = detection_transforms.build(duration_config)

    assert callable(frequency_transform)
    assert callable(duration_transform)

View File

@ -0,0 +1,137 @@
import numpy as np
import torch
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform
from batdetect2.postprocess.types import ClipDetectionsTensor
from batdetect2.targets.types import TargetProtocol
def _mock_clip_detections_tensor(
    *,
    time: float,
    duration: float,
    frequency: float,
    bandwidth: float,
) -> ClipDetectionsTensor:
    """Build a `ClipDetectionsTensor` holding one hand-made detection.

    `duration` is given in seconds and converted here, because the time
    component of `sizes` is expressed in milliseconds.
    """
    float32 = torch.float32
    # NOTE: size time is represented in milliseconds.
    size_row = [duration * 1_000, bandwidth]
    return ClipDetectionsTensor(
        scores=torch.tensor([0.9], dtype=float32),
        sizes=torch.tensor([size_row], dtype=float32),
        class_scores=torch.tensor([[0.8, 0.2]], dtype=float32),
        times=torch.tensor([time], dtype=float32),
        frequencies=torch.tensor([frequency], dtype=float32),
        features=torch.tensor([[1.0, 2.0]], dtype=float32),
    )
def test_pipeline_from_config_applies_detection_and_clip_transforms(
    clip: data.Clip,
    sample_targets: TargetProtocol,
) -> None:
    """Detection and clip transforms from the config are both applied."""
    config = {
        "detection_transforms": [
            {
                "name": "filter_by_duration",
                "min_duration": 0.08,
                "max_duration": 0.12,
            }
        ],
        "clip_transforms": [
            {
                "name": "remove_at_edges",
                "buffer": 0.1,
                "mode": "both",
            }
        ],
    }
    clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0})
    transform = build_output_transform(targets=sample_targets, config=config)

    # A 0.1 s detection placed 0.03 s after the clip start.
    raw = _mock_clip_detections_tensor(
        time=0.03,
        duration=0.1,
        frequency=60_000,
        bandwidth=1_000,
    )
    prediction = transform.to_clip_detections(raw, clip=clip)

    # The duration filter is expected to keep the detection; the edge
    # filter is expected to remove it.
    assert len(prediction.detections) == 0
def test_pipeline_keeps_detection_when_all_filters_pass(
    clip: data.Clip,
    sample_targets: TargetProtocol,
) -> None:
    """A detection passing every filter is kept and placed in clip time."""
    config = {
        "detection_transforms": [
            {
                "name": "filter_by_duration",
                "min_duration": 0.08,
                "max_duration": 0.12,
            },
        ],
        "clip_transforms": [
            {
                "name": "remove_at_edges",
                "buffer": 0.05,
                "mode": "both",
            }
        ],
    }
    clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0})
    transform = build_output_transform(targets=sample_targets, config=config)

    # A 0.1 s detection placed 0.3 s after the clip start.
    raw = _mock_clip_detections_tensor(
        time=0.3,
        duration=0.1,
        frequency=60_000,
        bandwidth=1_000,
    )
    prediction = transform.to_clip_detections(raw, clip=clip)

    assert len(prediction.detections) == 1

    # The reported start time includes the clip start (10 s + 0.3 s).
    bounds = compute_bounds(prediction.detections[0].geometry)
    start_time, _, _, _ = bounds
    assert np.isclose(start_time, 10.3)
def test_remove_above_nyquist_uses_clip_recording_metadata(
    clip: data.Clip,
    sample_targets: TargetProtocol,
) -> None:
    """The Nyquist limit comes from the recording attached to the clip."""
    config = {
        "clip_transforms": [
            {
                "name": "remove_above_nyquist",
                "mode": "high_freq",
                "buffer": 0,
            }
        ]
    }
    clip = clip.model_copy(update={"start_time": 0.0, "end_time": 1.0})
    transform = build_output_transform(targets=sample_targets, config=config)

    raw = _mock_clip_detections_tensor(
        time=0.5,
        duration=0.05,
        frequency=127_500,
        bandwidth=2_000,
    )
    prediction = transform.to_clip_detections(raw, clip=clip)

    # Assumes the clip fixture samplerate is 256_000, i.e. a Nyquist
    # frequency of 128_000. The decoded high bound is expected to land
    # above that (the exact value depends on how `sample_targets`
    # anchors `frequency` within the box - TODO confirm), so the
    # detection must be removed.
    assert len(prediction.detections) == 0