Separate targets and ROI

This commit is contained in:
mbsantiago 2026-03-29 13:11:42 +01:00
parent 716b3a3778
commit 2b235e28bb
33 changed files with 403 additions and 201 deletions

View File

@ -32,5 +32,6 @@ classification_targets:
value: Rhinolophus ferrumequinum
roi:
default:
name: anchor_bbox
anchor: top-left

View File

@ -20,7 +20,7 @@ install:
# Testing & Coverage
# Run tests using pytest.
test:
uv run pytest {{TESTS_DIR}}
uv run pytest -n auto {{TESTS_DIR}}
# Run tests and generate coverage data.
coverage:

View File

@ -88,6 +88,7 @@ dev = [
"pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0",
"deepdiff>=8.6.1",
"pytest-xdist[psutil]>=3.8.0",
]
dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"]

View File

@ -50,7 +50,13 @@ from batdetect2.postprocess import (
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
)
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
@ -70,6 +76,7 @@ class BatDetect2API:
outputs_config: OutputsConfig,
logging_config: AppLoggingConfig,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
@ -86,6 +93,7 @@ class BatDetect2API:
self.outputs_config = outputs_config
self.logging_config = logging_config
self.targets = targets
self.roi_mapper = roi_mapper
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.postprocessor = postprocessor
@ -125,6 +133,7 @@ class BatDetect2API:
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
@ -171,6 +180,7 @@ class BatDetect2API:
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
@ -205,6 +215,7 @@ class BatDetect2API:
self.model,
test_annotations,
targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
audio_config=audio_config or self.audio_config,
@ -391,6 +402,7 @@ class BatDetect2API:
self.model,
audio_files,
targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
output_transform=self.output_transform,
@ -416,6 +428,7 @@ class BatDetect2API:
self.model,
clips,
targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
output_transform=self.output_transform,
@ -472,6 +485,7 @@ class BatDetect2API:
config: BatDetect2Config,
) -> "BatDetect2API":
targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
audio_loader = build_audio_loader(config=config.audio)
@ -492,11 +506,13 @@ class BatDetect2API:
output_transform = build_output_transform(
config=config.outputs.transform,
targets=targets,
roi_mapper=roi_mapper,
)
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
)
@ -504,7 +520,8 @@ class BatDetect2API:
# to avoid device mismatch errors
model = build_model(
config=config.model,
targets=build_targets(config=config.model.targets),
targets=targets,
roi_mapper=roi_mapper,
preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.model.preprocess,
@ -524,6 +541,7 @@ class BatDetect2API:
outputs_config=config.outputs,
logging_config=config.logging,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,
@ -561,15 +579,18 @@ class BatDetect2API:
and targets_config != model_config.targets
):
targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
model = build_model_with_new_targets(
model=model,
targets=targets,
roi_mapper=roi_mapper,
)
model_config = model_config.model_copy(
update={"targets": targets_config}
)
targets = build_targets(config=model_config.targets)
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
audio_loader = build_audio_loader(config=audio_config)
@ -591,11 +612,13 @@ class BatDetect2API:
output_transform = build_output_transform(
config=outputs_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
evaluator = build_evaluator(
config=evaluation_config,
targets=targets,
roi_mapper=roi_mapper,
transform=output_transform,
)
@ -608,6 +631,7 @@ class BatDetect2API:
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader,
preprocessor=preprocessor,
postprocessor=postprocessor,

View File

@ -16,7 +16,7 @@ from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
@ -25,6 +25,7 @@ def run_evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
@ -46,6 +47,7 @@ def run_evaluate(
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
loader = build_test_loader(
test_annotations,
@ -57,6 +59,7 @@ def run_evaluate(
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
evaluator = build_evaluator(
config=evaluation_config,

View File

@ -8,8 +8,8 @@ from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
"Evaluator",
@ -67,17 +67,23 @@ class Evaluator:
def build_evaluator(
config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol:
targets = targets or build_targets()
roi_mapper = roi_mapper or build_roi_mapping()
if config is None:
config = EvaluationConfig()
if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets)
transform = transform or build_output_transform(
targets=targets,
roi_mapper=roi_mapper,
)
return Evaluator(
targets=targets,

View File

@ -18,13 +18,14 @@ from batdetect2.outputs import (
)
from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
def run_batch_inference(
model: Model,
clips: Sequence[data.Clip],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
@ -45,10 +46,12 @@ def run_batch_inference(
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
output_transform = output_transform or build_output_transform(
config=output_config.transform,
targets=targets,
roi_mapper=roi_mapper,
)
loader = build_inference_loader(
@ -78,6 +81,7 @@ def process_file_list(
model: Model,
paths: Sequence[data.PathLike],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None,
audio_config: AudioConfig | None = None,
preprocessor: PreprocessorProtocol | None = None,
@ -101,6 +105,7 @@ def process_file_list(
model,
clips,
targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader,
preprocessor=preprocessor,
batch_size=batch_size,

View File

@ -20,7 +20,8 @@ class InferenceModule(LightningModule):
self.model = model
self.detection_threshold = detection_threshold
self.output_transform = output_transform or build_output_transform(
targets=model.targets
targets=model.targets,
roi_mapper=model.roi_mapper,
)
def predict_step(

View File

@ -74,7 +74,7 @@ from batdetect2.postprocess.types import (
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
"BBoxHead",
@ -186,12 +186,15 @@ class Model(torch.nn.Module):
targets : TargetProtocol
Describes the set of target classes; used when building heads and
during training target construction.
roi_mapper : ROIMapperProtocol
Maps geometries to target-size channels and back.
"""
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
targets: TargetProtocol
roi_mapper: ROIMapperProtocol
def __init__(
self,
@ -199,12 +202,14 @@ class Model(torch.nn.Module):
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.roi_mapper = roi_mapper
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor.
@ -234,6 +239,7 @@ class Model(torch.nn.Module):
def build_model(
config: ModelConfig | None = None,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None,
) -> Model:
@ -272,10 +278,19 @@ def build_model(
"""
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.targets import build_roi_mapping, build_targets
config = config or ModelConfig()
targets = targets or build_targets(config=config.targets)
targets_config = getattr(targets, "config", None)
roi_config = (
targets_config.roi
if isinstance(targets_config, TargetConfig)
else config.targets.roi
)
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=config.samplerate,
@ -286,6 +301,7 @@ def build_model(
)
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
config=config.architecture,
)
return Model(
@ -293,16 +309,19 @@ def build_model(
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
roi_mapper=roi_mapper,
)
def build_model_with_new_targets(
model: Model,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Model:
"""Build a new model with a different target set."""
detector = build_detector(
num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
backbone=model.detector.backbone,
)
@ -311,4 +330,5 @@ def build_model_with_new_targets(
postprocessor=model.postprocessor,
preprocessor=model.preprocessor,
targets=targets,
roi_mapper=roi_mapper,
)

View File

@ -136,6 +136,7 @@ class Detector(DetectionModel):
def build_detector(
num_classes: int,
num_sizes: int = 2,
config: BackboneConfig | None = None,
backbone: BackboneModel | None = None,
) -> DetectionModel:
@ -181,6 +182,7 @@ def build_detector(
)
bbox_head = BBoxHead(
in_channels=backbone.out_channels,
num_sizes=num_sizes,
)
return Detector(
backbone=backbone,

View File

@ -165,14 +165,15 @@ class BBoxHead(nn.Module):
1×1 convolution with 2 output channels (duration, bandwidth).
"""
def __init__(self, in_channels: int):
def __init__(self, in_channels: int, num_sizes: int = 2):
"""Initialise the BBoxHead."""
super().__init__()
self.in_channels = in_channels
self.num_sizes = num_sizes
self.bbox = nn.Conv2d(
in_channels=self.in_channels,
out_channels=2,
out_channels=self.num_sizes,
kernel_size=1,
padding=0,
)

View File

@ -28,7 +28,7 @@ from batdetect2.postprocess.types import (
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
"ClipDetectionsTransformConfig",
@ -55,10 +55,12 @@ class OutputTransform(OutputTransformProtocol):
def __init__(
self,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
detection_transform_steps: Sequence[DetectionTransform] = (),
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
):
self.targets = targets
self.roi_mapper = roi_mapper
self.detection_transform_steps = list(detection_transform_steps)
self.clip_transform_steps = list(clip_transform_steps)
@ -89,7 +91,11 @@ class OutputTransform(OutputTransformProtocol):
detections: ClipDetectionsTensor,
start_time: float = 0,
) -> list[Detection]:
decoded = to_detections(detections.numpy(), targets=self.targets)
decoded = to_detections(
detections.numpy(),
targets=self.targets,
roi_mapper=self.roi_mapper,
)
shifted = shift_detections_to_start_time(
decoded,
start_time=start_time,
@ -151,8 +157,9 @@ class OutputTransform(OutputTransformProtocol):
def build_output_transform(
config: OutputTransformConfig | dict | None = None,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
) -> OutputTransformProtocol:
from batdetect2.targets import build_targets
from batdetect2.targets import build_roi_mapping, build_targets
if config is None:
config = OutputTransformConfig()
@ -161,9 +168,11 @@ def build_output_transform(
config = OutputTransformConfig.model_validate(config)
targets = targets or build_targets()
roi_mapper = roi_mapper or build_roi_mapping()
return OutputTransform(
targets=targets,
roi_mapper=roi_mapper,
detection_transform_steps=[
detection_transform_registry.build(transform_config)
for transform_config in config.detection_transforms

View File

@ -6,7 +6,7 @@ import numpy as np
from soundevent import data
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -25,6 +25,7 @@ DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
def to_detections(
detections: ClipDetectionsArray,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> List[Detection]:
predictions = []
@ -39,7 +40,7 @@ def to_detections(
):
highest_scoring_class = targets.class_names[class_scores.argmax()]
geom = targets.decode_roi(
geom = roi_mapper.decode(
(time, freq),
dims,
class_name=highest_scoring_class,

View File

@ -4,7 +4,7 @@ from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [
"plot_clip_annotation",
@ -48,6 +48,7 @@ def plot_clip_annotation(
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
figsize: tuple[int, int] | None = None,
ax: Axes | None = None,
size: int = 1,
@ -63,7 +64,11 @@ def plot_anchor_points(
if not targets.filter(sound_event):
continue
position, _ = targets.encode_roi(sound_event)
class_name = targets.encode_class(sound_event)
position, _ = roi_mapper.encode(
sound_event.sound_event,
class_name=class_name,
)
positions.append(position)
X, Y = zip(*positions, strict=False)

View File

@ -10,7 +10,10 @@ from batdetect2.targets.config import TargetConfig
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
ROIMapperProtocol,
ROIMappingConfig,
build_roi_mapper,
build_roi_mapping,
)
from batdetect2.targets.targets import (
Targets,
@ -30,12 +33,15 @@ from batdetect2.targets.types import (
Size,
SoundEventDecoder,
SoundEventEncoder,
SoundEventFilter,
TargetProtocol,
)
__all__ = [
"AnchorBBoxMapperConfig",
"Position",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig",
"ROITargetMapper",
"Size",
@ -46,6 +52,7 @@ __all__ = [
"TargetConfig",
"TargetProtocol",
"Targets",
"build_roi_mapping",
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",

View File

@ -14,7 +14,6 @@ from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
)
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import call_type, generic_class
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
@ -39,8 +38,6 @@ class TargetClassConfig(BaseConfig):
assign_tags: List[data.Tag] = Field(default_factory=list)
roi: ROIMapperConfig | None = None
_match_if: SoundEventConditionConfig = PrivateAttr()
@model_validator(mode="after")

View File

@ -9,7 +9,7 @@ from batdetect2.targets.classes import (
DEFAULT_DETECTION_CLASS,
TargetClassConfig,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
from batdetect2.targets.rois import ROIMappingConfig
__all__ = [
"TargetConfig",
@ -25,7 +25,7 @@ class TargetConfig(BaseConfig):
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
roi: ROIMappingConfig = Field(default_factory=ROIMappingConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):

View File

@ -29,7 +29,12 @@ from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, ROITargetMapper, Size
from batdetect2.targets.types import (
Position,
ROIMapperProtocol,
ROITargetMapper,
Size,
)
__all__ = [
"Anchor",
@ -40,12 +45,15 @@ __all__ = [
"DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig",
"ROIMapperImportConfig",
"ROITargetMapper",
"SIZE_HEIGHT",
"SIZE_ORDER",
"SIZE_WIDTH",
"build_roi_mapping",
"build_roi_mapper",
]
@ -456,6 +464,59 @@ implementations by using the `name` field as a discriminator.
"""
class ROIMappingConfig(BaseConfig):
    """Configuration for class-aware ROI mapping.

    Attributes
    ----------
    default : ROIMapperConfig
        Default mapper used when no class-specific override exists.
    overrides : dict[str, ROIMapperConfig]
        Optional class-specific mapper overrides by class name.
    """

    # Mapper applied to every class unless an override is present below.
    default: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
    # Per-class mapper configs keyed by class name; keys are expected to be
    # classification target names — TODO confirm against TargetConfig usage.
    overrides: dict[str, ROIMapperConfig] = Field(default_factory=dict)
class ClassAwareROIMapper(ROIMapperProtocol):
    """ROI mapper that dispatches to per-class mappers with a shared default.

    Encoding and decoding are delegated to the mapper registered for the
    given class name; when no override exists (or no class name is given),
    the default mapper is used.
    """

    # Size-dimension names, copied from the default mapper.
    dimension_names: list[str]

    def __init__(
        self,
        default_mapper: ROITargetMapper,
        overrides: dict[str, ROITargetMapper] | None = None,
    ):
        self.default_mapper = default_mapper
        self.overrides = {} if overrides is None else overrides
        self.dimension_names = list(default_mapper.dimension_names)

    def encode(
        self,
        sound_event: data.SoundEvent,
        class_name: str | None = None,
    ) -> tuple[Position, Size]:
        """Map a sound event's geometry to a (position, size) pair."""
        return self._select_mapper(class_name).encode(sound_event)

    def decode(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry:
        """Reconstruct a geometry from a (position, size) pair."""
        return self._select_mapper(class_name).decode(position, size)

    def _select_mapper(self, class_name: str | None = None) -> ROITargetMapper:
        # No class name means there is nothing to look up.
        if class_name is None:
            return self.default_mapper
        return self.overrides.get(class_name, self.default_mapper)
def build_roi_mapper(
config: ROIMapperConfig | None = None,
) -> ROITargetMapper:
@ -480,6 +541,36 @@ def build_roi_mapper(
return roi_mapper_registry.build(config)
def build_roi_mapping(
    config: ROIMappingConfig | None = None,
) -> ROIMapperProtocol:
    """Build a class-aware ROI mapper and validate consistency.

    Parameters
    ----------
    config : ROIMappingConfig, optional
        Mapping configuration; a default ``ROIMappingConfig`` is used when
        omitted.

    Returns
    -------
    ROIMapperProtocol
        A ``ClassAwareROIMapper`` combining the default mapper with any
        per-class overrides.

    Raises
    ------
    ValueError
        If an override's dimension names differ from the default mapper's.
    """
    if config is None:
        config = ROIMappingConfig()

    default_mapper = build_roi_mapper(config.default)
    reference_dims = list(default_mapper.dimension_names)

    # Build overrides one at a time so each can be validated against the
    # default mapper's dimension order before being accepted.
    overrides: dict[str, ROITargetMapper] = {}
    for class_name, mapper_config in config.overrides.items():
        mapper = build_roi_mapper(mapper_config)
        mapper_dims = list(mapper.dimension_names)
        if mapper_dims != reference_dims:
            raise ValueError(
                "All ROI mappers must share the same dimension order. "
                f"Default dimensions: {reference_dims}, "
                f"class '{class_name}' dimensions: {mapper_dims}."
            )
        overrides[class_name] = mapper

    return ClassAwareROIMapper(
        default_mapper=default_mapper,
        overrides=overrides,
    )
VALID_ANCHORS = [
"bottom-left",
"bottom-right",

View File

@ -12,21 +12,21 @@ from batdetect2.targets.classes import (
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
build_roi_mapper,
from batdetect2.targets.types import (
Position,
ROIMapperProtocol,
Size,
TargetProtocol,
)
from batdetect2.targets.types import Position, Size, TargetProtocol
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
"""Encapsulates the configured target class definition pipeline.
This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
functions for filtering, encoding (tags to class name), and decoding
(class name to tags). Geometry ROI mapping is handled separately by
``ROIMapperProtocol``.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
@ -39,14 +39,10 @@ class Targets(TargetProtocol):
generic_class_tags
A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches).
dimension_names
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: list[str]
detection_class_tags: list[data.Tag]
dimension_names: list[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
@ -63,10 +59,6 @@ class Targets(TargetProtocol):
config.classification_targets
)
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
@ -74,21 +66,6 @@ class Targets(TargetProtocol):
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
@ -147,75 +124,10 @@ class Targets(TargetProtocol):
"""
return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: str | None = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which
un-scales the dimensions and reconstructs the geometry (typically a
`BoundingBox`).
Parameters
----------
pos
The reference position `(time, frequency)`.
dims
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
@ -292,6 +204,7 @@ def load_targets(
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Iterable[tuple[str | None, Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
@ -303,6 +216,9 @@ def iterate_encoded_sound_events(
continue
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
position, size = roi_mapper.encode(
sound_event.sound_event,
class_name=class_name,
)
yield class_name, position, size

View File

@ -6,6 +6,7 @@ from soundevent import data
__all__ = [
"Position",
"ROIMapperProtocol",
"ROITargetMapper",
"Size",
"SoundEventDecoder",
@ -26,7 +27,6 @@ class TargetProtocol(Protocol):
class_names: list[str]
detection_class_tags: list[data.Tag]
detection_class_name: str
dimension_names: list[str]
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
@ -37,6 +37,23 @@ class TargetProtocol(Protocol):
def decode_class(self, class_label: str) -> list[data.Tag]: ...
class ROIMapperProtocol(Protocol):
    """Structural interface for mapping sound-event geometries to and from
    (position, size) pairs, optionally specialised by target class name."""

    # Ordered names of the size dimensions produced by `encode`; consumers
    # use len(dimension_names) to size the model's bbox head channels.
    dimension_names: list[str]

    def encode(
        self,
        sound_event: data.SoundEvent,
        class_name: str | None = None,
    ) -> tuple[Position, Size]: ...

    def decode(
        self,
        position: Position,
        size: Size,
        class_name: str | None = None,
    ) -> data.Geometry: ...
def encode_roi(
self,
sound_event: data.SoundEventAnnotation,

View File

@ -93,7 +93,8 @@ class ValidationMetrics(Callback):
model = pl_module.model
if self.output_transform is None:
self.output_transform = build_output_transform(
targets=model.targets
targets=model.targets,
roi_mapper=model.roi_mapper,
)
output_transform = self.output_transform

View File

@ -40,7 +40,7 @@ def build_checkpoint_callback(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
checkpoint_dir.mkdir(parents=True, exist_ok=True)
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
return ModelCheckpoint(
dirpath=str(checkpoint_dir),

View File

@ -14,8 +14,12 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.targets.types import TargetProtocol
from batdetect2.targets import (
build_roi_mapping,
build_targets,
iterate_encoded_sound_events,
)
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
from batdetect2.train.types import ClipLabeller, Heatmaps
__all__ = [
@ -42,6 +46,7 @@ class LabelConfig(BaseConfig):
def build_clip_labeler(
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
config: LabelConfig | None = None,
@ -53,12 +58,13 @@ def build_clip_labeler(
lambda: config.to_yaml_string(),
)
if targets is None:
targets = build_targets()
targets = targets or build_targets()
roi_mapper = roi_mapper or build_roi_mapping()
return partial(
generate_heatmaps,
targets=targets,
roi_mapper=roi_mapper,
min_freq=min_freq,
max_freq=max_freq,
target_sigma=config.sigma,
@ -73,6 +79,7 @@ def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
@ -89,7 +96,7 @@ def generate_heatmaps(
height = spec.shape[-2]
width = spec.shape[-1]
num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names)
num_dims = len(roi_mapper.dimension_names)
clip = clip_annotation.clip
# Initialize heatmaps
@ -109,6 +116,7 @@ def generate_heatmaps(
for class_name, (time, frequency), size in iterate_encoded_sound_events(
clip_annotation.sound_events,
targets,
roi_mapper,
):
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)

View File

@ -6,23 +6,24 @@ from lightning import Trainer, seed_everything
from loguru import logger
from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
from batdetect2.logging import (
LoggerConfig,
TensorBoardLoggerConfig,
build_logger,
)
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train import TrainingConfig
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
TargetProtocol,
build_roi_mapping,
build_targets,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.checkpoints import build_checkpoint_callback
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
@ -39,6 +40,7 @@ def run_train(
val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None,
targets: Optional["TargetProtocol"] = None,
roi_mapper: Optional["ROIMapperProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
@ -69,8 +71,15 @@ def run_train(
if model is not None:
targets = targets or model.targets
if roi_mapper is None and targets is model.targets:
roi_mapper = model.roi_mapper
targets = targets or build_targets(config=model_config.targets)
roi_mapper = roi_mapper or build_roi_mapping(
config=model_config.targets.roi
)
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or build_preprocessor(
@ -80,6 +89,7 @@ def run_train(
labeller = labeller or build_clip_labeler(
targets,
roi_mapper,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=train_config.labels,
@ -119,6 +129,7 @@ def run_train(
evaluator=build_evaluator(
train_config.validation,
targets=targets,
roi_mapper=roi_mapper,
),
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,

View File

@ -1,7 +1,7 @@
import numpy as np
import torch
import torch.nn.functional as F
from hypothesis import given
from hypothesis import given, settings
from hypothesis import strategies as st
from batdetect2.detector import parameters
@ -9,6 +9,7 @@ from batdetect2.utils import audio_utils, detector_utils
@given(duration=st.floats(min_value=0.1, max_value=1))
@settings(deadline=None)
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
@ -87,6 +88,7 @@ def test_pad_audio_without_fixed_size(duration: float):
@given(duration=st.floats(min_value=0.1, max_value=2))
@settings(deadline=None)
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float,
):

View File

@ -8,15 +8,6 @@ from click.testing import CliRunner
from batdetect2.cli import cli
def test_cli_detect_help() -> None:
    """User story: get usage help for legacy detect command."""
    runner = CliRunner()
    outcome = runner.invoke(cli, ["detect", "--help"])
    # A zero exit code and the command summary prove help is wired up.
    assert outcome.exit_code == 0
    assert "Detect bat calls in files in AUDIO_DIR" in outcome.output
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
"""User story: run legacy detect on example audio directory."""

View File

@ -41,6 +41,22 @@ def test_build_detector_custom_config():
assert model.backbone.encoder.in_channels == 2
def test_build_detector_custom_size_channels():
    """The size-prediction head should emit the requested channel count."""
    n_classes, n_sizes = 3, 4
    backbone_cfg = UNetBackboneConfig(in_channels=1, input_height=128)
    detector = build_detector(
        num_classes=n_classes,
        num_sizes=n_sizes,
        config=backbone_cfg,
    )
    # Run a dummy spectrogram through and inspect the size head only.
    spec = torch.randn(1, 1, 128, 64)
    preds = detector(spec)
    assert preds.size_preds.shape[1] == n_sizes
def test_detector_forward_pass_shapes(dummy_spectrogram):
"""Test that the forward pass produces correctly shaped outputs."""
num_classes = 4

View File

@ -6,6 +6,7 @@ from soundevent.geometry import compute_bounds
from batdetect2.models.types import ModelOutput
from batdetect2.outputs import build_output_transform
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets import build_roi_mapping
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
@ -37,7 +38,9 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
width = int(duration * sample_preprocessor.output_samplerate)
spec = torch.zeros((1, height, width), dtype=torch.float32)
labeler = build_clip_labeler(targets=sample_targets)
roi_mapper = build_roi_mapping()
labeler = build_clip_labeler(targets=sample_targets, roi_mapper=roi_mapper)
heatmaps = labeler(clip_annotation, spec)
output = ModelOutput(
@ -51,7 +54,10 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
clip_detection_tensors = postprocessor(output)
assert len(clip_detection_tensors) == 1
transform = build_output_transform(targets=sample_targets)
transform = build_output_transform(
targets=sample_targets,
roi_mapper=roi_mapper,
)
clip_detections = transform.to_clip_detections(
detections=clip_detection_tensors[0],
clip=clip,

View File

@ -12,6 +12,7 @@ from batdetect2.postprocess.types import (
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets import TargetConfig, build_roi_mapping
from batdetect2.targets.types import TargetProtocol
@ -27,9 +28,22 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
)
def _build_roi_mapper(targets: TargetProtocol):
    """Build an ROI mapper from the targets' config, falling back to defaults.

    Only a ``config`` attribute that is an actual ``TargetConfig`` contributes
    its ``roi`` section; anything else yields the default mapping.
    """
    candidate = getattr(targets, "config", None)
    if isinstance(candidate, TargetConfig):
        return build_roi_mapping(config=candidate.roi)
    return build_roi_mapping(config=None)
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
raw = _mock_clip_detections_tensor()
transform = build_output_transform(targets=sample_targets)
transform = build_output_transform(
targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
)
transformed = transform.to_detections(raw, start_time=2.5)
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
@ -43,7 +57,10 @@ def test_to_clip_detections_shifts_by_clip_start(
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets)
transform = build_output_transform(
targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
)
raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip)
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
@ -90,6 +107,7 @@ def test_detection_and_clip_transforms_applied_in_order(
transform = OutputTransform(
targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
detection_transform_steps=[boost_score, keep_high_score],
clip_transform_steps=[tag_clip_transform],
)

View File

@ -1,3 +1,5 @@
from pathlib import Path
import numpy as np
import pytest
import soundfile as sf
@ -22,8 +24,10 @@ from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
PeakEnergyBBoxMapper,
PeakEnergyBBoxMapperConfig,
ROIMappingConfig,
_build_bounding_box,
build_roi_mapper,
build_roi_mapping,
get_peak_energy_coordinates,
)
@ -630,3 +634,43 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
# Then
with pytest.raises(NotImplementedError):
build_roi_mapper(DummyConfig()) # type: ignore
def test_build_roi_mapping_applies_class_override():
    """A per-class override should move the encoded anchor position."""
    roi_config = ROIMappingConfig(
        default=AnchorBBoxMapperConfig(anchor="bottom-left"),
        overrides={"myomyo": AnchorBBoxMapperConfig(anchor="top-left")},
    )
    roi_mapper = build_roi_mapping(config=roi_config)

    bbox = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
    recording = data.Recording(
        path=Path("x.wav"),
        samplerate=256_000,
        channels=1,
        duration=1.0,
    )
    event = data.SoundEvent(recording=recording, geometry=bbox)

    # The default class uses the bottom-left anchor; the override top-left.
    position_default, _ = roi_mapper.encode(event, class_name="pippip")
    position_override, _ = roi_mapper.encode(event, class_name="myomyo")

    assert position_default == pytest.approx((0.1, 12_000))
    assert position_override == pytest.approx((0.1, 18_000))
def test_build_roi_mapping_rejects_dimension_mismatch():
    """Mixing mappers with incompatible dimension orders must fail fast."""
    mismatched = ROIMappingConfig(
        default=AnchorBBoxMapperConfig(),
        overrides={"myomyo": PeakEnergyBBoxMapperConfig()},
    )
    with pytest.raises(ValueError, match="same dimension order"):
        build_roi_mapping(config=mismatched)

View File

@ -1,9 +1,10 @@
from collections.abc import Callable
from pathlib import Path
import pytest
from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
def test_can_override_default_roi_mapper_per_class(
@ -32,18 +33,21 @@ def test_can_override_default_roi_mapper_per_class(
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi:
default:
name: anchor_bbox
anchor: bottom-left
overrides:
myomyo:
name: anchor_bbox
anchor: top-left
"""
config_path = create_temp_yaml(yaml_content)
config = TargetConfig.load(config_path)
targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
@ -60,8 +64,17 @@ def test_can_override_default_roi_mapper_per_class(
tags=[data.Tag(term=species, value="Myotis myotis")],
)
(time1, freq1), _ = targets.encode_roi(se1)
(time2, freq2), _ = targets.encode_roi(se2)
class_name1 = targets.encode_class(se1)
class_name2 = targets.encode_class(se2)
(time1, freq1), _ = roi_mapper.encode(
se1.sound_event,
class_name=class_name1,
)
(time2, freq2), _ = roi_mapper.encode(
se2.sound_event,
class_name=class_name2,
)
assert time1 == time2 == 0.1
assert freq1 == 12_000
@ -95,18 +108,21 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
tags:
- key: species
value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi:
default:
name: anchor_bbox
anchor: bottom-left
overrides:
myomyo:
name: anchor_bbox
anchor: top-left
"""
config_path = create_temp_yaml(yaml_content)
config = TargetConfig.load(config_path)
targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
@ -122,14 +138,14 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
tags=[data.Tag(term=species, value="Myotis myotis")],
)
position1, size1 = targets.encode_roi(se1)
position2, size2 = targets.encode_roi(se2)
position1, size1 = roi_mapper.encode(se1.sound_event, class_name="pippip")
position2, size2 = roi_mapper.encode(se2.sound_event, class_name="myomyo")
class_name1 = targets.encode_class(se1)
class_name2 = targets.encode_class(se2)
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)
recovered1 = roi_mapper.decode(position1, size1, class_name=class_name1)
recovered2 = roi_mapper.decode(position2, size2, class_name=class_name2)
assert recovered1 == geometry
assert recovered2 == geometry

View File

@ -42,28 +42,6 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
assert checkpoints
def test_train_without_validation_does_not_save_default_monitored_checkpoint(
    tmp_path: Path,
    example_annotations: list[data.ClipAnnotation],
) -> None:
    """Without a validation split no monitored checkpoint should be written."""
    fast_config = _build_fast_train_config()
    run_train(
        train_annotations=example_annotations[:1],
        val_annotations=None,
        train_config=fast_config.train,
        model_config=fast_config.model,
        audio_config=fast_config.audio,
        num_epochs=1,
        train_workers=0,
        val_workers=0,
        checkpoint_dir=tmp_path,
        seed=0,
    )
    # No *.ckpt file anywhere under the checkpoint directory.
    saved = list(tmp_path.rglob("*.ckpt"))
    assert saved == []
def test_train_without_validation_can_still_save_last_checkpoint(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],

View File

@ -3,8 +3,8 @@ from pathlib import Path
import torch
from soundevent import data
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMappingConfig
from batdetect2.train.labels import generate_heatmaps
recording = data.Recording(
@ -30,14 +30,17 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
):
config = sample_target_config.model_copy(
update=dict(
roi=AnchorBBoxMapperConfig(
roi=ROIMappingConfig(
default=AnchorBBoxMapperConfig(
time_scale=1,
frequency_scale=1,
)
)
)
)
targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
clip_annotation = data.ClipAnnotation(
clip=clip,
@ -60,6 +63,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
min_freq=0,
max_freq=100,
targets=targets,
roi_mapper=roi_mapper,
)
pippip_index = targets.class_names.index("pippip")
myomyo_index = targets.class_names.index("myomyo")