mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Allow training on existing model
This commit is contained in:
parent
0bf809e376
commit
9fa703b34b
@ -91,6 +91,7 @@ class BatDetect2API:
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.config.model,
|
||||
train_config=self.config.train,
|
||||
@ -116,7 +117,7 @@ class BatDetect2API:
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -183,8 +184,17 @@ class BatDetect2API:
|
||||
|
||||
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
wav = self.audio_loader.load_recording(recording)
|
||||
detections = self.process_audio(wav)
|
||||
|
||||
predictions = self.process_files(
|
||||
[audio_file],
|
||||
batch_size=self.config.inference.loader.batch_size,
|
||||
)
|
||||
detections = [
|
||||
detection
|
||||
for prediction in predictions
|
||||
for detection in prediction.detections
|
||||
]
|
||||
|
||||
return ClipDetections(
|
||||
clip=data.Clip(
|
||||
uuid=recording.uuid,
|
||||
@ -215,7 +225,7 @@ class BatDetect2API:
|
||||
|
||||
outputs = self.model.detector(spec)
|
||||
|
||||
detections = self.model.postprocessor(
|
||||
detections = self.postprocessor(
|
||||
outputs,
|
||||
)[0]
|
||||
return self.output_transform.to_detections(
|
||||
@ -239,10 +249,14 @@ class BatDetect2API:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
config=self.config,
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
audio_config=self.config.audio,
|
||||
preprocessor=self.preprocessor,
|
||||
inference_config=self.config.inference,
|
||||
output_config=self.config.outputs,
|
||||
output_transform=self.output_transform,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
@ -257,8 +271,11 @@ class BatDetect2API:
|
||||
clips,
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
audio_config=self.config.audio,
|
||||
preprocessor=self.preprocessor,
|
||||
config=self.config,
|
||||
inference_config=self.config.inference,
|
||||
output_config=self.config.outputs,
|
||||
output_transform=self.output_transform,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
@ -294,7 +311,7 @@ class BatDetect2API:
|
||||
def from_config(
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
):
|
||||
) -> "BatDetect2API":
|
||||
targets = build_targets(config=config.model.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
@ -309,12 +326,6 @@ class BatDetect2API:
|
||||
config=config.model.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
|
||||
# NOTE: Better to have a separate instance of preprocessor and
|
||||
# postprocessor as these may be moved to another device.
|
||||
model = build_model(config=config.model)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
@ -324,6 +335,19 @@ class BatDetect2API:
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
targets=targets,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
model = build_model(
|
||||
config=config.model,
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
@ -341,7 +365,7 @@ class BatDetect2API:
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
config: BatDetect2Config | None = None,
|
||||
):
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
@ -368,8 +392,6 @@ class BatDetect2API:
|
||||
config=config.model.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
@ -379,6 +401,16 @@ class BatDetect2API:
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
targets=targets,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
model.preprocessor = preprocessor
|
||||
model.postprocessor = postprocessor
|
||||
model.targets = targets
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
|
||||
@ -26,7 +26,7 @@ def evaluate_command(
|
||||
base_dir: Path,
|
||||
config_path: Path | None,
|
||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
|
||||
@ -14,7 +14,7 @@ from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
@ -35,7 +35,7 @@ def run_evaluate(
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
|
||||
audio_config = audio_config or AudioConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
|
||||
@ -6,8 +6,8 @@ from soundevent import data
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.audio.loader import build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
|
||||
@ -38,10 +38,10 @@ def get_recording_clips(
|
||||
discard_empty: bool = True,
|
||||
) -> Sequence[data.Clip]:
|
||||
start_time = 0
|
||||
duration = recording.duration
|
||||
recording_duration = recording.duration
|
||||
hop = duration * (1 - overlap)
|
||||
|
||||
num_clips = int(np.ceil(duration / hop))
|
||||
num_clips = int(np.ceil(recording_duration / hop))
|
||||
|
||||
if num_clips == 0:
|
||||
# This should only happen if the clip's duration is zero,
|
||||
@ -53,8 +53,8 @@ def get_recording_clips(
|
||||
start = start_time + i * hop
|
||||
end = start + duration
|
||||
|
||||
if end > duration:
|
||||
empty_duration = end - duration
|
||||
if end > recording_duration:
|
||||
empty_duration = end - recording_duration
|
||||
|
||||
if empty_duration > max_empty and discard_empty:
|
||||
# Discard clips that contain too much empty space
|
||||
|
||||
75
src/batdetect2/postprocess/remapping.py
Normal file
75
src/batdetect2/postprocess/remapping.py
Normal file
@ -0,0 +1,75 @@
|
||||
"""Remaps raw model output tensors to coordinate-aware xarray DataArrays.
|
||||
|
||||
This module provides utility functions to convert the raw numerical outputs
|
||||
(typically PyTorch tensors) from the BatDetect2 DNN model into
|
||||
`xarray.DataArray` objects. This step adds coordinate information
|
||||
(time in seconds, frequency in Hz) back to the model's predictions, making them
|
||||
interpretable in the context of the original audio signal and facilitating
|
||||
subsequent processing steps.
|
||||
|
||||
Functions are provided for common BatDetect2 output types: detection heatmaps,
|
||||
classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from soundevent.arrays import Dimensions
|
||||
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
|
||||
__all__ = [
|
||||
"to_xarray",
|
||||
]
|
||||
|
||||
|
||||
def to_xarray(
    array: torch.Tensor | np.ndarray,
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: List[str] | None = None,
    extra_coords: Dict[str, np.ndarray] | None = None,
) -> xr.DataArray:
    """Wrap a raw model output array in a coordinate-aware DataArray.

    The last two axes of ``array`` are interpreted as (frequency, time) and
    receive evenly spaced coordinates covering ``[min_freq, max_freq)`` and
    ``[start_time, end_time)`` respectively (``endpoint=False``: each value
    marks the start of its bin). Any leading axes are preserved as extra
    dimensions, named from ``extra_dims`` or generated as ``dim_0`` etc.

    Raises ``ValueError`` if ``array`` has fewer than two dimensions.
    """
    # Torch tensors are moved to host memory and converted up front so the
    # rest of the function deals with plain numpy arrays only.
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    leading_ndims = array.ndim - 2
    if leading_ndims < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {array.shape}"
        )

    num_freq_bins, num_time_bins = array.shape[-2], array.shape[-1]

    time_coords = np.linspace(
        start_time, end_time, num_time_bins, endpoint=False
    )
    freq_coords = np.linspace(
        min_freq, max_freq, num_freq_bins, endpoint=False
    )

    if extra_dims is None:
        leading_dims = [f"dim_{i}" for i in range(leading_ndims)]
    else:
        leading_dims = list(extra_dims)

    coords: Dict[str, np.ndarray] = dict(extra_coords or {})
    coords[Dimensions.frequency.value] = freq_coords
    coords[Dimensions.time.value] = time_coords

    return xr.DataArray(
        data=array,
        dims=[
            *leading_dims,
            Dimensions.frequency.value,
            Dimensions.time.value,
        ],
        coords=coords,
        name=name,
    )
|
||||
@ -111,6 +111,7 @@ def load_model_from_checkpoint(
|
||||
def build_training_module(
|
||||
model_config: ModelConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
model: Model | None = None,
|
||||
) -> TrainingModule:
|
||||
if model_config is None:
|
||||
model_config = ModelConfig()
|
||||
@ -121,4 +122,5 @@ def build_training_module(
|
||||
return TrainingModule(
|
||||
model_config=model_config.model_dump(mode="json"),
|
||||
train_config=train_config.model_dump(mode="json"),
|
||||
model=model,
|
||||
)
|
||||
|
||||
@ -11,7 +11,7 @@ from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
@ -33,6 +33,7 @@ __all__ = [
|
||||
def run_train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
model: Model | None = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
@ -57,10 +58,19 @@ def run_train(
|
||||
audio_config = audio_config or AudioConfig()
|
||||
train_config = train_config or TrainingConfig()
|
||||
|
||||
if model is not None:
|
||||
_validate_model_compatibility(model=model, model_config=model_config)
|
||||
|
||||
if model is not None:
|
||||
targets = targets or model.targets
|
||||
|
||||
targets = targets or build_targets(config=model_config.targets)
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
if model is not None:
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
|
||||
preprocessor = preprocessor or build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=model_config.preprocess,
|
||||
@ -98,6 +108,7 @@ def run_train(
|
||||
module = build_training_module(
|
||||
model_config=model_config,
|
||||
train_config=train_config,
|
||||
model=model,
|
||||
)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
@ -124,6 +135,49 @@ def run_train(
|
||||
return module
|
||||
|
||||
|
||||
def _validate_model_compatibility(
    model: Model,
    model_config: ModelConfig,
) -> None:
    """Check that ``model``'s weights fit the architecture implied by
    ``model_config``.

    A throwaway reference model is built from ``model_config``, and its
    state-dict keys and parameter shapes are compared against those of the
    provided model. A ``ValueError`` naming the first offending key is
    raised for any missing key, unexpected key, or shape mismatch.
    """
    reference = build_model(config=model_config)

    expected = {
        name: tuple(tensor.shape)
        for name, tensor in reference.state_dict().items()
    }
    actual = {
        name: tuple(tensor.shape)
        for name, tensor in model.state_dict().items()
    }

    # Report keys the config expects but the model lacks.
    missing = sorted(set(expected) - set(actual))
    if missing:
        raise ValueError(
            "Provided model is incompatible with model_config: "
            f"missing state key '{missing[0]}'."
        )

    # Report keys the model carries but the config does not expect.
    unexpected = sorted(set(actual) - set(expected))
    if unexpected:
        raise ValueError(
            "Provided model is incompatible with model_config: "
            f"unexpected state key '{unexpected[0]}'."
        )

    # Key sets match at this point; verify every parameter's shape.
    for name, expected_shape in expected.items():
        actual_shape = actual[name]
        if actual_shape != expected_shape:
            raise ValueError(
                "Provided model is incompatible with model_config: "
                f"shape mismatch for '{name}' (expected {expected_shape}, "
                f"got {actual_shape})."
            )
||||
|
||||
|
||||
def build_trainer(
|
||||
config: TrainingConfig,
|
||||
evaluator: "EvaluatorProtocol",
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
|
||||
import lightning as L
|
||||
import pytest
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
from soundevent import data
|
||||
@ -10,7 +11,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.targets.classes import TargetClassConfig
|
||||
from batdetect2.train import (
|
||||
TrainingConfig,
|
||||
TrainingModule,
|
||||
@ -223,3 +225,40 @@ def test_train_smoke_produces_loadable_checkpoint(
|
||||
).unsqueeze(0)
|
||||
outputs = model(wav.unsqueeze(0))
|
||||
assert outputs is not None
|
||||
|
||||
|
||||
def test_build_training_module_uses_provided_model() -> None:
    """The training module must wrap the exact model instance it was
    given, rather than building a fresh one from the config."""
    provided_model = build_model(ModelConfig())

    training_module = build_training_module(
        model_config=ModelConfig(),
        train_config=TrainingConfig(),
        model=provided_model,
    )

    # Identity check: the same object, not an equal copy.
    assert training_module.model is provided_model
|
||||
|
||||
|
||||
def test_run_train_rejects_incompatible_model_config(
    example_annotations: list[data.ClipAnnotation],
) -> None:
    """``run_train`` must fail fast when the supplied model's weights do
    not match the architecture described by ``model_config``."""
    existing_model = build_model(ModelConfig())

    # Appending an extra classification target changes the classifier
    # head's output size, making this config incompatible with the
    # already-built model above.
    mismatched_config = ModelConfig()
    mismatched_config.targets.classification_targets.append(
        TargetClassConfig(
            name="dummy_class",
            tags=[data.Tag(key="class", value="Dummy class")],
        )
    )

    with pytest.raises(
        ValueError,
        match="Provided model is incompatible with model_config",
    ):
        run_train(
            train_annotations=example_annotations[:1],
            val_annotations=None,
            model=existing_model,
            model_config=mismatched_config,
            train_config=TrainingConfig(),
        )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user