From c7ea361cf4e0f5d127df306a4ce79f9001e3ddc6 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 23 Jun 2025 18:52:36 +0100 Subject: [PATCH] Implement changes needed to make roi encode/decode class dependent --- src/batdetect2/configs.py | 2 +- src/batdetect2/postprocess/__init__.py | 4 +- src/batdetect2/postprocess/decoding.py | 19 +- src/batdetect2/postprocess/detection.py | 2 +- src/batdetect2/postprocess/extraction.py | 8 +- src/batdetect2/postprocess/types.py | 29 ++- src/batdetect2/targets/__init__.py | 53 ++++- src/batdetect2/targets/classes.py | 48 +--- src/batdetect2/targets/rois.py | 12 +- src/batdetect2/targets/terms.py | 1 + src/batdetect2/targets/types.py | 10 +- tests/conftest.py | 13 ++ tests/test_targets/test_classes.py | 14 +- tests/test_targets/test_rois.py | 6 +- tests/test_targets/test_targets.py | 117 ++++++++++ tests/test_train/test_labels.py | 4 +- tests/test_train/test_preprocessing.py | 271 +++++++++++++++++++++++ 17 files changed, 528 insertions(+), 85 deletions(-) create mode 100644 tests/test_targets/test_targets.py create mode 100644 tests/test_train/test_preprocessing.py diff --git a/src/batdetect2/configs.py b/src/batdetect2/configs.py index e5549cd..f252dd2 100644 --- a/src/batdetect2/configs.py +++ b/src/batdetect2/configs.py @@ -157,4 +157,4 @@ def load_config( if field: config = get_object_field(config, field) - return schema.model_validate(config) + return schema.model_validate(config or {}) diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 7a79289..f6f84ac 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry. The pipeline involves several configurable steps, implemented in submodules: 1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks. -2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency - coordinates to raw model output arrays. +2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw + model output arrays. 3. Detection Extraction (`.detection`): Identifies candidate detection points (location and score) based on thresholds and score ranking (top-k). 4. Data Extraction (`.extraction`): Gathers associated model outputs (size, diff --git a/src/batdetect2/postprocess/decoding.py b/src/batdetect2/postprocess/decoding.py index 3f4611c..6a5846c 100644 --- a/src/batdetect2/postprocess/decoding.py +++ b/src/batdetect2/postprocess/decoding.py @@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline. It takes the structured detection data extracted by the `extraction` module (typically an `xarray.Dataset` containing scores, positions, predicted sizes, class probabilities, and features for each detection point) and converts it -into meaningful, standardized prediction objects based on the `soundevent` data -model. +into standardized prediction objects based on the `soundevent` data model. The process involves: 1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction` @@ -33,7 +32,7 @@ import numpy as np import xarray as xr from soundevent import data -from batdetect2.postprocess.types import GeometryBuilder, RawPrediction +from batdetect2.postprocess.types import GeometryDecoder, RawPrediction from batdetect2.targets.classes import SoundEventDecoder from batdetect2.utils.arrays import iterate_over_array @@ -55,7 +54,7 @@ decoding. def convert_xr_dataset_to_raw_prediction( detection_dataset: xr.Dataset, - geometry_builder: GeometryBuilder, + geometry_decoder: GeometryDecoder, ) -> List[RawPrediction]: """Convert an xarray.Dataset of detections to RawPrediction objects. @@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction( output by `extract_detection_xr_dataset`. Expected variables include 'scores' (with time/freq coords), 'dimensions', 'classes', 'features'. Must have a 'detection' dimension. - geometry_builder : GeometryBuilder + geometry_decoder : GeometryDecoder A function that takes a position tuple `(time, freq)` and a NumPy array of dimensions, and returns the corresponding reconstructed `soundevent.data.Geometry`. @@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction( for det_num in range(detection_dataset.sizes["detection"]): det_info = detection_dataset.sel(detection=det_num) - geom = geometry_builder( + # TODO: Maybe clean this up + highest_scoring_class = det_info.coords["category"][ + det_info["classes"].argmax() + ].item() + + geom = geometry_decoder( (det_info.time, det_info.frequency), det_info.dimensions, + class_name=highest_scoring_class, ) detections.append( RawPrediction( - detection_score=det_info.score, + detection_score=det_info.scores, geometry=geom, class_scores=det_info.classes, features=det_info.features, diff --git a/src/batdetect2/postprocess/detection.py b/src/batdetect2/postprocess/detection.py index 78e7003..9b2a185 100644 --- a/src/batdetect2/postprocess/detection.py +++ b/src/batdetect2/postprocess/detection.py @@ -1,6 +1,6 @@ """Extracts candidate detection points from a model output heatmap. -This module implements a specific step within the BatDetect2 postprocessing +This module implements Step 3 within the BatDetect2 postprocessing pipeline. Its primary function is to identify potential sound event locations by finding peaks (local maxima or high-scoring points) in the detection heatmap produced by the neural network (usually after Non-Maximum Suppression and diff --git a/src/batdetect2/postprocess/extraction.py b/src/batdetect2/postprocess/extraction.py index 84019a2..2809ab7 100644 --- a/src/batdetect2/postprocess/extraction.py +++ b/src/batdetect2/postprocess/extraction.py @@ -1,9 +1,9 @@ """Extracts associated data for detected points from model output arrays. -This module implements a key step (Step 4) in the BatDetect2 postprocessing -pipeline. After candidate detection points (time, frequency, score) have been -identified, this module extracts the corresponding values from other raw model -output arrays, such as: +This module implements a Step 4 in the BatDetect2 postprocessing pipeline. +After candidate detection points (time, frequency, score) have been identified, +this module extracts the corresponding values from other raw model output +arrays, such as: - Predicted bounding box sizes (width, height). - Class probability scores for each defined target class. diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py index 70f1f39..c34b57f 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/postprocess/types.py @@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2 system that deal with model predictions. """ -from typing import Callable, List, NamedTuple, Protocol +from typing import List, NamedTuple, Optional, Protocol -import numpy as np import xarray as xr from soundevent import data from batdetect2.models.types import ModelOutput +from batdetect2.targets.types import Position, Size __all__ = [ "RawPrediction", "PostprocessorProtocol", - "GeometryBuilder", + "GeometryDecoder", ] -GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry] -"""Type alias for a function that recovers geometry from position and size. +# TODO: update the docstring +class GeometryDecoder(Protocol): + """Type alias for a function that recovers geometry from position and size. -This callable takes: -1. A position tuple `(time, frequency)`. -2. A NumPy array of size dimensions (e.g., `[width, height]`). -It should return the reconstructed `soundevent.data.Geometry` (typically a -`BoundingBox`). -""" + This callable takes: + 1. A position tuple `(time, frequency)`. + 2. A NumPy array of size dimensions (e.g., `[width, height]`). + 3. Optionally a class name of the highest scoring class. This is to accomodate + different ways of decoding geometry that depend on the predicted class. + It should return the reconstructed `soundevent.data.Geometry` (typically a + `BoundingBox`). + """ + + def __call__( + self, position: Position, size: Size, class_name: Optional[str] = None + ) -> data.Geometry: ... class RawPrediction(NamedTuple): diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 68a8095..34fa84e 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -23,6 +23,7 @@ object is via the `build_targets` or `load_targets` functions. from typing import List, Optional +from loguru import logger from pydantic import Field from soundevent import data @@ -49,7 +50,8 @@ from batdetect2.targets.filtering import ( load_filter_from_config, ) from batdetect2.targets.rois import ( - BBoxAnchorMapperConfig, + AnchorBBoxMapperConfig, + ROIMapperConfig, ROITargetMapper, build_roi_mapper, ) @@ -58,11 +60,11 @@ from batdetect2.targets.terms import ( TermInfo, TermRegistry, call_type, + default_term_registry, get_tag_from_info, get_term_from_key, individual, register_term, - default_term_registry, ) from batdetect2.targets.transform import ( DerivationRegistry, @@ -87,7 +89,7 @@ __all__ = [ "FilterConfig", "FilterRule", "MapValueRule", - "BBoxAnchorMapperConfig", + "AnchorBBoxMapperConfig", "ROITargetMapper", "ReplaceRule", "SoundEventDecoder", @@ -160,7 +162,7 @@ class TargetConfig(BaseConfig): classes: ClassesConfig = Field( default_factory=lambda: DEFAULT_CLASSES_CONFIG ) - roi: Optional[BBoxAnchorMapperConfig] = None + roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) def load_target_config( @@ -239,6 +241,7 @@ class Targets(TargetProtocol): generic_class_tags: List[data.Tag], filter_fn: Optional[SoundEventFilter] = None, transform_fn: Optional[SoundEventTransformation] = None, + roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None, ): """Initialize the Targets object. @@ -271,6 +274,16 @@ class Targets(TargetProtocol): self._encode_fn = encode_fn self._decode_fn = decode_fn self._transform_fn = transform_fn + self._roi_mapper_overrides = roi_mapper_overrides or {} + + for class_name in self._roi_mapper_overrides: + if class_name not in self.class_names: + # TODO: improve this warning + logger.warning( + "The ROI mapper overrides contains a class ({class_name}) " + "not present in the class names.", + class_name=class_name, + ) def filter(self, sound_event: data.SoundEventAnnotation) -> bool: """Apply the configured filter to a sound event annotation. @@ -375,9 +388,21 @@ class Targets(TargetProtocol): ValueError If the annotation lacks geometry. """ + class_name = self.encode_class(sound_event) + + if class_name in self._roi_mapper_overrides: + return self._roi_mapper_overrides[class_name].encode( + sound_event.sound_event + ) + return self._roi_mapper.encode(sound_event.sound_event) - def decode_roi(self, position: Position, size: Size) -> data.Geometry: + def decode_roi( + self, + position: Position, + size: Size, + class_name: Optional[str] = None, + ) -> data.Geometry: """Recover an approximate geometric ROI from a position and dimensions. Delegates to the internal ROI mapper's `recover_roi` method, which @@ -397,6 +422,13 @@ class Targets(TargetProtocol): data.Geometry The reconstructed geometry (typically `BoundingBox`). """ + print(class_name) + if class_name in self._roi_mapper_overrides: + return self._roi_mapper_overrides[class_name].decode( + position, + size, + ) + return self._roi_mapper.decode(position, size) @@ -452,10 +484,12 @@ DEFAULT_CLASSES = [ TargetClass( tags=[TagInfo(value="Nyctalus leisleri")], name="nyclei", + roi=AnchorBBoxMapperConfig(anchor="top-left"), ), TargetClass( tags=[TagInfo(value="Rhinolophus ferrumequinum")], name="rhifer", + roi=AnchorBBoxMapperConfig(anchor="top-left"), ), TargetClass( tags=[TagInfo(value="Plecotus auritus")], @@ -496,6 +530,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( ] ), classes=DEFAULT_CLASSES_CONFIG, + roi=AnchorBBoxMapperConfig(), ) @@ -565,12 +600,17 @@ def build_targets( if config.transforms else None ) - roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig()) + roi_mapper = build_roi_mapper(config.roi) class_names = get_class_names_from_config(config.classes) generic_class_tags = build_generic_class_tags( config.classes, term_registry=term_registry, ) + roi_overrides = { + class_config.name: build_roi_mapper(class_config.roi) + for class_config in config.classes.classes + if class_config.roi is not None + } return Targets( filter_fn=filter_fn, @@ -580,6 +620,7 @@ def build_targets( roi_mapper=roi_mapper, generic_class_tags=generic_class_tags, transform_fn=transform_fn, + roi_mapper_overrides=roi_overrides, ) diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index 7b947bd..d9083a3 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -6,49 +6,26 @@ from pydantic import Field, field_validator from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.targets.rois import ROIMapperConfig from batdetect2.targets.terms import ( GENERIC_CLASS_KEY, TagInfo, TermRegistry, - get_tag_from_info, default_term_registry, + get_tag_from_info, ) __all__ = [ - "SoundEventEncoder", - "SoundEventDecoder", - "TargetClass", - "ClassesConfig", - "load_classes_config", - "load_encoder_from_config", - "load_decoder_from_config", - "build_sound_event_encoder", - "build_sound_event_decoder", - "build_generic_class_tags", - "get_class_names_from_config", "DEFAULT_SPECIES_LIST", - "PositionMethod", - "CornerPosition", - "SizeMethod", - "BoundingBoxSize", + "build_generic_class_tags", + "build_sound_event_decoder", + "build_sound_event_encoder", + "get_class_names_from_config", + "load_classes_config", + "load_decoder_from_config", + "load_encoder_from_config", ] -class PositionMethod(BaseConfig): - """Base class for defining how to select a position from a geometry.""" - method_type: str - -class CornerPosition(PositionMethod): - """Selects a position based on a corner or center of the bounding box.""" - method_type: Literal["corner"] = "corner" - corner: Literal["upper_left", "lower_left", "center"] = "lower_left" - -class SizeMethod(BaseConfig): - """Base class for defining how to select a size from a geometry.""" - method_type: str - -class BoundingBoxSize(SizeMethod): - """Uses the width and height of the bounding box as the size.""" - method_type: Literal["bounding_box"] = "bounding_box" SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] """Type alias for a sound event class encoder function. @@ -134,8 +111,7 @@ class TargetClass(BaseConfig): tags: List[TagInfo] = Field(min_length=1) match_type: Literal["all", "any"] = Field(default="all") output_tags: Optional[List[TagInfo]] = None - position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left")) - size_method: SizeMethod = Field(default_factory=BoundingBoxSize) + roi: Optional[ROIMapperConfig] = None def _get_default_classes() -> List[TargetClass]: @@ -258,7 +234,7 @@ class ClassesConfig(BaseConfig): return v -def _is_target_class( +def is_target_class( sound_event_annotation: data.SoundEventAnnotation, tags: Set[data.Tag], match_all: bool = True, @@ -373,7 +349,7 @@ def build_sound_event_encoder( ( class_info.name, partial( - _is_target_class, + is_target_class, tags={ get_tag_from_info(tag_info, term_registry=term_registry) for tag_info in class_info.tags diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index c1787fc..330f6a0 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -34,7 +34,7 @@ from batdetect2.targets.types import Position, Size __all__ = [ "Anchor", "AnchorBBoxMapper", - "BBoxAnchorMapperConfig", + "AnchorBBoxMapperConfig", "DEFAULT_ANCHOR", "DEFAULT_FREQUENCY_SCALE", "DEFAULT_TIME_SCALE", @@ -148,7 +148,7 @@ class ROITargetMapper(Protocol): ... -class BBoxAnchorMapperConfig(BaseConfig): +class AnchorBBoxMapperConfig(BaseConfig): """Configuration for `AnchorBBoxMapper`. Defines parameters for converting ROIs into targets using a fixed anchor @@ -470,7 +470,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper): ROIMapperConfig = Annotated[ - Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig], + Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig], Field(discriminator="name"), ] """A discriminated union of all supported ROI mapper configurations. @@ -480,7 +480,9 @@ implementations by using the `name` field as a discriminator. """ -def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: +def build_roi_mapper( + config: Optional[ROIMapperConfig] = None, +) -> ROITargetMapper: """Factory function to create an ROITargetMapper from a config object. Parameters @@ -498,6 +500,8 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: NotImplementedError If the `name` in the config does not correspond to a known mapper. """ + config = config or AnchorBBoxMapperConfig() + if config.name == "anchor_bbox": return AnchorBBoxMapper( anchor=config.anchor, diff --git a/src/batdetect2/targets/terms.py b/src/batdetect2/targets/terms.py index d6a3814..9f82b84 100644 --- a/src/batdetect2/targets/terms.py +++ b/src/batdetect2/targets/terms.py @@ -235,6 +235,7 @@ default_term_registry = TermRegistry( [ *getmembers(terms, lambda x: isinstance(x, data.Term)), ("event", call_type), + ("species", terms.scientific_name), ("individual", individual), ("data_source", data_source), (GENERIC_CLASS_KEY, generic_class), diff --git a/src/batdetect2/targets/types.py b/src/batdetect2/targets/types.py index 19a0ea6..221897a 100644 --- a/src/batdetect2/targets/types.py +++ b/src/batdetect2/targets/types.py @@ -181,7 +181,13 @@ class TargetProtocol(Protocol): """ ... - def decode_roi(self, position: Position, size: Size) -> data.Geometry: + # TODO: Update docstrings + def decode_roi( + self, + position: Position, + size: Size, + class_name: Optional[str] = None, + ) -> data.Geometry: """Recover the ROI geometry from a position and dimensions. Performs the inverse mapping of `get_position` and `get_size`. It takes @@ -195,6 +201,8 @@ class TargetProtocol(Protocol): dims : np.ndarray The NumPy array containing the dimensions (e.g., predicted by the model), corresponding to the order in `dimension_names`. + class_name: str + class Returns ------- diff --git a/tests/conftest.py b/tests/conftest.py index 1bf855e..1c8b065 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import uuid from pathlib import Path from typing import Callable, List, Optional +from uuid import uuid4 import numpy as np import pytest @@ -447,3 +448,15 @@ def example_annotations( annotations = load_dataset(example_dataset) assert len(annotations) == 3 return annotations + + +@pytest.fixture +def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]: + """Create a temporary YAML file with the given content.""" + + def factory(content: str) -> Path: + temp_file = tmp_path / f"{uuid4()}.yaml" + temp_file.write_text(content) + return temp_file + + return factory diff --git a/tests/test_targets/test_classes.py b/tests/test_targets/test_classes.py index 4143c3c..fc4c155 100644 --- a/tests/test_targets/test_classes.py +++ b/tests/test_targets/test_classes.py @@ -12,11 +12,11 @@ from batdetect2.targets.classes import ( TargetClass, _get_default_class_name, _get_default_classes, - _is_target_class, build_generic_class_tags, build_sound_event_decoder, build_sound_event_encoder, get_class_names_from_config, + is_target_class, load_classes_config, load_decoder_from_config, load_encoder_from_config, @@ -145,7 +145,7 @@ def test_is_target_class_match_all( ), data.Tag(term=sample_term_registry["quality"], value="Good"), } - assert _is_target_class(sample_annotation, tags, match_all=True) is True + assert is_target_class(sample_annotation, tags, match_all=True) is True tags = { data.Tag( @@ -153,14 +153,14 @@ def test_is_target_class_match_all( value="Pipistrellus pipistrellus", ) } - assert _is_target_class(sample_annotation, tags, match_all=True) is True + assert is_target_class(sample_annotation, tags, match_all=True) is True tags = { data.Tag( term=sample_term_registry["species"], value="Myotis daubentonii" ) } - assert _is_target_class(sample_annotation, tags, match_all=True) is False + assert is_target_class(sample_annotation, tags, match_all=True) is False def test_is_target_class_match_any( @@ -174,7 +174,7 @@ def test_is_target_class_match_any( ), data.Tag(term=sample_term_registry["quality"], value="Good"), } - assert _is_target_class(sample_annotation, tags, match_all=False) is True + assert is_target_class(sample_annotation, tags, match_all=False) is True tags = { data.Tag( @@ -182,14 +182,14 @@ def test_is_target_class_match_any( value="Pipistrellus pipistrellus", ) } - assert _is_target_class(sample_annotation, tags, match_all=False) is True + assert is_target_class(sample_annotation, tags, match_all=False) is True tags = { data.Tag( term=sample_term_registry["species"], value="Myotis daubentonii" ) } - assert _is_target_class(sample_annotation, tags, match_all=False) is False + assert is_target_class(sample_annotation, tags, match_all=False) is False def test_get_class_names_from_config(): diff --git a/tests/test_targets/test_rois.py b/tests/test_targets/test_rois.py index c6afc24..8b92045 100644 --- a/tests/test_targets/test_rois.py +++ b/tests/test_targets/test_rois.py @@ -11,7 +11,7 @@ from batdetect2.targets.rois import ( SIZE_HEIGHT, SIZE_WIDTH, AnchorBBoxMapper, - BBoxAnchorMapperConfig, + AnchorBBoxMapperConfig, PeakEnergyBBoxMapper, PeakEnergyBBoxMapperConfig, _build_bounding_box, @@ -243,7 +243,7 @@ def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper): def test_build_roi_mapper(): """Test build_roi_mapper creates a configured BBoxEncoder.""" - config = BBoxAnchorMapperConfig( + config = AnchorBBoxMapperConfig( anchor="top-right", time_scale=2.0, frequency_scale=20.0 ) mapper = build_roi_mapper(config) @@ -571,7 +571,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle): def test_build_roi_mapper_for_anchor_bbox(): # Given - config = BBoxAnchorMapperConfig( + config = AnchorBBoxMapperConfig( anchor="center", time_scale=123.0, frequency_scale=456.0, diff --git a/tests/test_targets/test_targets.py b/tests/test_targets/test_targets.py new file mode 100644 index 0000000..bb4d00f --- /dev/null +++ b/tests/test_targets/test_targets.py @@ -0,0 +1,117 @@ +from collections.abc import Callable +from pathlib import Path + +from soundevent import data + +from batdetect2.targets import build_targets, load_target_config +from batdetect2.targets.terms import get_term_from_key + + +def test_can_override_default_roi_mapper_per_class( + create_temp_yaml: Callable[..., Path], + recording: data.Recording, + sample_term_registry, +): + yaml_content = """ + roi: + name: anchor_bbox + anchor: bottom-left + classes: + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + - name: myomyo + tags: + - key: species + value: Myotis myotis + roi: + name: anchor_bbox + anchor: top-left + generic_class: + - key: order + value: Chiroptera + """ + config_path = create_temp_yaml(yaml_content) + + config = load_target_config(config_path) + targets = build_targets(config, term_registry=sample_term_registry) + + geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) + + species = get_term_from_key("species", term_registry=sample_term_registry) + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + ) + + se2 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Myotis myotis")], + ) + + (time1, freq1), _ = targets.encode_roi(se1) + (time2, freq2), _ = targets.encode_roi(se2) + + assert time1 == time2 == 0.1 + assert freq1 == 12_000 + assert freq2 == 18_000 + + +# TODO: rename this test function +def test_roi_is_recovered_roundtrip_even_with_overriders( + create_temp_yaml, + sample_term_registry, + recording, +): + yaml_content = """ + roi: + name: anchor_bbox + anchor: bottom-left + classes: + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + - name: myomyo + tags: + - key: species + value: Myotis myotis + roi: + name: anchor_bbox + anchor: top-left + generic_class: + - key: order + value: Chiroptera + """ + config_path = create_temp_yaml(yaml_content) + + config = load_target_config(config_path) + targets = build_targets(config, term_registry=sample_term_registry) + + geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) + + species = get_term_from_key("species", term_registry=sample_term_registry) + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + ) + + se2 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Myotis myotis")], + ) + + position1, size1 = targets.encode_roi(se1) + position2, size2 = targets.encode_roi(se2) + + class_name1 = targets.encode_class(se1) + class_name2 = targets.encode_class(se2) + + recovered1 = targets.decode_roi(position1, size1, class_name=class_name1) + recovered2 = targets.decode_roi(position2, size2, class_name=class_name2) + + assert recovered1 == geometry + assert recovered2 == geometry diff --git a/tests/test_train/test_labels.py b/tests/test_train/test_labels.py index 27f7a1d..6e4e23c 100644 --- a/tests/test_train/test_labels.py +++ b/tests/test_train/test_labels.py @@ -5,7 +5,7 @@ import xarray as xr from soundevent import data from batdetect2.targets import TargetConfig, TargetProtocol, build_targets -from batdetect2.targets.rois import BBoxAnchorMapperConfig +from batdetect2.targets.rois import AnchorBBoxMapperConfig from batdetect2.targets.terms import TagInfo, TermRegistry from batdetect2.train.labels import generate_heatmaps @@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions( ): config = sample_target_config.model_copy( update=dict( - roi=BBoxAnchorMapperConfig( + roi=AnchorBBoxMapperConfig( time_scale=1, frequency_scale=1, ) diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py new file mode 100644 index 0000000..8727ee9 --- /dev/null +++ b/tests/test_train/test_preprocessing.py @@ -0,0 +1,271 @@ +import pytest +import torch +import xarray as xr +from soundevent import data + +from batdetect2.models.types import ModelOutput +from batdetect2.postprocess import build_postprocessor, load_postprocess_config +from batdetect2.preprocess import build_preprocessor, load_preprocessing_config +from batdetect2.targets import build_targets, load_target_config +from batdetect2.targets.terms import get_term_from_key +from batdetect2.train.labels import build_clip_labeler, load_label_config +from batdetect2.train.preprocess import generate_train_example + + +@pytest.fixture +def build_from_config( + create_temp_yaml, + sample_term_registry, +): + def build(yaml_content): + config_path = create_temp_yaml(yaml_content) + + targets_config = load_target_config(config_path, field="targets") + preprocessing_config = load_preprocessing_config( + config_path, + field="preprocessing", + ) + labels_config = load_label_config(config_path, field="labels") + postprocessing_config = load_postprocess_config( + config_path, + field="postprocessing", + ) + + targets = build_targets( + targets_config, term_registry=sample_term_registry + ) + preprocessor = build_preprocessor(preprocessing_config) + labeller = build_clip_labeler( + targets=targets, + config=labels_config, + ) + postprocessor = build_postprocessor( + targets, + config=postprocessing_config, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) + + return targets, preprocessor, labeller, postprocessor + + return build + + +# TODO: better name +def test_generated_train_example_has_expected_outputs( + build_from_config, + sample_term_registry, + recording, +): + yaml_content = """ + labels: + targets: + roi: + name: anchor_bbox + anchor: bottom-left + classes: + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + generic_class: + - key: order + value: Chiroptera + preprocessing: + postprocessing: + """ + _, preprocessor, labeller, _ = build_from_config(yaml_content) + + geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) + species = get_term_from_key("species", term_registry=sample_term_registry) + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + ) + clip_annotation = data.ClipAnnotation( + clip=data.Clip(start_time=0, end_time=0.5, recording=recording), + sound_events=[se1], + ) + + encoded = generate_train_example(clip_annotation, preprocessor, labeller) + + assert isinstance(encoded, xr.Dataset) + assert "audio" in encoded + assert "spectrogram" in encoded + assert "detection" in encoded + assert "class" in encoded + assert "size" in encoded + + spec_shape = encoded["spectrogram"].shape + assert len(spec_shape) == 2 + + height, width = spec_shape + assert encoded["detection"].shape == (height, width) + assert encoded["class"].shape == (1, height, width) + assert encoded["size"].shape == (2, height, width) + + +def test_encoding_decoding_roundtrip_recovers_object( + build_from_config, + sample_term_registry, + recording, +): + yaml_content = """ + labels: + targets: + roi: + name: anchor_bbox + anchor: bottom-left + classes: + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + generic_class: + - key: order + value: Chiroptera + preprocessing: + """ + _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) + + geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) + species = get_term_from_key("species", term_registry=sample_term_registry) + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")], + ) + clip = data.Clip(start_time=0, end_time=0.5, recording=recording) + clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) + + encoded = generate_train_example(clip_annotation, preprocessor, labeller) + predictions = postprocessor.get_predictions( + ModelOutput( + detection_probs=torch.tensor([[encoded["detection"].data]]), + size_preds=torch.tensor([encoded["size"].data]), + class_probs=torch.tensor([encoded["class"].data]), + features=torch.tensor([[encoded["spectrogram"].data]]), + ), + [clip], + )[0] + + assert isinstance(predictions, data.ClipPrediction) + assert len(predictions.sound_events) == 1 + + recovered = predictions.sound_events[0] + assert recovered.sound_event.geometry is not None + assert isinstance(recovered.sound_event.geometry, data.BoundingBox) + start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = ( + recovered.sound_event.geometry.coordinates + ) + start_time_or, low_freq_or, end_time_or, high_freq_or = ( + geometry.coordinates + ) + + assert start_time_rec == pytest.approx(start_time_or, abs=0.01) + assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000) + assert end_time_rec == pytest.approx(end_time_or, abs=0.01) + assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000) + + assert len(recovered.tags) == 2 + + predicted_species_tag = next( + iter(t for t in recovered.tags if t.tag.term == species), None + ) + assert predicted_species_tag is not None + assert predicted_species_tag.score == 1 + assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus" + + predicted_order_tag = next( + iter(t for t in recovered.tags if t.tag.term.label == "order"), None + ) + assert predicted_order_tag is not None + assert predicted_order_tag.score == 1 + assert predicted_order_tag.tag.value == "Chiroptera" + + +def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( + build_from_config, + sample_term_registry, + recording, +): + yaml_content = """ + labels: + targets: + roi: + name: anchor_bbox + anchor: bottom-left + classes: + classes: + - name: pippip + tags: + - key: species + value: Pipistrellus pipistrellus + - name: myomyo + tags: + - key: species + value: Myotis myotis + roi: + name: anchor_bbox + anchor: top-left + generic_class: + - key: order + value: Chiroptera + preprocessing: + """ + _, preprocessor, labeller, postprocessor = build_from_config(yaml_content) + + geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000]) + species = get_term_from_key("species", term_registry=sample_term_registry) + se1 = data.SoundEventAnnotation( + sound_event=data.SoundEvent(recording=recording, geometry=geometry), + tags=[data.Tag(term=species, value="Myotis myotis")], + ) + clip = data.Clip(start_time=0, end_time=0.5, recording=recording) + clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) + + encoded = generate_train_example(clip_annotation, preprocessor, labeller) + predictions = postprocessor.get_predictions( + ModelOutput( + detection_probs=torch.tensor([[encoded["detection"].data]]), + size_preds=torch.tensor([encoded["size"].data]), + class_probs=torch.tensor([encoded["class"].data]), + features=torch.tensor([[encoded["spectrogram"].data]]), + ), + [clip], + )[0] + + assert isinstance(predictions, data.ClipPrediction) + assert len(predictions.sound_events) == 1 + + recovered = predictions.sound_events[0] + assert recovered.sound_event.geometry is not None + assert isinstance(recovered.sound_event.geometry, data.BoundingBox) + start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = ( + recovered.sound_event.geometry.coordinates + ) + start_time_or, low_freq_or, end_time_or, high_freq_or = ( + geometry.coordinates + ) + + assert start_time_rec == pytest.approx(start_time_or, abs=0.01) + assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000) + assert end_time_rec == pytest.approx(end_time_or, abs=0.01) + assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000) + + assert len(recovered.tags) == 2 + + predicted_species_tag = next( + iter(t for t in recovered.tags if t.tag.term == species), None + ) + assert predicted_species_tag is not None + assert predicted_species_tag.score == 1 + assert predicted_species_tag.tag.value == "Myotis myotis" + + predicted_order_tag = next( + iter(t for t in recovered.tags if t.tag.term.label == "order"), None + ) + assert predicted_order_tag is not None + assert predicted_order_tag.score == 1 + assert predicted_order_tag.tag.value == "Chiroptera"