Separate targets and ROI

This commit is contained in:
mbsantiago 2026-03-29 13:11:42 +01:00
parent 716b3a3778
commit 2b235e28bb
33 changed files with 403 additions and 201 deletions

View File

@ -32,5 +32,6 @@ classification_targets:
value: Rhinolophus ferrumequinum value: Rhinolophus ferrumequinum
roi: roi:
default:
name: anchor_bbox name: anchor_bbox
anchor: top-left anchor: top-left

View File

@ -20,7 +20,7 @@ install:
# Testing & Coverage # Testing & Coverage
# Run tests using pytest. # Run tests using pytest.
test: test:
uv run pytest {{TESTS_DIR}} uv run pytest -n auto {{TESTS_DIR}}
# Run tests and generate coverage data. # Run tests and generate coverage data.
coverage: coverage:

View File

@ -88,6 +88,7 @@ dev = [
"pandas-stubs>=2.2.2.240807", "pandas-stubs>=2.2.2.240807",
"python-lsp-server>=1.13.0", "python-lsp-server>=1.13.0",
"deepdiff>=8.6.1", "deepdiff>=8.6.1",
"pytest-xdist[psutil]>=3.8.0",
] ]
dvclive = ["dvclive>=3.48.2"] dvclive = ["dvclive>=3.48.2"]
mlflow = ["mlflow>=3.1.1"] mlflow = ["mlflow>=3.1.1"]

View File

