From 9fa703b34b5ba265925bd62ccaf35f8227928faf Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 18 Mar 2026 13:58:52 +0000 Subject: [PATCH] Allow training on existing model --- src/batdetect2/api_v2.py | 64 +++++++++++++++------ src/batdetect2/cli/evaluate.py | 2 +- src/batdetect2/evaluate/evaluate.py | 4 +- src/batdetect2/inference/batch.py | 2 +- src/batdetect2/inference/clips.py | 8 +-- src/batdetect2/postprocess/remapping.py | 75 +++++++++++++++++++++++++ src/batdetect2/train/lightning.py | 2 + src/batdetect2/train/train.py | 56 +++++++++++++++++- tests/test_train/test_lightning.py | 41 +++++++++++++- 9 files changed, 228 insertions(+), 26 deletions(-) create mode 100644 src/batdetect2/postprocess/remapping.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 96006eb..5ff689e 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -91,6 +91,7 @@ class BatDetect2API: run_train( train_annotations=train_annotations, val_annotations=val_annotations, + model=self.model, targets=self.targets, model_config=self.config.model, train_config=self.config.train, @@ -116,7 +117,7 @@ class BatDetect2API: experiment_name: str | None = None, run_name: str | None = None, save_predictions: bool = True, - ) -> tuple[dict[str, float], list[list[Detection]]]: + ) -> tuple[dict[str, float], list[ClipDetections]]: return run_evaluate( self.model, test_annotations, @@ -183,8 +184,17 @@ class BatDetect2API: def process_file(self, audio_file: data.PathLike) -> ClipDetections: recording = data.Recording.from_file(audio_file, compute_hash=False) - wav = self.audio_loader.load_recording(recording) - detections = self.process_audio(wav) + + predictions = self.process_files( + [audio_file], + batch_size=self.config.inference.loader.batch_size, + ) + detections = [ + detection + for prediction in predictions + for detection in prediction.detections + ] + return ClipDetections( clip=data.Clip( uuid=recording.uuid, @@ -215,7 +225,7 @@ class BatDetect2API: outputs = self.model.detector(spec) - detections = self.model.postprocessor( + detections = self.postprocessor( outputs, )[0] return self.output_transform.to_detections( @@ -239,10 +249,14 @@ class BatDetect2API: return process_file_list( self.model, audio_files, - config=self.config, targets=self.targets, audio_loader=self.audio_loader, + audio_config=self.config.audio, preprocessor=self.preprocessor, + inference_config=self.config.inference, + output_config=self.config.outputs, + output_transform=self.output_transform, + batch_size=batch_size, num_workers=num_workers, ) @@ -257,8 +271,11 @@ class BatDetect2API: clips, targets=self.targets, audio_loader=self.audio_loader, + audio_config=self.config.audio, preprocessor=self.preprocessor, - config=self.config, + inference_config=self.config.inference, + output_config=self.config.outputs, + output_transform=self.output_transform, batch_size=batch_size, num_workers=num_workers, ) @@ -294,7 +311,7 @@ class BatDetect2API: def from_config( cls, config: BatDetect2Config, - ): + ) -> "BatDetect2API": targets = build_targets(config=config.model.targets) audio_loader = build_audio_loader(config=config.audio) @@ -309,12 +326,6 @@ class BatDetect2API: config=config.model.postprocess, ) - evaluator = build_evaluator(config=config.evaluation, targets=targets) - - # NOTE: Better to have a separate instance of preprocessor and - # postprocessor as these may be moved to another device. - model = build_model(config=config.model) - formatter = build_output_formatter( targets, config=config.outputs.format, @@ -324,6 +335,19 @@ class BatDetect2API: targets=targets, ) + evaluator = build_evaluator( + config=config.evaluation, + targets=targets, + transform=output_transform, + ) + + model = build_model( + config=config.model, + targets=targets, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) + return cls( config=config, targets=targets, @@ -341,7 +365,7 @@ class BatDetect2API: cls, path: data.PathLike, config: BatDetect2Config | None = None, - ): + ) -> "BatDetect2API": from batdetect2.audio import AudioConfig model, model_config = load_model_from_checkpoint(path) @@ -368,8 +392,6 @@ class BatDetect2API: config=config.model.postprocess, ) - evaluator = build_evaluator(config=config.evaluation, targets=targets) - formatter = build_output_formatter( targets, config=config.outputs.format, @@ -379,6 +401,16 @@ class BatDetect2API: targets=targets, ) + evaluator = build_evaluator( + config=config.evaluation, + targets=targets, + transform=output_transform, + ) + + model.preprocessor = preprocessor + model.postprocessor = postprocessor + model.targets = targets + return cls( config=config, targets=targets, diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 70458dc..f0d6517 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -26,7 +26,7 @@ def evaluate_command( base_dir: Path, config_path: Path | None, output_dir: Path = DEFAULT_OUTPUT_DIR, - num_workers: int | None = None, + num_workers: int = 0, experiment_name: str | None = None, run_name: str | None = None, ): diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index f5d7da9..1f74170 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -14,7 +14,7 @@ from batdetect2.logging import build_logger from batdetect2.models import Model from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs.types import OutputFormatterProtocol -from batdetect2.postprocess.types import Detection +from batdetect2.postprocess.types import ClipDetections from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets.types import TargetProtocol @@ -35,7 +35,7 @@ def run_evaluate( output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: str | None = None, run_name: str | None = None, -) -> tuple[dict[str, float], list[list[Detection]]]: +) -> tuple[dict[str, float], list[ClipDetections]]: audio_config = audio_config or AudioConfig() evaluation_config = evaluation_config or EvaluationConfig() diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 3ab2bb9..119b449 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -6,8 +6,8 @@ from soundevent import data from batdetect2.audio import AudioConfig from batdetect2.audio.loader import build_audio_loader from batdetect2.audio.types import AudioLoader -from batdetect2.inference import InferenceConfig from batdetect2.inference.clips import get_clips_from_files +from batdetect2.inference.config import InferenceConfig from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.lightning import InferenceModule from batdetect2.models import Model diff --git a/src/batdetect2/inference/clips.py b/src/batdetect2/inference/clips.py index b69e066..0f486d2 100644 --- a/src/batdetect2/inference/clips.py +++ b/src/batdetect2/inference/clips.py @@ -38,10 +38,10 @@ def get_recording_clips( discard_empty: bool = True, ) -> Sequence[data.Clip]: start_time = 0 - duration = recording.duration + recording_duration = recording.duration hop = duration * (1 - overlap) - num_clips = int(np.ceil(duration / hop)) + num_clips = int(np.ceil(recording_duration / hop)) if num_clips == 0: # This should only happen if the clip's duration is zero, @@ -53,8 +53,8 @@ def get_recording_clips( start = start_time + i * hop end = start + duration - if end > duration: - empty_duration = end - duration + if end > recording_duration: + empty_duration = end - recording_duration if empty_duration > max_empty and discard_empty: # Discard clips that contain too much empty space diff --git a/src/batdetect2/postprocess/remapping.py b/src/batdetect2/postprocess/remapping.py new file mode 100644 index 0000000..def0b04 --- /dev/null +++ b/src/batdetect2/postprocess/remapping.py @@ -0,0 +1,75 @@ +"""Remaps raw model output tensors to coordinate-aware xarray DataArrays. + +This module provides utility functions to convert the raw numerical outputs +(typically PyTorch tensors) from the BatDetect2 DNN model into +`xarray.DataArray` objects. This step adds coordinate information +(time in seconds, frequency in Hz) back to the model's predictions, making them +interpretable in the context of the original audio signal and facilitating +subsequent processing steps. + +Functions are provided for common BatDetect2 output types: detection heatmaps, +classification probability maps, size prediction maps, and potentially +intermediate features. +""" + +from typing import Dict, List + +import numpy as np +import torch +import xarray as xr +from soundevent.arrays import Dimensions + +from batdetect2.preprocess import MAX_FREQ, MIN_FREQ + +__all__ = [ + "to_xarray", +] + + +def to_xarray( + array: torch.Tensor | np.ndarray, + start_time: float, + end_time: float, + min_freq: float = MIN_FREQ, + max_freq: float = MAX_FREQ, + name: str = "xarray", + extra_dims: List[str] | None = None, + extra_coords: Dict[str, np.ndarray] | None = None, +) -> xr.DataArray: + if isinstance(array, torch.Tensor): + array = array.detach().cpu().numpy() + + extra_ndims = array.ndim - 2 + + if extra_ndims < 0: + raise ValueError( + "Input array must have at least 2 dimensions, " + f"got shape {array.shape}" + ) + + width = array.shape[-1] + height = array.shape[-2] + + times = np.linspace(start_time, end_time, width, endpoint=False) + freqs = np.linspace(min_freq, max_freq, height, endpoint=False) + + if extra_dims is None: + extra_dims = [f"dim_{i}" for i in range(extra_ndims)] + + if extra_coords is None: + extra_coords = {} + + return xr.DataArray( + data=array, + dims=[ + *extra_dims, + Dimensions.frequency.value, + Dimensions.time.value, + ], + coords={ + **extra_coords, + Dimensions.frequency.value: freqs, + Dimensions.time.value: times, + }, + name=name, + ) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 6e51ce1..0c4d6e2 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -111,6 +111,7 @@ def load_model_from_checkpoint( def build_training_module( model_config: ModelConfig | None = None, train_config: TrainingConfig | None = None, + model: Model | None = None, ) -> TrainingModule: if model_config is None: model_config = ModelConfig() @@ -121,4 +122,5 @@ def build_training_module( return TrainingModule( model_config=model_config.model_dump(mode="json"), train_config=train_config.model_dump(mode="json"), + model=model, ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 7b11c1b..8d59e17 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -11,7 +11,7 @@ from batdetect2.audio.types import AudioLoader from batdetect2.evaluate import build_evaluator from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import build_logger -from batdetect2.models import ModelConfig +from batdetect2.models import Model, ModelConfig, build_model from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets import build_targets @@ -33,6 +33,7 @@ __all__ = [ def run_train( train_annotations: Sequence[data.ClipAnnotation], val_annotations: Sequence[data.ClipAnnotation] | None = None, + model: Model | None = None, targets: Optional["TargetProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, @@ -57,10 +58,19 @@ def run_train( audio_config = audio_config or AudioConfig() train_config = train_config or TrainingConfig() + if model is not None: + _validate_model_compatibility(model=model, model_config=model_config) + + if model is not None: + targets = targets or model.targets + targets = targets or build_targets(config=model_config.targets) audio_loader = audio_loader or build_audio_loader(config=audio_config) + if model is not None: + preprocessor = preprocessor or model.preprocessor + preprocessor = preprocessor or build_preprocessor( input_samplerate=audio_loader.samplerate, config=model_config.preprocess, @@ -98,6 +108,7 @@ def run_train( module = build_training_module( model_config=model_config, train_config=train_config, + model=model, ) trainer = trainer or build_trainer( @@ -124,6 +135,49 @@ def run_train( return module +def _validate_model_compatibility( + model: Model, + model_config: ModelConfig, +) -> None: + reference_model = build_model(config=model_config) + + expected_shapes = { + key: tuple(value.shape) + for key, value in reference_model.state_dict().items() + } + actual_shapes = { + key: tuple(value.shape) for key, value in model.state_dict().items() + } + + expected_keys = set(expected_shapes) + actual_keys = set(actual_shapes) + + missing_keys = sorted(expected_keys - actual_keys) + if missing_keys: + key = missing_keys[0] + raise ValueError( + "Provided model is incompatible with model_config: " + f"missing state key '{key}'." + ) + + extra_keys = sorted(actual_keys - expected_keys) + if extra_keys: + key = extra_keys[0] + raise ValueError( + "Provided model is incompatible with model_config: " + f"unexpected state key '{key}'." + ) + + for key, expected_shape in expected_shapes.items(): + actual_shape = actual_shapes[key] + if actual_shape != expected_shape: + raise ValueError( + "Provided model is incompatible with model_config: " + f"shape mismatch for '{key}' (expected {expected_shape}, " + f"got {actual_shape})." + ) + + def build_trainer( config: TrainingConfig, evaluator: "EvaluatorProtocol", diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 2a16928..05f5fea 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -1,6 +1,7 @@ from pathlib import Path import lightning as L +import pytest import torch from deepdiff import DeepDiff from soundevent import data @@ -10,7 +11,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from batdetect2.api_v2 import BatDetect2API from batdetect2.audio.types import AudioLoader from batdetect2.config import BatDetect2Config -from batdetect2.models import ModelConfig +from batdetect2.models import ModelConfig, build_model +from batdetect2.targets.classes import TargetClassConfig from batdetect2.train import ( TrainingConfig, TrainingModule, @@ -223,3 +225,40 @@ def test_train_smoke_produces_loadable_checkpoint( ).unsqueeze(0) outputs = model(wav.unsqueeze(0)) assert outputs is not None + + +def test_build_training_module_uses_provided_model() -> None: + model = build_model(ModelConfig()) + + module = build_training_module( + model_config=ModelConfig(), + train_config=TrainingConfig(), + model=model, + ) + + assert module.model is model + + +def test_run_train_rejects_incompatible_model_config( + example_annotations: list[data.ClipAnnotation], +) -> None: + model = build_model(ModelConfig()) + incompatible_config = ModelConfig() + incompatible_config.targets.classification_targets.append( + TargetClassConfig( + name="dummy_class", + tags=[data.Tag(key="class", value="Dummy class")], + ) + ) + + with pytest.raises( + ValueError, + match="Provided model is incompatible with model_config", + ): + run_train( + train_annotations=example_annotations[:1], + val_annotations=None, + model=model, + model_config=incompatible_config, + train_config=TrainingConfig(), + )