From 2b235e28bb03798459404e259f14e7d7cc18f188 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 29 Mar 2026 13:11:42 +0100 Subject: [PATCH] Separate targets and ROI --- example_data/targets.yaml | 5 +- justfile | 2 +- pyproject.toml | 1 + src/batdetect2/api_v2.py | 28 ++++- src/batdetect2/evaluate/evaluate.py | 5 +- src/batdetect2/evaluate/evaluator.py | 12 +- src/batdetect2/inference/batch.py | 7 +- src/batdetect2/inference/lightning.py | 3 +- src/batdetect2/models/__init__.py | 24 +++- src/batdetect2/models/detectors.py | 2 + src/batdetect2/models/heads.py | 5 +- src/batdetect2/outputs/transforms/__init__.py | 15 ++- src/batdetect2/outputs/transforms/decoding.py | 5 +- src/batdetect2/plotting/clip_annotations.py | 9 +- src/batdetect2/targets/__init__.py | 7 ++ src/batdetect2/targets/classes.py | 3 - src/batdetect2/targets/config.py | 4 +- src/batdetect2/targets/rois.py | 93 ++++++++++++++- src/batdetect2/targets/targets.py | 112 +++--------------- src/batdetect2/targets/types.py | 19 ++- src/batdetect2/train/callbacks.py | 3 +- src/batdetect2/train/checkpoints.py | 2 +- src/batdetect2/train/labels.py | 18 ++- src/batdetect2/train/train.py | 29 +++-- tests/test_audio_utils.py | 4 +- tests/test_cli/test_detect.py | 9 -- tests/test_models/test_detectors.py | 16 +++ .../test_transform/test_roundtrip.py | 10 +- .../test_transform/test_transform.py | 22 +++- tests/test_targets/test_rois.py | 44 +++++++ tests/test_targets/test_targets.py | 50 +++++--- tests/test_train/test_checkpoints.py | 22 ---- tests/test_train/test_labels.py | 14 ++- 33 files changed, 403 insertions(+), 201 deletions(-) diff --git a/example_data/targets.yaml b/example_data/targets.yaml index 887d492..f630dc6 100644 --- a/example_data/targets.yaml +++ b/example_data/targets.yaml @@ -32,5 +32,6 @@ classification_targets: value: Rhinolophus ferrumequinum roi: - name: anchor_bbox - anchor: top-left + default: + name: anchor_bbox + anchor: top-left diff --git a/justfile b/justfile index d0ba419..39fa8b7 100644 --- a/justfile +++ b/justfile @@ -20,7 +20,7 @@ install: # Testing & Coverage # Run tests using pytest. test: - uv run pytest {{TESTS_DIR}} + uv run pytest -n auto {{TESTS_DIR}} # Run tests and generate coverage data. coverage: diff --git a/pyproject.toml b/pyproject.toml index 84e10c9..5c4a9ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,6 +88,7 @@ dev = [ "pandas-stubs>=2.2.2.240807", "python-lsp-server>=1.13.0", "deepdiff>=8.6.1", + "pytest-xdist[psutil]>=3.8.0", ] dvclive = ["dvclive>=3.48.2"] mlflow = ["mlflow>=3.1.1"] diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index d7a8733..176d28e 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -50,7 +50,13 @@ from batdetect2.postprocess import ( build_postprocessor, ) from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor -from batdetect2.targets import TargetConfig, TargetProtocol, build_targets +from batdetect2.targets import ( + ROIMapperProtocol, + TargetConfig, + TargetProtocol, + build_roi_mapping, + build_targets, +) from batdetect2.train import ( DEFAULT_CHECKPOINT_DIR, TrainingConfig, @@ -70,6 +76,7 @@ class BatDetect2API: outputs_config: OutputsConfig, logging_config: AppLoggingConfig, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, @@ -86,6 +93,7 @@ class BatDetect2API: self.outputs_config = outputs_config self.logging_config = logging_config self.targets = targets + self.roi_mapper = roi_mapper self.audio_loader = audio_loader self.preprocessor = preprocessor self.postprocessor = postprocessor @@ -125,6 +133,7 @@ class BatDetect2API: val_annotations=val_annotations, model=self.model, targets=self.targets, + roi_mapper=self.roi_mapper, model_config=model_config or self.model_config, audio_loader=self.audio_loader, preprocessor=self.preprocessor, @@ -171,6 +180,7 @@ class BatDetect2API: val_annotations=val_annotations, model=self.model, targets=self.targets, + roi_mapper=self.roi_mapper, model_config=model_config or self.model_config, preprocessor=self.preprocessor, audio_loader=self.audio_loader, @@ -205,6 +215,7 @@ class BatDetect2API: self.model, test_annotations, targets=self.targets, + roi_mapper=self.roi_mapper, audio_loader=self.audio_loader, preprocessor=self.preprocessor, audio_config=audio_config or self.audio_config, @@ -391,6 +402,7 @@ class BatDetect2API: self.model, audio_files, targets=self.targets, + roi_mapper=self.roi_mapper, audio_loader=self.audio_loader, preprocessor=self.preprocessor, output_transform=self.output_transform, @@ -416,6 +428,7 @@ class BatDetect2API: self.model, clips, targets=self.targets, + roi_mapper=self.roi_mapper, audio_loader=self.audio_loader, preprocessor=self.preprocessor, output_transform=self.output_transform, @@ -472,6 +485,7 @@ class BatDetect2API: config: BatDetect2Config, ) -> "BatDetect2API": targets = build_targets(config=config.model.targets) + roi_mapper = build_roi_mapping(config=config.model.targets.roi) audio_loader = build_audio_loader(config=config.audio) @@ -492,11 +506,13 @@ class BatDetect2API: output_transform = build_output_transform( config=config.outputs.transform, targets=targets, + roi_mapper=roi_mapper, ) evaluator = build_evaluator( config=config.evaluation, targets=targets, + roi_mapper=roi_mapper, transform=output_transform, ) @@ -504,7 +520,8 @@ class BatDetect2API: # to avoid device mismatch errors model = build_model( config=config.model, - targets=build_targets(config=config.model.targets), + targets=targets, + roi_mapper=roi_mapper, preprocessor=build_preprocessor( input_samplerate=audio_loader.samplerate, config=config.model.preprocess, @@ -524,6 +541,7 @@ class BatDetect2API: outputs_config=config.outputs, logging_config=config.logging, targets=targets, + roi_mapper=roi_mapper, audio_loader=audio_loader, preprocessor=preprocessor, postprocessor=postprocessor, @@ -561,15 +579,18 @@ class BatDetect2API: and targets_config != model_config.targets ): targets = build_targets(config=targets_config) + roi_mapper = build_roi_mapping(config=targets_config.roi) model = build_model_with_new_targets( model=model, targets=targets, + roi_mapper=roi_mapper, ) model_config = model_config.model_copy( update={"targets": targets_config} ) targets = build_targets(config=model_config.targets) + roi_mapper = build_roi_mapping(config=model_config.targets.roi) audio_loader = build_audio_loader(config=audio_config) @@ -591,11 +612,13 @@ class BatDetect2API: output_transform = build_output_transform( config=outputs_config.transform, targets=targets, + roi_mapper=roi_mapper, ) evaluator = build_evaluator( config=evaluation_config, targets=targets, + roi_mapper=roi_mapper, transform=output_transform, ) @@ -608,6 +631,7 @@ class BatDetect2API: outputs_config=outputs_config, logging_config=logging_config, targets=targets, + roi_mapper=roi_mapper, audio_loader=audio_loader, preprocessor=preprocessor, postprocessor=postprocessor, diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index c488e93..0a0ac5b 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -16,7 +16,7 @@ from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess.types import ClipDetections from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" @@ -25,6 +25,7 @@ def run_evaluate( model: Model, test_annotations: Sequence[data.ClipAnnotation], targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, audio_config: AudioConfig | None = None, @@ -46,6 +47,7 @@ def run_evaluate( preprocessor = preprocessor or model.preprocessor targets = targets or model.targets + roi_mapper = roi_mapper or model.roi_mapper loader = build_test_loader( test_annotations, @@ -57,6 +59,7 @@ def run_evaluate( output_transform = build_output_transform( config=output_config.transform, targets=targets, + roi_mapper=roi_mapper, ) evaluator = build_evaluator( config=evaluation_config, diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py index 685079f..6481fa9 100644 --- a/src/batdetect2/evaluate/evaluator.py +++ b/src/batdetect2/evaluate/evaluator.py @@ -8,8 +8,8 @@ from batdetect2.evaluate.tasks import build_task from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor -from batdetect2.targets import build_targets -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets import build_roi_mapping, build_targets +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ "Evaluator", @@ -67,17 +67,23 @@ class Evaluator: def build_evaluator( config: EvaluationConfig | dict | None = None, targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, transform: OutputTransformProtocol | None = None, ) -> EvaluatorProtocol: targets = targets or build_targets() + roi_mapper = roi_mapper or build_roi_mapping() + if config is None: config = EvaluationConfig() if not isinstance(config, EvaluationConfig): config = EvaluationConfig.model_validate(config) - transform = transform or build_output_transform(targets=targets) + transform = transform or build_output_transform( + targets=targets, + roi_mapper=roi_mapper, + ) return Evaluator( targets=targets, diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 7ccf71d..2c43bde 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -18,13 +18,14 @@ from batdetect2.outputs import ( ) from batdetect2.postprocess.types import ClipDetections from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol def run_batch_inference( model: Model, clips: Sequence[data.Clip], targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, audio_config: AudioConfig | None = None, @@ -45,10 +46,12 @@ def run_batch_inference( preprocessor = preprocessor or model.preprocessor targets = targets or model.targets + roi_mapper = roi_mapper or model.roi_mapper output_transform = output_transform or build_output_transform( config=output_config.transform, targets=targets, + roi_mapper=roi_mapper, ) loader = build_inference_loader( @@ -78,6 +81,7 @@ def process_file_list( model: Model, paths: Sequence[data.PathLike], targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, audio_loader: AudioLoader | None = None, audio_config: AudioConfig | None = None, preprocessor: PreprocessorProtocol | None = None, @@ -101,6 +105,7 @@ def process_file_list( model, clips, targets=targets, + roi_mapper=roi_mapper, audio_loader=audio_loader, preprocessor=preprocessor, batch_size=batch_size, diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 7e4b058..7ae010e 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -20,7 +20,8 @@ class InferenceModule(LightningModule): self.model = model self.detection_threshold = detection_threshold self.output_transform = output_transform or build_output_transform( - targets=model.targets + targets=model.targets, + roi_mapper=model.roi_mapper, ) def predict_step( diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index cd1dc1f..9b53004 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -74,7 +74,7 @@ from batdetect2.postprocess.types import ( from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets.config import TargetConfig -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ "BBoxHead", @@ -186,12 +186,15 @@ class Model(torch.nn.Module): targets : TargetProtocol Describes the set of target classes; used when building heads and during training target construction. + roi_mapper : ROIMapperProtocol + Maps geometries to target-size channels and back. """ detector: DetectionModel preprocessor: PreprocessorProtocol postprocessor: PostprocessorProtocol targets: TargetProtocol + roi_mapper: ROIMapperProtocol def __init__( self, @@ -199,12 +202,14 @@ class Model(torch.nn.Module): preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, ): super().__init__() self.detector = detector self.preprocessor = preprocessor self.postprocessor = postprocessor self.targets = targets + self.roi_mapper = roi_mapper def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: """Run the full detection pipeline on a waveform tensor. @@ -234,6 +239,7 @@ class Model(torch.nn.Module): def build_model( config: ModelConfig | None = None, targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, preprocessor: PreprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None, ) -> Model: @@ -272,10 +278,19 @@ def build_model( """ from batdetect2.postprocess import build_postprocessor from batdetect2.preprocess import build_preprocessor - from batdetect2.targets import build_targets + from batdetect2.targets import build_roi_mapping, build_targets config = config or ModelConfig() targets = targets or build_targets(config=config.targets) + + targets_config = getattr(targets, "config", None) + roi_config = ( + targets_config.roi + if isinstance(targets_config, TargetConfig) + else config.targets.roi + ) + + roi_mapper = roi_mapper or build_roi_mapping(config=roi_config) preprocessor = preprocessor or build_preprocessor( config=config.preprocess, input_samplerate=config.samplerate, @@ -286,6 +301,7 @@ def build_model( ) detector = build_detector( num_classes=len(targets.class_names), + num_sizes=len(roi_mapper.dimension_names), config=config.architecture, ) return Model( @@ -293,16 +309,19 @@ def build_model( postprocessor=postprocessor, preprocessor=preprocessor, targets=targets, + roi_mapper=roi_mapper, ) def build_model_with_new_targets( model: Model, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, ) -> Model: """Build a new model with a different target set.""" detector = build_detector( num_classes=len(targets.class_names), + num_sizes=len(roi_mapper.dimension_names), backbone=model.detector.backbone, ) @@ -311,4 +330,5 @@ def build_model_with_new_targets( postprocessor=model.postprocessor, preprocessor=model.preprocessor, targets=targets, + roi_mapper=roi_mapper, ) diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index 12a6a77..a3894ce 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -136,6 +136,7 @@ class Detector(DetectionModel): def build_detector( num_classes: int, + num_sizes: int = 2, config: BackboneConfig | None = None, backbone: BackboneModel | None = None, ) -> DetectionModel: @@ -181,6 +182,7 @@ def build_detector( ) bbox_head = BBoxHead( in_channels=backbone.out_channels, + num_sizes=num_sizes, ) return Detector( backbone=backbone, diff --git a/src/batdetect2/models/heads.py b/src/batdetect2/models/heads.py index 65a2a40..ba7b437 100644 --- a/src/batdetect2/models/heads.py +++ b/src/batdetect2/models/heads.py @@ -165,14 +165,15 @@ class BBoxHead(nn.Module): 1×1 convolution with 2 output channels (duration, bandwidth). """ - def __init__(self, in_channels: int): + def __init__(self, in_channels: int, num_sizes: int = 2): """Initialise the BBoxHead.""" super().__init__() self.in_channels = in_channels + self.num_sizes = num_sizes self.bbox = nn.Conv2d( in_channels=self.in_channels, - out_channels=2, + out_channels=self.num_sizes, kernel_size=1, padding=0, ) diff --git a/src/batdetect2/outputs/transforms/__init__.py b/src/batdetect2/outputs/transforms/__init__.py index e82f84e..5214149 100644 --- a/src/batdetect2/outputs/transforms/__init__.py +++ b/src/batdetect2/outputs/transforms/__init__.py @@ -28,7 +28,7 @@ from batdetect2.postprocess.types import ( ClipDetectionsTensor, Detection, ) -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ "ClipDetectionsTransformConfig", @@ -55,10 +55,12 @@ class OutputTransform(OutputTransformProtocol): def __init__( self, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, detection_transform_steps: Sequence[DetectionTransform] = (), clip_transform_steps: Sequence[ClipDetectionsTransform] = (), ): self.targets = targets + self.roi_mapper = roi_mapper self.detection_transform_steps = list(detection_transform_steps) self.clip_transform_steps = list(clip_transform_steps) @@ -89,7 +91,11 @@ class OutputTransform(OutputTransformProtocol): detections: ClipDetectionsTensor, start_time: float = 0, ) -> list[Detection]: - decoded = to_detections(detections.numpy(), targets=self.targets) + decoded = to_detections( + detections.numpy(), + targets=self.targets, + roi_mapper=self.roi_mapper, + ) shifted = shift_detections_to_start_time( decoded, start_time=start_time, @@ -151,8 +157,9 @@ class OutputTransform(OutputTransformProtocol): def build_output_transform( config: OutputTransformConfig | dict | None = None, targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, ) -> OutputTransformProtocol: - from batdetect2.targets import build_targets + from batdetect2.targets import build_roi_mapping, build_targets if config is None: config = OutputTransformConfig() @@ -161,9 +168,11 @@ def build_output_transform( config = OutputTransformConfig.model_validate(config) targets = targets or build_targets() + roi_mapper = roi_mapper or build_roi_mapping() return OutputTransform( targets=targets, + roi_mapper=roi_mapper, detection_transform_steps=[ detection_transform_registry.build(transform_config) for transform_config in config.detection_transforms diff --git a/src/batdetect2/outputs/transforms/decoding.py b/src/batdetect2/outputs/transforms/decoding.py index f04d3c4..298226f 100644 --- a/src/batdetect2/outputs/transforms/decoding.py +++ b/src/batdetect2/outputs/transforms/decoding.py @@ -6,7 +6,7 @@ import numpy as np from soundevent import data from batdetect2.postprocess.types import ClipDetectionsArray, Detection -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -25,6 +25,7 @@ DEFAULT_CLASSIFICATION_THRESHOLD = 0.1 def to_detections( detections: ClipDetectionsArray, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, ) -> List[Detection]: predictions = [] @@ -39,7 +40,7 @@ def to_detections( ): highest_scoring_class = targets.class_names[class_scores.argmax()] - geom = targets.decode_roi( + geom = roi_mapper.decode( (time, freq), dims, class_name=highest_scoring_class, diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py index a866360..67aaf94 100644 --- a/src/batdetect2/plotting/clip_annotations.py +++ b/src/batdetect2/plotting/clip_annotations.py @@ -4,7 +4,7 @@ from soundevent import data, plot from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.common import create_ax from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol __all__ = [ "plot_clip_annotation", @@ -48,6 +48,7 @@ def plot_clip_annotation( def plot_anchor_points( clip_annotation: data.ClipAnnotation, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, figsize: tuple[int, int] | None = None, ax: Axes | None = None, size: int = 1, @@ -63,7 +64,11 @@ def plot_anchor_points( if not targets.filter(sound_event): continue - position, _ = targets.encode_roi(sound_event) + class_name = targets.encode_class(sound_event) + position, _ = roi_mapper.encode( + sound_event.sound_event, + class_name=class_name, + ) positions.append(position) X, Y = zip(*positions, strict=False) diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index c6f556e..5bba3d1 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -10,7 +10,10 @@ from batdetect2.targets.config import TargetConfig from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, ROIMapperConfig, + ROIMapperProtocol, + ROIMappingConfig, build_roi_mapper, + build_roi_mapping, ) from batdetect2.targets.targets import ( Targets, @@ -30,12 +33,15 @@ from batdetect2.targets.types import ( Size, SoundEventDecoder, SoundEventEncoder, + SoundEventFilter, TargetProtocol, ) __all__ = [ "AnchorBBoxMapperConfig", "Position", + "ROIMappingConfig", + "ROIMapperProtocol", "ROIMapperConfig", "ROITargetMapper", "Size", @@ -46,6 +52,7 @@ __all__ = [ "TargetConfig", "TargetProtocol", "Targets", + "build_roi_mapping", "build_roi_mapper", "build_sound_event_decoder", "build_sound_event_encoder", diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index e7b3604..7639cf0 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -14,7 +14,6 @@ from batdetect2.data.conditions import ( SoundEventConditionConfig, build_sound_event_condition, ) -from batdetect2.targets.rois import ROIMapperConfig from batdetect2.targets.terms import call_type, generic_class from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder @@ -39,8 +38,6 @@ class TargetClassConfig(BaseConfig): assign_tags: List[data.Tag] = Field(default_factory=list) - roi: ROIMapperConfig | None = None - _match_if: SoundEventConditionConfig = PrivateAttr() @model_validator(mode="after") diff --git a/src/batdetect2/targets/config.py b/src/batdetect2/targets/config.py index 9d3f0c5..aa7cca9 100644 --- a/src/batdetect2/targets/config.py +++ b/src/batdetect2/targets/config.py @@ -9,7 +9,7 @@ from batdetect2.targets.classes import ( DEFAULT_DETECTION_CLASS, TargetClassConfig, ) -from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig +from batdetect2.targets.rois import ROIMappingConfig __all__ = [ "TargetConfig", @@ -25,7 +25,7 @@ class TargetConfig(BaseConfig): default_factory=lambda: DEFAULT_CLASSES ) - roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) + roi: ROIMappingConfig = Field(default_factory=ROIMappingConfig) @field_validator("classification_targets") def check_unique_class_names(cls, v: List[TargetClassConfig]): diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index e0e8504..4114e1a 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -29,7 +29,12 @@ from batdetect2.core.arrays import spec_to_xarray from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets.types import Position, ROITargetMapper, Size +from batdetect2.targets.types import ( + Position, + ROIMapperProtocol, + ROITargetMapper, + Size, +) __all__ = [ "Anchor", @@ -40,12 +45,15 @@ __all__ = [ "DEFAULT_TIME_SCALE", "PeakEnergyBBoxMapper", "PeakEnergyBBoxMapperConfig", + "ROIMappingConfig", + "ROIMapperProtocol", "ROIMapperConfig", "ROIMapperImportConfig", "ROITargetMapper", "SIZE_HEIGHT", "SIZE_ORDER", "SIZE_WIDTH", + "build_roi_mapping", "build_roi_mapper", ] @@ -456,6 +464,59 @@ implementations by using the `name` field as a discriminator. """ +class ROIMappingConfig(BaseConfig): + """Configuration for class-aware ROI mapping. + + Attributes + ---------- + default : ROIMapperConfig + Default mapper used when no class-specific override exists. + overrides : dict[str, ROIMapperConfig] + Optional class-specific mapper overrides by class name. + """ + + default: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) + overrides: dict[str, ROIMapperConfig] = Field(default_factory=dict) + + +class ClassAwareROIMapper(ROIMapperProtocol): + """Apply a default ROI mapper with optional per-class overrides.""" + + dimension_names: list[str] + + def __init__( + self, + default_mapper: ROITargetMapper, + overrides: dict[str, ROITargetMapper] | None = None, + ): + self.default_mapper = default_mapper + self.overrides = overrides or {} + self.dimension_names = list(default_mapper.dimension_names) + + def encode( + self, + sound_event: data.SoundEvent, + class_name: str | None = None, + ) -> tuple[Position, Size]: + mapper = self._select_mapper(class_name) + return mapper.encode(sound_event) + + def decode( + self, + position: Position, + size: Size, + class_name: str | None = None, + ) -> data.Geometry: + mapper = self._select_mapper(class_name) + return mapper.decode(position, size) + + def _select_mapper(self, class_name: str | None = None) -> ROITargetMapper: + if class_name is not None and class_name in self.overrides: + return self.overrides[class_name] + + return self.default_mapper + + def build_roi_mapper( config: ROIMapperConfig | None = None, ) -> ROITargetMapper: @@ -480,6 +541,36 @@ def build_roi_mapper( return roi_mapper_registry.build(config) +def build_roi_mapping( + config: ROIMappingConfig | None = None, +) -> ROIMapperProtocol: + """Build a class-aware ROI mapper and validate consistency.""" + config = config or ROIMappingConfig() + + default_mapper = build_roi_mapper(config.default) + overrides = { + class_name: build_roi_mapper(mapper_config) + for class_name, mapper_config in config.overrides.items() + } + + expected = list(default_mapper.dimension_names) + + for class_name, mapper in overrides.items(): + actual = list(mapper.dimension_names) + + if actual != expected: + raise ValueError( + "All ROI mappers must share the same dimension order. " + f"Default dimensions: {expected}, " + f"class '{class_name}' dimensions: {actual}." + ) + + return ClassAwareROIMapper( + default_mapper=default_mapper, + overrides=overrides, + ) + + VALID_ANCHORS = [ "bottom-left", "bottom-right", diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py index 895a487..72e0262 100644 --- a/src/batdetect2/targets/targets.py +++ b/src/batdetect2/targets/targets.py @@ -12,21 +12,21 @@ from batdetect2.targets.classes import ( get_class_names_from_config, ) from batdetect2.targets.config import TargetConfig -from batdetect2.targets.rois import ( - AnchorBBoxMapperConfig, - build_roi_mapper, +from batdetect2.targets.types import ( + Position, + ROIMapperProtocol, + Size, + TargetProtocol, ) -from batdetect2.targets.types import Position, Size, TargetProtocol class Targets(TargetProtocol): - """Encapsulates the complete configured target definition pipeline. + """Encapsulates the configured target class definition pipeline. This class implements the `TargetProtocol`, holding the configured - functions for filtering, transforming, encoding (tags to class name), - decoding (class name to tags), and mapping ROIs (geometry to position/size - and back). It provides a high-level interface to apply these steps and - access relevant metadata like class names and dimension names. + functions for filtering, encoding (tags to class name), and decoding + (class name to tags). Geometry ROI mapping is handled separately by + ``ROIMapperProtocol``. Instances are typically created using the `build_targets` factory function or the `load_targets` convenience loader. @@ -39,14 +39,10 @@ class Targets(TargetProtocol): generic_class_tags A list of `soundevent.data.Tag` objects representing the configured generic class category (used when no specific class matches). - dimension_names - The names of the size dimensions handled by the ROI mapper - (e.g., ['width', 'height']). """ class_names: list[str] detection_class_tags: list[data.Tag] - dimension_names: list[str] detection_class_name: str def __init__(self, config: TargetConfig): @@ -63,10 +59,6 @@ class Targets(TargetProtocol): config.classification_targets ) - self._roi_mapper = build_roi_mapper(config.roi) - - self.dimension_names = self._roi_mapper.dimension_names - self.class_names = get_class_names_from_config( config.classification_targets ) @@ -74,21 +66,6 @@ class Targets(TargetProtocol): self.detection_class_name = config.detection_target.name self.detection_class_tags = config.detection_target.assign_tags - self._roi_mapper_overrides = { - class_config.name: build_roi_mapper(class_config.roi) - for class_config in config.classification_targets - if class_config.roi is not None - } - - for class_name in self._roi_mapper_overrides: - if class_name not in self.class_names: - # TODO: improve this warning - logger.warning( - "The ROI mapper overrides contains a class ({class_name}) " - "not present in the class names.", - class_name=class_name, - ) - def filter(self, sound_event: data.SoundEventAnnotation) -> bool: """Apply the configured filter to a sound event annotation. @@ -147,75 +124,10 @@ class Targets(TargetProtocol): """ return self._decode_fn(class_label) - def encode_roi( - self, sound_event: data.SoundEventAnnotation - ) -> tuple[Position, Size]: - """Extract the target reference position from the annotation's roi. - - Delegates to the internal ROI mapper's `get_roi_position` method. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation containing the geometry (ROI). - - Returns - ------- - tuple[float, float] - The reference position `(time, frequency)`. - - Raises - ------ - ValueError - If the annotation lacks geometry. - """ - class_name = self.encode_class(sound_event) - - if class_name in self._roi_mapper_overrides: - return self._roi_mapper_overrides[class_name].encode( - sound_event.sound_event - ) - - return self._roi_mapper.encode(sound_event.sound_event) - - def decode_roi( - self, - position: Position, - size: Size, - class_name: str | None = None, - ) -> data.Geometry: - """Recover an approximate geometric ROI from a position and dimensions. - - Delegates to the internal ROI mapper's `recover_roi` method, which - un-scales the dimensions and reconstructs the geometry (typically a - `BoundingBox`). - - Parameters - ---------- - pos - The reference position `(time, frequency)`. - dims - NumPy array with size dimensions (e.g., from model prediction), - matching the order in `self.dimension_names`. - - Returns - ------- - data.Geometry - The reconstructed geometry (typically `BoundingBox`). - """ - if class_name in self._roi_mapper_overrides: - return self._roi_mapper_overrides[class_name].decode( - position, - size, - ) - - return self._roi_mapper.decode(position, size) - DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( classification_targets=DEFAULT_CLASSES, detection_target=DEFAULT_DETECTION_CLASS, - roi=AnchorBBoxMapperConfig(), ) @@ -292,6 +204,7 @@ def load_targets( def iterate_encoded_sound_events( sound_events: Iterable[data.SoundEventAnnotation], targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, ) -> Iterable[tuple[str | None, Position, Size]]: for sound_event in sound_events: if not targets.filter(sound_event): @@ -303,6 +216,9 @@ def iterate_encoded_sound_events( continue class_name = targets.encode_class(sound_event) - position, size = targets.encode_roi(sound_event) + position, size = roi_mapper.encode( + sound_event.sound_event, + class_name=class_name, + ) yield class_name, position, size diff --git a/src/batdetect2/targets/types.py b/src/batdetect2/targets/types.py index af5ab44..4f435ba 100644 --- a/src/batdetect2/targets/types.py +++ b/src/batdetect2/targets/types.py @@ -6,6 +6,7 @@ from soundevent import data __all__ = [ "Position", + "ROIMapperProtocol", "ROITargetMapper", "Size", "SoundEventDecoder", @@ -26,7 +27,6 @@ class TargetProtocol(Protocol): class_names: list[str] detection_class_tags: list[data.Tag] detection_class_name: str - dimension_names: list[str] def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ... @@ -37,6 +37,23 @@ class TargetProtocol(Protocol): def decode_class(self, class_label: str) -> list[data.Tag]: ... + +class ROIMapperProtocol(Protocol): + dimension_names: list[str] + + def encode( + self, + sound_event: data.SoundEvent, + class_name: str | None = None, + ) -> tuple[Position, Size]: ... + + def decode( + self, + position: Position, + size: Size, + class_name: str | None = None, + ) -> data.Geometry: ... + def encode_roi( self, sound_event: data.SoundEventAnnotation, diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 6d17c3b..e4a9881 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -93,7 +93,8 @@ class ValidationMetrics(Callback): model = pl_module.model if self.output_transform is None: self.output_transform = build_output_transform( - targets=model.targets + targets=model.targets, + roi_mapper=model.roi_mapper, ) output_transform = self.output_transform diff --git a/src/batdetect2/train/checkpoints.py b/src/batdetect2/train/checkpoints.py index e93c7ca..ef69ca6 100644 --- a/src/batdetect2/train/checkpoints.py +++ b/src/batdetect2/train/checkpoints.py @@ -40,7 +40,7 @@ def build_checkpoint_callback( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name - checkpoint_dir.mkdir(parents=True, exist_ok=True) + Path(checkpoint_dir).mkdir(parents=True, exist_ok=True) return ModelCheckpoint( dirpath=str(checkpoint_dir), diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index ce055d3..790abdd 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -14,8 +14,12 @@ from soundevent import data from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.targets import build_targets, iterate_encoded_sound_events -from batdetect2.targets.types import TargetProtocol +from batdetect2.targets import ( + build_roi_mapping, + build_targets, + iterate_encoded_sound_events, +) +from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol from batdetect2.train.types import ClipLabeller, Heatmaps __all__ = [ @@ -42,6 +46,7 @@ class LabelConfig(BaseConfig): def build_clip_labeler( targets: TargetProtocol | None = None, + roi_mapper: ROIMapperProtocol | None = None, min_freq: float = MIN_FREQ, max_freq: float = MAX_FREQ, config: LabelConfig | None = None, @@ -53,12 +58,13 @@ def build_clip_labeler( lambda: config.to_yaml_string(), ) - if targets is None: - targets = build_targets() + targets = targets or build_targets() + roi_mapper = roi_mapper or build_roi_mapping() return partial( generate_heatmaps, targets=targets, + roi_mapper=roi_mapper, min_freq=min_freq, max_freq=max_freq, target_sigma=config.sigma, @@ -73,6 +79,7 @@ def generate_heatmaps( clip_annotation: data.ClipAnnotation, spec: torch.Tensor, targets: TargetProtocol, + roi_mapper: ROIMapperProtocol, min_freq: float, max_freq: float, target_sigma: float = 3.0, @@ -89,7 +96,7 @@ def generate_heatmaps( height = spec.shape[-2] width = spec.shape[-1] num_classes = len(targets.class_names) - num_dims = len(targets.dimension_names) + num_dims = len(roi_mapper.dimension_names) clip = clip_annotation.clip # Initialize heatmaps @@ -109,6 +116,7 @@ def generate_heatmaps( for class_name, (time, frequency), size in iterate_encoded_sound_events( clip_annotation.sound_events, targets, + roi_mapper, ): time_index = map_to_pixels(time, width, clip.start_time, clip.end_time) freq_index = map_to_pixels(frequency, height, min_freq, max_freq) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 2842738..a9d91d9 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -6,23 +6,24 @@ from lightning import Trainer, seed_everything from loguru import logger from soundevent import data -from batdetect2.audio import AudioConfig, build_audio_loader -from batdetect2.audio.types import AudioLoader -from batdetect2.evaluate import build_evaluator -from batdetect2.evaluate.types import EvaluatorProtocol +from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader +from batdetect2.evaluate import EvaluatorProtocol, build_evaluator from batdetect2.logging import ( LoggerConfig, TensorBoardLoggerConfig, build_logger, ) from batdetect2.models import Model, ModelConfig, build_model -from batdetect2.preprocess import build_preprocessor -from batdetect2.preprocess.types import PreprocessorProtocol -from batdetect2.targets import build_targets -from batdetect2.targets.types import TargetProtocol -from batdetect2.train import TrainingConfig +from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor +from batdetect2.targets import ( + ROIMapperProtocol, + TargetProtocol, + build_roi_mapping, + build_targets, +) from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.checkpoints import build_checkpoint_callback +from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module @@ -39,6 +40,7 @@ def run_train( val_annotations: Sequence[data.ClipAnnotation] | None = None, model: Model | None = None, targets: Optional["TargetProtocol"] = None, + roi_mapper: Optional["ROIMapperProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, labeller: Optional["ClipLabeller"] = None, @@ -69,8 +71,15 @@ def run_train( if model is not None: targets = targets or model.targets + if roi_mapper is None and targets is model.targets: + roi_mapper = model.roi_mapper + targets = targets or build_targets(config=model_config.targets) + roi_mapper = roi_mapper or build_roi_mapping( + config=model_config.targets.roi + ) + audio_loader = audio_loader or build_audio_loader(config=audio_config) preprocessor = preprocessor or build_preprocessor( @@ -80,6 +89,7 @@ def run_train( labeller = labeller or build_clip_labeler( targets, + roi_mapper, min_freq=preprocessor.min_freq, max_freq=preprocessor.max_freq, config=train_config.labels, @@ -119,6 +129,7 @@ def run_train( evaluator=build_evaluator( train_config.validation, targets=targets, + roi_mapper=roi_mapper, ), checkpoint_dir=checkpoint_dir, num_epochs=num_epochs, diff --git a/tests/test_audio_utils.py b/tests/test_audio_utils.py index 6e635e2..1414401 100644 --- a/tests/test_audio_utils.py +++ b/tests/test_audio_utils.py @@ -1,7 +1,7 @@ import numpy as np import torch import torch.nn.functional as F -from hypothesis import given +from hypothesis import given, settings from hypothesis import strategies as st from batdetect2.detector import parameters @@ -9,6 +9,7 @@ from batdetect2.utils import audio_utils, detector_utils @given(duration=st.floats(min_value=0.1, max_value=1)) +@settings(deadline=None) def test_can_compute_correct_spectrogram_width(duration: float): samplerate = parameters.TARGET_SAMPLERATE_HZ params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS @@ -87,6 +88,7 @@ def test_pad_audio_without_fixed_size(duration: float): @given(duration=st.floats(min_value=0.1, max_value=2)) +@settings(deadline=None) def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor( duration: float, ): diff --git a/tests/test_cli/test_detect.py b/tests/test_cli/test_detect.py index ed26b95..e10cca1 100644 --- a/tests/test_cli/test_detect.py +++ b/tests/test_cli/test_detect.py @@ -8,15 +8,6 @@ from click.testing import CliRunner from batdetect2.cli import cli -def test_cli_detect_help() -> None: - """User story: get usage help for legacy detect command.""" - - result = CliRunner().invoke(cli, ["detect", "--help"]) - - assert result.exit_code == 0 - assert "Detect bat calls in files in AUDIO_DIR" in result.output - - def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None: """User story: run legacy detect on example audio directory.""" diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index 5cee836..f5ce769 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -41,6 +41,22 @@ def test_build_detector_custom_config(): assert model.backbone.encoder.in_channels == 2 +def test_build_detector_custom_size_channels(): + num_classes = 3 + num_sizes = 4 + config = UNetBackboneConfig(in_channels=1, input_height=128) + + model = build_detector( + num_classes=num_classes, + num_sizes=num_sizes, + config=config, + ) + + dummy = torch.randn(1, 1, 128, 64) + output = model(dummy) + assert output.size_preds.shape[1] == num_sizes + + def test_detector_forward_pass_shapes(dummy_spectrogram): """Test that the forward pass produces correctly shaped outputs.""" num_classes = 4 diff --git a/tests/test_outputs/test_transform/test_roundtrip.py b/tests/test_outputs/test_transform/test_roundtrip.py index 6b70e55..ffc6d57 100644 --- a/tests/test_outputs/test_transform/test_roundtrip.py +++ b/tests/test_outputs/test_transform/test_roundtrip.py @@ -6,6 +6,7 @@ from soundevent.geometry import compute_bounds from batdetect2.models.types import ModelOutput from batdetect2.outputs import build_output_transform from batdetect2.postprocess import build_postprocessor +from batdetect2.targets import build_roi_mapping from batdetect2.targets.types import TargetProtocol from batdetect2.train.labels import build_clip_labeler @@ -37,7 +38,9 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform( width = int(duration * sample_preprocessor.output_samplerate) spec = torch.zeros((1, height, width), dtype=torch.float32) - labeler = build_clip_labeler(targets=sample_targets) + roi_mapper = build_roi_mapping() + + labeler = build_clip_labeler(targets=sample_targets, roi_mapper=roi_mapper) heatmaps = labeler(clip_annotation, spec) output = ModelOutput( @@ -51,7 +54,10 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform( clip_detection_tensors = postprocessor(output) assert len(clip_detection_tensors) == 1 - transform = build_output_transform(targets=sample_targets) + transform = build_output_transform( + targets=sample_targets, + roi_mapper=roi_mapper, + ) clip_detections = transform.to_clip_detections( detections=clip_detection_tensors[0], clip=clip, diff --git a/tests/test_outputs/test_transform/test_transform.py b/tests/test_outputs/test_transform/test_transform.py index 4fe3de1..1239991 100644 --- a/tests/test_outputs/test_transform/test_transform.py +++ b/tests/test_outputs/test_transform/test_transform.py @@ -12,6 +12,7 @@ from batdetect2.postprocess.types import ( ClipDetectionsTensor, Detection, ) +from batdetect2.targets import TargetConfig, build_roi_mapping from batdetect2.targets.types import TargetProtocol @@ -27,9 +28,22 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor: ) +def _build_roi_mapper(targets: TargetProtocol): + config_obj = getattr(targets, "config", None) + target_config = ( + config_obj if isinstance(config_obj, TargetConfig) else None + ) + return build_roi_mapping( + config=(target_config.roi if target_config is not None else None), + ) + + def test_shift_time_to_clip_start(sample_targets: TargetProtocol): raw = _mock_clip_detections_tensor() - transform = build_output_transform(targets=sample_targets) + transform = build_output_transform( + targets=sample_targets, + roi_mapper=_build_roi_mapper(sample_targets), + ) transformed = transform.to_detections(raw, start_time=2.5) start_time, _, end_time, _ = compute_bounds(transformed[0].geometry) @@ -43,7 +57,10 @@ def test_to_clip_detections_shifts_by_clip_start( sample_targets: TargetProtocol, ): clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) - transform = build_output_transform(targets=sample_targets) + transform = build_output_transform( + targets=sample_targets, + roi_mapper=_build_roi_mapper(sample_targets), + ) raw = _mock_clip_detections_tensor() shifted = transform.to_clip_detections(detections=raw, clip=clip) start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry) @@ -90,6 +107,7 @@ def test_detection_and_clip_transforms_applied_in_order( transform = OutputTransform( targets=sample_targets, + roi_mapper=_build_roi_mapper(sample_targets), detection_transform_steps=[boost_score, keep_high_score], clip_transform_steps=[tag_clip_transform], ) diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index 2d91d5e..475df00 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -1,3 +1,5 @@ +from pathlib import Path + import numpy as np import pytest import soundfile as sf @@ -22,8 +24,10 @@ from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, PeakEnergyBBoxMapper, PeakEnergyBBoxMapperConfig, + ROIMappingConfig, _build_bounding_box, build_roi_mapper, + build_roi_mapping, get_peak_energy_coordinates, ) @@ -630,3 +634,43 @@ def test_build_roi_mapper_raises_error_for_unknown_name(): # Then with pytest.raises(NotImplementedError): build_roi_mapper(DummyConfig()) # type: ignore + + +def test_build_roi_mapping_applies_class_override(): + config = ROIMappingConfig( + default=AnchorBBoxMapperConfig(anchor="bottom-left"), + overrides={ + "myomyo": AnchorBBoxMapperConfig(anchor="top-left"), + }, + ) + + mapper = build_roi_mapping(config=config) + + geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) + sound_event = data.SoundEvent( + recording=data.Recording( + path=Path("x.wav"), + samplerate=256_000, + channels=1, + duration=1.0, + ), + geometry=geometry, + ) + + default_position, _ = mapper.encode(sound_event, class_name="pippip") + override_position, _ = mapper.encode(sound_event, class_name="myomyo") + + assert default_position == pytest.approx((0.1, 12_000)) + assert override_position == pytest.approx((0.1, 18_000)) + + +def test_build_roi_mapping_rejects_dimension_mismatch(): + config = ROIMappingConfig( + default=AnchorBBoxMapperConfig(), + overrides={ + "myomyo": PeakEnergyBBoxMapperConfig(), + }, + ) + + with pytest.raises(ValueError, match="same dimension order"): + build_roi_mapping(config=config) diff --git a/tests/test_targets/test_targets.py b/tests/test_targets/test_targets.py index c823cc5..aa8433f 100644 --- a/tests/test_targets/test_targets.py +++ b/tests/test_targets/test_targets.py @@ -1,9 +1,10 @@ from collections.abc import Callable from pathlib import Path +import pytest from soundevent import data, terms -from batdetect2.targets import TargetConfig, build_targets +from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets def test_can_override_default_roi_mapper_per_class( @@ -32,18 +33,21 @@ def test_can_override_default_roi_mapper_per_class( tags: - key: species value: Myotis myotis - roi: - name: anchor_bbox - anchor: top-left roi: - name: anchor_bbox - anchor: bottom-left + default: + name: anchor_bbox + anchor: bottom-left + overrides: + myomyo: + name: anchor_bbox + anchor: top-left """ config_path = create_temp_yaml(yaml_content) config = TargetConfig.load(config_path) targets = build_targets(config) + roi_mapper = build_roi_mapping(config=config.roi) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) @@ -60,8 +64,17 @@ def test_can_override_default_roi_mapper_per_class( tags=[data.Tag(term=species, value="Myotis myotis")], ) - (time1, freq1), _ = targets.encode_roi(se1) - (time2, freq2), _ = targets.encode_roi(se2) + class_name1 = targets.encode_class(se1) + class_name2 = targets.encode_class(se2) + + (time1, freq1), _ = roi_mapper.encode( + se1.sound_event, + class_name=class_name1, + ) + (time2, freq2), _ = roi_mapper.encode( + se2.sound_event, + class_name=class_name2, + ) assert time1 == time2 == 0.1 assert freq1 == 12_000 @@ -95,18 +108,21 @@ def test_roi_is_recovered_roundtrip_even_with_overriders( tags: - key: species value: Myotis myotis - roi: - name: anchor_bbox - anchor: top-left roi: - name: anchor_bbox - anchor: bottom-left + default: + name: anchor_bbox + anchor: bottom-left + overrides: + myomyo: + name: anchor_bbox + anchor: top-left """ config_path = create_temp_yaml(yaml_content) config = TargetConfig.load(config_path) targets = build_targets(config) + roi_mapper = build_roi_mapping(config=config.roi) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) @@ -122,14 +138,14 @@ def test_roi_is_recovered_roundtrip_even_with_overriders( tags=[data.Tag(term=species, value="Myotis myotis")], ) - position1, size1 = targets.encode_roi(se1) - position2, size2 = targets.encode_roi(se2) + position1, size1 = roi_mapper.encode(se1.sound_event, class_name="pippip") + position2, size2 = roi_mapper.encode(se2.sound_event, class_name="myomyo") class_name1 = targets.encode_class(se1) class_name2 = targets.encode_class(se2) - recovered1 = targets.decode_roi(position1, size1, class_name=class_name1) - recovered2 = targets.decode_roi(position2, size2, class_name=class_name2) + recovered1 = roi_mapper.decode(position1, size1, class_name=class_name1) + recovered2 = roi_mapper.decode(position2, size2, class_name=class_name2) assert recovered1 == geometry assert recovered2 == geometry diff --git a/tests/test_train/test_checkpoints.py b/tests/test_train/test_checkpoints.py index a1b9031..7ff97bb 100644 --- a/tests/test_train/test_checkpoints.py +++ b/tests/test_train/test_checkpoints.py @@ -42,28 +42,6 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir( assert checkpoints -def test_train_without_validation_does_not_save_default_monitored_checkpoint( - tmp_path: Path, - example_annotations: list[data.ClipAnnotation], -) -> None: - config = _build_fast_train_config() - - run_train( - train_annotations=example_annotations[:1], - val_annotations=None, - train_config=config.train, - model_config=config.model, - audio_config=config.audio, - num_epochs=1, - train_workers=0, - val_workers=0, - checkpoint_dir=tmp_path, - seed=0, - ) - - assert not list(tmp_path.rglob("*.ckpt")) - - def test_train_without_validation_can_still_save_last_checkpoint( tmp_path: Path, example_annotations: list[data.ClipAnnotation], diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 19c39ab..e1de1d6 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -3,8 +3,8 @@ from pathlib import Path import torch from soundevent import data -from batdetect2.targets import TargetConfig, build_targets -from batdetect2.targets.rois import AnchorBBoxMapperConfig +from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets +from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMappingConfig from batdetect2.train.labels import generate_heatmaps recording = data.Recording( @@ -30,14 +30,17 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( ): config = sample_target_config.model_copy( update=dict( - roi=AnchorBBoxMapperConfig( - time_scale=1, - frequency_scale=1, + roi=ROIMappingConfig( + default=AnchorBBoxMapperConfig( + time_scale=1, + frequency_scale=1, + ) ) ) ) targets = build_targets(config) + roi_mapper = build_roi_mapping(config=config.roi) clip_annotation = data.ClipAnnotation( clip=clip, @@ -60,6 +63,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( min_freq=0, max_freq=100, targets=targets, + roi_mapper=roi_mapper, ) pippip_index = targets.class_names.index("pippip") myomyo_index = targets.class_names.index("myomyo")