mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Implement changes needed to make roi encode/decode class dependent
This commit is contained in:
parent
3407e1b5f0
commit
c7ea361cf4
@ -157,4 +157,4 @@ def load_config(
|
||||
if field:
|
||||
config = get_object_field(config, field)
|
||||
|
||||
return schema.model_validate(config)
|
||||
return schema.model_validate(config or {})
|
||||
|
@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
|
||||
|
||||
The pipeline involves several configurable steps, implemented in submodules:
|
||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
||||
coordinates to raw model output arrays.
|
||||
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||
model output arrays.
|
||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||
(location and score) based on thresholds and score ranking (top-k).
|
||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||
|
@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
||||
It takes the structured detection data extracted by the `extraction` module
|
||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||
class probabilities, and features for each detection point) and converts it
|
||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
||||
model.
|
||||
into standardized prediction objects based on the `soundevent` data model.
|
||||
|
||||
The process involves:
|
||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||
@ -33,7 +32,7 @@ import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
||||
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||
from batdetect2.targets.classes import SoundEventDecoder
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
|
||||
@ -55,7 +54,7 @@ decoding.
|
||||
|
||||
def convert_xr_dataset_to_raw_prediction(
|
||||
detection_dataset: xr.Dataset,
|
||||
geometry_builder: GeometryBuilder,
|
||||
geometry_decoder: GeometryDecoder,
|
||||
) -> List[RawPrediction]:
|
||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||
|
||||
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
output by `extract_detection_xr_dataset`. Expected variables include
|
||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||
Must have a 'detection' dimension.
|
||||
geometry_builder : GeometryBuilder
|
||||
geometry_decoder : GeometryDecoder
|
||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||
of dimensions, and returns the corresponding reconstructed
|
||||
`soundevent.data.Geometry`.
|
||||
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
|
||||
for det_num in range(detection_dataset.sizes["detection"]):
|
||||
det_info = detection_dataset.sel(detection=det_num)
|
||||
|
||||
geom = geometry_builder(
|
||||
# TODO: Maybe clean this up
|
||||
highest_scoring_class = det_info.coords["category"][
|
||||
det_info["classes"].argmax()
|
||||
].item()
|
||||
|
||||
geom = geometry_decoder(
|
||||
(det_info.time, det_info.frequency),
|
||||
det_info.dimensions,
|
||||
class_name=highest_scoring_class,
|
||||
)
|
||||
|
||||
detections.append(
|
||||
RawPrediction(
|
||||
detection_score=det_info.score,
|
||||
detection_score=det_info.scores,
|
||||
geometry=geom,
|
||||
class_scores=det_info.classes,
|
||||
features=det_info.features,
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Extracts candidate detection points from a model output heatmap.
|
||||
|
||||
This module implements a specific step within the BatDetect2 postprocessing
|
||||
This module implements Step 3 within the BatDetect2 postprocessing
|
||||
pipeline. Its primary function is to identify potential sound event locations
|
||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||
produced by the neural network (usually after Non-Maximum Suppression and
|
||||
|
@ -1,9 +1,9 @@
|
||||
"""Extracts associated data for detected points from model output arrays.
|
||||
|
||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
||||
pipeline. After candidate detection points (time, frequency, score) have been
|
||||
identified, this module extracts the corresponding values from other raw model
|
||||
output arrays, such as:
|
||||
This module implements Step 4 of the BatDetect2 postprocessing pipeline.
|
||||
After candidate detection points (time, frequency, score) have been identified,
|
||||
this module extracts the corresponding values from other raw model output
|
||||
arrays, such as:
|
||||
|
||||
- Predicted bounding box sizes (width, height).
|
||||
- Class probability scores for each defined target class.
|
||||
|
@ -11,31 +11,38 @@ modularity and consistent interaction between different parts of the BatDetect2
|
||||
system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from typing import Callable, List, NamedTuple, Protocol
|
||||
from typing import List, NamedTuple, Optional, Protocol
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.targets.types import Position, Size
|
||||
|
||||
__all__ = [
|
||||
"RawPrediction",
|
||||
"PostprocessorProtocol",
|
||||
"GeometryBuilder",
|
||||
"GeometryDecoder",
|
||||
]
|
||||
|
||||
|
||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
||||
# TODO: update the docstring
|
||||
class GeometryDecoder(Protocol):
|
||||
"""Protocol for a callable that recovers geometry from position and size.
|
||||
|
||||
This callable takes:
|
||||
1. A position tuple `(time, frequency)`.
|
||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||
3. Optionally, the name of the highest-scoring class. This is to accommodate
|
||||
different ways of decoding geometry that depend on the predicted class.
|
||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||
`BoundingBox`).
|
||||
"""
|
||||
|
||||
def __call__(
|
||||
self, position: Position, size: Size, class_name: Optional[str] = None
|
||||
) -> data.Geometry: ...
|
||||
|
||||
|
||||
class RawPrediction(NamedTuple):
|
||||
"""Intermediate representation of a single detected sound event.
|
||||
|
@ -23,6 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
@ -49,7 +50,8 @@ from batdetect2.targets.filtering import (
|
||||
load_filter_from_config,
|
||||
)
|
||||
from batdetect2.targets.rois import (
|
||||
BBoxAnchorMapperConfig,
|
||||
AnchorBBoxMapperConfig,
|
||||
ROIMapperConfig,
|
||||
ROITargetMapper,
|
||||
build_roi_mapper,
|
||||
)
|
||||
@ -58,11 +60,11 @@ from batdetect2.targets.terms import (
|
||||
TermInfo,
|
||||
TermRegistry,
|
||||
call_type,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
get_term_from_key,
|
||||
individual,
|
||||
register_term,
|
||||
default_term_registry,
|
||||
)
|
||||
from batdetect2.targets.transform import (
|
||||
DerivationRegistry,
|
||||
@ -87,7 +89,7 @@ __all__ = [
|
||||
"FilterConfig",
|
||||
"FilterRule",
|
||||
"MapValueRule",
|
||||
"BBoxAnchorMapperConfig",
|
||||
"AnchorBBoxMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"ReplaceRule",
|
||||
"SoundEventDecoder",
|
||||
@ -160,7 +162,7 @@ class TargetConfig(BaseConfig):
|
||||
classes: ClassesConfig = Field(
|
||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||
)
|
||||
roi: Optional[BBoxAnchorMapperConfig] = None
|
||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||
|
||||
|
||||
def load_target_config(
|
||||
@ -239,6 +241,7 @@ class Targets(TargetProtocol):
|
||||
generic_class_tags: List[data.Tag],
|
||||
filter_fn: Optional[SoundEventFilter] = None,
|
||||
transform_fn: Optional[SoundEventTransformation] = None,
|
||||
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||
):
|
||||
"""Initialize the Targets object.
|
||||
|
||||
@ -271,6 +274,16 @@ class Targets(TargetProtocol):
|
||||
self._encode_fn = encode_fn
|
||||
self._decode_fn = decode_fn
|
||||
self._transform_fn = transform_fn
|
||||
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||
|
||||
for class_name in self._roi_mapper_overrides:
|
||||
if class_name not in self.class_names:
|
||||
# TODO: improve this warning
|
||||
logger.warning(
|
||||
"The ROI mapper overrides contains a class ({class_name}) "
|
||||
"not present in the class names.",
|
||||
class_name=class_name,
|
||||
)
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
@ -375,9 +388,21 @@ class Targets(TargetProtocol):
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
class_name = self.encode_class(sound_event)
|
||||
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].encode(
|
||||
sound_event.sound_event
|
||||
)
|
||||
|
||||
return self._roi_mapper.encode(sound_event.sound_event)
|
||||
|
||||
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
|
||||
def decode_roi(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
Delegates to the internal ROI mapper's `decode` method, which
|
||||
@ -397,6 +422,13 @@ class Targets(TargetProtocol):
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
print(class_name)
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].decode(
|
||||
position,
|
||||
size,
|
||||
)
|
||||
|
||||
return self._roi_mapper.decode(position, size)
|
||||
|
||||
|
||||
@ -452,10 +484,12 @@ DEFAULT_CLASSES = [
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||
name="nyclei",
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||
name="rhifer",
|
||||
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
),
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Plecotus auritus")],
|
||||
@ -496,6 +530,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
]
|
||||
),
|
||||
classes=DEFAULT_CLASSES_CONFIG,
|
||||
roi=AnchorBBoxMapperConfig(),
|
||||
)
|
||||
|
||||
|
||||
@ -565,12 +600,17 @@ def build_targets(
|
||||
if config.transforms
|
||||
else None
|
||||
)
|
||||
roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig())
|
||||
roi_mapper = build_roi_mapper(config.roi)
|
||||
class_names = get_class_names_from_config(config.classes)
|
||||
generic_class_tags = build_generic_class_tags(
|
||||
config.classes,
|
||||
term_registry=term_registry,
|
||||
)
|
||||
roi_overrides = {
|
||||
class_config.name: build_roi_mapper(class_config.roi)
|
||||
for class_config in config.classes.classes
|
||||
if class_config.roi is not None
|
||||
}
|
||||
|
||||
return Targets(
|
||||
filter_fn=filter_fn,
|
||||
@ -580,6 +620,7 @@ def build_targets(
|
||||
roi_mapper=roi_mapper,
|
||||
generic_class_tags=generic_class_tags,
|
||||
transform_fn=transform_fn,
|
||||
roi_mapper_overrides=roi_overrides,
|
||||
)
|
||||
|
||||
|
||||
|
@ -6,49 +6,26 @@ from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import (
|
||||
GENERIC_CLASS_KEY,
|
||||
TagInfo,
|
||||
TermRegistry,
|
||||
get_tag_from_info,
|
||||
default_term_registry,
|
||||
get_tag_from_info,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SoundEventEncoder",
|
||||
"SoundEventDecoder",
|
||||
"TargetClass",
|
||||
"ClassesConfig",
|
||||
"load_classes_config",
|
||||
"load_encoder_from_config",
|
||||
"load_decoder_from_config",
|
||||
"build_sound_event_encoder",
|
||||
"build_sound_event_decoder",
|
||||
"build_generic_class_tags",
|
||||
"get_class_names_from_config",
|
||||
"DEFAULT_SPECIES_LIST",
|
||||
"PositionMethod",
|
||||
"CornerPosition",
|
||||
"SizeMethod",
|
||||
"BoundingBoxSize",
|
||||
"build_generic_class_tags",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
"get_class_names_from_config",
|
||||
"load_classes_config",
|
||||
"load_decoder_from_config",
|
||||
"load_encoder_from_config",
|
||||
]
|
||||
|
||||
class PositionMethod(BaseConfig):
|
||||
"""Base class for defining how to select a position from a geometry."""
|
||||
method_type: str
|
||||
|
||||
class CornerPosition(PositionMethod):
|
||||
"""Selects a position based on a corner or center of the bounding box."""
|
||||
method_type: Literal["corner"] = "corner"
|
||||
corner: Literal["upper_left", "lower_left", "center"] = "lower_left"
|
||||
|
||||
class SizeMethod(BaseConfig):
|
||||
"""Base class for defining how to select a size from a geometry."""
|
||||
method_type: str
|
||||
|
||||
class BoundingBoxSize(SizeMethod):
|
||||
"""Uses the width and height of the bounding box as the size."""
|
||||
method_type: Literal["bounding_box"] = "bounding_box"
|
||||
|
||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||
"""Type alias for a sound event class encoder function.
|
||||
@ -134,8 +111,7 @@ class TargetClass(BaseConfig):
|
||||
tags: List[TagInfo] = Field(min_length=1)
|
||||
match_type: Literal["all", "any"] = Field(default="all")
|
||||
output_tags: Optional[List[TagInfo]] = None
|
||||
position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left"))
|
||||
size_method: SizeMethod = Field(default_factory=BoundingBoxSize)
|
||||
roi: Optional[ROIMapperConfig] = None
|
||||
|
||||
|
||||
def _get_default_classes() -> List[TargetClass]:
|
||||
@ -258,7 +234,7 @@ class ClassesConfig(BaseConfig):
|
||||
return v
|
||||
|
||||
|
||||
def _is_target_class(
|
||||
def is_target_class(
|
||||
sound_event_annotation: data.SoundEventAnnotation,
|
||||
tags: Set[data.Tag],
|
||||
match_all: bool = True,
|
||||
@ -373,7 +349,7 @@ def build_sound_event_encoder(
|
||||
(
|
||||
class_info.name,
|
||||
partial(
|
||||
_is_target_class,
|
||||
is_target_class,
|
||||
tags={
|
||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||
for tag_info in class_info.tags
|
||||
|
@ -34,7 +34,7 @@ from batdetect2.targets.types import Position, Size
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
"AnchorBBoxMapper",
|
||||
"BBoxAnchorMapperConfig",
|
||||
"AnchorBBoxMapperConfig",
|
||||
"DEFAULT_ANCHOR",
|
||||
"DEFAULT_FREQUENCY_SCALE",
|
||||
"DEFAULT_TIME_SCALE",
|
||||
@ -148,7 +148,7 @@ class ROITargetMapper(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class BBoxAnchorMapperConfig(BaseConfig):
|
||||
class AnchorBBoxMapperConfig(BaseConfig):
|
||||
"""Configuration for `AnchorBBoxMapper`.
|
||||
|
||||
Defines parameters for converting ROIs into targets using a fixed anchor
|
||||
@ -470,7 +470,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
||||
|
||||
|
||||
ROIMapperConfig = Annotated[
|
||||
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
"""A discriminated union of all supported ROI mapper configurations.
|
||||
@ -480,7 +480,9 @@ implementations by using the `name` field as a discriminator.
|
||||
"""
|
||||
|
||||
|
||||
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
||||
def build_roi_mapper(
|
||||
config: Optional[ROIMapperConfig] = None,
|
||||
) -> ROITargetMapper:
|
||||
"""Factory function to create an ROITargetMapper from a config object.
|
||||
|
||||
Parameters
|
||||
@ -498,6 +500,8 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
||||
NotImplementedError
|
||||
If the `name` in the config does not correspond to a known mapper.
|
||||
"""
|
||||
config = config or AnchorBBoxMapperConfig()
|
||||
|
||||
if config.name == "anchor_bbox":
|
||||
return AnchorBBoxMapper(
|
||||
anchor=config.anchor,
|
||||
|
@ -235,6 +235,7 @@ default_term_registry = TermRegistry(
|
||||
[
|
||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||
("event", call_type),
|
||||
("species", terms.scientific_name),
|
||||
("individual", individual),
|
||||
("data_source", data_source),
|
||||
(GENERIC_CLASS_KEY, generic_class),
|
||||
|
@ -181,7 +181,13 @@ class TargetProtocol(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
|
||||
# TODO: Update docstrings
|
||||
def decode_roi(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: Optional[str] = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover the ROI geometry from a position and dimensions.
|
||||
|
||||
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
||||
@ -195,6 +201,8 @@ class TargetProtocol(Protocol):
|
||||
dims : np.ndarray
|
||||
The NumPy array containing the dimensions (e.g., predicted
|
||||
by the model), corresponding to the order in `dimension_names`.
|
||||
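class_name : str, optional
|
||||
Name of the predicted class, if known. Allows implementations to apply a class-specific ROI decoding.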
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@ -447,3 +448,15 @@ def example_annotations(
|
||||
annotations = load_dataset(example_dataset)
|
||||
assert len(annotations) == 3
|
||||
return annotations
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||
"""Create a temporary YAML file with the given content."""
|
||||
|
||||
def factory(content: str) -> Path:
|
||||
temp_file = tmp_path / f"{uuid4()}.yaml"
|
||||
temp_file.write_text(content)
|
||||
return temp_file
|
||||
|
||||
return factory
|
||||
|
@ -12,11 +12,11 @@ from batdetect2.targets.classes import (
|
||||
TargetClass,
|
||||
_get_default_class_name,
|
||||
_get_default_classes,
|
||||
_is_target_class,
|
||||
build_generic_class_tags,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
is_target_class,
|
||||
load_classes_config,
|
||||
load_decoder_from_config,
|
||||
load_encoder_from_config,
|
||||
@ -145,7 +145,7 @@ def test_is_target_class_match_all(
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
@ -153,14 +153,14 @@ def test_is_target_class_match_all(
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||
|
||||
|
||||
def test_is_target_class_match_any(
|
||||
@ -174,7 +174,7 @@ def test_is_target_class_match_any(
|
||||
),
|
||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
@ -182,14 +182,14 @@ def test_is_target_class_match_any(
|
||||
value="Pipistrellus pipistrellus",
|
||||
)
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||
|
||||
tags = {
|
||||
data.Tag(
|
||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||
)
|
||||
}
|
||||
assert _is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||
|
||||
|
||||
def test_get_class_names_from_config():
|
||||
|
@ -11,7 +11,7 @@ from batdetect2.targets.rois import (
|
||||
SIZE_HEIGHT,
|
||||
SIZE_WIDTH,
|
||||
AnchorBBoxMapper,
|
||||
BBoxAnchorMapperConfig,
|
||||
AnchorBBoxMapperConfig,
|
||||
PeakEnergyBBoxMapper,
|
||||
PeakEnergyBBoxMapperConfig,
|
||||
_build_bounding_box,
|
||||
@ -243,7 +243,7 @@ def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
|
||||
|
||||
def test_build_roi_mapper():
|
||||
"""Test build_roi_mapper creates a configured BBoxEncoder."""
|
||||
config = BBoxAnchorMapperConfig(
|
||||
config = AnchorBBoxMapperConfig(
|
||||
anchor="top-right", time_scale=2.0, frequency_scale=20.0
|
||||
)
|
||||
mapper = build_roi_mapper(config)
|
||||
@ -571,7 +571,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
||||
|
||||
def test_build_roi_mapper_for_anchor_bbox():
|
||||
# Given
|
||||
config = BBoxAnchorMapperConfig(
|
||||
config = AnchorBBoxMapperConfig(
|
||||
anchor="center",
|
||||
time_scale=123.0,
|
||||
frequency_scale=456.0,
|
||||
|
117
tests/test_targets/test_targets.py
Normal file
117
tests/test_targets/test_targets.py
Normal file
@ -0,0 +1,117 @@
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
|
||||
|
||||
def test_can_override_default_roi_mapper_per_class(
|
||||
create_temp_yaml: Callable[..., Path],
|
||||
recording: data.Recording,
|
||||
sample_term_registry,
|
||||
):
|
||||
yaml_content = """
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: myomyo
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
"""
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = load_target_config(config_path)
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
|
||||
se2 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
)
|
||||
|
||||
(time1, freq1), _ = targets.encode_roi(se1)
|
||||
(time2, freq2), _ = targets.encode_roi(se2)
|
||||
|
||||
assert time1 == time2 == 0.1
|
||||
assert freq1 == 12_000
|
||||
assert freq2 == 18_000
|
||||
|
||||
|
||||
# TODO: rename this test function
|
||||
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||
create_temp_yaml,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
- name: myomyo
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
"""
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = load_target_config(config_path)
|
||||
targets = build_targets(config, term_registry=sample_term_registry)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
|
||||
se2 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
)
|
||||
|
||||
position1, size1 = targets.encode_roi(se1)
|
||||
position2, size2 = targets.encode_roi(se2)
|
||||
|
||||
class_name1 = targets.encode_class(se1)
|
||||
class_name2 = targets.encode_class(se2)
|
||||
|
||||
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
|
||||
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)
|
||||
|
||||
assert recovered1 == geometry
|
||||
assert recovered2 == geometry
|
@ -5,7 +5,7 @@ import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets.rois import BBoxAnchorMapperConfig
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||
from batdetect2.train.labels import generate_heatmaps
|
||||
|
||||
@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
update=dict(
|
||||
roi=BBoxAnchorMapperConfig(
|
||||
roi=AnchorBBoxMapperConfig(
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
)
|
||||
|
271
tests/test_train/test_preprocessing.py
Normal file
271
tests/test_train/test_preprocessing.py
Normal file
@ -0,0 +1,271 @@
|
||||
import pytest
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.preprocess import generate_train_example
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def build_from_config(
|
||||
create_temp_yaml,
|
||||
sample_term_registry,
|
||||
):
|
||||
def build(yaml_content):
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
targets_config = load_target_config(config_path, field="targets")
|
||||
preprocessing_config = load_preprocessing_config(
|
||||
config_path,
|
||||
field="preprocessing",
|
||||
)
|
||||
labels_config = load_label_config(config_path, field="labels")
|
||||
postprocessing_config = load_postprocess_config(
|
||||
config_path,
|
||||
field="postprocessing",
|
||||
)
|
||||
|
||||
targets = build_targets(
|
||||
targets_config, term_registry=sample_term_registry
|
||||
)
|
||||
preprocessor = build_preprocessor(preprocessing_config)
|
||||
labeller = build_clip_labeler(
|
||||
targets=targets,
|
||||
config=labels_config,
|
||||
)
|
||||
postprocessor = build_postprocessor(
|
||||
targets,
|
||||
config=postprocessing_config,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
return targets, preprocessor, labeller, postprocessor
|
||||
|
||||
return build
|
||||
|
||||
|
||||
# TODO: better name
|
||||
def test_generated_train_example_has_expected_outputs(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
labels:
|
||||
targets:
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
preprocessing:
|
||||
postprocessing:
|
||||
"""
|
||||
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
clip_annotation = data.ClipAnnotation(
|
||||
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
||||
sound_events=[se1],
|
||||
)
|
||||
|
||||
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
|
||||
|
||||
assert isinstance(encoded, xr.Dataset)
|
||||
assert "audio" in encoded
|
||||
assert "spectrogram" in encoded
|
||||
assert "detection" in encoded
|
||||
assert "class" in encoded
|
||||
assert "size" in encoded
|
||||
|
||||
spec_shape = encoded["spectrogram"].shape
|
||||
assert len(spec_shape) == 2
|
||||
|
||||
height, width = spec_shape
|
||||
assert encoded["detection"].shape == (height, width)
|
||||
assert encoded["class"].shape == (1, height, width)
|
||||
assert encoded["size"].shape == (2, height, width)
|
||||
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object(
|
||||
build_from_config,
|
||||
sample_term_registry,
|
||||
recording,
|
||||
):
|
||||
yaml_content = """
|
||||
labels:
|
||||
targets:
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
classes:
|
||||
classes:
|
||||
- name: pippip
|
||||
tags:
|
||||
- key: species
|
||||
value: Pipistrellus pipistrellus
|
||||
generic_class:
|
||||
- key: order
|
||||
value: Chiroptera
|
||||
preprocessing:
|
||||
"""
|
||||
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||
)
|
||||
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||
|
||||
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
|
||||
predictions = postprocessor.get_predictions(
|
||||
ModelOutput(
|
||||
detection_probs=torch.tensor([[encoded["detection"].data]]),
|
||||
size_preds=torch.tensor([encoded["size"].data]),
|
||||
class_probs=torch.tensor([encoded["class"].data]),
|
||||
features=torch.tensor([[encoded["spectrogram"].data]]),
|
||||
),
|
||||
[clip],
|
||||
)[0]
|
||||
|
||||
assert isinstance(predictions, data.ClipPrediction)
|
||||
assert len(predictions.sound_events) == 1
|
||||
|
||||
recovered = predictions.sound_events[0]
|
||||
assert recovered.sound_event.geometry is not None
|
||||
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||
recovered.sound_event.geometry.coordinates
|
||||
)
|
||||
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||
geometry.coordinates
|
||||
)
|
||||
|
||||
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||
|
||||
assert len(recovered.tags) == 2
|
||||
|
||||
predicted_species_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||
)
|
||||
assert predicted_species_tag is not None
|
||||
assert predicted_species_tag.score == 1
|
||||
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||
|
||||
predicted_order_tag = next(
|
||||
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||
)
|
||||
assert predicted_order_tag is not None
|
||||
assert predicted_order_tag.score == 1
|
||||
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||
|
||||
|
||||
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
    build_from_config,
    sample_term_registry,
    recording,
):
    """Round-trip one annotation through encoding and postprocessing.

    A single "Myotis myotis" bounding-box annotation is encoded with
    ``generate_train_example`` and the encoded arrays are fed straight back
    into the postprocessor as if they were model outputs. The recovered
    prediction must match the original geometry (within tolerance) and
    carry both the species tag and the generic "order" tag with score 1.

    The configuration declares a target-level ROI anchored at
    ``bottom-left`` and a second ROI anchored at ``top-left`` for the
    ``myomyo`` class, so the annotated class uses a different anchor than
    the default one.
    """
    # NOTE(review): the nesting of this YAML was reconstructed from the key
    # names (class-level `roi` under `myomyo`, `generic_class` as a sibling
    # of the inner `classes` list) -- confirm against the targets config
    # schema. `labels` and `preprocessing` are intentionally left empty.
    yaml_content = """
    labels:
    targets:
        roi:
            name: anchor_bbox
            anchor: bottom-left
        classes:
            classes:
                - name: pippip
                  tags:
                      - key: species
                        value: Pipistrellus pipistrellus
                - name: myomyo
                  tags:
                      - key: species
                        value: Myotis myotis
                  roi:
                      name: anchor_bbox
                      anchor: top-left
            generic_class:
                - key: order
                  value: Chiroptera
    preprocessing:
    """
    _, preprocessor, labeller, postprocessor = build_from_config(yaml_content)

    geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
    species = get_term_from_key("species", term_registry=sample_term_registry)
    se1 = data.SoundEventAnnotation(
        sound_event=data.SoundEvent(recording=recording, geometry=geometry),
        tags=[data.Tag(term=species, value="Myotis myotis")],
    )
    clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
    clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])

    encoded = generate_train_example(clip_annotation, preprocessor, labeller)

    # The encoded targets stand in for model outputs. The extra list levels
    # add leading dimensions (presumably batch, plus a channel axis for the
    # detection and feature arrays -- TODO confirm against ModelOutput).
    predictions = postprocessor.get_predictions(
        ModelOutput(
            detection_probs=torch.tensor([[encoded["detection"].data]]),
            size_preds=torch.tensor([encoded["size"].data]),
            class_probs=torch.tensor([encoded["class"].data]),
            features=torch.tensor([[encoded["spectrogram"].data]]),
        ),
        [clip],
    )[0]

    assert isinstance(predictions, data.ClipPrediction)
    assert len(predictions.sound_events) == 1

    recovered = predictions.sound_events[0]
    assert recovered.sound_event.geometry is not None
    assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
    start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
        recovered.sound_event.geometry.coordinates
    )
    start_time_or, low_freq_or, end_time_or, high_freq_or = (
        geometry.coordinates
    )

    # Times are compared to within 0.01 and frequencies to within 1000, in
    # the units of the bounding-box coordinates.
    assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
    assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
    assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
    assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)

    # One tag for the species class plus one for the generic class.
    assert len(recovered.tags) == 2

    predicted_species_tag = next(
        (t for t in recovered.tags if t.tag.term == species), None
    )
    assert predicted_species_tag is not None
    assert predicted_species_tag.score == 1
    assert predicted_species_tag.tag.value == "Myotis myotis"

    predicted_order_tag = next(
        (t for t in recovered.tags if t.tag.term.label == "order"), None
    )
    assert predicted_order_tag is not None
    assert predicted_order_tag.score == 1
    assert predicted_order_tag.tag.value == "Chiroptera"
|
Loading…
Reference in New Issue
Block a user