mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Separate targets and ROI
This commit is contained in:
parent
716b3a3778
commit
2b235e28bb
@ -32,5 +32,6 @@ classification_targets:
|
||||
value: Rhinolophus ferrumequinum
|
||||
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
default:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
|
||||
2
justfile
2
justfile
@ -20,7 +20,7 @@ install:
|
||||
# Testing & Coverage
|
||||
# Run tests using pytest.
|
||||
test:
|
||||
uv run pytest {{TESTS_DIR}}
|
||||
uv run pytest -n auto {{TESTS_DIR}}
|
||||
|
||||
# Run tests and generate coverage data.
|
||||
coverage:
|
||||
|
||||
@ -88,6 +88,7 @@ dev = [
|
||||
"pandas-stubs>=2.2.2.240807",
|
||||
"python-lsp-server>=1.13.0",
|
||||
"deepdiff>=8.6.1",
|
||||
"pytest-xdist[psutil]>=3.8.0",
|
||||
]
|
||||
dvclive = ["dvclive>=3.48.2"]
|
||||
mlflow = ["mlflow>=3.1.1"]
|
||||
|
||||
@ -50,7 +50,13 @@ from batdetect2.postprocess import (
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
@ -70,6 +76,7 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig,
|
||||
logging_config: AppLoggingConfig,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
@ -86,6 +93,7 @@ class BatDetect2API:
|
||||
self.outputs_config = outputs_config
|
||||
self.logging_config = logging_config
|
||||
self.targets = targets
|
||||
self.roi_mapper = roi_mapper
|
||||
self.audio_loader = audio_loader
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
@ -125,6 +133,7 @@ class BatDetect2API:
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
model_config=model_config or self.model_config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
@ -171,6 +180,7 @@ class BatDetect2API:
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
model_config=model_config or self.model_config,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
@ -205,6 +215,7 @@ class BatDetect2API:
|
||||
self.model,
|
||||
test_annotations,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
@ -391,6 +402,7 @@ class BatDetect2API:
|
||||
self.model,
|
||||
audio_files,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
output_transform=self.output_transform,
|
||||
@ -416,6 +428,7 @@ class BatDetect2API:
|
||||
self.model,
|
||||
clips,
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
output_transform=self.output_transform,
|
||||
@ -472,6 +485,7 @@ class BatDetect2API:
|
||||
config: BatDetect2Config,
|
||||
) -> "BatDetect2API":
|
||||
targets = build_targets(config=config.model.targets)
|
||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
@ -492,11 +506,13 @@ class BatDetect2API:
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
@ -504,7 +520,8 @@ class BatDetect2API:
|
||||
# to avoid device mismatch errors
|
||||
model = build_model(
|
||||
config=config.model,
|
||||
targets=build_targets(config=config.model.targets),
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
preprocessor=build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.model.preprocess,
|
||||
@ -524,6 +541,7 @@ class BatDetect2API:
|
||||
outputs_config=config.outputs,
|
||||
logging_config=config.logging,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
@ -561,15 +579,18 @@ class BatDetect2API:
|
||||
and targets_config != model_config.targets
|
||||
):
|
||||
targets = build_targets(config=targets_config)
|
||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||
model = build_model_with_new_targets(
|
||||
model=model,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
model_config = model_config.model_copy(
|
||||
update={"targets": targets_config}
|
||||
)
|
||||
|
||||
targets = build_targets(config=model_config.targets)
|
||||
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
|
||||
|
||||
audio_loader = build_audio_loader(config=audio_config)
|
||||
|
||||
@ -591,11 +612,13 @@ class BatDetect2API:
|
||||
output_transform = build_output_transform(
|
||||
config=outputs_config.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=evaluation_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
@ -608,6 +631,7 @@ class BatDetect2API:
|
||||
outputs_config=outputs_config,
|
||||
logging_config=logging_config,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
|
||||
@ -16,7 +16,7 @@ from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
@ -25,6 +25,7 @@ def run_evaluate(
|
||||
model: Model,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
@ -46,6 +47,7 @@ def run_evaluate(
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
roi_mapper = roi_mapper or model.roi_mapper
|
||||
|
||||
loader = build_test_loader(
|
||||
test_annotations,
|
||||
@ -57,6 +59,7 @@ def run_evaluate(
|
||||
output_transform = build_output_transform(
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
evaluator = build_evaluator(
|
||||
config=evaluation_config,
|
||||
|
||||
@ -8,8 +8,8 @@ from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.evaluate.types import EvaluationTaskProtocol, EvaluatorProtocol
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections, ClipDetectionsTensor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"Evaluator",
|
||||
@ -67,17 +67,23 @@ class Evaluator:
|
||||
def build_evaluator(
|
||||
config: EvaluationConfig | dict | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
transform: OutputTransformProtocol | None = None,
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping()
|
||||
|
||||
if config is None:
|
||||
config = EvaluationConfig()
|
||||
|
||||
if not isinstance(config, EvaluationConfig):
|
||||
config = EvaluationConfig.model_validate(config)
|
||||
|
||||
transform = transform or build_output_transform(targets=targets)
|
||||
transform = transform or build_output_transform(
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
return Evaluator(
|
||||
targets=targets,
|
||||
|
||||
@ -18,13 +18,14 @@ from batdetect2.outputs import (
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
|
||||
def run_batch_inference(
|
||||
model: Model,
|
||||
clips: Sequence[data.Clip],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
@ -45,10 +46,12 @@ def run_batch_inference(
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
roi_mapper = roi_mapper or model.roi_mapper
|
||||
|
||||
output_transform = output_transform or build_output_transform(
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
loader = build_inference_loader(
|
||||
@ -78,6 +81,7 @@ def process_file_list(
|
||||
model: Model,
|
||||
paths: Sequence[data.PathLike],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
@ -101,6 +105,7 @@ def process_file_list(
|
||||
model,
|
||||
clips,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
batch_size=batch_size,
|
||||
|
||||
@ -20,7 +20,8 @@ class InferenceModule(LightningModule):
|
||||
self.model = model
|
||||
self.detection_threshold = detection_threshold
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=model.targets
|
||||
targets=model.targets,
|
||||
roi_mapper=model.roi_mapper,
|
||||
)
|
||||
|
||||
def predict_step(
|
||||
|
||||
@ -74,7 +74,7 @@ from batdetect2.postprocess.types import (
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
@ -186,12 +186,15 @@ class Model(torch.nn.Module):
|
||||
targets : TargetProtocol
|
||||
Describes the set of target classes; used when building heads and
|
||||
during training target construction.
|
||||
roi_mapper : ROIMapperProtocol
|
||||
Maps geometries to target-size channels and back.
|
||||
"""
|
||||
|
||||
detector: DetectionModel
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
targets: TargetProtocol
|
||||
roi_mapper: ROIMapperProtocol
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -199,12 +202,14 @@ class Model(torch.nn.Module):
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
):
|
||||
super().__init__()
|
||||
self.detector = detector
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.targets = targets
|
||||
self.roi_mapper = roi_mapper
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||
"""Run the full detection pipeline on a waveform tensor.
|
||||
@ -234,6 +239,7 @@ class Model(torch.nn.Module):
|
||||
def build_model(
|
||||
config: ModelConfig | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
postprocessor: PostprocessorProtocol | None = None,
|
||||
) -> Model:
|
||||
@ -272,10 +278,19 @@ def build_model(
|
||||
"""
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
|
||||
config = config or ModelConfig()
|
||||
targets = targets or build_targets(config=config.targets)
|
||||
|
||||
targets_config = getattr(targets, "config", None)
|
||||
roi_config = (
|
||||
targets_config.roi
|
||||
if isinstance(targets_config, TargetConfig)
|
||||
else config.targets.roi
|
||||
)
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
config=config.preprocess,
|
||||
input_samplerate=config.samplerate,
|
||||
@ -286,6 +301,7 @@ def build_model(
|
||||
)
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
num_sizes=len(roi_mapper.dimension_names),
|
||||
config=config.architecture,
|
||||
)
|
||||
return Model(
|
||||
@ -293,16 +309,19 @@ def build_model(
|
||||
postprocessor=postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
|
||||
def build_model_with_new_targets(
|
||||
model: Model,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
) -> Model:
|
||||
"""Build a new model with a different target set."""
|
||||
detector = build_detector(
|
||||
num_classes=len(targets.class_names),
|
||||
num_sizes=len(roi_mapper.dimension_names),
|
||||
backbone=model.detector.backbone,
|
||||
)
|
||||
|
||||
@ -311,4 +330,5 @@ def build_model_with_new_targets(
|
||||
postprocessor=model.postprocessor,
|
||||
preprocessor=model.preprocessor,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
@ -136,6 +136,7 @@ class Detector(DetectionModel):
|
||||
|
||||
def build_detector(
|
||||
num_classes: int,
|
||||
num_sizes: int = 2,
|
||||
config: BackboneConfig | None = None,
|
||||
backbone: BackboneModel | None = None,
|
||||
) -> DetectionModel:
|
||||
@ -181,6 +182,7 @@ def build_detector(
|
||||
)
|
||||
bbox_head = BBoxHead(
|
||||
in_channels=backbone.out_channels,
|
||||
num_sizes=num_sizes,
|
||||
)
|
||||
return Detector(
|
||||
backbone=backbone,
|
||||
|
||||
@ -165,14 +165,15 @@ class BBoxHead(nn.Module):
|
||||
1×1 convolution with 2 output channels (duration, bandwidth).
|
||||
"""
|
||||
|
||||
def __init__(self, in_channels: int):
|
||||
def __init__(self, in_channels: int, num_sizes: int = 2):
|
||||
"""Initialise the BBoxHead."""
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.num_sizes = num_sizes
|
||||
|
||||
self.bbox = nn.Conv2d(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=2,
|
||||
out_channels=self.num_sizes,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
@ -28,7 +28,7 @@ from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionsTransformConfig",
|
||||
@ -55,10 +55,12 @@ class OutputTransform(OutputTransformProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
detection_transform_steps: Sequence[DetectionTransform] = (),
|
||||
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
|
||||
):
|
||||
self.targets = targets
|
||||
self.roi_mapper = roi_mapper
|
||||
self.detection_transform_steps = list(detection_transform_steps)
|
||||
self.clip_transform_steps = list(clip_transform_steps)
|
||||
|
||||
@ -89,7 +91,11 @@ class OutputTransform(OutputTransformProtocol):
|
||||
detections: ClipDetectionsTensor,
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]:
|
||||
decoded = to_detections(detections.numpy(), targets=self.targets)
|
||||
decoded = to_detections(
|
||||
detections.numpy(),
|
||||
targets=self.targets,
|
||||
roi_mapper=self.roi_mapper,
|
||||
)
|
||||
shifted = shift_detections_to_start_time(
|
||||
decoded,
|
||||
start_time=start_time,
|
||||
@ -151,8 +157,9 @@ class OutputTransform(OutputTransformProtocol):
|
||||
def build_output_transform(
|
||||
config: OutputTransformConfig | dict | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
) -> OutputTransformProtocol:
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
|
||||
if config is None:
|
||||
config = OutputTransformConfig()
|
||||
@ -161,9 +168,11 @@ def build_output_transform(
|
||||
config = OutputTransformConfig.model_validate(config)
|
||||
|
||||
targets = targets or build_targets()
|
||||
roi_mapper = roi_mapper or build_roi_mapping()
|
||||
|
||||
return OutputTransform(
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
detection_transform_steps=[
|
||||
detection_transform_registry.build(transform_config)
|
||||
for transform_config in config.detection_transforms
|
||||
|
||||
@ -6,7 +6,7 @@ import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
@ -25,6 +25,7 @@ DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||
def to_detections(
|
||||
detections: ClipDetectionsArray,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
) -> List[Detection]:
|
||||
predictions = []
|
||||
|
||||
@ -39,7 +40,7 @@ def to_detections(
|
||||
):
|
||||
highest_scoring_class = targets.class_names[class_scores.argmax()]
|
||||
|
||||
geom = targets.decode_roi(
|
||||
geom = roi_mapper.decode(
|
||||
(time, freq),
|
||||
dims,
|
||||
class_name=highest_scoring_class,
|
||||
|
||||
@ -4,7 +4,7 @@ from soundevent import data, plot
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.plotting.common import create_ax
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_clip_annotation",
|
||||
@ -48,6 +48,7 @@ def plot_clip_annotation(
|
||||
def plot_anchor_points(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
figsize: tuple[int, int] | None = None,
|
||||
ax: Axes | None = None,
|
||||
size: int = 1,
|
||||
@ -63,7 +64,11 @@ def plot_anchor_points(
|
||||
if not targets.filter(sound_event):
|
||||
continue
|
||||
|
||||
position, _ = targets.encode_roi(sound_event)
|
||||
class_name = targets.encode_class(sound_event)
|
||||
position, _ = roi_mapper.encode(
|
||||
sound_event.sound_event,
|
||||
class_name=class_name,
|
||||
)
|
||||
positions.append(position)
|
||||
|
||||
X, Y = zip(*positions, strict=False)
|
||||
|
||||
@ -10,7 +10,10 @@ from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
ROIMapperConfig,
|
||||
ROIMapperProtocol,
|
||||
ROIMappingConfig,
|
||||
build_roi_mapper,
|
||||
build_roi_mapping,
|
||||
)
|
||||
from batdetect2.targets.targets import (
|
||||
Targets,
|
||||
@ -30,12 +33,15 @@ from batdetect2.targets.types import (
|
||||
Size,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
SoundEventFilter,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AnchorBBoxMapperConfig",
|
||||
"Position",
|
||||
"ROIMappingConfig",
|
||||
"ROIMapperProtocol",
|
||||
"ROIMapperConfig",
|
||||
"ROITargetMapper",
|
||||
"Size",
|
||||
@ -46,6 +52,7 @@ __all__ = [
|
||||
"TargetConfig",
|
||||
"TargetProtocol",
|
||||
"Targets",
|
||||
"build_roi_mapping",
|
||||
"build_roi_mapper",
|
||||
"build_sound_event_decoder",
|
||||
"build_sound_event_encoder",
|
||||
|
||||
@ -14,7 +14,6 @@ from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
from batdetect2.targets.rois import ROIMapperConfig
|
||||
from batdetect2.targets.terms import call_type, generic_class
|
||||
from batdetect2.targets.types import SoundEventDecoder, SoundEventEncoder
|
||||
|
||||
@ -39,8 +38,6 @@ class TargetClassConfig(BaseConfig):
|
||||
|
||||
assign_tags: List[data.Tag] = Field(default_factory=list)
|
||||
|
||||
roi: ROIMapperConfig | None = None
|
||||
|
||||
_match_if: SoundEventConditionConfig = PrivateAttr()
|
||||
|
||||
@model_validator(mode="after")
|
||||
|
||||
@ -9,7 +9,7 @@ from batdetect2.targets.classes import (
|
||||
DEFAULT_DETECTION_CLASS,
|
||||
TargetClassConfig,
|
||||
)
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
|
||||
from batdetect2.targets.rois import ROIMappingConfig
|
||||
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
@ -25,7 +25,7 @@ class TargetConfig(BaseConfig):
|
||||
default_factory=lambda: DEFAULT_CLASSES
|
||||
)
|
||||
|
||||
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||
roi: ROIMappingConfig = Field(default_factory=ROIMappingConfig)
|
||||
|
||||
@field_validator("classification_targets")
|
||||
def check_unique_class_names(cls, v: List[TargetClassConfig]):
|
||||
|
||||
@ -29,7 +29,12 @@ from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import Position, ROITargetMapper, Size
|
||||
from batdetect2.targets.types import (
|
||||
Position,
|
||||
ROIMapperProtocol,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
@ -40,12 +45,15 @@ __all__ = [
|
||||
"DEFAULT_TIME_SCALE",
|
||||
"PeakEnergyBBoxMapper",
|
||||
"PeakEnergyBBoxMapperConfig",
|
||||
"ROIMappingConfig",
|
||||
"ROIMapperProtocol",
|
||||
"ROIMapperConfig",
|
||||
"ROIMapperImportConfig",
|
||||
"ROITargetMapper",
|
||||
"SIZE_HEIGHT",
|
||||
"SIZE_ORDER",
|
||||
"SIZE_WIDTH",
|
||||
"build_roi_mapping",
|
||||
"build_roi_mapper",
|
||||
]
|
||||
|
||||
@ -456,6 +464,59 @@ implementations by using the `name` field as a discriminator.
|
||||
"""
|
||||
|
||||
|
||||
class ROIMappingConfig(BaseConfig):
|
||||
"""Configuration for class-aware ROI mapping.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
default : ROIMapperConfig
|
||||
Default mapper used when no class-specific override exists.
|
||||
overrides : dict[str, ROIMapperConfig]
|
||||
Optional class-specific mapper overrides by class name.
|
||||
"""
|
||||
|
||||
default: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
|
||||
overrides: dict[str, ROIMapperConfig] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ClassAwareROIMapper(ROIMapperProtocol):
|
||||
"""Apply a default ROI mapper with optional per-class overrides."""
|
||||
|
||||
dimension_names: list[str]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default_mapper: ROITargetMapper,
|
||||
overrides: dict[str, ROITargetMapper] | None = None,
|
||||
):
|
||||
self.default_mapper = default_mapper
|
||||
self.overrides = overrides or {}
|
||||
self.dimension_names = list(default_mapper.dimension_names)
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sound_event: data.SoundEvent,
|
||||
class_name: str | None = None,
|
||||
) -> tuple[Position, Size]:
|
||||
mapper = self._select_mapper(class_name)
|
||||
return mapper.encode(sound_event)
|
||||
|
||||
def decode(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry:
|
||||
mapper = self._select_mapper(class_name)
|
||||
return mapper.decode(position, size)
|
||||
|
||||
def _select_mapper(self, class_name: str | None = None) -> ROITargetMapper:
|
||||
if class_name is not None and class_name in self.overrides:
|
||||
return self.overrides[class_name]
|
||||
|
||||
return self.default_mapper
|
||||
|
||||
|
||||
def build_roi_mapper(
|
||||
config: ROIMapperConfig | None = None,
|
||||
) -> ROITargetMapper:
|
||||
@ -480,6 +541,36 @@ def build_roi_mapper(
|
||||
return roi_mapper_registry.build(config)
|
||||
|
||||
|
||||
def build_roi_mapping(
|
||||
config: ROIMappingConfig | None = None,
|
||||
) -> ROIMapperProtocol:
|
||||
"""Build a class-aware ROI mapper and validate consistency."""
|
||||
config = config or ROIMappingConfig()
|
||||
|
||||
default_mapper = build_roi_mapper(config.default)
|
||||
overrides = {
|
||||
class_name: build_roi_mapper(mapper_config)
|
||||
for class_name, mapper_config in config.overrides.items()
|
||||
}
|
||||
|
||||
expected = list(default_mapper.dimension_names)
|
||||
|
||||
for class_name, mapper in overrides.items():
|
||||
actual = list(mapper.dimension_names)
|
||||
|
||||
if actual != expected:
|
||||
raise ValueError(
|
||||
"All ROI mappers must share the same dimension order. "
|
||||
f"Default dimensions: {expected}, "
|
||||
f"class '{class_name}' dimensions: {actual}."
|
||||
)
|
||||
|
||||
return ClassAwareROIMapper(
|
||||
default_mapper=default_mapper,
|
||||
overrides=overrides,
|
||||
)
|
||||
|
||||
|
||||
VALID_ANCHORS = [
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
|
||||
@ -12,21 +12,21 @@ from batdetect2.targets.classes import (
|
||||
get_class_names_from_config,
|
||||
)
|
||||
from batdetect2.targets.config import TargetConfig
|
||||
from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
build_roi_mapper,
|
||||
from batdetect2.targets.types import (
|
||||
Position,
|
||||
ROIMapperProtocol,
|
||||
Size,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.targets.types import Position, Size, TargetProtocol
|
||||
|
||||
|
||||
class Targets(TargetProtocol):
|
||||
"""Encapsulates the complete configured target definition pipeline.
|
||||
"""Encapsulates the configured target class definition pipeline.
|
||||
|
||||
This class implements the `TargetProtocol`, holding the configured
|
||||
functions for filtering, transforming, encoding (tags to class name),
|
||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||
and back). It provides a high-level interface to apply these steps and
|
||||
access relevant metadata like class names and dimension names.
|
||||
functions for filtering, encoding (tags to class name), and decoding
|
||||
(class name to tags). Geometry ROI mapping is handled separately by
|
||||
``ROIMapperProtocol``.
|
||||
|
||||
Instances are typically created using the `build_targets` factory function
|
||||
or the `load_targets` convenience loader.
|
||||
@ -39,14 +39,10 @@ class Targets(TargetProtocol):
|
||||
generic_class_tags
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names
|
||||
The names of the size dimensions handled by the ROI mapper
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
class_names: list[str]
|
||||
detection_class_tags: list[data.Tag]
|
||||
dimension_names: list[str]
|
||||
detection_class_name: str
|
||||
|
||||
def __init__(self, config: TargetConfig):
|
||||
@ -63,10 +59,6 @@ class Targets(TargetProtocol):
|
||||
config.classification_targets
|
||||
)
|
||||
|
||||
self._roi_mapper = build_roi_mapper(config.roi)
|
||||
|
||||
self.dimension_names = self._roi_mapper.dimension_names
|
||||
|
||||
self.class_names = get_class_names_from_config(
|
||||
config.classification_targets
|
||||
)
|
||||
@ -74,21 +66,6 @@ class Targets(TargetProtocol):
|
||||
self.detection_class_name = config.detection_target.name
|
||||
self.detection_class_tags = config.detection_target.assign_tags
|
||||
|
||||
self._roi_mapper_overrides = {
|
||||
class_config.name: build_roi_mapper(class_config.roi)
|
||||
for class_config in config.classification_targets
|
||||
if class_config.roi is not None
|
||||
}
|
||||
|
||||
for class_name in self._roi_mapper_overrides:
|
||||
if class_name not in self.class_names:
|
||||
# TODO: improve this warning
|
||||
logger.warning(
|
||||
"The ROI mapper overrides contains a class ({class_name}) "
|
||||
"not present in the class names.",
|
||||
class_name=class_name,
|
||||
)
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
|
||||
"""Apply the configured filter to a sound event annotation.
|
||||
|
||||
@ -147,75 +124,10 @@ class Targets(TargetProtocol):
|
||||
"""
|
||||
return self._decode_fn(class_label)
|
||||
|
||||
def encode_roi(
|
||||
self, sound_event: data.SoundEventAnnotation
|
||||
) -> tuple[Position, Size]:
|
||||
"""Extract the target reference position from the annotation's roi.
|
||||
|
||||
Delegates to the internal ROI mapper's `get_roi_position` method.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_event : data.SoundEventAnnotation
|
||||
The annotation containing the geometry (ROI).
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[float, float]
|
||||
The reference position `(time, frequency)`.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the annotation lacks geometry.
|
||||
"""
|
||||
class_name = self.encode_class(sound_event)
|
||||
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].encode(
|
||||
sound_event.sound_event
|
||||
)
|
||||
|
||||
return self._roi_mapper.encode(sound_event.sound_event)
|
||||
|
||||
def decode_roi(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry:
|
||||
"""Recover an approximate geometric ROI from a position and dimensions.
|
||||
|
||||
Delegates to the internal ROI mapper's `recover_roi` method, which
|
||||
un-scales the dimensions and reconstructs the geometry (typically a
|
||||
`BoundingBox`).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pos
|
||||
The reference position `(time, frequency)`.
|
||||
dims
|
||||
NumPy array with size dimensions (e.g., from model prediction),
|
||||
matching the order in `self.dimension_names`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
data.Geometry
|
||||
The reconstructed geometry (typically `BoundingBox`).
|
||||
"""
|
||||
if class_name in self._roi_mapper_overrides:
|
||||
return self._roi_mapper_overrides[class_name].decode(
|
||||
position,
|
||||
size,
|
||||
)
|
||||
|
||||
return self._roi_mapper.decode(position, size)
|
||||
|
||||
|
||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
classification_targets=DEFAULT_CLASSES,
|
||||
detection_target=DEFAULT_DETECTION_CLASS,
|
||||
roi=AnchorBBoxMapperConfig(),
|
||||
)
|
||||
|
||||
|
||||
@ -292,6 +204,7 @@ def load_targets(
|
||||
def iterate_encoded_sound_events(
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
) -> Iterable[tuple[str | None, Position, Size]]:
|
||||
for sound_event in sound_events:
|
||||
if not targets.filter(sound_event):
|
||||
@ -303,6 +216,9 @@ def iterate_encoded_sound_events(
|
||||
continue
|
||||
|
||||
class_name = targets.encode_class(sound_event)
|
||||
position, size = targets.encode_roi(sound_event)
|
||||
position, size = roi_mapper.encode(
|
||||
sound_event.sound_event,
|
||||
class_name=class_name,
|
||||
)
|
||||
|
||||
yield class_name, position, size
|
||||
|
||||
@ -6,6 +6,7 @@ from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"Position",
|
||||
"ROIMapperProtocol",
|
||||
"ROITargetMapper",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
@ -26,7 +27,6 @@ class TargetProtocol(Protocol):
|
||||
class_names: list[str]
|
||||
detection_class_tags: list[data.Tag]
|
||||
detection_class_name: str
|
||||
dimension_names: list[str]
|
||||
|
||||
def filter(self, sound_event: data.SoundEventAnnotation) -> bool: ...
|
||||
|
||||
@ -37,6 +37,23 @@ class TargetProtocol(Protocol):
|
||||
|
||||
def decode_class(self, class_label: str) -> list[data.Tag]: ...
|
||||
|
||||
|
||||
class ROIMapperProtocol(Protocol):
|
||||
dimension_names: list[str]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
sound_event: data.SoundEvent,
|
||||
class_name: str | None = None,
|
||||
) -> tuple[Position, Size]: ...
|
||||
|
||||
def decode(
|
||||
self,
|
||||
position: Position,
|
||||
size: Size,
|
||||
class_name: str | None = None,
|
||||
) -> data.Geometry: ...
|
||||
|
||||
def encode_roi(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
|
||||
@ -93,7 +93,8 @@ class ValidationMetrics(Callback):
|
||||
model = pl_module.model
|
||||
if self.output_transform is None:
|
||||
self.output_transform = build_output_transform(
|
||||
targets=model.targets
|
||||
targets=model.targets,
|
||||
roi_mapper=model.roi_mapper,
|
||||
)
|
||||
|
||||
output_transform = self.output_transform
|
||||
|
||||
@ -40,7 +40,7 @@ def build_checkpoint_callback(
|
||||
if run_name is not None:
|
||||
checkpoint_dir = checkpoint_dir / run_name
|
||||
|
||||
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
||||
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return ModelCheckpoint(
|
||||
dirpath=str(checkpoint_dir),
|
||||
|
||||
@ -14,8 +14,12 @@ from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.targets import (
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
iterate_encoded_sound_events,
|
||||
)
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
from batdetect2.train.types import ClipLabeller, Heatmaps
|
||||
|
||||
__all__ = [
|
||||
@ -42,6 +46,7 @@ class LabelConfig(BaseConfig):
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
min_freq: float = MIN_FREQ,
|
||||
max_freq: float = MAX_FREQ,
|
||||
config: LabelConfig | None = None,
|
||||
@ -53,12 +58,13 @@ def build_clip_labeler(
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets()
|
||||
targets = targets or build_targets()
|
||||
roi_mapper = roi_mapper or build_roi_mapping()
|
||||
|
||||
return partial(
|
||||
generate_heatmaps,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
target_sigma=config.sigma,
|
||||
@ -73,6 +79,7 @@ def generate_heatmaps(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
spec: torch.Tensor,
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
target_sigma: float = 3.0,
|
||||
@ -89,7 +96,7 @@ def generate_heatmaps(
|
||||
height = spec.shape[-2]
|
||||
width = spec.shape[-1]
|
||||
num_classes = len(targets.class_names)
|
||||
num_dims = len(targets.dimension_names)
|
||||
num_dims = len(roi_mapper.dimension_names)
|
||||
clip = clip_annotation.clip
|
||||
|
||||
# Initialize heatmaps
|
||||
@ -109,6 +116,7 @@ def generate_heatmaps(
|
||||
for class_name, (time, frequency), size in iterate_encoded_sound_events(
|
||||
clip_annotation.sound_events,
|
||||
targets,
|
||||
roi_mapper,
|
||||
):
|
||||
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
|
||||
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
|
||||
|
||||
@ -6,23 +6,24 @@ from lightning import Trainer, seed_everything
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||
from batdetect2.logging import (
|
||||
LoggerConfig,
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetProtocol,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
@ -39,6 +40,7 @@ def run_train(
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
model: Model | None = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
@ -69,8 +71,15 @@ def run_train(
|
||||
if model is not None:
|
||||
targets = targets or model.targets
|
||||
|
||||
if roi_mapper is None and targets is model.targets:
|
||||
roi_mapper = model.roi_mapper
|
||||
|
||||
targets = targets or build_targets(config=model_config.targets)
|
||||
|
||||
roi_mapper = roi_mapper or build_roi_mapping(
|
||||
config=model_config.targets.roi
|
||||
)
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
@ -80,6 +89,7 @@ def run_train(
|
||||
|
||||
labeller = labeller or build_clip_labeler(
|
||||
targets,
|
||||
roi_mapper,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
config=train_config.labels,
|
||||
@ -119,6 +129,7 @@ def run_train(
|
||||
evaluator=build_evaluator(
|
||||
train_config.validation,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from hypothesis import given
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
from batdetect2.detector import parameters
|
||||
@ -9,6 +9,7 @@ from batdetect2.utils import audio_utils, detector_utils
|
||||
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=1))
|
||||
@settings(deadline=None)
|
||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
@ -87,6 +88,7 @@ def test_pad_audio_without_fixed_size(duration: float):
|
||||
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
@settings(deadline=None)
|
||||
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||
duration: float,
|
||||
):
|
||||
|
||||
@ -8,15 +8,6 @@ from click.testing import CliRunner
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
def test_cli_detect_help() -> None:
|
||||
"""User story: get usage help for legacy detect command."""
|
||||
|
||||
result = CliRunner().invoke(cli, ["detect", "--help"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Detect bat calls in files in AUDIO_DIR" in result.output
|
||||
|
||||
|
||||
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||
"""User story: run legacy detect on example audio directory."""
|
||||
|
||||
|
||||
@ -41,6 +41,22 @@ def test_build_detector_custom_config():
|
||||
assert model.backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_detector_custom_size_channels():
|
||||
num_classes = 3
|
||||
num_sizes = 4
|
||||
config = UNetBackboneConfig(in_channels=1, input_height=128)
|
||||
|
||||
model = build_detector(
|
||||
num_classes=num_classes,
|
||||
num_sizes=num_sizes,
|
||||
config=config,
|
||||
)
|
||||
|
||||
dummy = torch.randn(1, 1, 128, 64)
|
||||
output = model(dummy)
|
||||
assert output.size_preds.shape[1] == num_sizes
|
||||
|
||||
|
||||
def test_detector_forward_pass_shapes(dummy_spectrogram):
|
||||
"""Test that the forward pass produces correctly shaped outputs."""
|
||||
num_classes = 4
|
||||
|
||||
@ -6,6 +6,7 @@ from soundevent.geometry import compute_bounds
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.targets import build_roi_mapping
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
|
||||
@ -37,7 +38,9 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
|
||||
width = int(duration * sample_preprocessor.output_samplerate)
|
||||
spec = torch.zeros((1, height, width), dtype=torch.float32)
|
||||
|
||||
labeler = build_clip_labeler(targets=sample_targets)
|
||||
roi_mapper = build_roi_mapping()
|
||||
|
||||
labeler = build_clip_labeler(targets=sample_targets, roi_mapper=roi_mapper)
|
||||
heatmaps = labeler(clip_annotation, spec)
|
||||
|
||||
output = ModelOutput(
|
||||
@ -51,7 +54,10 @@ def test_annotation_roundtrip_through_postprocess_and_output_transform(
|
||||
clip_detection_tensors = postprocessor(output)
|
||||
assert len(clip_detection_tensors) == 1
|
||||
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
transform = build_output_transform(
|
||||
targets=sample_targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
clip_detections = transform.to_clip_detections(
|
||||
detections=clip_detection_tensors[0],
|
||||
clip=clip,
|
||||
|
||||
@ -12,6 +12,7 @@ from batdetect2.postprocess.types import (
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
@ -27,9 +28,22 @@ def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||
)
|
||||
|
||||
|
||||
def _build_roi_mapper(targets: TargetProtocol):
|
||||
config_obj = getattr(targets, "config", None)
|
||||
target_config = (
|
||||
config_obj if isinstance(config_obj, TargetConfig) else None
|
||||
)
|
||||
return build_roi_mapping(
|
||||
config=(target_config.roi if target_config is not None else None),
|
||||
)
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(sample_targets: TargetProtocol):
|
||||
raw = _mock_clip_detections_tensor()
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
transform = build_output_transform(
|
||||
targets=sample_targets,
|
||||
roi_mapper=_build_roi_mapper(sample_targets),
|
||||
)
|
||||
|
||||
transformed = transform.to_detections(raw, start_time=2.5)
|
||||
start_time, _, end_time, _ = compute_bounds(transformed[0].geometry)
|
||||
@ -43,7 +57,10 @@ def test_to_clip_detections_shifts_by_clip_start(
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
transform = build_output_transform(
|
||||
targets=sample_targets,
|
||||
roi_mapper=_build_roi_mapper(sample_targets),
|
||||
)
|
||||
raw = _mock_clip_detections_tensor()
|
||||
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||
start_time, _, end_time, _ = compute_bounds(shifted.detections[0].geometry)
|
||||
@ -90,6 +107,7 @@ def test_detection_and_clip_transforms_applied_in_order(
|
||||
|
||||
transform = OutputTransform(
|
||||
targets=sample_targets,
|
||||
roi_mapper=_build_roi_mapper(sample_targets),
|
||||
detection_transform_steps=[boost_score, keep_high_score],
|
||||
clip_transform_steps=[tag_clip_transform],
|
||||
)
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import soundfile as sf
|
||||
@ -22,8 +24,10 @@ from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
PeakEnergyBBoxMapper,
|
||||
PeakEnergyBBoxMapperConfig,
|
||||
ROIMappingConfig,
|
||||
_build_bounding_box,
|
||||
build_roi_mapper,
|
||||
build_roi_mapping,
|
||||
get_peak_energy_coordinates,
|
||||
)
|
||||
|
||||
@ -630,3 +634,43 @@ def test_build_roi_mapper_raises_error_for_unknown_name():
|
||||
# Then
|
||||
with pytest.raises(NotImplementedError):
|
||||
build_roi_mapper(DummyConfig()) # type: ignore
|
||||
|
||||
|
||||
def test_build_roi_mapping_applies_class_override():
|
||||
config = ROIMappingConfig(
|
||||
default=AnchorBBoxMapperConfig(anchor="bottom-left"),
|
||||
overrides={
|
||||
"myomyo": AnchorBBoxMapperConfig(anchor="top-left"),
|
||||
},
|
||||
)
|
||||
|
||||
mapper = build_roi_mapping(config=config)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
sound_event = data.SoundEvent(
|
||||
recording=data.Recording(
|
||||
path=Path("x.wav"),
|
||||
samplerate=256_000,
|
||||
channels=1,
|
||||
duration=1.0,
|
||||
),
|
||||
geometry=geometry,
|
||||
)
|
||||
|
||||
default_position, _ = mapper.encode(sound_event, class_name="pippip")
|
||||
override_position, _ = mapper.encode(sound_event, class_name="myomyo")
|
||||
|
||||
assert default_position == pytest.approx((0.1, 12_000))
|
||||
assert override_position == pytest.approx((0.1, 18_000))
|
||||
|
||||
|
||||
def test_build_roi_mapping_rejects_dimension_mismatch():
|
||||
config = ROIMappingConfig(
|
||||
default=AnchorBBoxMapperConfig(),
|
||||
overrides={
|
||||
"myomyo": PeakEnergyBBoxMapperConfig(),
|
||||
},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="same dimension order"):
|
||||
build_roi_mapping(config=config)
|
||||
|
||||
@ -1,9 +1,10 @@
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import TargetConfig, build_targets
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
|
||||
|
||||
def test_can_override_default_roi_mapper_per_class(
|
||||
@ -32,18 +33,21 @@ def test_can_override_default_roi_mapper_per_class(
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
default:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
overrides:
|
||||
myomyo:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
"""
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = TargetConfig.load(config_path)
|
||||
targets = build_targets(config)
|
||||
roi_mapper = build_roi_mapping(config=config.roi)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
@ -60,8 +64,17 @@ def test_can_override_default_roi_mapper_per_class(
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
)
|
||||
|
||||
(time1, freq1), _ = targets.encode_roi(se1)
|
||||
(time2, freq2), _ = targets.encode_roi(se2)
|
||||
class_name1 = targets.encode_class(se1)
|
||||
class_name2 = targets.encode_class(se2)
|
||||
|
||||
(time1, freq1), _ = roi_mapper.encode(
|
||||
se1.sound_event,
|
||||
class_name=class_name1,
|
||||
)
|
||||
(time2, freq2), _ = roi_mapper.encode(
|
||||
se2.sound_event,
|
||||
class_name=class_name2,
|
||||
)
|
||||
|
||||
assert time1 == time2 == 0.1
|
||||
assert freq1 == 12_000
|
||||
@ -95,18 +108,21 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
|
||||
roi:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
default:
|
||||
name: anchor_bbox
|
||||
anchor: bottom-left
|
||||
overrides:
|
||||
myomyo:
|
||||
name: anchor_bbox
|
||||
anchor: top-left
|
||||
"""
|
||||
config_path = create_temp_yaml(yaml_content)
|
||||
|
||||
config = TargetConfig.load(config_path)
|
||||
targets = build_targets(config)
|
||||
roi_mapper = build_roi_mapping(config=config.roi)
|
||||
|
||||
geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000])
|
||||
|
||||
@ -122,14 +138,14 @@ def test_roi_is_recovered_roundtrip_even_with_overriders(
|
||||
tags=[data.Tag(term=species, value="Myotis myotis")],
|
||||
)
|
||||
|
||||
position1, size1 = targets.encode_roi(se1)
|
||||
position2, size2 = targets.encode_roi(se2)
|
||||
position1, size1 = roi_mapper.encode(se1.sound_event, class_name="pippip")
|
||||
position2, size2 = roi_mapper.encode(se2.sound_event, class_name="myomyo")
|
||||
|
||||
class_name1 = targets.encode_class(se1)
|
||||
class_name2 = targets.encode_class(se2)
|
||||
|
||||
recovered1 = targets.decode_roi(position1, size1, class_name=class_name1)
|
||||
recovered2 = targets.decode_roi(position2, size2, class_name=class_name2)
|
||||
recovered1 = roi_mapper.decode(position1, size1, class_name=class_name1)
|
||||
recovered2 = roi_mapper.decode(position2, size2, class_name=class_name2)
|
||||
|
||||
assert recovered1 == geometry
|
||||
assert recovered2 == geometry
|
||||
|
||||
@ -42,28 +42,6 @@ def test_train_saves_checkpoint_in_requested_experiment_run_dir(
|
||||
assert checkpoints
|
||||
|
||||
|
||||
def test_train_without_validation_does_not_save_default_monitored_checkpoint(
|
||||
tmp_path: Path,
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
config = _build_fast_train_config()
|
||||
|
||||
run_train(
|
||||
train_annotations=example_annotations[:1],
|
||||
val_annotations=None,
|
||||
train_config=config.train,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
num_epochs=1,
|
||||
train_workers=0,
|
||||
val_workers=0,
|
||||
checkpoint_dir=tmp_path,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
assert not list(tmp_path.rglob("*.ckpt"))
|
||||
|
||||
|
||||
def test_train_without_validation_can_still_save_last_checkpoint(
|
||||
tmp_path: Path,
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
|
||||
@ -3,8 +3,8 @@ from pathlib import Path
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.targets import TargetConfig, build_targets
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMappingConfig
|
||||
from batdetect2.train.labels import generate_heatmaps
|
||||
|
||||
recording = data.Recording(
|
||||
@ -30,14 +30,17 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
):
|
||||
config = sample_target_config.model_copy(
|
||||
update=dict(
|
||||
roi=AnchorBBoxMapperConfig(
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
roi=ROIMappingConfig(
|
||||
default=AnchorBBoxMapperConfig(
|
||||
time_scale=1,
|
||||
frequency_scale=1,
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
targets = build_targets(config)
|
||||
roi_mapper = build_roi_mapping(config=config.roi)
|
||||
|
||||
clip_annotation = data.ClipAnnotation(
|
||||
clip=clip,
|
||||
@ -60,6 +63,7 @@ def test_generated_heatmap_are_non_zero_at_correct_positions(
|
||||
min_freq=0,
|
||||
max_freq=100,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
pippip_index = targets.class_names.index("pippip")
|
||||
myomyo_index = targets.class_names.index("myomyo")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user