Allow training on existing model

This commit is contained in:
mbsantiago 2026-03-18 13:58:52 +00:00
parent 0bf809e376
commit 9fa703b34b
9 changed files with 228 additions and 26 deletions

View File

@ -91,6 +91,7 @@ class BatDetect2API:
run_train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model,
targets=self.targets, targets=self.targets,
model_config=self.config.model, model_config=self.config.model,
train_config=self.config.train, train_config=self.config.train,
@ -116,7 +117,7 @@ class BatDetect2API:
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> tuple[dict[str, float], list[list[Detection]]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate( return run_evaluate(
self.model, self.model,
test_annotations, test_annotations,
@ -183,8 +184,17 @@ class BatDetect2API:
def process_file(self, audio_file: data.PathLike) -> ClipDetections: def process_file(self, audio_file: data.PathLike) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav) predictions = self.process_files(
[audio_file],
batch_size=self.config.inference.loader.batch_size,
)
detections = [
detection
for prediction in predictions
for detection in prediction.detections
]
return ClipDetections( return ClipDetections(
clip=data.Clip( clip=data.Clip(
uuid=recording.uuid, uuid=recording.uuid,
@ -215,7 +225,7 @@ class BatDetect2API:
outputs = self.model.detector(spec) outputs = self.model.detector(spec)
detections = self.model.postprocessor( detections = self.postprocessor(
outputs, outputs,
)[0] )[0]
return self.output_transform.to_detections( return self.output_transform.to_detections(
@ -239,10 +249,14 @@ class BatDetect2API:
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
config=self.config,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
) )
@ -257,8 +271,11 @@ class BatDetect2API:
clips, clips,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
config=self.config, inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
) )
@ -294,7 +311,7 @@ class BatDetect2API:
def from_config( def from_config(
cls, cls,
config: BatDetect2Config, config: BatDetect2Config,
): ) -> "BatDetect2API":
targets = build_targets(config=config.model.targets) targets = build_targets(config=config.model.targets)
audio_loader = build_audio_loader(config=config.audio) audio_loader = build_audio_loader(config=config.audio)
@ -309,12 +326,6 @@ class BatDetect2API:
config=config.model.postprocess, config=config.model.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets)
# NOTE: Better to have a separate instance of preprocessor and
# postprocessor as these may be moved to another device.
model = build_model(config=config.model)
formatter = build_output_formatter( formatter = build_output_formatter(
targets, targets,
config=config.outputs.format, config=config.outputs.format,
@ -324,6 +335,19 @@ class BatDetect2API:
targets=targets, targets=targets,
) )
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
transform=output_transform,
)
model = build_model(
config=config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
return cls( return cls(
config=config, config=config,
targets=targets, targets=targets,
@ -341,7 +365,7 @@ class BatDetect2API:
cls, cls,
path: data.PathLike, path: data.PathLike,
config: BatDetect2Config | None = None, config: BatDetect2Config | None = None,
): ) -> "BatDetect2API":
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
model, model_config = load_model_from_checkpoint(path) model, model_config = load_model_from_checkpoint(path)
@ -368,8 +392,6 @@ class BatDetect2API:
config=config.model.postprocess, config=config.model.postprocess,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter( formatter = build_output_formatter(
targets, targets,
config=config.outputs.format, config=config.outputs.format,
@ -379,6 +401,16 @@ class BatDetect2API:
targets=targets, targets=targets,
) )
evaluator = build_evaluator(
config=config.evaluation,
targets=targets,
transform=output_transform,
)
model.preprocessor = preprocessor
model.postprocessor = postprocessor
model.targets = targets
return cls( return cls(
config=config, config=config,
targets=targets, targets=targets,

View File

@ -26,7 +26,7 @@ def evaluate_command(
base_dir: Path, base_dir: Path,
config_path: Path | None, config_path: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR, output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int | None = None, num_workers: int = 0,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
): ):

View File

@ -14,7 +14,7 @@ from batdetect2.logging import build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import Detection from batdetect2.postprocess.types import ClipDetections
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -35,7 +35,7 @@ def run_evaluate(
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
) -> tuple[dict[str, float], list[list[Detection]]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
audio_config = audio_config or AudioConfig() audio_config = audio_config or AudioConfig()
evaluation_config = evaluation_config or EvaluationConfig() evaluation_config = evaluation_config or EvaluationConfig()

View File

@ -6,8 +6,8 @@ from soundevent import data
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.audio.loader import build_audio_loader from batdetect2.audio.loader import build_audio_loader
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.inference import InferenceConfig
from batdetect2.inference.clips import get_clips_from_files from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models import Model

View File

@ -38,10 +38,10 @@ def get_recording_clips(
discard_empty: bool = True, discard_empty: bool = True,
) -> Sequence[data.Clip]: ) -> Sequence[data.Clip]:
start_time = 0 start_time = 0
duration = recording.duration recording_duration = recording.duration
hop = duration * (1 - overlap) hop = duration * (1 - overlap)
num_clips = int(np.ceil(duration / hop)) num_clips = int(np.ceil(recording_duration / hop))
if num_clips == 0: if num_clips == 0:
# This should only happen if the clip's duration is zero, # This should only happen if the clip's duration is zero,
@ -53,8 +53,8 @@ def get_recording_clips(
start = start_time + i * hop start = start_time + i * hop
end = start + duration end = start + duration
if end > duration: if end > recording_duration:
empty_duration = end - duration empty_duration = end - recording_duration
if empty_duration > max_empty and discard_empty: if empty_duration > max_empty and discard_empty:
# Discard clips that contain too much empty space # Discard clips that contain too much empty space

View File

@ -0,0 +1,75 @@
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
This module provides utility functions to convert the raw numerical outputs
(typically PyTorch tensors) from the BatDetect2 DNN model into
`xarray.DataArray` objects. This step adds coordinate information
(time in seconds, frequency in Hz) back to the model's predictions, making them
interpretable in the context of the original audio signal and facilitating
subsequent processing steps.
Functions are provided for common BatDetect2 output types: detection heatmaps,
classification probability maps, size prediction maps, and potentially
intermediate features.
"""
from typing import Dict, List
import numpy as np
import torch
import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
__all__ = [
"to_xarray",
]
def to_xarray(
    array: torch.Tensor | np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: List[str] | None = None,
    extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray:
    """Wrap a raw model output array in a coordinate-aware DataArray.

    The last two dimensions of ``array`` are interpreted as
    ``(frequency, time)``. Any leading dimensions are preserved and named
    via ``extra_dims`` (auto-named ``dim_0``, ``dim_1``, ... when omitted).

    Parameters
    ----------
    array
        Model output; a torch tensor is detached and moved to CPU first.
    start_time, end_time
        Time span (seconds) covered by the last axis.
    min_freq, max_freq
        Frequency span (Hz) covered by the second-to-last axis.
    name
        Name assigned to the resulting DataArray.
    extra_dims
        Names for the leading dimensions; length must equal
        ``array.ndim - 2``.
    extra_coords
        Optional coordinate arrays for the extra dimensions.

    Returns
    -------
    xr.DataArray
        Array with time/frequency coordinates attached. Coordinates are
        generated with ``endpoint=False`` so each value marks the left
        edge of its bin.

    Raises
    ------
    ValueError
        If ``array`` has fewer than 2 dimensions, or if ``extra_dims``
        does not match the number of leading dimensions.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    extra_ndims = array.ndim - 2
    if extra_ndims < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {array.shape}"
        )

    if extra_dims is None:
        extra_dims = [f"dim_{i}" for i in range(extra_ndims)]
    elif len(extra_dims) != extra_ndims:
        # Fail here with a clear message instead of letting xarray raise a
        # less informative dimension-mismatch error.
        raise ValueError(
            f"extra_dims has {len(extra_dims)} name(s) but the array has "
            f"{extra_ndims} leading dimension(s) (shape {array.shape})"
        )

    if extra_coords is None:
        extra_coords = {}

    width = array.shape[-1]
    height = array.shape[-2]

    times = np.linspace(start_time, end_time, width, endpoint=False)
    freqs = np.linspace(min_freq, max_freq, height, endpoint=False)

    return xr.DataArray(
        data=array,
        dims=[
            *extra_dims,
            Dimensions.frequency.value,
            Dimensions.time.value,
        ],
        coords={
            **extra_coords,
            Dimensions.frequency.value: freqs,
            Dimensions.time.value: times,
        },
        name=name,
    )

View File

@ -111,6 +111,7 @@ def load_model_from_checkpoint(
def build_training_module( def build_training_module(
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
model: Model | None = None,
) -> TrainingModule: ) -> TrainingModule:
if model_config is None: if model_config is None:
model_config = ModelConfig() model_config = ModelConfig()
@ -121,4 +122,5 @@ def build_training_module(
return TrainingModule( return TrainingModule(
model_config=model_config.model_dump(mode="json"), model_config=model_config.model_dump(mode="json"),
train_config=train_config.model_dump(mode="json"), train_config=train_config.model_dump(mode="json"),
model=model,
) )

View File

@ -11,7 +11,7 @@ from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import build_evaluator from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import build_logger from batdetect2.logging import build_logger
from batdetect2.models import ModelConfig from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
@ -33,6 +33,7 @@ __all__ = [
def run_train( def run_train(
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None,
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
@ -57,10 +58,19 @@ def run_train(
audio_config = audio_config or AudioConfig() audio_config = audio_config or AudioConfig()
train_config = train_config or TrainingConfig() train_config = train_config or TrainingConfig()
if model is not None:
_validate_model_compatibility(model=model, model_config=model_config)
if model is not None:
targets = targets or model.targets
targets = targets or build_targets(config=model_config.targets) targets = targets or build_targets(config=model_config.targets)
audio_loader = audio_loader or build_audio_loader(config=audio_config) audio_loader = audio_loader or build_audio_loader(config=audio_config)
if model is not None:
preprocessor = preprocessor or model.preprocessor
preprocessor = preprocessor or build_preprocessor( preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate, input_samplerate=audio_loader.samplerate,
config=model_config.preprocess, config=model_config.preprocess,
@ -98,6 +108,7 @@ def run_train(
module = build_training_module( module = build_training_module(
model_config=model_config, model_config=model_config,
train_config=train_config, train_config=train_config,
model=model,
) )
trainer = trainer or build_trainer( trainer = trainer or build_trainer(
@ -124,6 +135,49 @@ def run_train(
return module return module
def _validate_model_compatibility(
    model: Model,
    model_config: ModelConfig,
) -> None:
    """Check that ``model``'s weights match what ``model_config`` would build.

    Builds a throwaway reference model from ``model_config`` and compares
    its state-dict keys and tensor shapes against the provided model.

    Raises
    ------
    ValueError
        If any state keys are missing, unexpected, or have mismatched
        shapes. All discrepancies are collected and reported in a single
        message (rather than only the first) so users can fix everything
        at once.
    """
    reference_model = build_model(config=model_config)

    expected_shapes = {
        key: tuple(value.shape)
        for key, value in reference_model.state_dict().items()
    }
    actual_shapes = {
        key: tuple(value.shape) for key, value in model.state_dict().items()
    }

    problems: list[str] = []

    for key in sorted(set(expected_shapes) - set(actual_shapes)):
        problems.append(f"missing state key '{key}'")

    for key in sorted(set(actual_shapes) - set(expected_shapes)):
        problems.append(f"unexpected state key '{key}'")

    for key, expected_shape in expected_shapes.items():
        actual_shape = actual_shapes.get(key)
        if actual_shape is not None and actual_shape != expected_shape:
            problems.append(
                f"shape mismatch for '{key}' "
                f"(expected {expected_shape}, got {actual_shape})"
            )

    if problems:
        # Keep the "Provided model is incompatible with model_config"
        # prefix: callers (and tests) match on it.
        raise ValueError(
            "Provided model is incompatible with model_config: "
            + "; ".join(problems)
            + "."
        )
def build_trainer( def build_trainer(
config: TrainingConfig, config: TrainingConfig,
evaluator: "EvaluatorProtocol", evaluator: "EvaluatorProtocol",

View File

@ -1,6 +1,7 @@
from pathlib import Path from pathlib import Path
import lightning as L import lightning as L
import pytest
import torch import torch
from deepdiff import DeepDiff from deepdiff import DeepDiff
from soundevent import data from soundevent import data
@ -10,7 +11,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig, build_model
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.train import ( from batdetect2.train import (
TrainingConfig, TrainingConfig,
TrainingModule, TrainingModule,
@ -223,3 +225,40 @@ def test_train_smoke_produces_loadable_checkpoint(
).unsqueeze(0) ).unsqueeze(0)
outputs = model(wav.unsqueeze(0)) outputs = model(wav.unsqueeze(0))
assert outputs is not None assert outputs is not None
def test_build_training_module_uses_provided_model() -> None:
    """The training module must wrap the exact model instance it was given."""
    provided = build_model(ModelConfig())

    training_module = build_training_module(
        model_config=ModelConfig(),
        train_config=TrainingConfig(),
        model=provided,
    )

    # Identity check: the module must not rebuild or copy the model.
    assert training_module.model is provided
def test_run_train_rejects_incompatible_model_config(
    example_annotations: list[data.ClipAnnotation],
) -> None:
    """run_train should fail fast when the model does not match the config."""
    existing_model = build_model(ModelConfig())

    # Add an extra target class so the classifier head shapes of a model
    # built from this config no longer line up with `existing_model`.
    mismatched_config = ModelConfig()
    dummy_target = TargetClassConfig(
        name="dummy_class",
        tags=[data.Tag(key="class", value="Dummy class")],
    )
    mismatched_config.targets.classification_targets.append(dummy_target)

    with pytest.raises(
        ValueError,
        match="Provided model is incompatible with model_config",
    ):
        run_train(
            train_annotations=example_annotations[:1],
            val_annotations=None,
            model=existing_model,
            model_config=mismatched_config,
            train_config=TrainingConfig(),
        )