mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Implement changes needed to make roi encode/decode class dependent
This commit is contained in:
parent
3407e1b5f0
commit
c7ea361cf4
@ -157,4 +157,4 @@ def load_config(
|
|||||||
if field:
|
if field:
|
||||||
config = get_object_field(config, field)
|
config = get_object_field(config, field)
|
||||||
|
|
||||||
return schema.model_validate(config)
|
return schema.model_validate(config or {})
|
||||||
|
@ -7,8 +7,8 @@ containing detected sound events with associated class tags and geometry.
|
|||||||
|
|
||||||
The pipeline involves several configurable steps, implemented in submodules:
|
The pipeline involves several configurable steps, implemented in submodules:
|
||||||
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
1. Non-Maximum Suppression (`.nms`): Isolates distinct detection peaks.
|
||||||
2. Coordinate Remapping (`.remapping`): Adds real-world time/frequency
|
2. Coordinate Remapping (`.remapping`): Adds time/frequency coordinates to raw
|
||||||
coordinates to raw model output arrays.
|
model output arrays.
|
||||||
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
3. Detection Extraction (`.detection`): Identifies candidate detection points
|
||||||
(location and score) based on thresholds and score ranking (top-k).
|
(location and score) based on thresholds and score ranking (top-k).
|
||||||
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
4. Data Extraction (`.extraction`): Gathers associated model outputs (size,
|
||||||
|
@ -4,8 +4,7 @@ This module handles the final stages of the BatDetect2 postprocessing pipeline.
|
|||||||
It takes the structured detection data extracted by the `extraction` module
|
It takes the structured detection data extracted by the `extraction` module
|
||||||
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
(typically an `xarray.Dataset` containing scores, positions, predicted sizes,
|
||||||
class probabilities, and features for each detection point) and converts it
|
class probabilities, and features for each detection point) and converts it
|
||||||
into meaningful, standardized prediction objects based on the `soundevent` data
|
into standardized prediction objects based on the `soundevent` data model.
|
||||||
model.
|
|
||||||
|
|
||||||
The process involves:
|
The process involves:
|
||||||
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
1. Converting the `xarray.Dataset` into a list of intermediate `RawPrediction`
|
||||||
@ -33,7 +32,7 @@ import numpy as np
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import GeometryBuilder, RawPrediction
|
from batdetect2.postprocess.types import GeometryDecoder, RawPrediction
|
||||||
from batdetect2.targets.classes import SoundEventDecoder
|
from batdetect2.targets.classes import SoundEventDecoder
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
@ -55,7 +54,7 @@ decoding.
|
|||||||
|
|
||||||
def convert_xr_dataset_to_raw_prediction(
|
def convert_xr_dataset_to_raw_prediction(
|
||||||
detection_dataset: xr.Dataset,
|
detection_dataset: xr.Dataset,
|
||||||
geometry_builder: GeometryBuilder,
|
geometry_decoder: GeometryDecoder,
|
||||||
) -> List[RawPrediction]:
|
) -> List[RawPrediction]:
|
||||||
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
"""Convert an xarray.Dataset of detections to RawPrediction objects.
|
||||||
|
|
||||||
@ -72,7 +71,7 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
output by `extract_detection_xr_dataset`. Expected variables include
|
output by `extract_detection_xr_dataset`. Expected variables include
|
||||||
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
'scores' (with time/freq coords), 'dimensions', 'classes', 'features'.
|
||||||
Must have a 'detection' dimension.
|
Must have a 'detection' dimension.
|
||||||
geometry_builder : GeometryBuilder
|
geometry_decoder : GeometryDecoder
|
||||||
A function that takes a position tuple `(time, freq)` and a NumPy array
|
A function that takes a position tuple `(time, freq)` and a NumPy array
|
||||||
of dimensions, and returns the corresponding reconstructed
|
of dimensions, and returns the corresponding reconstructed
|
||||||
`soundevent.data.Geometry`.
|
`soundevent.data.Geometry`.
|
||||||
@ -96,14 +95,20 @@ def convert_xr_dataset_to_raw_prediction(
|
|||||||
for det_num in range(detection_dataset.sizes["detection"]):
|
for det_num in range(detection_dataset.sizes["detection"]):
|
||||||
det_info = detection_dataset.sel(detection=det_num)
|
det_info = detection_dataset.sel(detection=det_num)
|
||||||
|
|
||||||
geom = geometry_builder(
|
# TODO: Maybe clean this up
|
||||||
|
highest_scoring_class = det_info.coords["category"][
|
||||||
|
det_info["classes"].argmax()
|
||||||
|
].item()
|
||||||
|
|
||||||
|
geom = geometry_decoder(
|
||||||
(det_info.time, det_info.frequency),
|
(det_info.time, det_info.frequency),
|
||||||
det_info.dimensions,
|
det_info.dimensions,
|
||||||
|
class_name=highest_scoring_class,
|
||||||
)
|
)
|
||||||
|
|
||||||
detections.append(
|
detections.append(
|
||||||
RawPrediction(
|
RawPrediction(
|
||||||
detection_score=det_info.score,
|
detection_score=det_info.scores,
|
||||||
geometry=geom,
|
geometry=geom,
|
||||||
class_scores=det_info.classes,
|
class_scores=det_info.classes,
|
||||||
features=det_info.features,
|
features=det_info.features,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
"""Extracts candidate detection points from a model output heatmap.
|
"""Extracts candidate detection points from a model output heatmap.
|
||||||
|
|
||||||
This module implements a specific step within the BatDetect2 postprocessing
|
This module implements Step 3 within the BatDetect2 postprocessing
|
||||||
pipeline. Its primary function is to identify potential sound event locations
|
pipeline. Its primary function is to identify potential sound event locations
|
||||||
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
by finding peaks (local maxima or high-scoring points) in the detection heatmap
|
||||||
produced by the neural network (usually after Non-Maximum Suppression and
|
produced by the neural network (usually after Non-Maximum Suppression and
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
"""Extracts associated data for detected points from model output arrays.
|
"""Extracts associated data for detected points from model output arrays.
|
||||||
|
|
||||||
This module implements a key step (Step 4) in the BatDetect2 postprocessing
|
This module implements Step 4 in the BatDetect2 postprocessing pipeline.
|
||||||
pipeline. After candidate detection points (time, frequency, score) have been
|
After candidate detection points (time, frequency, score) have been identified,
|
||||||
identified, this module extracts the corresponding values from other raw model
|
this module extracts the corresponding values from other raw model output
|
||||||
output arrays, such as:
|
arrays, such as:
|
||||||
|
|
||||||
- Predicted bounding box sizes (width, height).
|
- Predicted bounding box sizes (width, height).
|
||||||
- Class probability scores for each defined target class.
|
- Class probability scores for each defined target class.
|
||||||
|
@ -11,30 +11,37 @@ modularity and consistent interaction between different parts of the BatDetect2
|
|||||||
system that deal with model predictions.
|
system that deal with model predictions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Callable, List, NamedTuple, Protocol
|
from typing import List, NamedTuple, Optional, Protocol
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.targets.types import Position, Size
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"RawPrediction",
|
"RawPrediction",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
"GeometryBuilder",
|
"GeometryDecoder",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
GeometryBuilder = Callable[[tuple[float, float], np.ndarray], data.Geometry]
|
# TODO: update the docstring
|
||||||
"""Type alias for a function that recovers geometry from position and size.
|
class GeometryDecoder(Protocol):
|
||||||
|
"""Protocol for a callable that recovers geometry from position and size.
|
||||||
|
|
||||||
This callable takes:
|
This callable takes:
|
||||||
1. A position tuple `(time, frequency)`.
|
1. A position tuple `(time, frequency)`.
|
||||||
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
2. A NumPy array of size dimensions (e.g., `[width, height]`).
|
||||||
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
3. Optionally a class name of the highest-scoring class. This is to accommodate
|
||||||
`BoundingBox`).
|
different ways of decoding geometry that depend on the predicted class.
|
||||||
"""
|
It should return the reconstructed `soundevent.data.Geometry` (typically a
|
||||||
|
`BoundingBox`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, position: Position, size: Size, class_name: Optional[str] = None
|
||||||
|
) -> data.Geometry: ...
|
||||||
|
|
||||||
|
|
||||||
class RawPrediction(NamedTuple):
|
class RawPrediction(NamedTuple):
|
||||||
|
@ -23,6 +23,7 @@ object is via the `build_targets` or `load_targets` functions.
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -49,7 +50,8 @@ from batdetect2.targets.filtering import (
|
|||||||
load_filter_from_config,
|
load_filter_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
BBoxAnchorMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
|
ROIMapperConfig,
|
||||||
ROITargetMapper,
|
ROITargetMapper,
|
||||||
build_roi_mapper,
|
build_roi_mapper,
|
||||||
)
|
)
|
||||||
@ -58,11 +60,11 @@ from batdetect2.targets.terms import (
|
|||||||
TermInfo,
|
TermInfo,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
call_type,
|
call_type,
|
||||||
|
default_term_registry,
|
||||||
get_tag_from_info,
|
get_tag_from_info,
|
||||||
get_term_from_key,
|
get_term_from_key,
|
||||||
individual,
|
individual,
|
||||||
register_term,
|
register_term,
|
||||||
default_term_registry,
|
|
||||||
)
|
)
|
||||||
from batdetect2.targets.transform import (
|
from batdetect2.targets.transform import (
|
||||||
DerivationRegistry,
|
DerivationRegistry,
|
||||||
@ -87,7 +89,7 @@ __all__ = [
|
|||||||
"FilterConfig",
|
"FilterConfig",
|
||||||
"FilterRule",
|
"FilterRule",
|
||||||
"MapValueRule",
|
"MapValueRule",
|
||||||
"BBoxAnchorMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"ReplaceRule",
|
"ReplaceRule",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
@ -160,7 +162,7 @@ class TargetConfig(BaseConfig):
|
|||||||
classes: ClassesConfig = Field(
|
classes: ClassesConfig = Field(
|
||||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||||
)
|
)
|
||||||
roi: Optional[BBoxAnchorMapperConfig] = None
|
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_target_config(
|
def load_target_config(
|
||||||
@ -239,6 +241,7 @@ class Targets(TargetProtocol):
|
|||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
filter_fn: Optional[SoundEventFilter] = None,
|
filter_fn: Optional[SoundEventFilter] = None,
|
||||||
transform_fn: Optional[SoundEventTransformation] = None,
|
transform_fn: Optional[SoundEventTransformation] = None,
|
||||||
|
roi_mapper_overrides: Optional[dict[str, ROITargetMapper]] = None,
|
||||||
):
|
):
|
||||||
"""Initialize the Targets object.
|
"""Initialize the Targets object.
|
||||||
|
|
||||||
@ -271,6 +274,16 @@ class Targets(TargetProtocol):
|
|||||||
self._encode_fn = encode_fn
|
self._encode_fn = encode_fn
|
||||||
self._decode_fn = decode_fn
|
self._decode_fn = decode_fn
|
||||||
self._transform_fn = transform_fn
|
self._transform_fn = transform_fn
|
||||||
|
self._roi_mapper_overrides = roi_mapper_overrides or {}
|
||||||
|
|
||||||
|
for class_name in self._roi_mapper_overrides:
|
||||||
|
if class_name not in self.class_names:
|
||||||
|
# TODO: improve this warning
|
||||||
|
logger.warning(
|
||||||
|
"The ROI mapper overrides contains a class ({class_name}) "
|
||||||
|
"not present in the class names.",
|
||||||
|
class_name=class_name,
|
||||||
|
)
|
||||||
|
|
||||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||||
"""Apply the configured filter to a sound event annotation.
|
"""Apply the configured filter to a sound event annotation.
|
||||||
@ -375,9 +388,21 @@ class Targets(TargetProtocol):
|
|||||||
ValueError
|
ValueError
|
||||||
If the annotation lacks geometry.
|
If the annotation lacks geometry.
|
||||||
"""
|
"""
|
||||||
|
class_name = self.encode_class(sound_event)
|
||||||
|
|
||||||
|
if class_name in self._roi_mapper_overrides:
|
||||||
|
return self._roi_mapper_overrides[class_name].encode(
|
||||||
|
sound_event.sound_event
|
||||||
|
)
|
||||||
|
|
||||||
return self._roi_mapper.encode(sound_event.sound_event)
|
return self._roi_mapper.encode(sound_event.sound_event)
|
||||||
|
|
||||||
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
|
def decode_roi(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
|
) -> data.Geometry:
|
||||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||||
|
|
||||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||||
@ -397,6 +422,13 @@ class Targets(TargetProtocol):
|
|||||||
data.Geometry
|
data.Geometry
|
||||||
The reconstructed geometry (typically `BoundingBox`).
|
The reconstructed geometry (typically `BoundingBox`).
|
||||||
"""
|
"""
|
||||||
|
print(class_name)
|
||||||
|
if class_name in self._roi_mapper_overrides:
|
||||||
|
return self._roi_mapper_overrides[class_name].decode(
|
||||||
|
position,
|
||||||
|
size,
|
||||||
|
)
|
||||||
|
|
||||||
return self._roi_mapper.decode(position, size)
|
return self._roi_mapper.decode(position, size)
|
||||||
|
|
||||||
|
|
||||||
@ -452,10 +484,12 @@ DEFAULT_CLASSES = [
|
|||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||||
name="nyclei",
|
name="nyclei",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
),
|
),
|
||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||||
name="rhifer",
|
name="rhifer",
|
||||||
|
roi=AnchorBBoxMapperConfig(anchor="top-left"),
|
||||||
),
|
),
|
||||||
TargetClass(
|
TargetClass(
|
||||||
tags=[TagInfo(value="Plecotus auritus")],
|
tags=[TagInfo(value="Plecotus auritus")],
|
||||||
@ -496,6 +530,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
]
|
]
|
||||||
),
|
),
|
||||||
classes=DEFAULT_CLASSES_CONFIG,
|
classes=DEFAULT_CLASSES_CONFIG,
|
||||||
|
roi=AnchorBBoxMapperConfig(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -565,12 +600,17 @@ def build_targets(
|
|||||||
if config.transforms
|
if config.transforms
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
roi_mapper = build_roi_mapper(config.roi or BBoxAnchorMapperConfig())
|
roi_mapper = build_roi_mapper(config.roi)
|
||||||
class_names = get_class_names_from_config(config.classes)
|
class_names = get_class_names_from_config(config.classes)
|
||||||
generic_class_tags = build_generic_class_tags(
|
generic_class_tags = build_generic_class_tags(
|
||||||
config.classes,
|
config.classes,
|
||||||
term_registry=term_registry,
|
term_registry=term_registry,
|
||||||
)
|
)
|
||||||
|
roi_overrides = {
|
||||||
|
class_config.name: build_roi_mapper(class_config.roi)
|
||||||
|
for class_config in config.classes.classes
|
||||||
|
if class_config.roi is not None
|
||||||
|
}
|
||||||
|
|
||||||
return Targets(
|
return Targets(
|
||||||
filter_fn=filter_fn,
|
filter_fn=filter_fn,
|
||||||
@ -580,6 +620,7 @@ def build_targets(
|
|||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
generic_class_tags=generic_class_tags,
|
generic_class_tags=generic_class_tags,
|
||||||
transform_fn=transform_fn,
|
transform_fn=transform_fn,
|
||||||
|
roi_mapper_overrides=roi_overrides,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -6,49 +6,26 @@ from pydantic import Field, field_validator
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.targets.rois import ROIMapperConfig
|
||||||
from batdetect2.targets.terms import (
|
from batdetect2.targets.terms import (
|
||||||
GENERIC_CLASS_KEY,
|
GENERIC_CLASS_KEY,
|
||||||
TagInfo,
|
TagInfo,
|
||||||
TermRegistry,
|
TermRegistry,
|
||||||
get_tag_from_info,
|
|
||||||
default_term_registry,
|
default_term_registry,
|
||||||
|
get_tag_from_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SoundEventEncoder",
|
|
||||||
"SoundEventDecoder",
|
|
||||||
"TargetClass",
|
|
||||||
"ClassesConfig",
|
|
||||||
"load_classes_config",
|
|
||||||
"load_encoder_from_config",
|
|
||||||
"load_decoder_from_config",
|
|
||||||
"build_sound_event_encoder",
|
|
||||||
"build_sound_event_decoder",
|
|
||||||
"build_generic_class_tags",
|
|
||||||
"get_class_names_from_config",
|
|
||||||
"DEFAULT_SPECIES_LIST",
|
"DEFAULT_SPECIES_LIST",
|
||||||
"PositionMethod",
|
"build_generic_class_tags",
|
||||||
"CornerPosition",
|
"build_sound_event_decoder",
|
||||||
"SizeMethod",
|
"build_sound_event_encoder",
|
||||||
"BoundingBoxSize",
|
"get_class_names_from_config",
|
||||||
|
"load_classes_config",
|
||||||
|
"load_decoder_from_config",
|
||||||
|
"load_encoder_from_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
class PositionMethod(BaseConfig):
|
|
||||||
"""Base class for defining how to select a position from a geometry."""
|
|
||||||
method_type: str
|
|
||||||
|
|
||||||
class CornerPosition(PositionMethod):
|
|
||||||
"""Selects a position based on a corner or center of the bounding box."""
|
|
||||||
method_type: Literal["corner"] = "corner"
|
|
||||||
corner: Literal["upper_left", "lower_left", "center"] = "lower_left"
|
|
||||||
|
|
||||||
class SizeMethod(BaseConfig):
|
|
||||||
"""Base class for defining how to select a size from a geometry."""
|
|
||||||
method_type: str
|
|
||||||
|
|
||||||
class BoundingBoxSize(SizeMethod):
|
|
||||||
"""Uses the width and height of the bounding box as the size."""
|
|
||||||
method_type: Literal["bounding_box"] = "bounding_box"
|
|
||||||
|
|
||||||
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
SoundEventEncoder = Callable[[data.SoundEventAnnotation], Optional[str]]
|
||||||
"""Type alias for a sound event class encoder function.
|
"""Type alias for a sound event class encoder function.
|
||||||
@ -134,8 +111,7 @@ class TargetClass(BaseConfig):
|
|||||||
tags: List[TagInfo] = Field(min_length=1)
|
tags: List[TagInfo] = Field(min_length=1)
|
||||||
match_type: Literal["all", "any"] = Field(default="all")
|
match_type: Literal["all", "any"] = Field(default="all")
|
||||||
output_tags: Optional[List[TagInfo]] = None
|
output_tags: Optional[List[TagInfo]] = None
|
||||||
position_method: PositionMethod = Field(default_factory=lambda: CornerPosition(corner="lower_left"))
|
roi: Optional[ROIMapperConfig] = None
|
||||||
size_method: SizeMethod = Field(default_factory=BoundingBoxSize)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_classes() -> List[TargetClass]:
|
def _get_default_classes() -> List[TargetClass]:
|
||||||
@ -258,7 +234,7 @@ class ClassesConfig(BaseConfig):
|
|||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
def _is_target_class(
|
def is_target_class(
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
tags: Set[data.Tag],
|
tags: Set[data.Tag],
|
||||||
match_all: bool = True,
|
match_all: bool = True,
|
||||||
@ -373,7 +349,7 @@ def build_sound_event_encoder(
|
|||||||
(
|
(
|
||||||
class_info.name,
|
class_info.name,
|
||||||
partial(
|
partial(
|
||||||
_is_target_class,
|
is_target_class,
|
||||||
tags={
|
tags={
|
||||||
get_tag_from_info(tag_info, term_registry=term_registry)
|
get_tag_from_info(tag_info, term_registry=term_registry)
|
||||||
for tag_info in class_info.tags
|
for tag_info in class_info.tags
|
||||||
|
@ -34,7 +34,7 @@ from batdetect2.targets.types import Position, Size
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"Anchor",
|
"Anchor",
|
||||||
"AnchorBBoxMapper",
|
"AnchorBBoxMapper",
|
||||||
"BBoxAnchorMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"DEFAULT_ANCHOR",
|
"DEFAULT_ANCHOR",
|
||||||
"DEFAULT_FREQUENCY_SCALE",
|
"DEFAULT_FREQUENCY_SCALE",
|
||||||
"DEFAULT_TIME_SCALE",
|
"DEFAULT_TIME_SCALE",
|
||||||
@ -148,7 +148,7 @@ class ROITargetMapper(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class BBoxAnchorMapperConfig(BaseConfig):
|
class AnchorBBoxMapperConfig(BaseConfig):
|
||||||
"""Configuration for `AnchorBBoxMapper`.
|
"""Configuration for `AnchorBBoxMapper`.
|
||||||
|
|
||||||
Defines parameters for converting ROIs into targets using a fixed anchor
|
Defines parameters for converting ROIs into targets using a fixed anchor
|
||||||
@ -470,7 +470,7 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
|
|||||||
|
|
||||||
|
|
||||||
ROIMapperConfig = Annotated[
|
ROIMapperConfig = Annotated[
|
||||||
Union[BBoxAnchorMapperConfig, PeakEnergyBBoxMapperConfig],
|
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
"""A discriminated union of all supported ROI mapper configurations.
|
"""A discriminated union of all supported ROI mapper configurations.
|
||||||
@ -480,7 +480,9 @@ implementations by using the `name` field as a discriminator.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
def build_roi_mapper(
|
||||||
|
config: Optional[ROIMapperConfig] = None,
|
||||||
|
) -> ROITargetMapper:
|
||||||
"""Factory function to create an ROITargetMapper from a config object.
|
"""Factory function to create an ROITargetMapper from a config object.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -498,6 +500,8 @@ def build_roi_mapper(config: ROIMapperConfig) -> ROITargetMapper:
|
|||||||
NotImplementedError
|
NotImplementedError
|
||||||
If the `name` in the config does not correspond to a known mapper.
|
If the `name` in the config does not correspond to a known mapper.
|
||||||
"""
|
"""
|
||||||
|
config = config or AnchorBBoxMapperConfig()
|
||||||
|
|
||||||
if config.name == "anchor_bbox":
|
if config.name == "anchor_bbox":
|
||||||
return AnchorBBoxMapper(
|
return AnchorBBoxMapper(
|
||||||
anchor=config.anchor,
|
anchor=config.anchor,
|
||||||
|
@ -235,6 +235,7 @@ default_term_registry = TermRegistry(
|
|||||||
[
|
[
|
||||||
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
*getmembers(terms, lambda x: isinstance(x, data.Term)),
|
||||||
("event", call_type),
|
("event", call_type),
|
||||||
|
("species", terms.scientific_name),
|
||||||
("individual", individual),
|
("individual", individual),
|
||||||
("data_source", data_source),
|
("data_source", data_source),
|
||||||
(GENERIC_CLASS_KEY, generic_class),
|
(GENERIC_CLASS_KEY, generic_class),
|
||||||
|
@ -181,7 +181,13 @@ class TargetProtocol(Protocol):
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def decode_roi(self, position: Position, size: Size) -> data.Geometry:
|
# TODO: Update docstrings
|
||||||
|
def decode_roi(
|
||||||
|
self,
|
||||||
|
position: Position,
|
||||||
|
size: Size,
|
||||||
|
class_name: Optional[str] = None,
|
||||||
|
) -> data.Geometry:
|
||||||
"""Recover the ROI geometry from a position and dimensions.
|
"""Recover the ROI geometry from a position and dimensions.
|
||||||
|
|
||||||
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
Performs the inverse mapping of `get_position` and `get_size`. It takes
|
||||||
@ -195,6 +201,8 @@ class TargetProtocol(Protocol):
|
|||||||
dims : np.ndarray
|
dims : np.ndarray
|
||||||
The NumPy array containing the dimensions (e.g., predicted
|
The NumPy array containing the dimensions (e.g., predicted
|
||||||
by the model), corresponding to the order in `dimension_names`.
|
by the model), corresponding to the order in `dimension_names`.
|
||||||
|
class_name: str
|
||||||
|
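Optional name of the predicted class; when a class-specific ROI mapper is configured for it, that mapper is used for decoding.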
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import uuid
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional
|
from typing import Callable, List, Optional
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
@ -447,3 +448,15 @@ def example_annotations(
|
|||||||
annotations = load_dataset(example_dataset)
|
annotations = load_dataset(example_dataset)
|
||||||
assert len(annotations) == 3
|
assert len(annotations) == 3
|
||||||
return annotations
|
return annotations
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
||||||
|
"""Create a temporary YAML file with the given content."""
|
||||||
|
|
||||||
|
def factory(content: str) -> Path:
|
||||||
|
temp_file = tmp_path / f"{uuid4()}.yaml"
|
||||||
|
temp_file.write_text(content)
|
||||||
|
return temp_file
|
||||||
|
|
||||||
|
return factory
|
||||||
|
@ -12,11 +12,11 @@ from batdetect2.targets.classes import (
|
|||||||
TargetClass,
|
TargetClass,
|
||||||
_get_default_class_name,
|
_get_default_class_name,
|
||||||
_get_default_classes,
|
_get_default_classes,
|
||||||
_is_target_class,
|
|
||||||
build_generic_class_tags,
|
build_generic_class_tags,
|
||||||
build_sound_event_decoder,
|
build_sound_event_decoder,
|
||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
|
is_target_class,
|
||||||
load_classes_config,
|
load_classes_config,
|
||||||
load_decoder_from_config,
|
load_decoder_from_config,
|
||||||
load_encoder_from_config,
|
load_encoder_from_config,
|
||||||
@ -145,7 +145,7 @@ def test_is_target_class_match_all(
|
|||||||
),
|
),
|
||||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=True) is True
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(
|
||||||
@ -153,14 +153,14 @@ def test_is_target_class_match_all(
|
|||||||
value="Pipistrellus pipistrellus",
|
value="Pipistrellus pipistrellus",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=True) is True
|
assert is_target_class(sample_annotation, tags, match_all=True) is True
|
||||||
|
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=True) is False
|
assert is_target_class(sample_annotation, tags, match_all=True) is False
|
||||||
|
|
||||||
|
|
||||||
def test_is_target_class_match_any(
|
def test_is_target_class_match_any(
|
||||||
@ -174,7 +174,7 @@ def test_is_target_class_match_any(
|
|||||||
),
|
),
|
||||||
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
data.Tag(term=sample_term_registry["quality"], value="Good"),
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=False) is True
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(
|
||||||
@ -182,14 +182,14 @@ def test_is_target_class_match_any(
|
|||||||
value="Pipistrellus pipistrellus",
|
value="Pipistrellus pipistrellus",
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=False) is True
|
assert is_target_class(sample_annotation, tags, match_all=False) is True
|
||||||
|
|
||||||
tags = {
|
tags = {
|
||||||
data.Tag(
|
data.Tag(
|
||||||
term=sample_term_registry["species"], value="Myotis daubentonii"
|
term=sample_term_registry["species"], value="Myotis daubentonii"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
assert _is_target_class(sample_annotation, tags, match_all=False) is False
|
assert is_target_class(sample_annotation, tags, match_all=False) is False
|
||||||
|
|
||||||
|
|
||||||
def test_get_class_names_from_config():
|
def test_get_class_names_from_config():
|
||||||
|
@ -11,7 +11,7 @@ from batdetect2.targets.rois import (
|
|||||||
SIZE_HEIGHT,
|
SIZE_HEIGHT,
|
||||||
SIZE_WIDTH,
|
SIZE_WIDTH,
|
||||||
AnchorBBoxMapper,
|
AnchorBBoxMapper,
|
||||||
BBoxAnchorMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
PeakEnergyBBoxMapper,
|
PeakEnergyBBoxMapper,
|
||||||
PeakEnergyBBoxMapperConfig,
|
PeakEnergyBBoxMapperConfig,
|
||||||
_build_bounding_box,
|
_build_bounding_box,
|
||||||
@ -243,7 +243,7 @@ def test_anchor_bbox_mapper_decode_invalid_size_shape(default_mapper):
|
|||||||
|
|
||||||
def test_build_roi_mapper():
|
def test_build_roi_mapper():
|
||||||
"""Test build_roi_mapper creates a configured BBoxEncoder."""
|
"""Test build_roi_mapper creates a configured BBoxEncoder."""
|
||||||
config = BBoxAnchorMapperConfig(
|
config = AnchorBBoxMapperConfig(
|
||||||
anchor="top-right", time_scale=2.0, frequency_scale=20.0
|
anchor="top-right", time_scale=2.0, frequency_scale=20.0
|
||||||
)
|
)
|
||||||
mapper = build_roi_mapper(config)
|
mapper = build_roi_mapper(config)
|
||||||
@ -571,7 +571,7 @@ def test_peak_energy_bbox_mapper_encode_decode_roundtrip(generate_whistle):
|
|||||||
|
|
||||||
def test_build_roi_mapper_for_anchor_bbox():
|
def test_build_roi_mapper_for_anchor_bbox():
|
||||||
# Given
|
# Given
|
||||||
config = BBoxAnchorMapperConfig(
|
config = AnchorBBoxMapperConfig(
|
||||||
anchor="center",
|
anchor="center",
|
||||||
time_scale=123.0,
|
time_scale=123.0,
|
||||||
frequency_scale=456.0,
|
frequency_scale=456.0,
|
||||||
|
117
tests/test_targets/test_targets.py
Normal file
117
tests/test_targets/test_targets.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
|
from batdetect2.targets.terms import get_term_from_key
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_override_default_roi_mapper_per_class(
|
||||||
|
create_temp_yaml: Callable[..., Path],
|
||||||
|
recording: data.Recording,
|
||||||
|
sample_term_registry,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
"""
|
||||||
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
|
config = load_target_config(config_path)
|
||||||
|
targets = build_targets(config, term_registry=sample_term_registry)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
|
||||||
|
se2 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
|
||||||
|
(time1, freq1), _ = targets.encode_roi(se1)
|
||||||
|
(time2, freq2), _ = targets.encode_roi(se2)
|
||||||
|
|
||||||
|
assert time1 == time2 == 0.1
|
||||||
|
assert freq1 == 12_000
|
||||||
|
assert freq2 == 18_000
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: rename this test function
|
||||||
|
def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||||
|
create_temp_yaml,
|
||||||
|
sample_term_registry,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
"""
|
||||||
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
|
config = load_target_config(config_path)
|
||||||
|
targets = build_targets(config, term_registry=sample_term_registry)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
|
||||||
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
|
||||||
|
se2 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
|
||||||
|
position1, size1 = targets.encode_roi(se1)
|
||||||
|
position2, size2 = targets.encode_roi(se2)
|
||||||
|
|
||||||
|
class_name1 = targets.encode_class(se1)
|
||||||
|
class_name2 = targets.encode_class(se2)
|
||||||
|
|
||||||
|
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
|
||||||
|
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)
|
||||||
|
|
||||||
|
assert recovered1 == geometry
|
||||||
|
assert recovered2 == geometry
|
@ -5,7 +5,7 @@ import xarray as xr
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||||
from batdetect2.targets.rois import BBoxAnchorMapperConfig
|
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||||
from batdetect2.targets.terms import TagInfo, TermRegistry
|
from batdetect2.targets.terms import TagInfo, TermRegistry
|
||||||
from batdetect2.train.labels import generate_heatmaps
|
from batdetect2.train.labels import generate_heatmaps
|
||||||
|
|
||||||
@ -85,7 +85,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
|||||||
):
|
):
|
||||||
config = sample_target_config.model_copy(
|
config = sample_target_config.model_copy(
|
||||||
update=dict(
|
update=dict(
|
||||||
roi=BBoxAnchorMapperConfig(
|
roi=AnchorBBoxMapperConfig(
|
||||||
time_scale=1,
|
time_scale=1,
|
||||||
frequency_scale=1,
|
frequency_scale=1,
|
||||||
)
|
)
|
||||||
|
271
tests/test_train/test_preprocessing.py
Normal file
271
tests/test_train/test_preprocessing.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.models.types import ModelOutput
|
||||||
|
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||||
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
|
from batdetect2.targets.terms import get_term_from_key
|
||||||
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def build_from_config(
|
||||||
|
create_temp_yaml,
|
||||||
|
sample_term_registry,
|
||||||
|
):
|
||||||
|
def build(yaml_content):
|
||||||
|
config_path = create_temp_yaml(yaml_content)
|
||||||
|
|
||||||
|
targets_config = load_target_config(config_path, field="targets")
|
||||||
|
preprocessing_config = load_preprocessing_config(
|
||||||
|
config_path,
|
||||||
|
field="preprocessing",
|
||||||
|
)
|
||||||
|
labels_config = load_label_config(config_path, field="labels")
|
||||||
|
postprocessing_config = load_postprocess_config(
|
||||||
|
config_path,
|
||||||
|
field="postprocessing",
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = build_targets(
|
||||||
|
targets_config, term_registry=sample_term_registry
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(preprocessing_config)
|
||||||
|
labeller = build_clip_labeler(
|
||||||
|
targets=targets,
|
||||||
|
config=labels_config,
|
||||||
|
)
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
targets,
|
||||||
|
config=postprocessing_config,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
return targets, preprocessor, labeller, postprocessor
|
||||||
|
|
||||||
|
return build
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: better name
|
||||||
|
def test_generated_train_example_has_expected_outputs(
|
||||||
|
build_from_config,
|
||||||
|
sample_term_registry,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
labels:
|
||||||
|
targets:
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
preprocessing:
|
||||||
|
postprocessing:
|
||||||
|
"""
|
||||||
|
_, preprocessor, labeller, _ = build_from_config(yaml_content)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||||
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
clip_annotation = data.ClipAnnotation(
|
||||||
|
clip=data.Clip(start_time=0, end_time=0.5, recording=recording),
|
||||||
|
sound_events=[se1],
|
||||||
|
)
|
||||||
|
|
||||||
|
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
|
||||||
|
|
||||||
|
assert isinstance(encoded, xr.Dataset)
|
||||||
|
assert "audio" in encoded
|
||||||
|
assert "spectrogram" in encoded
|
||||||
|
assert "detection" in encoded
|
||||||
|
assert "class" in encoded
|
||||||
|
assert "size" in encoded
|
||||||
|
|
||||||
|
spec_shape = encoded["spectrogram"].shape
|
||||||
|
assert len(spec_shape) == 2
|
||||||
|
|
||||||
|
height, width = spec_shape
|
||||||
|
assert encoded["detection"].shape == (height, width)
|
||||||
|
assert encoded["class"].shape == (1, height, width)
|
||||||
|
assert encoded["size"].shape == (2, height, width)
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoding_decoding_roundtrip_recovers_object(
|
||||||
|
build_from_config,
|
||||||
|
sample_term_registry,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
labels:
|
||||||
|
targets:
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
preprocessing:
|
||||||
|
"""
|
||||||
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Pipistrellus pipistrellus")],
|
||||||
|
)
|
||||||
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
|
|
||||||
|
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
|
||||||
|
predictions = postprocessor.get_predictions(
|
||||||
|
ModelOutput(
|
||||||
|
detection_probs=torch.tensor([[encoded["detection"].data]]),
|
||||||
|
size_preds=torch.tensor([encoded["size"].data]),
|
||||||
|
class_probs=torch.tensor([encoded["class"].data]),
|
||||||
|
features=torch.tensor([[encoded["spectrogram"].data]]),
|
||||||
|
),
|
||||||
|
[clip],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert isinstance(predictions, data.ClipPrediction)
|
||||||
|
assert len(predictions.sound_events) == 1
|
||||||
|
|
||||||
|
recovered = predictions.sound_events[0]
|
||||||
|
assert recovered.sound_event.geometry is not None
|
||||||
|
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||||
|
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||||
|
recovered.sound_event.geometry.coordinates
|
||||||
|
)
|
||||||
|
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||||
|
geometry.coordinates
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||||
|
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||||
|
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||||
|
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||||
|
|
||||||
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
|
predicted_species_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||||
|
)
|
||||||
|
assert predicted_species_tag is not None
|
||||||
|
assert predicted_species_tag.score == 1
|
||||||
|
assert predicted_species_tag.tag.value == "Pipistrellus pipistrellus"
|
||||||
|
|
||||||
|
predicted_order_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||||
|
)
|
||||||
|
assert predicted_order_tag is not None
|
||||||
|
assert predicted_order_tag.score == 1
|
||||||
|
assert predicted_order_tag.tag.value == "Chiroptera"
|
||||||
|
|
||||||
|
|
||||||
|
def test_encoding_decoding_roundtrip_recovers_object_with_roi_override(
|
||||||
|
build_from_config,
|
||||||
|
sample_term_registry,
|
||||||
|
recording,
|
||||||
|
):
|
||||||
|
yaml_content = """
|
||||||
|
labels:
|
||||||
|
targets:
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: bottom-left
|
||||||
|
classes:
|
||||||
|
classes:
|
||||||
|
- name: pippip
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Pipistrellus pipistrellus
|
||||||
|
- name: myomyo
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
roi:
|
||||||
|
name: anchor_bbox
|
||||||
|
anchor: top-left
|
||||||
|
generic_class:
|
||||||
|
- key: order
|
||||||
|
value: Chiroptera
|
||||||
|
preprocessing:
|
||||||
|
"""
|
||||||
|
_, preprocessor, labeller, postprocessor = build_from_config(yaml_content)
|
||||||
|
|
||||||
|
geometry = data.BoundingBox(coordinates=[0.1, 40_000, 0.2, 80_000])
|
||||||
|
species = get_term_from_key("species", term_registry=sample_term_registry)
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(recording=recording, geometry=geometry),
|
||||||
|
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
clip = data.Clip(start_time=0, end_time=0.5, recording=recording)
|
||||||
|
clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1])
|
||||||
|
|
||||||
|
encoded = generate_train_example(clip_annotation, preprocessor, labeller)
|
||||||
|
predictions = postprocessor.get_predictions(
|
||||||
|
ModelOutput(
|
||||||
|
detection_probs=torch.tensor([[encoded["detection"].data]]),
|
||||||
|
size_preds=torch.tensor([encoded["size"].data]),
|
||||||
|
class_probs=torch.tensor([encoded["class"].data]),
|
||||||
|
features=torch.tensor([[encoded["spectrogram"].data]]),
|
||||||
|
),
|
||||||
|
[clip],
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert isinstance(predictions, data.ClipPrediction)
|
||||||
|
assert len(predictions.sound_events) == 1
|
||||||
|
|
||||||
|
recovered = predictions.sound_events[0]
|
||||||
|
assert recovered.sound_event.geometry is not None
|
||||||
|
assert isinstance(recovered.sound_event.geometry, data.BoundingBox)
|
||||||
|
start_time_rec, low_freq_rec, end_time_rec, high_freq_rec = (
|
||||||
|
recovered.sound_event.geometry.coordinates
|
||||||
|
)
|
||||||
|
start_time_or, low_freq_or, end_time_or, high_freq_or = (
|
||||||
|
geometry.coordinates
|
||||||
|
)
|
||||||
|
|
||||||
|
assert start_time_rec == pytest.approx(start_time_or, abs=0.01)
|
||||||
|
assert low_freq_rec == pytest.approx(low_freq_or, abs=1_000)
|
||||||
|
assert end_time_rec == pytest.approx(end_time_or, abs=0.01)
|
||||||
|
assert high_freq_rec == pytest.approx(high_freq_or, abs=1_000)
|
||||||
|
|
||||||
|
assert len(recovered.tags) == 2
|
||||||
|
|
||||||
|
predicted_species_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term == species), None
|
||||||
|
)
|
||||||
|
assert predicted_species_tag is not None
|
||||||
|
assert predicted_species_tag.score == 1
|
||||||
|
assert predicted_species_tag.tag.value == "Myotis myotis"
|
||||||
|
|
||||||
|
predicted_order_tag = next(
|
||||||
|
iter(t for t in recovered.tags if t.tag.term.label == "order"), None
|
||||||
|
)
|
||||||
|
assert predicted_order_tag is not None
|
||||||
|
assert predicted_order_tag.score == 1
|
||||||
|
assert predicted_order_tag.tag.value == "Chiroptera"
|
Loading…
Reference in New Issue
Block a user