Small fixes

This commit is contained in:
mbsantiago 2025-09-08 18:35:02 +01:00
parent d8d2e5a2c2
commit c73984b213
4 changed files with 74 additions and 89 deletions

View File

@ -1,7 +1,6 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.preprocess import preprocess
from batdetect2.cli.train import train_command
__all__ = [
@ -9,7 +8,6 @@ __all__ = [
"detect",
"data",
"train_command",
"preprocess",
]

View File

@ -20,7 +20,7 @@ selecting and configuring the desired mapper. This module separates the
*geometric* aspect of target definition from *semantic* classification.
"""
from typing import Annotated, List, Literal, Optional, Protocol, Tuple, Union
from typing import Annotated, Literal, Optional, Tuple, Union
import numpy as np
from pydantic import Field
@ -30,7 +30,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, Size
from batdetect2.typing.targets import Position, ROITargetMapper, Size
from batdetect2.utils.arrays import spec_to_xarray
__all__ = [
@ -83,73 +83,6 @@ DEFAULT_ANCHOR = "bottom-left"
"""Default reference position within the geometry ('bottom-left' corner)."""
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Position
The reference position (time, frequency).
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...
class AnchorBBoxMapperConfig(BaseConfig):
"""Configuration for `AnchorBBoxMapper`.
@ -475,7 +408,10 @@ class PeakEnergyBBoxMapper(ROITargetMapper):
ROIMapperConfig = Annotated[
Union[AnchorBBoxMapperConfig, PeakEnergyBBoxMapperConfig],
Union[
AnchorBBoxMapperConfig,
PeakEnergyBBoxMapperConfig,
],
Field(discriminator="name"),
]
"""A discriminated union of all supported ROI mapper configurations.
@ -553,7 +489,7 @@ def _build_bounding_box(
) -> data.BoundingBox:
"""Construct a BoundingBox from a reference point, size, and position type.
Internal helper for `BBoxEncoder.recover_roi`. Calculates the box
Internal helper for `BBoxEncoder.decode`. Calculates the box
coordinates [start_time, low_freq, end_time, high_freq] based on where
the input `pos` (time, freq) is located relative to the box (e.g.,
center, corner).

View File

@ -227,3 +227,70 @@ class TargetProtocol(Protocol):
if reconstruction fails based on the configured position type.
"""
...
class ROITargetMapper(Protocol):
"""Protocol defining the interface for ROI-to-target mapping.
Specifies the `encode` and `decode` methods required for converting a
`soundevent.data.SoundEvent` into a target representation (a reference
position and a size vector) and for recovering an approximate ROI from that
representation.
Attributes
----------
dimension_names : List[str]
A list containing the names of the dimensions in the `Size` array
returned by `encode` and expected by `decode`.
"""
dimension_names: List[str]
def encode(self, sound_event: data.SoundEvent) -> tuple[Position, Size]:
"""Encode a SoundEvent's geometry into a position and size.
Parameters
----------
sound_event : data.SoundEvent
The input sound event, which must have a geometry attribute.
Returns
-------
Tuple[Position, Size]
A tuple containing:
- The reference position as (time, frequency) coordinates.
- A NumPy array with the calculated size dimensions.
Raises
------
ValueError
If the sound event does not have a geometry.
"""
...
def decode(self, position: Position, size: Size) -> data.Geometry:
"""Decode a position and size back into a geometric ROI.
Performs the inverse mapping: takes a reference position and size
dimensions and reconstructs a geometric representation.
Parameters
----------
position : Position
The reference position (time, frequency).
size : Size
NumPy array containing the size dimensions, matching the order
and meaning specified by `dimension_names`.
Returns
-------
soundevent.data.Geometry
The reconstructed geometry, typically a `BoundingBox`.
Raises
------
ValueError
If the `size` array has an unexpected shape or if reconstruction
fails.
"""
...

View File

@ -1,16 +0,0 @@
import pytest
from batdetect2.targets import terms
def test_tag_info_and_get_tag_from_info():
tag_info = TagInfo(value="Myotis myotis", key="event")
tag = terms.get_tag_from_info(tag_info)
assert tag.value == "Myotis myotis"
assert tag.term == terms.call_type
def test_get_tag_from_info_key_not_found():
tag_info = TagInfo(value="test", key="non_existent_key")
with pytest.raises(KeyError):
terms.get_tag_from_info(tag_info)