Implement changes needed to make roi encode/decode class dependent

This commit is contained in:
mbsantiago 2025-06-23 18:52:36 +01:00
parent 3407e1b5f0
commit c7ea361cf4
17 changed files with 528 additions and 85 deletions

View File

@ -157,4 +157,4 @@ def load_config(
if field: if field:
config = get_object_field(config, field) config = get_object_field(config, field)
return schema.model_validate(config) return schema.model_validate(config or {})

View File

@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
The pipeline involves several configurable steps, implemented in submodules: The pipeline involves several configurable steps, implemented in submodules:
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks. 1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency 2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
coordinates to raw model output arrays. model output arrays.
3. Detection Extraction (`.detection`): Identifies candidate detection points 3. Detection Extraction (`.detection`): Identifies candidate detection points
(location and score) based on thresholds and score ranking (top-k). (location and score) based on thresholds and score ranking (top-k).
4. Data Extraction (`.extraction`): Gathers associated model outputs (size, 4. Data Extraction (`.extraction`): Gathers associated model outputs (size,

View File

@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
It takes the structured detection data extracted by the `extraction` module It takes the structured detection data extracted by the `extraction` module
(typically an `xarray.Dataset` containing scores, positions, predicted sizes, (typically an `xarray.Dataset` containing scores, positions, predicted sizes,
class probabilities, and features for each detection point) and converts it class probabilities, and features for each detection point) and converts it
into meaningful, standardized prediction objects based on the `soundevent` data into standardized prediction objects based on the `soundevent` data model.
model.
The process involves: The process involves:
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction` 1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
@ -33,7 +32,7 @@ import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
from batdetect2.targets.classes import SoundEventDecoder from batdetect2.targets.classes import SoundEventDecoder
from batdetect2.utils.arrays import iterate_over_array from batdetect2.utils.arrays import iterate_over_array
@ -55,7 +54,7 @@ decoding.
def convert_xr_dataset_to_raw_prediction( def convert_xr_dataset_to_raw_prediction(
detection_dataset: xr.Dataset, detection_dataset: xr.Dataset,
geometry_builder: GeometryBuilder, geometry_decoder: GeometryDecoder,
) -> List[RawPrediction]: ) -> List[RawPrediction]:
"""Convert an xarray.Dataset of detections to RawPrediction objects. """Convert an xarray.Dataset of detections to RawPrediction objects.
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
output by `extract_detection_xr_dataset`. Expected variables include output by `extract_detection_xr_dataset`. Expected variables include
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'. 'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
Must have a 'detection' dimension. Must have a 'detection' dimension.
geometry_builder : GeometryBuilder geometry_decoder : GeometryDecoder
A function that takes a position tuple `(time, freq)` and a NumPy array A function that takes a position tuple `(time, freq)` and a NumPy array
of dimensions, and returns the corresponding reconstructed of dimensions, and returns the corresponding reconstructed
`soundevent.data.Geometry`. `soundevent.data.Geometry`.
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
for det_num in range(detection_dataset.sizes["detection"]): for det_num in range(detection_dataset.sizes["detection"]):
det_info = detection_dataset.sel(detection=det_num) det_info = detection_dataset.sel(detection=det_num)
geom = geometry_builder( # TODO: Maybe clean this up
highest_scoring_class = det_info.coords["category"][
det_info["classes"].argmax()
].item()
geom = geometry_decoder(
(det_info.time, det_info.frequency), (det_info.time, det_info.frequency),
det_info.dimensions, det_info.dimensions,
class_name=highest_scoring_class,
) )
detections.append( detections.append(
RawPrediction( RawPrediction(
detection_score=det_info.score, detection_score=det_info.scores,
geometry=geom, geometry=geom,
class_scores=det_info.classes, class_scores=det_info.classes,
features=det_info.features, features=det_info.features,

View File

@ -1,6 +1,6 @@
"""Extracts candidate detection points from a model output heatmap. """Extracts candidate detection points from a model output heatmap.
This module implements a specific step within the BatDetect2 postprocessing This module implements Step 3 within the BatDetect2 postprocessing
pipeline. Its primary function is to identify potential sound event locations pipeline. Its primary function is to identify potential sound event locations
by finding peaks (local maxima or high-scoring points) in the detection heatmap by finding peaks (local maxima or high-scoring points) in the detection heatmap
produced by the neural network (usually after Non-Maximum Suppression and produced by the neural network (usually after Non-Maximum Suppression and

View File

@ -1,9 +1,9 @@
"""Extracts associated data for detected points from model output arrays. """Extracts associated data for detected points from model output arrays.
This module implements a key step (Step 4) in the BatDetect2 postprocessing This module implements Step 4 in the BatDetect2 postprocessing pipeline.
pipeline. After candidate detection points (time, frequency, score) have been After candidate detection points (time, frequency, score) have been identified,
identified, this module extracts the corresponding values from other raw model this module extracts the corresponding values from other raw model output
output arrays, such as: arrays, such as:
- Predicted bounding box sizes (width, height). - Predicted bounding box sizes (width, height).
- Class probability scores for each defined target class. - Class probability scores for each defined target class.

View File

@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2
system that deal with model predictions. system that deal with model predictions.
""" """
from typing import Callable, List, NamedTuple, Protocol from typing import List, NamedTuple, Optional, Protocol
import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.targets.types import Position, Size
__all__ = [ __all__ = [
"RawPrediction", "RawPrediction",
"PostprocessorProtocol", "PostprocessorProtocol",
"GeometryBuilder", "GeometryDecoder",
] ]
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry] # TODO: update the docstring
"""Type alias for a function that recovers geometry from position and size. class GeometryDecoder(Protocol):
"""Type alias for a function that recovers geometry from position and size.
This callable takes: This callable takes:
1. A position tuple `(time, frequency)`. 1. A position tuple `(time, frequency)`.
2. A NumPy array of size dimensions (e.g., `[width, height]`). 2. A NumPy array of size dimensions (e.g., `[width, height]`).
It should return the reconstructed `soundevent.data.Geometry` (typically a 3. Optionally a class name of the highest scoring class. This is to accommodate
`BoundingBox`). different ways of decoding geometry that depend on the predicted class.
""" It should return the reconstructed `soundevent.data.Geometry` (typically a
`BoundingBox`).
"""
def __call__(
self, position: Position, size: Size, class_name: Optional[str] = None
) -> data.Geometry: ...
class RawPrediction(NamedTuple): class RawPrediction(NamedTuple):

View File

@ -23,6 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional from typing import List, Optional
from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -49,7 +50,8 @@ from batdetect2.targets.filtering import (
load_filter_from_config, load_filter_from_config,
) )
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
BBoxAnchorMapperConfig, AnchorBBoxMapperConfig,
ROIMapperConfig,
ROITargetMapper, ROITargetMapper,
build_roi_mapper, build_roi_mapper,
) )
@ -58,11 +60,11 @@ from batdetect2.targets.terms import (
TermInfo, TermInfo,
TermRegistry, TermRegistry,
call_type, call_type,
default_term_registry,
get_tag_from_info, get_tag_from_info,
get_term_from_key, get_term_from_key,
individual, individual,
register_term, register_term,
default_term_registry,
) )
from batdetect2.targets.transform import ( from batdetect2.targets.transform import (
DerivationRegistry, DerivationRegistry,
@ -87,7 +89,7 @@ __all__ = [
"FilterConfig", "FilterConfig",
"FilterRule", "FilterRule",
"MapValueRule", "MapValueRule",
"BBoxAnchorMapperConfig", "AnchorBBoxMapperConfig",
"ROITargetMapper", "ROITargetMapper",
"ReplaceRule", "ReplaceRule",
"SoundEventDecoder", "SoundEventDecoder",
@ -160,7 +162,7 @@ class TargetConfig(BaseConfig):
classes: ClassesConfig = Field( classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG default_factory=lambda: DEFAULT_CLASSES_CONFIG
) )
roi: Optional[BBoxAnchorMapperConfig] = None roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
def load_target_config( def load_target_config(
@ -239,6 +241,7 @@ class Targets(TargetProtocol):
generic_class_tags: List[data.Tag], generic_class_tags: List[data.Tag],
filter_fn: Optional[SoundEventFilter] = None, filter_fn: Optional[SoundEventFilter] = None,
transform_fn: Optional[SoundEventTransformation] = None, transform_fn: Optional[SoundEventTransformation] = None,
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
): ):
"""Initialize the Targets object. """Initialize the Targets object.
@ -271,6 +274,16 @@ class Targets(TargetProtocol):
self._encode_fn = encode_fn self._encode_fn = encode_fn
self._decode_fn = decode_fn self._decode_fn = decode_fn
self._transform_fn = transform_fn self._transform_fn = transform_fn
self._roi_mapper_overrides = roi_mapper_overrides or {}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation. """Apply the configured filter to a sound event annotation.
@ -375,9 +388,21 @@ class Targets(TargetProtocol):
ValueError ValueError
If the annotation lacks geometry. If the annotation lacks geometry.
""" """
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event) return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(self, position: Position, size: Size) -> data.Geometry: def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions. """Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which Delegates to the internal ROI mapper's `recover_roi` method, which
@ -397,6 +422,13 @@ class Targets(TargetProtocol):
data.Geometry data.Geometry
The reconstructed geometry (typically `BoundingBox`). The reconstructed geometry (typically `BoundingBox`).
""" """
print(class_name)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size) return self._roi_mapper.decode(position, size)
@ -452,10 +484,12 @@ DEFAULT_CLASSES = [
TargetClass( TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")], tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei", name="nyclei",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
), ),
TargetClass( TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")], tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer", name="rhifer",
roi=AnchorBBoxMapperConfig(anchor="top-left"),
), ),
TargetClass( TargetClass(
tags=[TagInfo(value="Plecotus auritus")], tags=[TagInfo(value="Plecotus auritus")],
@ -496,6 +530,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
] ]
), ),
classes=DEFAULT_CLASSES_CONFIG, classes=DEFAULT_CLASSES_CONFIG,
roi=AnchorBBoxMapperConfig(),
) )
@ -565,12 +600,17 @@ def build_targets(
if config.transforms if config.transforms
else None else None
) )
roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig()) roi_mapper = build_roi_mapper(config.roi)
class_names = get_class_names_from_config(config.classes) class_names = get_class_names_from_config(config.classes)
generic_class_tags = build_generic_class_tags( generic_class_tags = build_generic_class_tags(
config.classes, config.classes,
term_registry=term_registry, term_registry=term_registry,
) )
roi_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classes.classes
if class_config.roi is not None
}
return Targets( return Targets(
filter_fn=filter_fn, filter_fn=filter_fn,
@ -580,6 +620,7 @@ def build_targets(
roi_mapper=roi_mapper, roi_mapper=roi_mapper,
generic_class_tags=generic_class_tags, generic_class_tags=generic_class_tags,
transform_fn=transform_fn, transform_fn=transform_fn,
roi_mapper_overrides=roi_overrides,
) )

View File

@ -6,49 +6,26 @@ from pydantic import Field, field_validator
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import ( from batdetect2.targets.terms import (
GENERIC_CLASS_KEY, GENERIC_CLASS_KEY,
TagInfo, TagInfo,
TermRegistry, TermRegistry,
get_tag_from_info,
default_term_registry, default_term_registry,
get_tag_from_info,
) )
__all__ = [ __all__ = [
"SoundEventEncoder",
"SoundEventDecoder",
"TargetClass",
"ClassesConfig",
"load_classes_config",
"load_encoder_from_config",
"load_decoder_from_config",
"build_sound_event_encoder",
"build_sound_event_decoder",
"build_generic_class_tags",
"get_class_names_from_config",
"DEFAULT_SPECIES_LIST", "DEFAULT_SPECIES_LIST",
"PositionMethod", "build_generic_class_tags",
"CornerPosition", "build_sound_event_decoder",
"SizeMethod", "build_sound_event_encoder",
"BoundingBoxSize", "get_class_names_from_config",
"load_classes_config",
"load_decoder_from_config",
"load_encoder_from_config",
] ]
class PositionMethod(BaseConfig):
"""Base class for defining how to select a position from a geometry."""
method_type: str
class CornerPosition(PositionMethod):
"""Selects a position based on a corner or center of the bounding box."""
method_type: Literal["corner"] = "corner"
corner: Literal["upper_left", "lower_left", "center"] = "lower_left"
class SizeMethod(BaseConfig):
"""Base class for defining how to select a size from a geometry."""
method_type: str
class BoundingBoxSize(SizeMethod):
"""Uses the width and height of the bounding box as the size."""
method_type: Literal["bounding_box"] = "bounding_box"
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]] SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Type alias for a sound event class encoder function. """Type alias for a sound event class encoder function.
@ -134,8 +111,7 @@ class TargetClass(BaseConfig):
tags: List[TagInfo] = Field(min_length=1) tags: List[TagInfo] = Field(min_length=1)
match_type: Literal["all", "any"] = Field(default="all") match_type: Literal["all", "any"] = Field(default="all")
output_tags: Optional[List[TagInfo]] = None output_tags: Optional[List[TagInfo]] = None
position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left")) roi: Optional[ROIMapperConfig] = None
size_method: SizeMethod = Field(default_factory=BoundingBoxSize)
def _get_default_classes() -> List[TargetClass]: def _get_default_classes() -> List[TargetClass]:
@ -258,7 +234,7 @@ class ClassesConfig(BaseConfig):
return v return v
def _is_target_class( def is_target_class(
sound_event_annotation: data.SoundEventAnnotation, sound_event_annotation: data.SoundEventAnnotation,
tags: Set[data.Tag], tags: Set[data.Tag],
match_all: bool = True, match_all: bool = True,
@ -373,7 +349,7 @@ def build_sound_event_encoder(
( (
class_info.name, class_info.name,
partial( partial(
_is_target_class, is_target_class,
tags={ tags={
get_tag_from_info(tag_info, term_registry=term_registry) get_tag_from_info(tag_info, term_registry=term_registry)
for tag_info in class_info.tags for tag_info in class_info.tags

View File

@ -34,7 +34,7 @@ from batdetect2.targets.types import Position, Size
__all__ = [ __all__ = [
"Anchor", "Anchor",
"AnchorBBoxMapper", "AnchorBBoxMapper",
"BBoxAnchorMapperConfig", "AnchorBBoxMapperConfig",
"DEFAULT_ANCHOR", "DEFAULT_ANCHOR",
"DEFAULT_FREQUENCY_SCALE", "DEFAULT_FREQUENCY_SCALE",
"DEFAULT_TIME_SCALE", "DEFAULT_TIME_SCALE",
@ -148,7 +148,7 @@ class ROITargetMapper(Protocol):
... ...
class BBoxAnchorMapperConfig(BaseConfig): class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`. """Configuration for `AnchorBBoxMapper`.
Defines parameters for converting ROIs into targets using a fixed anchor Defines parameters for converting ROIs into targets using a fixed anchor
@ -470,7 +470,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
ROIMapperConfig = Annotated[ ROIMapperConfig = Annotated[
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig], Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
Field(discriminator="name"), Field(discriminator="name"),
] ]
"""A discriminated union of all supported ROI mapper configurations. """A discriminated union of all supported ROI mapper configurations.
@ -480,7 +480,9 @@ implementations by using the `name` field as a discriminator.
""" """
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper: def build_roi_mapper(
config: Optional[ROIMapperConfig] = None,
) -> ROITargetMapper:
"""Factory function to create an ROITargetMapper from a config object. """Factory function to create an ROITargetMapper from a config object.
Parameters Parameters
@ -498,6 +500,8 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
NotImplementedError NotImplementedError
If the `name` in the config does not correspond to a known mapper. If the `name` in the config does not correspond to a known mapper.
""" """
config = config or AnchorBBoxMapperConfig()
if config.name == "anchor_bbox": if config.name == "anchor_bbox":
return AnchorBBoxMapper( return AnchorBBoxMapper(
anchor=config.anchor, anchor=config.anchor,

View File

@ -235,6 +235,7 @@ default_term_registry = TermRegistry(
[ [
*getmembers(terms, lambda x: isinstance(x, data.Term)), *getmembers(terms, lambda x: isinstance(x, data.Term)),
("event", call_type), ("event", call_type),
("species", terms.scientific_name),
("individual", individual), ("individual", individual),
("data_source", data_source), ("data_source", data_source),
(GENERIC_CLASS_KEY, generic_class), (GENERIC_CLASS_KEY, generic_class),

View File

@ -181,7 +181,13 @@ class TargetProtocol(Protocol):
""" """
... ...
def decode_roi(self, position: Position, size: Size) -> data.Geometry: # TODO: Update docstrings
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover the ROI geometry from a position and dimensions. """Recover the ROI geometry from a position and dimensions.
Performs the inverse mapping of `get_position` and `get_size`. It takes Performs the inverse mapping of `get_position` and `get_size`. It takes
@ -195,6 +201,8 @@ class TargetProtocol(Protocol):
dims : np.ndarray dims : np.ndarray
The NumPy array containing the dimensions (e.g., predicted The NumPy array containing the dimensions (e.g., predicted
by the model), corresponding to the order in `dimension_names`. by the model), corresponding to the order in `dimension_names`.
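class_name : str, optional
Name of the predicted class, if known. Implementations may use it to apply class-specific ROI decoding. Defaults to None.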
Returns Returns
------- -------

View File

@ -1,6 +1,7 @@
import uuid import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional from typing import Callable, List, Optional
from uuid import uuid4
import numpy as np import numpy as np
import pytest import pytest
@ -447,3 +448,15 @@ def example_annotations(
annotations = load_dataset(example_dataset) annotations = load_dataset(example_dataset)
assert len(annotations) == 3 assert len(annotations) == 3
return annotations return annotations
@pytest.fixture
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
    """Return a factory that writes YAML text to a fresh temporary file.

    Parameters
    ----------
    tmp_path : Path
        Directory in which the files are created (presumably pytest's
        built-in per-test `tmp_path` fixture).

    Returns
    -------
    Callable[[str], Path]
        A function that takes the file content as a string, writes it to a
        new `.yaml` file inside `tmp_path` and returns the path to it.
    """

    def factory(content: str) -> Path:
        # A random UUID per call lets one test create several config files
        # in the same directory without name collisions.
        temp_file = tmp_path / f"{uuid4()}.yaml"
        # Pin the encoding: without it `write_text` uses the platform locale
        # encoding, which can fail or corrupt non-ASCII content.
        temp_file.write_text(content, encoding="utf-8")
        return temp_file

    return factory

View File

@ -12,11 +12,11 @@ from batdetect2.targets.classes import (
TargetClass, TargetClass,
_get_default_class_name, _get_default_class_name,
_get_default_classes, _get_default_classes,
_is_target_class,
build_generic_class_tags, build_generic_class_tags,
build_sound_event_decoder, build_sound_event_decoder,
build_sound_event_encoder, build_sound_event_encoder,
get_class_names_from_config, get_class_names_from_config,
is_target_class,
load_classes_config, load_classes_config,
load_decoder_from_config, load_decoder_from_config,
load_encoder_from_config, load_encoder_from_config,
@ -145,7 +145,7 @@ def test_is_target_class_match_all(
), ),
data.Tag(term=sample_term_registry["quality"], value="Good"), data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert _is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = { tags = {
data.Tag( data.Tag(
@ -153,14 +153,14 @@ def test_is_target_class_match_all(
value="Pipistrellus pipistrellus", value="Pipistrellus pipistrellus",
) )
} }
assert _is_target_class(sample_annotation, tags, match_all=True) is True assert is_target_class(sample_annotation, tags, match_all=True) is True
tags = { tags = {
data.Tag( data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii" term=sample_term_registry["species"], value="Myotis daubentonii"
) )
} }
assert _is_target_class(sample_annotation, tags, match_all=True) is False assert is_target_class(sample_annotation, tags, match_all=True) is False
def test_is_target_class_match_any( def test_is_target_class_match_any(
@ -174,7 +174,7 @@ def test_is_target_class_match_any(
), ),
data.Tag(term=sample_term_registry["quality"], value="Good"), data.Tag(term=sample_term_registry["quality"], value="Good"),
} }
assert _is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = { tags = {
data.Tag( data.Tag(
@ -182,14 +182,14 @@ def test_is_target_class_match_any(
value="Pipistrellus pipistrellus", value="Pipistrellus pipistrellus",
) )
} }
assert _is_target_class(sample_annotation, tags, match_all=False) is True assert is_target_class(sample_annotation, tags, match_all=False) is True
tags = { tags = {
data.Tag( data.Tag(
term=sample_term_registry["species"], value="Myotis daubentonii" term=sample_term_registry["species"], value="Myotis daubentonii"
) )
} }
assert _is_target_class(sample_annotation, tags, match_all=False) is False assert is_target_class(sample_annotation, tags, match_all=False) is False
def test_get_class_names_from_config(): def test_get_class_names_from_config():

View File

@ -11,7 +11,7 @@ from batdetect2.targets.rois import (
SIZE_HEIGHT, SIZE_HEIGHT,
SIZE_WIDTH, SIZE_WIDTH,
AnchorBBoxMapper, AnchorBBoxMapper,
BBoxAnchorMapperConfig, AnchorBBoxMapperConfig,
PeakEnergyBBoxMapper, PeakEnergyBBoxMapper,
PeakEnergyBBoxMapperConfig, PeakEnergyBBoxMapperConfig,
_build_bounding_box, _build_bounding_box,
@ -243,7 +243,7 @@ def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
def test_build_roi_mapper(): def test_build_roi_mapper():
"""Test build_roi_mapper creates a configured BBoxEncoder.""" """Test build_roi_mapper creates a configured BBoxEncoder."""
config = BBoxAnchorMapperConfig( config = AnchorBBoxMapperConfig(
anchor="top-right", time_scale=2.0, frequency_scale=20.0 anchor="top-right", time_scale=2.0, frequency_scale=20.0
) )
mapper = build_roi_mapper(config) mapper = build_roi_mapper(config)
@ -571,7 +571,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
def test_build_roi_mapper_for_anchor_bbox(): def test_build_roi_mapper_for_anchor_bbox():
# Given # Given
config = BBoxAnchorMapperConfig( config = AnchorBBoxMapperConfig(
anchor="center", anchor="center",
time_scale=123.0, time_scale=123.0,
frequency_scale=456.0, frequency_scale=456.0,

View File

@ -0,0 +1,117 @@
from collections.abc import Callable
from pathlib import Path
from soundevent import data
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
def test_can_override_default_roi_mapper_per_class(
    create_temp_yaml: Callable[..., Path],
    recording: data.Recording,
    sample_term_registry,
):
    """Check that a class-level `roi` entry overrides the default ROI mapper.

    The config declares a top-level `anchor_bbox` mapper anchored at
    `bottom-left` and gives the `myomyo` class its own `anchor_bbox` mapper
    anchored at `top-left`. Two annotations sharing the same bounding box but
    tagged with different species are encoded, and the test expects the
    encoded frequency to differ between them while the time stays the same.
    """
    # NOTE(review): the YAML nesting below appears flattened in this view of
    # the file; confirm the indentation of the literal against the real test.
    yaml_content = """
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
"""
    config_path = create_temp_yaml(yaml_content)
    config = load_target_config(config_path)
    targets = build_targets(config, term_registry=sample_term_registry)

    # Same geometry for both annotations, so any difference in the encoded
    # position can only come from the mapper chosen for each class.
    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
    )
    se2 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Myotis myotis")],
    )

    # Only the position part of the encoding is inspected; the size is ignored.
    (time1, freq1), _ = targets.encode_roi(se1)
    (time2, freq2), _ = targets.encode_roi(se2)

    # Both are expected at the box start time.
    assert time1 == time2 == 0.1
    # Default mapper (bottom-left): expected at the lower frequency bound.
    assert freq1 == 12_000
    # Per-class override (top-left): expected at the upper frequency bound.
    assert freq2 == 18_000
# TODO: rename this test function
def test_roi_is_recovered_roundtrip_even_with_overriders(
    create_temp_yaml,
    sample_term_registry,
    recording,
):
    """Check that encode/decode round-trips the ROI with per-class overrides.

    Uses a config with a default `bottom-left` anchored mapper and a
    `top-left` anchored override for the `myomyo` class. Each annotation is
    encoded with `encode_roi`, its class name obtained with `encode_class`,
    and the result decoded with `decode_roi(..., class_name=...)`. The
    decoded geometry is expected to equal the original bounding box for both
    classes.
    """
    # NOTE(review): the YAML nesting below appears flattened in this view of
    # the file; confirm the indentation of the literal against the real test.
    yaml_content = """
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
- name: myomyo
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
generic_class:
- key: order
value: Chiroptera
"""
    config_path = create_temp_yaml(yaml_content)
    config = load_target_config(config_path)
    targets = build_targets(config, term_registry=sample_term_registry)

    # Both annotations share one geometry and differ only in their species tag.
    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
    )
    se2 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Myotis myotis")],
    )

    position1, size1 = targets.encode_roi(se1)
    position2, size2 = targets.encode_roi(se2)

    # The class name is passed back to `decode_roi` so that decoding can pick
    # the same mapper that was used for encoding.
    class_name1 = targets.encode_class(se1)
    class_name2 = targets.encode_class(se2)
    recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
    recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)

    assert recovered1 == geometry
    assert recovered2 == geometry

View File

@ -5,7 +5,7 @@ import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets.rois import BBoxAnchorMapperConfig from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets.terms import TagInfo, TermRegistry from batdetect2.targets.terms import TagInfo, TermRegistry
from batdetect2.train.labels import generate_heatmaps from batdetect2.train.labels import generate_heatmaps
@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
): ):
config = sample_target_config.model_copy( config = sample_target_config.model_copy(
update=dict( update=dict(
roi=BBoxAnchorMapperConfig( roi=AnchorBBoxMapperConfig(
time_scale=1, time_scale=1,
frequency_scale=1, frequency_scale=1,
) )

View File

@ -0,0 +1,271 @@
import pytest
import torch
import xarray as xr
from soundevent import data
from batdetect2.models.types import ModelOutput
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.targets.terms import get_term_from_key
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.preprocess import generate_train_example
@pytest.fixture
def build_from_config(
    create_temp_yaml,
    sample_term_registry,
):
    """Provide a builder that assembles the pipeline pieces from YAML text.

    The returned callable writes the YAML to a temporary file, loads the
    `targets`, `preprocessing`, `labels` and `postprocessing` sections from
    it, and returns the tuple
    `(targets, preprocessor, labeller, postprocessor)`.
    """

    def build(yaml_content):
        yaml_path = create_temp_yaml(yaml_content)

        # Each section of the same file is loaded through its own loader.
        target_cfg = load_target_config(yaml_path, field="targets")
        preprocess_cfg = load_preprocessing_config(
            yaml_path,
            field="preprocessing",
        )
        label_cfg = load_label_config(yaml_path, field="labels")
        postprocess_cfg = load_postprocess_config(
            yaml_path,
            field="postprocessing",
        )

        target_obj = build_targets(
            target_cfg, term_registry=sample_term_registry
        )
        preproc = build_preprocessor(preprocess_cfg)
        clip_labeller = build_clip_labeler(
            targets=target_obj,
            config=label_cfg,
        )
        # The postprocessor takes its frequency range from the preprocessor.
        postproc = build_postprocessor(
            target_obj,
            config=postprocess_cfg,
            min_freq=preproc.min_freq,
            max_freq=preproc.max_freq,
        )
        return target_obj, preproc, clip_labeller, postproc

    return build
# TODO: better name
def test_generated_train_example_has_expected_outputs(
    build_from_config,
    sample_term_registry,
    recording,
):
    """Check the variables and shapes of a generated training example.

    Builds a preprocessor and labeller from a config with a single `pippip`
    class, generates a training example for a 0.5 s clip holding one
    annotated sound event, and checks that the result is an `xr.Dataset`
    containing `audio`, `spectrogram`, `detection`, `class` and `size`, with
    the label arrays matching the spectrogram's two dimensions.
    """
    # NOTE(review): the YAML nesting below appears flattened in this view of
    # the file; confirm the indentation of the literal against the real test.
    yaml_content = """
labels:
targets:
roi:
name: anchor_bbox
anchor: bottom-left
classes:
classes:
- name: pippip
tags:
- key: species
value: Pipistrellus pipistrellus
generic_class:
- key: order
value: Chiroptera
preprocessing:
postprocessing:
"""
    # Only the preprocessor and labeller are needed for this test.
    _, preprocessor, labeller, _ = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
    )
    clip_annotation = data.ClipAnnotation(
        clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
        sound_events=[se1],
    )

    encoded = generate_train_example(clip_annotation, preprocessor, labeller)

    assert isinstance(encoded, xr.Dataset)
    assert "audio" in encoded
    assert "spectrogram" in encoded
    assert "detection" in encoded
    assert "class" in encoded
    assert "size" in encoded

    # The spectrogram must be 2-D; its shape is the reference for the labels.
    spec_shape = encoded["spectrogram"].shape
    assert len(spec_shape) == 2
    height, width = spec_shape

    assert encoded["detection"].shape == (height, width)
    # Leading axis of 1: presumably one channel per configured class (only
    # `pippip` here); confirm against the labeller.
    assert encoded["class"].shape == (1, height, width)
    # Leading axis of 2: presumably the two size dimensions (width, height);
    # confirm against the labeller.
    assert encoded["size"].shape == (2, height, width)
def test_encoding_decoding_roundtrip_recovers_object(
    build_from_config,
    sample_term_registry,
    recording,
):
    """Encode an annotation and decode it back through the postprocessor.

    A single bounding box tagged as *Pipistrellus pipistrellus* is turned
    into training targets with `generate_train_example`. Those targets are
    then passed to the postprocessor as if they were model outputs. Exactly
    one sound event must come back, with a bounding box matching the
    original (tolerance 0.01 on the time coordinates and 1_000 on the
    frequency coordinates) and two tags, the species tag and the generic
    ``order`` tag, each with a score of 1.
    """
    # NOTE(review): the nesting of this YAML document was reconstructed (the
    # reviewed copy had no indentation) -- confirm it against the targets
    # config schema.
    yaml_content = """
    labels:

    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left

      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus

        generic_class:
          - key: order
            value: Chiroptera

    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(clip_annotation, preprocessor, labeller)

    # Feed the training targets back in as a one-item batch of model outputs.
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=torch.tensor([[encoded["detection"].data]]),
            size_preds=torch.tensor([encoded["size"].data]),
            class_probs=torch.tensor([encoded["class"].data]),
            features=torch.tensor([[encoded["spectrogram"].data]]),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)

    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    assert len(recovered.tags) == 2

    # `next` accepts the generator directly; no `iter()` wrapper is needed.
    predicted_species_tag = next(
        (t for t in recovered.tags if t.tag.term == species), None
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"

    predicted_order_tag = next(
        (t for t in recovered.tags if t.tag.term.label == "order"), None
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
    build_from_config,
    sample_term_registry,
    recording,
):
    """Round-trip an annotation whose class declares its own ROI mapping.

    The config sets a targets-level ``roi`` anchored at ``bottom-left``,
    while the ``myomyo`` class carries its own ``roi`` anchored at
    ``top-left``. A bounding box tagged as *Myotis myotis* is encoded with
    `generate_train_example` and decoded by the postprocessor. Exactly one
    sound event must come back, with a bounding box matching the original
    (tolerance 0.01 on the time coordinates and 1_000 on the frequency
    coordinates) and two tags, the species tag and the generic ``order``
    tag, each with a score of 1.
    """
    # NOTE(review): the nesting of this YAML document was reconstructed (the
    # reviewed copy had no indentation). In particular the second `roi` is
    # assumed to belong to the `myomyo` class entry -- confirm against the
    # targets config schema.
    yaml_content = """
    labels:

    targets:
      roi:
        name: anchor_bbox
        anchor: bottom-left

      classes:
        classes:
          - name: pippip
            tags:
              - key: species
                value: Pipistrellus pipistrellus

          - name: myomyo
            tags:
              - key: species
                value: Myotis myotis
            roi:
              name: anchor_bbox
              anchor: top-left

        generic_class:
          - key: order
            value: Chiroptera

    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    # Tagged with the class that carries its own `roi` entry.
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Myotis myotis")],
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(clip_annotation, preprocessor, labeller)

    # Feed the training targets back in as a one-item batch of model outputs.
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=torch.tensor([[encoded["detection"].data]]),
            size_preds=torch.tensor([encoded["size"].data]),
            class_probs=torch.tensor([encoded["class"].data]),
            features=torch.tensor([[encoded["spectrogram"].data]]),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)

    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    assert len(recovered.tags) == 2

    # `next` accepts the generator directly; no `iter()` wrapper is needed.
    predicted_species_tag = next(
        (t for t in recovered.tags if t.tag.term == species), None
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Myotis myotis"

    predicted_order_tag = next(
        (t for t in recovered.tags if t.tag.term.label == "order"), None
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"