mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add a few clip and detection transforms for outputs
This commit is contained in:
parent
45ae15eed5
commit
7b1cb402b4
@ -1,13 +1,16 @@
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core.registries import (
|
from batdetect2.core import (
|
||||||
|
BaseConfig,
|
||||||
ImportConfig,
|
ImportConfig,
|
||||||
Registry,
|
Registry,
|
||||||
add_import_config,
|
add_import_config,
|
||||||
)
|
)
|
||||||
from batdetect2.outputs.types import ClipDetectionsTransform
|
from batdetect2.outputs.types import ClipDetectionsTransform
|
||||||
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClipDetectionsTransformConfig",
|
"ClipDetectionsTransformConfig",
|
||||||
@ -25,7 +28,139 @@ class ClipDetectionsTransformImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveAboveNyquistConfig(BaseConfig):
|
||||||
|
"""Configuration for `RemoveAboveNyquist`.
|
||||||
|
|
||||||
|
Defines parameters for removing detections above the Nyquist frequency.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : Literal["remove_above_nyquist"]
|
||||||
|
The unique identifier for this transform type.
|
||||||
|
min_freq : float
|
||||||
|
The minimum frequency (in Hz) for detections to be kept.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["remove_above_nyquist"] = "remove_above_nyquist"
|
||||||
|
mode: Literal["low_freq", "high_freq"] = "high_freq"
|
||||||
|
buffer: float = 0
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveAboveNyquist:
|
||||||
|
def __init__(self, mode: Literal["low_freq", "high_freq"], buffer: float):
|
||||||
|
self.mode = mode
|
||||||
|
self.buffer = buffer
|
||||||
|
|
||||||
|
def __call__(self, detections: ClipDetections) -> ClipDetections:
|
||||||
|
recording = detections.clip.recording
|
||||||
|
nyquist = recording.samplerate / 2
|
||||||
|
threshold = nyquist - self.buffer
|
||||||
|
|
||||||
|
return ClipDetections(
|
||||||
|
clip=detections.clip,
|
||||||
|
detections=[
|
||||||
|
detection
|
||||||
|
for detection in detections.detections
|
||||||
|
if self._is_below_threshold(detection, threshold)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_below_threshold(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
threshold: float,
|
||||||
|
) -> bool:
|
||||||
|
_, low_freq, _, high_freq = compute_bounds(detection.geometry)
|
||||||
|
|
||||||
|
if self.mode == "low_freq":
|
||||||
|
return low_freq < threshold
|
||||||
|
|
||||||
|
return high_freq < threshold
|
||||||
|
|
||||||
|
@clip_transforms.register(RemoveAboveNyquistConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: RemoveAboveNyquistConfig):
|
||||||
|
return RemoveAboveNyquist(
|
||||||
|
mode=config.mode,
|
||||||
|
buffer=config.buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveAtEdgesConfig(BaseConfig):
|
||||||
|
"""Configuration for `RemoveAtEdges`.
|
||||||
|
|
||||||
|
Defines parameters for removing detections at the edges of the clip.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : Literal["remove_at_edges"]
|
||||||
|
The unique identifier for this transform type.
|
||||||
|
buffer : float
|
||||||
|
The amount of time (in seconds) to remove detections from the edge.
|
||||||
|
mode : Literal["start_time", "end_time", "both"]
|
||||||
|
Criteria for removing detections at the edges of the clip.
|
||||||
|
If "start_time", remove detections with a start time within the
|
||||||
|
buffer. If "end_time", remove detections with an end time within
|
||||||
|
the buffer. If "both", remove detections with a start time within
|
||||||
|
the buffer or an end time within the buffer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["remove_at_edges"] = "remove_at_edges"
|
||||||
|
buffer: float = 0.1
|
||||||
|
mode: Literal["start_time", "end_time", "both"] = "both"
|
||||||
|
|
||||||
|
|
||||||
|
class RemoveAtEdges:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
buffer: float,
|
||||||
|
mode: Literal["start_time", "end_time", "both"],
|
||||||
|
):
|
||||||
|
self.buffer = buffer
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, detections: ClipDetections) -> ClipDetections:
|
||||||
|
clip = detections.clip
|
||||||
|
start = clip.start_time + self.buffer
|
||||||
|
end = clip.end_time - self.buffer
|
||||||
|
|
||||||
|
return ClipDetections(
|
||||||
|
clip=detections.clip,
|
||||||
|
detections=[
|
||||||
|
detection
|
||||||
|
for detection in detections.detections
|
||||||
|
if self._is_within_buffer(detection, start, end)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_within_buffer(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
start: float,
|
||||||
|
end: float,
|
||||||
|
) -> bool:
|
||||||
|
start_time, _, end_time, _ = compute_bounds(detection.geometry)
|
||||||
|
|
||||||
|
if self.mode == "start_time":
|
||||||
|
return (start_time >= start) and (start_time <= end)
|
||||||
|
|
||||||
|
if self.mode == "end_time":
|
||||||
|
return (end_time >= start) and (end_time <= end)
|
||||||
|
|
||||||
|
return (start_time >= start) and (end_time <= end)
|
||||||
|
|
||||||
|
@clip_transforms.register(RemoveAtEdgesConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: RemoveAtEdgesConfig):
|
||||||
|
return RemoveAtEdges(
|
||||||
|
buffer=config.buffer,
|
||||||
|
mode=config.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
ClipDetectionsTransformConfig = Annotated[
|
ClipDetectionsTransformConfig = Annotated[
|
||||||
ClipDetectionsTransformImportConfig,
|
ClipDetectionsTransformImportConfig
|
||||||
|
| RemoveAboveNyquistConfig
|
||||||
|
| RemoveAtEdgesConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -3,9 +3,10 @@ from dataclasses import replace
|
|||||||
from typing import Annotated, Literal
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.geometry import shift_geometry
|
from soundevent.geometry import compute_bounds, shift_geometry
|
||||||
|
|
||||||
from batdetect2.core.registries import (
|
from batdetect2.core import (
|
||||||
|
BaseConfig,
|
||||||
ImportConfig,
|
ImportConfig,
|
||||||
Registry,
|
Registry,
|
||||||
add_import_config,
|
add_import_config,
|
||||||
@ -31,8 +32,125 @@ class DetectionTransformImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterByFrequencyConfig(BaseConfig):
|
||||||
|
"""Configuration for `FilterByFrequency`.
|
||||||
|
|
||||||
|
Defines parameters for filtering detections by frequency.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : Literal["filter_by_frequency"]
|
||||||
|
The unique identifier for this transform type.
|
||||||
|
min_freq : float
|
||||||
|
The minimum frequency (in Hz) for detections to be kept.
|
||||||
|
max_freq : float
|
||||||
|
The maximum frequency (in Hz) for detections to be kept.
|
||||||
|
mode : Literal["low_freq", "high_freq", "both"]
|
||||||
|
Criteria for filtering detections by frequency.
|
||||||
|
If "low_freq", keep detections with a low frequency within the
|
||||||
|
specified range. If "high_freq", keep detections with a high
|
||||||
|
frequency within the specified range. If "both", keep detections
|
||||||
|
with a low frequency within the specified range or a high frequency
|
||||||
|
within the specified range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["filter_by_frequency"] = "filter_by_frequency"
|
||||||
|
min_freq: float = 0
|
||||||
|
max_freq: float = float("inf")
|
||||||
|
mode: Literal["low_freq", "high_freq", "both"] = "both"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterByFrequency:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_freq: float = 0,
|
||||||
|
max_freq: float = float("inf"),
|
||||||
|
mode: Literal["low_freq", "high_freq", "both"] = "both",
|
||||||
|
):
|
||||||
|
self.min_freq = min_freq
|
||||||
|
self.max_freq = max_freq
|
||||||
|
self.mode = mode
|
||||||
|
|
||||||
|
def __call__(self, detection: Detection) -> Detection | None:
|
||||||
|
if self._is_within_frequency_range(detection):
|
||||||
|
return detection
|
||||||
|
|
||||||
|
def _is_within_frequency_range(self, detection: Detection) -> bool:
|
||||||
|
_, low_freq, _, high_freq = compute_bounds(detection.geometry)
|
||||||
|
|
||||||
|
if self.mode == "low_freq":
|
||||||
|
return (low_freq >= self.min_freq) and (low_freq <= self.max_freq)
|
||||||
|
|
||||||
|
if self.mode == "high_freq":
|
||||||
|
return (high_freq >= self.min_freq) and (
|
||||||
|
high_freq <= self.max_freq
|
||||||
|
)
|
||||||
|
|
||||||
|
return (low_freq >= self.min_freq) or (high_freq <= self.max_freq)
|
||||||
|
|
||||||
|
@detection_transforms.register(FilterByFrequencyConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: FilterByFrequencyConfig):
|
||||||
|
return FilterByFrequency(
|
||||||
|
min_freq=config.min_freq,
|
||||||
|
max_freq=config.max_freq,
|
||||||
|
mode=config.mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FilterByDurationConfig(BaseConfig):
|
||||||
|
"""Configuration for `FilterByDuration`.
|
||||||
|
|
||||||
|
Defines parameters for filtering detections by duration.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
name : Literal["filter_by_duration"]
|
||||||
|
The unique identifier for this transform type.
|
||||||
|
min_duration : float
|
||||||
|
The minimum duration (in seconds) for detections to be kept.
|
||||||
|
max_duration : float
|
||||||
|
The maximum duration (in seconds) for detections to be kept.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["filter_by_duration"] = "filter_by_duration"
|
||||||
|
min_duration: float = 0
|
||||||
|
max_duration: float = float("inf")
|
||||||
|
|
||||||
|
|
||||||
|
class FilterByDuration:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
min_duration: float = 0,
|
||||||
|
max_duration: float = float("inf"),
|
||||||
|
):
|
||||||
|
self.min_duration = min_duration
|
||||||
|
self.max_duration = max_duration
|
||||||
|
|
||||||
|
def __call__(self, detection: Detection) -> Detection | None:
|
||||||
|
if self._is_within_duration_range(detection):
|
||||||
|
return detection
|
||||||
|
|
||||||
|
def _is_within_duration_range(self, detection: Detection) -> bool:
|
||||||
|
start_time, _, end_time, _ = compute_bounds(detection.geometry)
|
||||||
|
duration = end_time - start_time
|
||||||
|
return (duration >= self.min_duration) and (
|
||||||
|
duration <= self.max_duration
|
||||||
|
)
|
||||||
|
|
||||||
|
@detection_transforms.register(FilterByDurationConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: FilterByDurationConfig):
|
||||||
|
return FilterByDuration(
|
||||||
|
min_duration=config.min_duration,
|
||||||
|
max_duration=config.max_duration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DetectionTransformConfig = Annotated[
|
DetectionTransformConfig = Annotated[
|
||||||
DetectionTransformImportConfig,
|
DetectionTransformImportConfig
|
||||||
|
| FilterByFrequencyConfig
|
||||||
|
| FilterByDurationConfig,
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
121
tests/test_outputs/test_transform/test_clip_transforms.py
Normal file
121
tests/test_outputs/test_transform/test_clip_transforms.py
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.outputs.transforms.clip_transforms import (
|
||||||
|
RemoveAboveNyquist,
|
||||||
|
RemoveAtEdges,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||||
|
|
||||||
|
|
||||||
|
def _detection(
|
||||||
|
start_time: float,
|
||||||
|
low_freq: float,
|
||||||
|
end_time: float,
|
||||||
|
high_freq: float,
|
||||||
|
) -> Detection:
|
||||||
|
return Detection(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||||
|
),
|
||||||
|
detection_score=0.9,
|
||||||
|
class_scores=np.array([0.9]),
|
||||||
|
features=np.array([1.0, 2.0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_above_nyquist_high_freq_mode(clip: data.Clip) -> None:
|
||||||
|
# Nyquist should be at 128kHz
|
||||||
|
assert clip.recording.samplerate == 256_000
|
||||||
|
|
||||||
|
transform = RemoveAboveNyquist(mode="high_freq", buffer=0)
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(0.1, 10_000, 0.2, 120_000),
|
||||||
|
_detection(0.1, 10_000, 0.2, 130_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_above_nyquist_low_freq_mode(clip: data.Clip) -> None:
|
||||||
|
transform = RemoveAboveNyquist(mode="low_freq", buffer=0)
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(0.1, 120_000, 0.2, 140_000),
|
||||||
|
_detection(0.1, 130_000, 0.2, 140_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_above_nyquist_respects_buffer(clip: data.Clip) -> None:
|
||||||
|
transform = RemoveAboveNyquist(mode="high_freq", buffer=5_000)
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(0.1, 10_000, 0.2, 122_000),
|
||||||
|
_detection(0.1, 10_000, 0.2, 124_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_at_edges_start_mode(clip: data.Clip) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
|
||||||
|
transform = RemoveAtEdges(buffer=1.0, mode="start_time")
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(10.2, 20_000, 10.4, 30_000),
|
||||||
|
_detection(11.2, 20_000, 11.4, 30_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_at_edges_end_mode(clip: data.Clip) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
|
||||||
|
transform = RemoveAtEdges(buffer=1.0, mode="end_time")
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(11.2, 20_000, 19.8, 30_000),
|
||||||
|
_detection(11.2, 20_000, 18.6, 30_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_at_edges_both_mode(clip: data.Clip) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 10.0, "end_time": 20.0})
|
||||||
|
transform = RemoveAtEdges(buffer=1.0, mode="both")
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=[
|
||||||
|
_detection(10.2, 20_000, 18.5, 30_000),
|
||||||
|
_detection(11.2, 20_000, 18.5, 30_000),
|
||||||
|
_detection(11.2, 20_000, 19.8, 30_000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
out = transform(prediction)
|
||||||
|
|
||||||
|
assert len(out.detections) == 1
|
||||||
117
tests/test_outputs/test_transform/test_detection_transforms.py
Normal file
117
tests/test_outputs/test_transform/test_detection_transforms.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.outputs.transforms.detection_transforms import (
|
||||||
|
FilterByDuration,
|
||||||
|
FilterByDurationConfig,
|
||||||
|
FilterByFrequency,
|
||||||
|
FilterByFrequencyConfig,
|
||||||
|
detection_transforms,
|
||||||
|
shift_detection_time,
|
||||||
|
shift_detections_to_start_time,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.types import Detection
|
||||||
|
|
||||||
|
|
||||||
|
def _detection(
|
||||||
|
start_time: float,
|
||||||
|
low_freq: float,
|
||||||
|
end_time: float,
|
||||||
|
high_freq: float,
|
||||||
|
) -> Detection:
|
||||||
|
return Detection(
|
||||||
|
geometry=data.BoundingBox(
|
||||||
|
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||||
|
),
|
||||||
|
detection_score=0.9,
|
||||||
|
class_scores=np.array([0.9]),
|
||||||
|
features=np.array([1.0, 2.0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shift_detection_time_moves_geometry_by_offset() -> None:
|
||||||
|
detection = _detection(0.1, 20_000, 0.2, 30_000)
|
||||||
|
|
||||||
|
shifted = shift_detection_time(detection, time=2.5)
|
||||||
|
start, low, end, high = compute_bounds(shifted.geometry)
|
||||||
|
|
||||||
|
assert np.isclose(start, 2.6)
|
||||||
|
assert np.isclose(end, 2.7)
|
||||||
|
assert np.isclose(low, 20_000)
|
||||||
|
assert np.isclose(high, 30_000)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shift_detections_to_start_time_zero_is_identity() -> None:
|
||||||
|
detections = [_detection(0.1, 20_000, 0.2, 30_000)]
|
||||||
|
|
||||||
|
shifted = shift_detections_to_start_time(detections, start_time=0)
|
||||||
|
|
||||||
|
assert len(shifted) == 1
|
||||||
|
assert shifted[0] is detections[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_frequency_low_freq_mode() -> None:
|
||||||
|
transform = FilterByFrequency(
|
||||||
|
min_freq=20_000,
|
||||||
|
max_freq=40_000,
|
||||||
|
mode="low_freq",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transform(_detection(0.1, 25_000, 0.2, 60_000)) is not None
|
||||||
|
assert transform(_detection(0.1, 10_000, 0.2, 60_000)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_frequency_high_freq_mode() -> None:
|
||||||
|
transform = FilterByFrequency(
|
||||||
|
min_freq=20_000,
|
||||||
|
max_freq=40_000,
|
||||||
|
mode="high_freq",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert transform(_detection(0.1, 10_000, 0.2, 35_000)) is not None
|
||||||
|
assert transform(_detection(0.1, 10_000, 0.2, 60_000)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_frequency_both_mode_current_semantics() -> None:
|
||||||
|
transform = FilterByFrequency(
|
||||||
|
min_freq=20_000,
|
||||||
|
max_freq=40_000,
|
||||||
|
mode="both",
|
||||||
|
)
|
||||||
|
|
||||||
|
# low >= min passes
|
||||||
|
assert transform(_detection(0.1, 25_000, 0.2, 80_000)) is not None
|
||||||
|
# high <= max passes
|
||||||
|
assert transform(_detection(0.1, 10_000, 0.2, 35_000)) is not None
|
||||||
|
# neither condition passes
|
||||||
|
assert transform(_detection(0.1, 10_000, 0.2, 80_000)) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_by_duration_keeps_within_range() -> None:
|
||||||
|
transform = FilterByDuration(min_duration=0.04, max_duration=0.06)
|
||||||
|
|
||||||
|
kept = transform(_detection(0.1, 20_000, 0.15, 30_000))
|
||||||
|
removed = transform(_detection(0.1, 20_000, 0.2, 30_000))
|
||||||
|
|
||||||
|
assert kept is not None
|
||||||
|
assert removed is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection_transform_registry_builds_builtin_transforms() -> None:
|
||||||
|
frequency_transform = detection_transforms.build(
|
||||||
|
FilterByFrequencyConfig(
|
||||||
|
min_freq=20_000,
|
||||||
|
max_freq=40_000,
|
||||||
|
mode="high_freq",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
duration_transform = detection_transforms.build(
|
||||||
|
FilterByDurationConfig(
|
||||||
|
min_duration=0.01,
|
||||||
|
max_duration=0.2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert callable(frequency_transform)
|
||||||
|
assert callable(duration_transform)
|
||||||
@ -0,0 +1,137 @@
|
|||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.outputs import build_output_transform
|
||||||
|
from batdetect2.postprocess.types import ClipDetectionsTensor
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_clip_detections_tensor(
|
||||||
|
*,
|
||||||
|
time: float,
|
||||||
|
duration: float,
|
||||||
|
frequency: float,
|
||||||
|
bandwidth: float,
|
||||||
|
) -> ClipDetectionsTensor:
|
||||||
|
# NOTE: size time is represented in milliseconds.
|
||||||
|
return ClipDetectionsTensor(
|
||||||
|
scores=torch.tensor([0.9], dtype=torch.float32),
|
||||||
|
sizes=torch.tensor(
|
||||||
|
[[duration * 1_000, bandwidth]], dtype=torch.float32
|
||||||
|
),
|
||||||
|
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
||||||
|
times=torch.tensor([time], dtype=torch.float32),
|
||||||
|
frequencies=torch.tensor([frequency], dtype=torch.float32),
|
||||||
|
features=torch.tensor([[1.0, 2.0]], dtype=torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_from_config_applies_detection_and_clip_transforms(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0})
|
||||||
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
config={
|
||||||
|
"detection_transforms": [
|
||||||
|
{
|
||||||
|
"name": "filter_by_duration",
|
||||||
|
"min_duration": 0.08,
|
||||||
|
"max_duration": 0.12,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"clip_transforms": [
|
||||||
|
{
|
||||||
|
"name": "remove_at_edges",
|
||||||
|
"buffer": 0.1,
|
||||||
|
"mode": "both",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = _mock_clip_detections_tensor(
|
||||||
|
time=0.03,
|
||||||
|
duration=0.1,
|
||||||
|
frequency=60_000,
|
||||||
|
bandwidth=1_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction = transform.to_clip_detections(raw, clip=clip)
|
||||||
|
|
||||||
|
# duration filter keeps it, edge filter removes it.
|
||||||
|
assert len(prediction.detections) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_keeps_detection_when_all_filters_pass(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 10.0, "end_time": 11.0})
|
||||||
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
config={
|
||||||
|
"detection_transforms": [
|
||||||
|
{
|
||||||
|
"name": "filter_by_duration",
|
||||||
|
"min_duration": 0.08,
|
||||||
|
"max_duration": 0.12,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"clip_transforms": [
|
||||||
|
{
|
||||||
|
"name": "remove_at_edges",
|
||||||
|
"buffer": 0.05,
|
||||||
|
"mode": "both",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = _mock_clip_detections_tensor(
|
||||||
|
time=0.3,
|
||||||
|
duration=0.1,
|
||||||
|
frequency=60_000,
|
||||||
|
bandwidth=1_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction = transform.to_clip_detections(raw, clip=clip)
|
||||||
|
|
||||||
|
assert len(prediction.detections) == 1
|
||||||
|
start_time, _, _, _ = compute_bounds(prediction.detections[0].geometry)
|
||||||
|
assert np.isclose(start_time, 10.3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_remove_above_nyquist_uses_clip_recording_metadata(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
) -> None:
|
||||||
|
clip = clip.model_copy(update={"start_time": 0.0, "end_time": 1.0})
|
||||||
|
transform = build_output_transform(
|
||||||
|
targets=sample_targets,
|
||||||
|
config={
|
||||||
|
"clip_transforms": [
|
||||||
|
{
|
||||||
|
"name": "remove_above_nyquist",
|
||||||
|
"mode": "high_freq",
|
||||||
|
"buffer": 0,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
raw = _mock_clip_detections_tensor(
|
||||||
|
time=0.5,
|
||||||
|
duration=0.05,
|
||||||
|
frequency=127_500,
|
||||||
|
bandwidth=2_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
prediction = transform.to_clip_detections(raw, clip=clip)
|
||||||
|
|
||||||
|
# clip fixture samplerate is 256_000, nyquist is 128_000, high bound
|
||||||
|
# becomes 128_500 and must be removed.
|
||||||
|
assert len(prediction.detections) == 0
|
||||||
Loading…
Reference in New Issue
Block a user