@ -50,7 +50,13 @@ from batdetect2.postprocess import (
build_postprocessor, build_postprocessor,
) )
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
)
from batdetect2.train import ( from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR, DEFAULT_CHECKPOINT_DIR,
TrainingConfig, TrainingConfig,
@ -70,6 +76,7 @@ class BatDetect2API:
outputs_config: OutputsConfig, outputs_config: OutputsConfig,
logging_config: AppLoggingConfig, logging_config: AppLoggingConfig,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
@ -86,6 +93,7 @@ class BatDetect2API:
self.outputs_config = outputs_config self.outputs_config = outputs_config
self.logging_config = logging_config self.logging_config = logging_config
self.targets = targets self.targets = targets
self.roi_mapper = roi_mapper
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
@ -125,6 +133,7 @@ class BatDetect2API:
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config, model_config=model_config or self.model_config,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
@ -171,6 +180,7 @@ class BatDetect2API:
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper,
model_config=model_config or self.model_config, model_config=model_config or self.model_config,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
@ -205,6 +215,7 @@ class BatDetect2API:
self.model, self.model,
test_annotations, test_annotations,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
@ -391,6 +402,7 @@ class BatDetect2API:
self.model, self.model,
audio_files, audio_files,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
output_transform=self.output_transform, output_transform=self.output_transform,
@ -416,6 +428,7 @@ class BatDetect2API:
self.model, self.model,
clips, clips,
targets=self.targets, targets=self.targets,
roi_mapper=self.roi_mapper,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
output_transform=self.output_transform, output_transform=self.output_transform,
@ -472,6 +485,7 @@ class BatDetect2API:
config: BatDetect2Config, config: BatDetect2Config,
) -> "BatDetect2API": ) -> "BatDetect2API":
targets = build_targets(config=config.model.targets) targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
audio_loader = build_audio_loader(config=config.audio) audio_loader = build_audio_loader(config=config.audio)
@ -492,11 +506,13 @@ class BatDetect2API:
output_transform = build_output_transform( output_transform = build_output_transform(
config=config.outputs.transform, config=config.outputs.transform,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=config.evaluation, config=config.evaluation,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
transform=output_transform, transform=output_transform,
) )
@ -504,7 +520,8 @@ class BatDetect2API:
# to avoid device mismatch errors # to avoid device mismatch errors
model = build_model( model = build_model(
config=config.model, config=config.model,
targets=build_targets(config=config.model.targets), targets=targets,
roi_mapper=roi_mapper,
preprocessor=build_preprocessor( preprocessor=build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=config.model.preprocess, config=config.model.preprocess,
@ -524,6 +541,7 @@ class BatDetect2API:
outputs_config=config.outputs, outputs_config=config.outputs,
logging_config=config.logging, logging_config=config.logging,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
@ -561,15 +579,18 @@ class BatDetect2API:
and targets_config != model_config.targets and targets_config != model_config.targets
): ):
targets = build_targets(config=targets_config) targets = build_targets(config=targets_config)
roi_mapper = build_roi_mapping(config=targets_config.roi)
model = build_model_with_new_targets( model = build_model_with_new_targets(
model=model, model=model,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
model_config = model_config.model_copy( model_config = model_config.model_copy(
update={"targets": targets_config} update={"targets": targets_config}
) )
targets = build_targets(config=model_config.targets) targets = build_targets(config=model_config.targets)
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
audio_loader = build_audio_loader(config=audio_config) audio_loader = build_audio_loader(config=audio_config)
@ -591,11 +612,13 @@ class BatDetect2API:
output_transform = build_output_transform( output_transform = build_output_transform(
config=outputs_config.transform, config=outputs_config.transform,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=evaluation_config, config=evaluation_config,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
transform=output_transform, transform=output_transform,
) )
@ -608,6 +631,7 @@ class BatDetect2API:
outputs_config=outputs_config, outputs_config=outputs_config,
logging_config=logging_config, logging_config=logging_config,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,

View File

@ -16,7 +16,7 @@ from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
@ -25,6 +25,7 @@ def run_evaluate(
model: Model, model: Model,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
@ -46,6 +47,7 @@ def run_evaluate(
preprocessor = preprocessor or model.preprocessor preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
loader = build_test_loader( loader = build_test_loader(
test_annotations, test_annotations,
@ -57,6 +59,7 @@ def run_evaluate(
output_transform = build_output_transform( output_transform = build_output_transform(
config=output_config.transform, config=output_config.transform,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
evaluator = build_evaluator( evaluator = build_evaluator(
config=evaluation_config, config=evaluation_config,

View File

@ -8,8 +8,8 @@ from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
from batdetect2.targets import build_targets from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
"Evaluator", "Evaluator",
@ -67,17 +67,23 @@ class Evaluator:
def build_evaluator( def build_evaluator(
config: EvaluationConfig | dict | None = None, config: EvaluationConfig | dict | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
transform: OutputTransformProtocol | None = None, transform: OutputTransformProtocol | None = None,
) -> EvaluatorProtocol: ) -> EvaluatorProtocol:
targets = targets or build_targets() targets = targets or build_targets()
roi_mapper = roi_mapper or build_roi_mapping()
if config is None: if config is None:
config = EvaluationConfig() config = EvaluationConfig()
if not isinstance(config, EvaluationConfig): if not isinstance(config, EvaluationConfig):
config = EvaluationConfig.model_validate(config) config = EvaluationConfig.model_validate(config)
transform = transform or build_output_transform(targets=targets) transform = transform or build_output_transform(
targets=targets,
roi_mapper=roi_mapper,
)
return Evaluator( return Evaluator(
targets=targets, targets=targets,

View File

@ -18,13 +18,14 @@ from batdetect2.outputs import (
) )
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
def run_batch_inference( def run_batch_inference(
model: Model, model: Model,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
@ -45,10 +46,12 @@ def run_batch_inference(
preprocessor = preprocessor or model.preprocessor preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets targets = targets or model.targets
roi_mapper = roi_mapper or model.roi_mapper
output_transform = output_transform or build_output_transform( output_transform = output_transform or build_output_transform(
config=output_config.transform, config=output_config.transform,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
loader = build_inference_loader( loader = build_inference_loader(
@ -78,6 +81,7 @@ def process_file_list(
model: Model, model: Model,
paths: Sequence[data.PathLike], paths: Sequence[data.PathLike],
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
@ -101,6 +105,7 @@ def process_file_list(
model, model,
clips, clips,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
batch_size=batch_size, batch_size=batch_size,

View File

@ -20,7 +20,8 @@ class InferenceModule(LightningModule):
self.model = model self.model = model
self.detection_threshold = detection_threshold self.detection_threshold = detection_threshold
self.output_transform = output_transform or build_output_transform( self.output_transform = output_transform or build_output_transform(
targets=model.targets targets=model.targets,
roi_mapper=model.roi_mapper,
) )
def predict_step( def predict_step(

View File

@ -74,7 +74,7 @@ from batdetect2.postprocess.types import (
from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.config import TargetConfig from batdetect2.targets.config import TargetConfig
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
"BBoxHead", "BBoxHead",
@ -186,12 +186,15 @@ class Model(torch.nn.Module):
targets : TargetProtocol targets : TargetProtocol
Describes the set of target classes; used when building heads and Describes the set of target classes; used when building heads and
during training target construction. during training target construction.
roi_mapper : ROIMapperProtocol
Maps geometries to target-size channels and back.
""" """
detector: DetectionModel detector: DetectionModel
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
targets: TargetProtocol targets: TargetProtocol
roi_mapper: ROIMapperProtocol
def __init__( def __init__(
self, self,
@ -199,12 +202,14 @@ class Model(torch.nn.Module):
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
): ):
super().__init__() super().__init__()
self.detector = detector self.detector = detector
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
self.roi_mapper = roi_mapper
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]: def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
"""Run the full detection pipeline on a waveform tensor. """Run the full detection pipeline on a waveform tensor.
@ -234,6 +239,7 @@ class Model(torch.nn.Module):
def build_model( def build_model(
config: ModelConfig | None = None, config: ModelConfig | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
postprocessor: PostprocessorProtocol | None = None, postprocessor: PostprocessorProtocol | None = None,
) -> Model: ) -> Model:
@ -272,10 +278,19 @@ def build_model(
""" """
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_roi_mapping, build_targets
config = config or ModelConfig() config = config or ModelConfig()
targets = targets or build_targets(config=config.targets) targets = targets or build_targets(config=config.targets)
targets_config = getattr(targets, "config", None)
roi_config = (
targets_config.roi
if isinstance(targets_config, TargetConfig)
else config.targets.roi
)
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
config=config.preprocess, config=config.preprocess,
input_samplerate=config.samplerate, input_samplerate=config.samplerate,
@ -286,6 +301,7 @@ def build_model(
) )
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
config=config.architecture, config=config.architecture,
) )
return Model( return Model(
@ -293,16 +309,19 @@ def build_model(
postprocessor=postprocessor, postprocessor=postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
def build_model_with_new_targets( def build_model_with_new_targets(
model: Model, model: Model,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Model: ) -> Model:
"""Build a new model with a different target set.""" """Build a new model with a different target set."""
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(targets.class_names),
num_sizes=len(roi_mapper.dimension_names),
backbone=model.detector.backbone, backbone=model.detector.backbone,
) )
@ -311,4 +330,5 @@ def build_model_with_new_targets(
postprocessor=model.postprocessor, postprocessor=model.postprocessor,
preprocessor=model.preprocessor, preprocessor=model.preprocessor,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )

View File

@ -136,6 +136,7 @@ class Detector(DetectionModel):
def build_detector( def build_detector(
num_classes: int, num_classes: int,
num_sizes: int = 2,
config: BackboneConfig | None = None, config: BackboneConfig | None = None,
backbone: BackboneModel | None = None, backbone: BackboneModel | None = None,
) -> DetectionModel: ) -> DetectionModel:
@ -181,6 +182,7 @@ def build_detector(
) )
bbox_head = BBoxHead( bbox_head = BBoxHead(
in_channels=backbone.out_channels, in_channels=backbone.out_channels,
num_sizes=num_sizes,
) )
return Detector( return Detector(
backbone=backbone, backbone=backbone,

View File

@ -165,14 +165,15 @@ class BBoxHead(nn.Module):
1×1 convolution with 2 output channels (duration, bandwidth). 1×1 convolution with 2 output channels (duration, bandwidth).
""" """
def __init__(self, in_channels: int): def __init__(self, in_channels: int, num_sizes: int = 2):
"""Initialise the BBoxHead.""" """Initialise the BBoxHead."""
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.num_sizes = num_sizes
self.bbox = nn.Conv2d( self.bbox = nn.Conv2d(
in_channels=self.in_channels, in_channels=self.in_channels,
out_channels=2, out_channels=self.num_sizes,
kernel_size=1, kernel_size=1,
padding=0, padding=0,
) )

View File

@ -28,7 +28,7 @@ from batdetect2.postprocess.types import (
ClipDetectionsTensor, ClipDetectionsTensor,
Detection, Detection,
) )
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
"ClipDetectionsTransformConfig", "ClipDetectionsTransformConfig",
@ -55,10 +55,12 @@ class OutputTransform(OutputTransformProtocol):
def __init__( def __init__(
self, self,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
detection_transform_steps: Sequence[DetectionTransform] = (), detection_transform_steps: Sequence[DetectionTransform] = (),
clip_transform_steps: Sequence[ClipDetectionsTransform] = (), clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
): ):
self.targets = targets self.targets = targets
self.roi_mapper = roi_mapper
self.detection_transform_steps = list(detection_transform_steps) self.detection_transform_steps = list(detection_transform_steps)
self.clip_transform_steps = list(clip_transform_steps) self.clip_transform_steps = list(clip_transform_steps)
@ -89,7 +91,11 @@ class OutputTransform(OutputTransformProtocol):
detections: ClipDetectionsTensor, detections: ClipDetectionsTensor,
start_time: float = 0, start_time: float = 0,
) -> list[Detection]: ) -> list[Detection]:
decoded = to_detections(detections.numpy(), targets=self.targets) decoded = to_detections(
detections.numpy(),
targets=self.targets,
roi_mapper=self.roi_mapper,
)
shifted = shift_detections_to_start_time( shifted = shift_detections_to_start_time(
decoded, decoded,
start_time=start_time, start_time=start_time,
@ -151,8 +157,9 @@ class OutputTransform(OutputTransformProtocol):
def build_output_transform( def build_output_transform(
config: OutputTransformConfig | dict | None = None, config: OutputTransformConfig | dict | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
) -> OutputTransformProtocol: ) -> OutputTransformProtocol:
from batdetect2.targets import build_targets from batdetect2.targets import build_roi_mapping, build_targets
if config is None: if config is None:
config = OutputTransformConfig() config = OutputTransformConfig()
@ -161,9 +168,11 @@ def build_output_transform(
config = OutputTransformConfig.model_validate(config) config = OutputTransformConfig.model_validate(config)
targets = targets or build_targets() targets = targets or build_targets()
roi_mapper = roi_mapper or build_roi_mapping()
return OutputTransform( return OutputTransform(
targets=targets, targets=targets,
roi_mapper=roi_mapper,
detection_transform_steps=[ detection_transform_steps=[
detection_transform_registry.build(transform_config) detection_transform_registry.build(transform_config)
for transform_config in config.detection_transforms for transform_config in config.detection_transforms

View File

@ -6,7 +6,7 @@ import numpy as np
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import ClipDetectionsArray, Detection from batdetect2.postprocess.types import ClipDetectionsArray, Detection
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD", "DEFAULT_CLASSIFICATION_THRESHOLD",
@ -25,6 +25,7 @@ DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
def to_detections( def to_detections(
detections: ClipDetectionsArray, detections: ClipDetectionsArray,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> List[Detection]: ) -> List[Detection]:
predictions = [] predictions = []
@ -39,7 +40,7 @@ def to_detections(
): ):
highest_scoring_class = targets.class_names[class_scores.argmax()] highest_scoring_class = targets.class_names[class_scores.argmax()]
geom = targets.decode_roi( geom = roi_mapper.decode(
(time, freq), (time, freq),
dims, dims,
class_name=highest_scoring_class, class_name=highest_scoring_class,

View File

@ -4,7 +4,7 @@ from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax from batdetect2.plotting.common import create_ax
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
__all__ = [ __all__ = [
"plot_clip_annotation", "plot_clip_annotation",
@ -48,6 +48,7 @@ def plot_clip_annotation(
def plot_anchor_points( def plot_anchor_points(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
figsize: tuple[int, int] | None = None, figsize: tuple[int, int] | None = None,
ax: Axes | None = None, ax: Axes | None = None,
size: int = 1, size: int = 1,
@ -63,7 +64,11 @@ def plot_anchor_points(
if not targets.filter(sound_event): if not targets.filter(sound_event):
continue continue
position, _ = targets.encode_roi(sound_event) class_name = targets.encode_class(sound_event)
position, _ = roi_mapper.encode(
sound_event.sound_event,
class_name=class_name,
)
positions.append(position) positions.append(position)
X, Y = zip(*positions, strict=False) X, Y = zip(*positions, strict=False)

View File

@ -10,7 +10,10 @@ from batdetect2.targets.config import TargetConfig
from batdetect2.targets.rois import ( from batdetect2.targets.rois import (
AnchorBBoxMapperConfig, AnchorBBoxMapperConfig,
ROIMapperConfig, ROIMapperConfig,
ROIMapperProtocol,
ROIMappingConfig,
build_roi_mapper, build_roi_mapper,
build_roi_mapping,
) )
from batdetect2.targets.targets import ( from batdetect2.targets.targets import (
Targets, Targets,
@ -30,12 +33,15 @@ from batdetect2.targets.types import (
Size, Size,
SoundEventDecoder, SoundEventDecoder,
SoundEventEncoder, SoundEventEncoder,
SoundEventFilter,
TargetProtocol, TargetProtocol,
) )
__all__ = [ __all__ = [
"AnchorBBoxMapperConfig", "AnchorBBoxMapperConfig",
"Position", "Position",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig", "ROIMapperConfig",
"ROITargetMapper", "ROITargetMapper",
"Size", "Size",
@ -46,6 +52,7 @@ __all__ = [
"TargetConfig", "TargetConfig",
"TargetProtocol", "TargetProtocol",
"Targets", "Targets",
"build_roi_mapping",
"build_roi_mapper", "build_roi_mapper",
"build_sound_event_decoder", "build_sound_event_decoder",
"build_sound_event_encoder", "build_sound_event_encoder",

View File

@ -14,7 +14,6 @@ from batdetect2.data.conditions import (
SoundEventConditionConfig, SoundEventConditionConfig,
build_sound_event_condition, build_sound_event_condition,
) )
from batdetect2.targets.rois import ROIMapperConfig
from batdetect2.targets.terms import call_type, generic_class from batdetect2.targets.terms import call_type, generic_class
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
@ -39,8 +38,6 @@ class TargetClassConfig(BaseConfig):
assign_tags: List[data.Tag] = Field(default_factory=list) assign_tags: List[data.Tag] = Field(default_factory=list)
roi: ROIMapperConfig | None = None
_match_if: SoundEventConditionConfig = PrivateAttr() _match_if: SoundEventConditionConfig = PrivateAttr()
@model_validator(mode="after") @model_validator(mode="after")

View File

@ -9,7 +9,7 @@ from batdetect2.targets.classes import (
DEFAULT_DETECTION_CLASS, DEFAULT_DETECTION_CLASS,
TargetClassConfig, TargetClassConfig,
) )
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig from batdetect2.targets.rois import ROIMappingConfig
__all__ = [ __all__ = [
"TargetConfig", "TargetConfig",
@ -25,7 +25,7 @@ class TargetConfig(BaseConfig):
default_factory=lambda: DEFAULT_CLASSES default_factory=lambda: DEFAULT_CLASSES
) )
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) roi: ROIMappingConfig = Field(default_factory=ROIMappingConfig)
@field_validator("classification_targets") @field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]): def check_unique_class_names(cls, v: List[TargetClassConfig]):

View File

@ -29,7 +29,12 @@ from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import Position, ROITargetMapper, Size from batdetect2.targets.types import (
Position,
ROIMapperProtocol,
ROITargetMapper,
Size,
)
__all__ = [ __all__ = [
"Anchor", "Anchor",
@ -40,12 +45,15 @@ __all__ = [
"DEFAULT_TIME_SCALE", "DEFAULT_TIME_SCALE",
"PeakEnergyBBoxMapper", "PeakEnergyBBoxMapper",
"PeakEnergyBBoxMapperConfig", "PeakEnergyBBoxMapperConfig",
"ROIMappingConfig",
"ROIMapperProtocol",
"ROIMapperConfig", "ROIMapperConfig",
"ROIMapperImportConfig", "ROIMapperImportConfig",
"ROITargetMapper", "ROITargetMapper",
"SIZE_HEIGHT", "SIZE_HEIGHT",
"SIZE_ORDER", "SIZE_ORDER",
"SIZE_WIDTH", "SIZE_WIDTH",
"build_roi_mapping",
"build_roi_mapper", "build_roi_mapper",
] ]
@ -456,6 +464,59 @@ implementations by using the `name` field as a discriminator.
""" """
class ROIMappingConfig(BaseConfig):
"""Configuration for class-aware ROI mapping.
Attributes
----------
default : ROIMapperConfig
Default mapper used when no class-specific override exists.
overrides : dict[str, ROIMapperConfig]
Optional class-specific mapper overrides by class name.
"""
default: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
overrides: dict[str, ROIMapperConfig] = Field(default_factory=dict)
class ClassAwareROIMapper(ROIMapperProtocol):
"""Apply a default ROI mapper with optional per-class overrides."""
dimension_names: list[str]
def __init__(
self,
default_mapper: ROITargetMapper,
overrides: dict[str, ROITargetMapper] | None = None,
):
self.default_mapper = default_mapper
self.overrides = overrides or {}
self.dimension_names = list(default_mapper.dimension_names)
def encode(
self,
sound_event: data.SoundEvent,
class_name: str | None = None,
) -> tuple[Position, Size]:
mapper = self._select_mapper(class_name)
return mapper.encode(sound_event)
def decode(
self,
position: Position,
size: Size,
class_name: str | None = None,
) -> data.Geometry:
mapper = self._select_mapper(class_name)
return mapper.decode(position, size)
def _select_mapper(self, class_name: str | None = None) -> ROITargetMapper:
if class_name is not None and class_name in self.overrides:
return self.overrides[class_name]
return self.default_mapper
def build_roi_mapper( def build_roi_mapper(
config: ROIMapperConfig | None = None, config: ROIMapperConfig | None = None,
) -> ROITargetMapper: ) -> ROITargetMapper:
@ -480,6 +541,36 @@ def build_roi_mapper(
return roi_mapper_registry.build(config) return roi_mapper_registry.build(config)
def build_roi_mapping(
config: ROIMappingConfig | None = None,
) -> ROIMapperProtocol:
"""Build a class-aware ROI mapper and validate consistency."""
config = config or ROIMappingConfig()
default_mapper = build_roi_mapper(config.default)
overrides = {
class_name: build_roi_mapper(mapper_config)
for class_name, mapper_config in config.overrides.items()
}
expected = list(default_mapper.dimension_names)
for class_name, mapper in overrides.items():
actual = list(mapper.dimension_names)
if actual != expected:
raise ValueError(
"All ROI mappers must share the same dimension order. "
f"Default dimensions: {expected}, "
f"class '{class_name}' dimensions: {actual}."
)
return ClassAwareROIMapper(
default_mapper=default_mapper,
overrides=overrides,
)
VALID_ANCHORS = [ VALID_ANCHORS = [
"bottom-left", "bottom-left",
"bottom-right", "bottom-right",

View File

@ -12,21 +12,21 @@ from batdetect2.targets.classes import (
get_class_names_from_config, get_class_names_from_config,
) )
from batdetect2.targets.config import TargetConfig from batdetect2.targets.config import TargetConfig
from batdetect2.targets.rois import ( from batdetect2.targets.types import (
AnchorBBoxMapperConfig, Position,
build_roi_mapper, ROIMapperProtocol,
Size,
TargetProtocol,
) )
from batdetect2.targets.types import Position, Size, TargetProtocol
class Targets(TargetProtocol): class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline. """Encapsulates the configured target class definition pipeline.
This class implements the `TargetProtocol`, holding the configured This class implements the `TargetProtocol`, holding the configured
functions for filtering, transforming, encoding (tags to class name), functions for filtering, encoding (tags to class name), and decoding
decoding (class name to tags), and mapping ROIs (geometry to position/size (class name to tags). Geometry ROI mapping is handled separately by
and back). It provides a high-level interface to apply these steps and ``ROIMapperProtocol``.
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader. or the `load_targets` convenience loader.
@ -39,14 +39,10 @@ class Targets(TargetProtocol):
generic_class_tags generic_class_tags
A list of `soundevent.data.Tag` objects representing the configured A list of `soundevent.data.Tag` objects representing the configured
generic class category (used when no specific class matches). generic class category (used when no specific class matches).
dimension_names
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
""" """
class_names: list[str] class_names: list[str]
detection_class_tags: list[data.Tag] detection_class_tags: list[data.Tag]
dimension_names: list[str]
detection_class_name: str detection_class_name: str
def __init__(self, config: TargetConfig): def __init__(self, config: TargetConfig):
@ -63,10 +59,6 @@ class Targets(TargetProtocol):
config.classification_targets config.classification_targets
) )
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config( self.class_names = get_class_names_from_config(
config.classification_targets config.classification_targets
) )
@ -74,21 +66,6 @@ class Targets(TargetProtocol):
self.detection_class_name = config.detection_target.name self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation. """Apply the configured filter to a sound event annotation.
@ -147,75 +124,10 @@ class Targets(TargetProtocol):
""" """
return self._decode_fn(class_label) return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
tuple[float, float]
The reference position `(time, frequency)`.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: str | None = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
Delegates to the internal ROI mapper's `recover_roi` method, which
un-scales the dimensions and reconstructs the geometry (typically a
`BoundingBox`).
Parameters
----------
pos
The reference position `(time, frequency)`.
dims
NumPy array with size dimensions (e.g., from model prediction),
matching the order in `self.dimension_names`.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES, classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS, detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
) )
@ -292,6 +204,7 @@ def load_targets(
def iterate_encoded_sound_events( def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation], sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
) -> Iterable[tuple[str | None, Position, Size]]: ) -> Iterable[tuple[str | None, Position, Size]]:
for sound_event in sound_events: for sound_event in sound_events:
if not targets.filter(sound_event): if not targets.filter(sound_event):
@ -303,6 +216,9 @@ def iterate_encoded_sound_events(
continue continue
class_name = targets.encode_class(sound_event) class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event) position, size = roi_mapper.encode(
sound_event.sound_event,
class_name=class_name,
)
yield class_name, position, size yield class_name, position, size

View File

@ -6,6 +6,7 @@ from soundevent import data
__all__ = [ __all__ = [
"Position", "Position",
"ROIMapperProtocol",
"ROITargetMapper", "ROITargetMapper",
"Size", "Size",
"SoundEventDecoder", "SoundEventDecoder",
@ -26,7 +27,6 @@ class TargetProtocol(Protocol):
class_names: list[str] class_names: list[str]
detection_class_tags: list[data.Tag] detection_class_tags: list[data.Tag]
detection_class_name: str detection_class_name: str
dimension_names: list[str]
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ... def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
@ -37,6 +37,23 @@ class TargetProtocol(Protocol):
def decode_class(self, class_label: str) -> list[data.Tag]: ... def decode_class(self, class_label: str) -> list[data.Tag]: ...
class ROIMapperProtocol(Protocol):
dimension_names: list[str]
def encode(
self,
sound_event: data.SoundEvent,
class_name: str | None = None,
) -> tuple[Position, Size]: ...
def decode(
self,
position: Position,
size: Size,
class_name: str | None = None,
) -> data.Geometry: ...
def encode_roi( def encode_roi(
self, self,
sound_event: data.SoundEventAnnotation, sound_event: data.SoundEventAnnotation,

View File

@ -93,7 +93,8 @@ class ValidationMetrics(Callback):
model = pl_module.model model = pl_module.model
if self.output_transform is None: if self.output_transform is None:
self.output_transform = build_output_transform( self.output_transform = build_output_transform(
targets=model.targets targets=model.targets,
roi_mapper=model.roi_mapper,
) )
output_transform = self.output_transform output_transform = self.output_transform

View File

@ -40,7 +40,7 @@ def build_checkpoint_callback(
if run_name is not None: if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name checkpoint_dir = checkpoint_dir / run_name
checkpoint_dir.mkdir(parents=True, exist_ok=True) Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
return ModelCheckpoint( return ModelCheckpoint(
dirpath=str(checkpoint_dir), dirpath=str(checkpoint_dir),

View File

@ -14,8 +14,12 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events from batdetect2.targets import (
from batdetect2.targets.types import TargetProtocol build_roi_mapping,
build_targets,
iterate_encoded_sound_events,
)
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
from batdetect2.train.types import ClipLabeller, Heatmaps from batdetect2.train.types import ClipLabeller, Heatmaps
__all__ = [ __all__ = [
@ -42,6 +46,7 @@ class LabelConfig(BaseConfig):
def build_clip_labeler( def build_clip_labeler(
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
min_freq: float = MIN_FREQ, min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ, max_freq: float = MAX_FREQ,
config: LabelConfig | None = None, config: LabelConfig | None = None,
@ -53,12 +58,13 @@ def build_clip_labeler(
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
if targets is None: targets = targets or build_targets()
targets = build_targets() roi_mapper = roi_mapper or build_roi_mapping()
return partial( return partial(
generate_heatmaps, generate_heatmaps,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
min_freq=min_freq, min_freq=min_freq,
max_freq=max_freq, max_freq=max_freq,
target_sigma=config.sigma, target_sigma=config.sigma,
@ -73,6 +79,7 @@ def generate_heatmaps(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
spec: torch.Tensor, spec: torch.Tensor,
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,
min_freq: float, min_freq: float,
max_freq: float, max_freq: float,
target_sigma: float = 3.0, target_sigma: float = 3.0,
@ -89,7 +96,7 @@ def generate_heatmaps(
height = spec.shape[-2] height = spec.shape[-2]
width = spec.shape[-1] width = spec.shape[-1]
num_classes = len(targets.class_names) num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names) num_dims = len(roi_mapper.dimension_names)
clip = clip_annotation.clip clip = clip_annotation.clip
# Initialize heatmaps # Initialize heatmaps
@ -109,6 +116,7 @@ def generate_heatmaps(
for class_name, (time, frequency), size in iterate_encoded_sound_events( for class_name, (time, frequency), size in iterate_encoded_sound_events(
clip_annotation.sound_events, clip_annotation.sound_events,
targets, targets,
roi_mapper,
): ):
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time) time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
freq_index = map_to_pixels(frequency, height, min_freq, max_freq) freq_index = map_to_pixels(frequency, height, min_freq, max_freq)

View File

@ -6,23 +6,24 @@ from lightning import Trainer, seed_everything
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.audio.types import AudioLoader from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import ( from batdetect2.logging import (
LoggerConfig, LoggerConfig,
TensorBoardLoggerConfig, TensorBoardLoggerConfig,
build_logger, build_logger,
) )
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import (
from batdetect2.targets import build_targets ROIMapperProtocol,
from batdetect2.targets.types import TargetProtocol TargetProtocol,
from batdetect2.train import TrainingConfig build_roi_mapping,
build_targets,
)
from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.checkpoints import build_checkpoint_callback from batdetect2.train.checkpoints import build_checkpoint_callback
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module from batdetect2.train.lightning import build_training_module
@ -39,6 +40,7 @@ def run_train(
val_annotations: Sequence[data.ClipAnnotation] | None = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None, model: Model | None = None,
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
roi_mapper: Optional["ROIMapperProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None, labeller: Optional["ClipLabeller"] = None,
@ -69,8 +71,15 @@ def run_train(
if model is not None: if model is not None:
targets = targets or model.targets targets = targets or model.targets
if roi_mapper is None and targets is model.targets:
roi_mapper = model.roi_mapper
targets = targets or build_targets(config=model_config.targets) targets = targets or build_targets(config=model_config.targets)
roi_mapper = roi_mapper or build_roi_mapping(
config=model_config.targets.roi
)
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
@ -80,6 +89,7 @@ def run_train(
labeller = labeller or build_clip_labeler( labeller = labeller or build_clip_labeler(
targets, targets,
roi_mapper,
min_freq=preprocessor.min_freq, min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq, max_freq=preprocessor.max_freq,
config=train_config.labels, config=train_config.labels,
@ -119,6 +129,7 @@ def run_train(
evaluator=build_evaluator( evaluator=build_evaluator(
train_config.validation, train_config.validation,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
), ),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs, num_epochs=num_epochs,

View File

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from hypothesis import given from hypothesis import given, settings
from hypothesis import strategies as st from hypothesis import strategies as st
from batdetect2.detector import parameters from batdetect2.detector import parameters
@ -9,6 +9,7 @@ from batdetect2.utils import audio_utils, detector_utils
@given(duration=st.floats(min_value=0.1, max_value=1)) @given(duration=st.floats(min_value=0.1, max_value=1))
@settings(deadline=None)
def test_can_compute_correct_spectrogram_width(duration: float): def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
@ -87,6 +88,7 @@ def test_pad_audio_without_fixed_size(duration: float):
@given(duration=st.floats(min_value=0.1, max_value=2)) @given(duration=st.floats(min_value=0.1, max_value=2))
@settings(deadline=None)
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor( def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float, duration: float,
): ):

View File

@ -8,15 +8,6 @@ from click.testing import CliRunner
from batdetect2.cli import cli from batdetect2.cli import cli
def test_cli_detect_help() -> None:
"""User story: get usage help for legacy detect command."""
result = CliRunner().invoke(cli, ["detect", "--help"])
assert result.exit_code == 0
assert "Detect bat calls in files in AUDIO_DIR" in result.output
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None: def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
"""User story: run legacy detect on example audio directory.""" """User story: run legacy detect on example audio directory."""

View File

@ -41,6 +41,22 @@ def test_build_detector_custom_config():
assert model.backbone.encoder.in_channels == 2 assert model.backbone.encoder.in_channels == 2
def test_build_detector_custom_size_channels():
num_classes = 3
num_sizes = 4
config = UNetBackboneConfig(in_channels=1, input_height=128)
model = build_detector(
num_classes=num_classes,
num_sizes=num_sizes,
config=config,
)
dummy = torch.randn(1, 1, 128, 64)
output = model(dummy)
assert output.size_preds.shape[1] == num_sizes
def test_detector_forward_pass_shapes(dummy_spectrogram): def test_detector_forward_pass_shapes(dummy_spectrogram):
"""Test that the forward pass produces correctly shaped outputs.""" """Test that the forward pass produces correctly shaped outputs."""
num_classes = 4 num_classes = 4

View File

@ -6,6 +6,7 @@ from soundevent.geometry import compute_bounds
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.outputs import build_output_transform from batdetect2.outputs import build_output_transform
from batdetect2.postprocess import build_postprocessor from batdetect2.postprocess import build_postprocessor
from batdetect2.targets import build_roi_mapping
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler from batdetect2.train.labels import build_clip_labeler
@ -37,7 +38,9 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
width = int(duration * sample_preprocessor.output_samplerate) width = int(duration * sample_preprocessor.output_samplerate)
spec = torch.zeros((1, height, width), dtype=torch.float32) spec = torch.zeros((1, height, width), dtype=torch.float32)
labeler = build_clip_labeler(targets=sample_targets) roi_mapper = build_roi_mapping()
labeler = build_clip_labeler(targets=sample_targets, roi_mapper=roi_mapper)
heatmaps = labeler(clip_annotation, spec) heatmaps = labeler(clip_annotation, spec)
output = ModelOutput( output = ModelOutput(
@ -51,7 +54,10 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
clip_detection_tensors = postprocessor(output) clip_detection_tensors = postprocessor(output)
assert len(clip_detection_tensors) == 1 assert len(clip_detection_tensors) == 1
transform = build_output_transform(targets=sample_targets) transform = build_output_transform(
targets=sample_targets,
roi_mapper=roi_mapper,
)
clip_detections = transform.to_clip_detections( clip_detections = transform.to_clip_detections(
detections=clip_detection_tensors[0], detections=clip_detection_tensors[0],
clip=clip, clip=clip,

View File

@ -12,6 +12,7 @@ from batdetect2.postprocess.types import (
ClipDetectionsTensor, ClipDetectionsTensor,
Detection, Detection,
) )
from batdetect2.targets import TargetConfig, build_roi_mapping
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -27,9 +28,22 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
) )
def _build_roi_mapper(targets: TargetProtocol):
config_obj = getattr(targets, "config", None)
target_config = (
config_obj if isinstance(config_obj, TargetConfig) else None
)
return build_roi_mapping(
config=(target_config.roi if target_config is not None else None),
)
def test_shift_time_to_clip_start(sample_targets: TargetProtocol): def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
raw = _mock_clip_detections_tensor() raw = _mock_clip_detections_tensor()
transform = build_output_transform(targets=sample_targets) transform = build_output_transform(
targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
)
transformed = transform.to_detections(raw, start_time=2.5) transformed = transform.to_detections(raw, start_time=2.5)
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry) start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
@ -43,7 +57,10 @@ def test_to_clip_detections_shifts_by_clip_start(
sample_targets: TargetProtocol, sample_targets: TargetProtocol,
): ):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets) transform = build_output_transform(
targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
)
raw = _mock_clip_detections_tensor() raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip) shifted = transform.to_clip_detections(detections=raw, clip=clip)
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry) start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
@ -90,6 +107,7 @@ def test_detection_and_clip_transforms_applied_in_order(
transform = OutputTransform( transform = OutputTransform(
targets=sample_targets, targets=sample_targets,
roi_mapper=_build_roi_mapper(sample_targets),
detection_transform_steps=[boost_score, keep_high_score], detection_transform_steps=[boost_score, keep_high_score],
clip_transform_steps=[tag_clip_transform], clip_transform_steps=[tag_clip_transform],
) )

View File

@ -1,3 +1,5 @@
from pathlib import Path
import numpy as np import numpy as np
import pytest import pytest
import soundfile as sf import soundfile as sf
@ -22,8 +24,10 @@ from batdetect2.targets.rois import (
AnchorBBoxMapperConfig, AnchorBBoxMapperConfig,
PeakEnergyBBoxMapper, PeakEnergyBBoxMapper,
PeakEnergyBBoxMapperConfig, PeakEnergyBBoxMapperConfig,
ROIMappingConfig,
_build_bounding_box, _build_bounding_box,
build_roi_mapper, build_roi_mapper,
build_roi_mapping,
get_peak_energy_coordinates, get_peak_energy_coordinates,
) )
@ -630,3 +634,43 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
# Then # Then
with pytest.raises(NotImplementedError): with pytest.raises(NotImplementedError):
build_roi_mapper(DummyConfig()) # type: ignore build_roi_mapper(DummyConfig()) # type: ignore
def test_build_roi_mapping_applies_class_override():
config = ROIMappingConfig(
default=AnchorBBoxMapperConfig(anchor="bottom-left"),
overrides={
"myomyo": AnchorBBoxMapperConfig(anchor="top-left"),
},
)
mapper = build_roi_mapping(config=config)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
sound_event = data.SoundEvent(
recording=data.Recording(
path=Path("x.wav"),
samplerate=256_000,
channels=1,
duration=1.0,
),
geometry=geometry,
)
default_position, _ = mapper.encode(sound_event, class_name="pippip")
override_position, _ = mapper.encode(sound_event, class_name="myomyo")
assert default_position == pytest.approx((0.1, 12_000))
assert override_position == pytest.approx((0.1, 18_000))
def test_build_roi_mapping_rejects_dimension_mismatch():
config = ROIMappingConfig(
default=AnchorBBoxMapperConfig(),
overrides={
"myomyo": PeakEnergyBBoxMapperConfig(),
},
)
with pytest.raises(ValueError, match="same dimension order"):
build_roi_mapping(config=config)

View File

@ -1,9 +1,10 @@
from collections.abc import Callable from collections.abc import Callable
from pathlib import Path from pathlib import Path
import pytest
from soundevent import data, terms from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_targets from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
def test_can_override_default_roi_mapper_per_class( def test_can_override_default_roi_mapper_per_class(
@ -32,18 +33,21 @@ def test_can_override_default_roi_mapper_per_class(
tags: tags:
- key: species - key: species
value: Myotis myotis value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi: roi:
default:
name: anchor_bbox name: anchor_bbox
anchor: bottom-left anchor: bottom-left
overrides:
myomyo:
name: anchor_bbox
anchor: top-left
""" """
config_path = create_temp_yaml(yaml_content) config_path = create_temp_yaml(yaml_content)
config = TargetConfig.load(config_path) config = TargetConfig.load(config_path)
targets = build_targets(config) targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
@ -60,8 +64,17 @@ def test_can_override_default_roi_mapper_per_class(
tags=[data.Tag(term=species, value="Myotis myotis")], tags=[data.Tag(term=species, value="Myotis myotis")],
) )
(time1, freq1), _ = targets.encode_roi(se1) class_name1 = targets.encode_class(se1)
(time2, freq2), _ = targets.encode_roi(se2) class_name2 = targets.encode_class(se2)
(time1, freq1), _ = roi_mapper.encode(
se1.sound_event,
class_name=class_name1,
)
(time2, freq2), _ = roi_mapper.encode(
se2.sound_event,
class_name=class_name2,
)
assert time1 == time2 == 0.1 assert time1 == time2 == 0.1
assert freq1 == 12_000 assert freq1 == 12_000
@ -95,18 +108,21 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
tags: tags:
- key: species - key: species
value: Myotis myotis value: Myotis myotis
roi:
name: anchor_bbox
anchor: top-left
roi: roi:
default:
name: anchor_bbox name: anchor_bbox
anchor: bottom-left anchor: bottom-left
overrides:
myomyo:
name: anchor_bbox
anchor: top-left
""" """
config_path = create_temp_yaml(yaml_content) config_path = create_temp_yaml(yaml_content)
config = TargetConfig.load(config_path) config = TargetConfig.load(config_path)
targets = build_targets(config) targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
@ -122,14 +138,14 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
tags=[data.Tag(term=species, value="Myotis myotis")], tags=[data.Tag(term=species, value="Myotis myotis")],
) )
position1, size1 = targets.encode_roi(se1) position1, size1 = roi_mapper.encode(se1.sound_event, class_name="pippip")
position2, size2 = targets.encode_roi(se2) position2, size2 = roi_mapper.encode(se2.sound_event, class_name="myomyo")
class_name1 = targets.encode_class(se1) class_name1 = targets.encode_class(se1)
class_name2 = targets.encode_class(se2) class_name2 = targets.encode_class(se2)
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1) recovered1 = roi_mapper.decode(position1, size1, class_name=class_name1)
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2) recovered2 = roi_mapper.decode(position2, size2, class_name=class_name2)
assert recovered1 == geometry assert recovered1 == geometry
assert recovered2 == geometry assert recovered2 == geometry

View File

@ -42,28 +42,6 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
assert checkpoints assert checkpoints
def test_train_without_validation_does_not_save_default_monitored_checkpoint(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],
) -> None:
config = _build_fast_train_config()
run_train(
train_annotations=example_annotations[:1],
val_annotations=None,
train_config=config.train,
model_config=config.model,
audio_config=config.audio,
num_epochs=1,
train_workers=0,
val_workers=0,
checkpoint_dir=tmp_path,
seed=0,
)
assert not list(tmp_path.rglob("*.ckpt"))
def test_train_without_validation_can_still_save_last_checkpoint( def test_train_without_validation_can_still_save_last_checkpoint(
tmp_path: Path, tmp_path: Path,
example_annotations: list[data.ClipAnnotation], example_annotations: list[data.ClipAnnotation],

View File

@ -3,8 +3,8 @@ from pathlib import Path
import torch import torch
from soundevent import data from soundevent import data
from batdetect2.targets import TargetConfig, build_targets from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
from batdetect2.targets.rois import AnchorBBoxMapperConfig from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMappingConfig
from batdetect2.train.labels import generate_heatmaps from batdetect2.train.labels import generate_heatmaps
recording = data.Recording( recording = data.Recording(
@ -30,14 +30,17 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
): ):
config = sample_target_config.model_copy( config = sample_target_config.model_copy(
update=dict( update=dict(
roi=AnchorBBoxMapperConfig( roi=ROIMappingConfig(
default=AnchorBBoxMapperConfig(
time_scale=1, time_scale=1,
frequency_scale=1, frequency_scale=1,
) )
) )
) )
)
targets = build_targets(config) targets = build_targets(config)
roi_mapper = build_roi_mapping(config=config.roi)
clip_annotation = data.ClipAnnotation( clip_annotation = data.ClipAnnotation(
clip=clip, clip=clip,
@ -60,6 +63,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
min_freq=0, min_freq=0,
max_freq=100, max_freq=100,
targets=targets, targets=targets,
roi_mapper=roi_mapper,
) )
pippip_index = targets.class_names.index("pippip") pippip_index = targets.class_names.index("pippip")
myomyo_index = targets.class_names.index("myomyo") myomyo_index = targets.class_names.index("myomyo")