Ensure config is source of truth

This commit is contained in:
mbsantiago 2026-03-18 18:33:51 +00:00
parent ebe7e134e9
commit f9056eb19a
3 changed files with 157 additions and 26 deletions

View File

@ -6,24 +6,35 @@ import torch
from soundevent import data from soundevent import data
from soundevent.audio.files import get_audio_files from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioLoader, build_audio_loader from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs from batdetect2.core import merge_configs
from batdetect2.data import Dataset, load_dataset_from_config from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import ( from batdetect2.evaluate import (
DEFAULT_EVAL_DIR, DEFAULT_EVAL_DIR,
EvaluationConfig,
EvaluatorProtocol, EvaluatorProtocol,
build_evaluator, build_evaluator,
run_evaluate, run_evaluate,
save_evaluation_results, save_evaluation_results,
) )
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import (
InferenceConfig,
process_file_list,
run_batch_inference,
)
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model, build_model_with_new_targets from batdetect2.models import (
Model,
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.models.detectors import Detector from batdetect2.models.detectors import Detector
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputFormatConfig, OutputFormatConfig,
OutputFormatterProtocol, OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol, OutputTransformProtocol,
build_output_formatter, build_output_formatter,
build_output_transform, build_output_transform,
@ -39,6 +50,7 @@ from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetProtocol, build_targets from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.train import ( from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR, DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint, load_model_from_checkpoint,
run_train, run_train,
) )
@ -47,7 +59,12 @@ from batdetect2.train import (
class BatDetect2API: class BatDetect2API:
def __init__( def __init__(
self, self,
config: BatDetect2Config, model_config: ModelConfig,
audio_config: AudioConfig,
train_config: TrainingConfig,
evaluation_config: EvaluationConfig,
inference_config: InferenceConfig,
outputs_config: OutputsConfig,
targets: TargetProtocol, targets: TargetProtocol,
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
@ -57,7 +74,12 @@ class BatDetect2API:
output_transform: OutputTransformProtocol, output_transform: OutputTransformProtocol,
model: Model, model: Model,
): ):
self.config = config self.model_config = model_config
self.audio_config = audio_config
self.train_config = train_config
self.evaluation_config = evaluation_config
self.inference_config = inference_config
self.outputs_config = outputs_config
self.targets = targets self.targets = targets
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.preprocessor = preprocessor self.preprocessor = preprocessor
@ -88,15 +110,15 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
): ):
run_train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
model_config=self.config.model, model_config=self.model_config,
train_config=self.config.train,
audio_config=self.config.audio,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
train_workers=train_workers, train_workers=train_workers,
@ -107,6 +129,8 @@ class BatDetect2API:
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
) )
return self return self
@ -125,6 +149,8 @@ class BatDetect2API:
num_epochs: int | None = None, num_epochs: int | None = None,
run_name: str | None = None, run_name: str | None = None,
seed: int | None = None, seed: int | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection.""" """Fine-tune the model with trainable-parameter selection."""
@ -135,8 +161,7 @@ class BatDetect2API:
val_annotations=val_annotations, val_annotations=val_annotations,
model=self.model, model=self.model,
targets=self.targets, targets=self.targets,
model_config=self.config.model, model_config=self.model_config,
train_config=self.config.train,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
train_workers=train_workers, train_workers=train_workers,
@ -147,6 +172,8 @@ class BatDetect2API:
num_epochs=num_epochs, num_epochs=num_epochs,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
) )
return self return self
@ -158,6 +185,9 @@ class BatDetect2API:
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate( return run_evaluate(
self.model, self.model,
@ -165,9 +195,9 @@ class BatDetect2API:
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
audio_config=self.config.audio, audio_config=audio_config or self.audio_config,
evaluation_config=self.config.evaluation, evaluation_config=evaluation_config or self.evaluation_config,
output_config=self.config.outputs, output_config=outputs_config or self.outputs_config,
num_workers=num_workers, num_workers=num_workers,
output_dir=output_dir, output_dir=output_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
@ -256,12 +286,20 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0) tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor) return self.preprocessor(tensor)
def process_file(self, audio_file: data.PathLike) -> ClipDetections: def process_file(
self,
audio_file: data.PathLike,
batch_size: int | None = None,
) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
predictions = self.process_files( predictions = self.process_files(
[audio_file], [audio_file],
batch_size=self.config.inference.loader.batch_size, batch_size=(
batch_size
if batch_size is not None
else self.inference_config.loader.batch_size
),
) )
detections = [ detections = [
detection detection
@ -319,19 +357,22 @@ class BatDetect2API:
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int = 0, num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform, output_transform=self.output_transform,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
) )
def process_clips( def process_clips(
@ -339,19 +380,22 @@ class BatDetect2API:
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int = 0, num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
clips, clips,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform, output_transform=self.output_transform,
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
) )
def save_predictions( def save_predictions(
@ -435,7 +479,12 @@ class BatDetect2API:
) )
return cls( return cls(
config=config, model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
@ -514,7 +563,12 @@ class BatDetect2API:
model.targets = targets model.targets = targets
return cls( return cls(
config=config, model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence from typing import Annotated, Sequence
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -44,7 +44,7 @@ def build_task(
def evaluate_task( def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections], predictions: Sequence[ClipDetections],
task: Optional["str"] = None, task: str | None = None,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None, config: TaskConfig | dict | None = None,
): ):

View File

@ -8,10 +8,12 @@ import torch
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead from batdetect2.models.heads import ClassifierHead
from batdetect2.train import load_model_from_checkpoint from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module from batdetect2.train.lightning import build_training_module
@ -328,3 +330,78 @@ def test_user_can_save_evaluation_results_to_disk(
assert isinstance(metrics, dict) assert isinstance(metrics, dict)
assert (tmp_path / "metrics.json").exists() assert (tmp_path / "metrics.json").exists()
def test_process_file_uses_resolved_batch_size_by_default(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
    monkeypatch,
) -> None:
    """User story: process_file defaults to the configured batch size.

    ``process_files`` is swapped for a recording stub so the arguments
    that ``process_file`` forwards can be inspected without running
    inference.
    """
    recorded: dict[str, object] = {}

    def _record_process_files(audio_files, batch_size=None, **kwargs):
        # Remember exactly what process_file handed over, then return an
        # empty prediction list so the caller has nothing to aggregate.
        recorded.update(
            audio_files=audio_files,
            batch_size=batch_size,
            kwargs=kwargs,
        )
        return []

    monkeypatch.setattr(api_v2, "process_files", _record_process_files)

    target_file = example_audio_files[0]
    api_v2.process_file(target_file)

    # No explicit batch_size was given, so the value stored on the API's
    # inference config is the one expected to be forwarded.
    expected_batch_size = api_v2.inference_config.loader.batch_size
    assert recorded["audio_files"] == [target_file]
    assert recorded["batch_size"] == expected_batch_size
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
    """User story: per-call overrides leave the stored defaults untouched.

    The module-level runners used by ``process_files`` and ``train`` are
    replaced with recording stubs. The test checks that the override
    objects are forwarded by identity and that the configs held on the
    API instance still differ from the override values afterwards.
    """
    api = BatDetect2API.from_config(BatDetect2Config())

    inference_override = InferenceConfig.model_validate(
        {"loader": {"batch_size": 7}}
    )
    audio_override = AudioConfig.model_validate({"samplerate": 384000})
    train_override = TrainingConfig.model_validate(
        {"trainer": {"max_epochs": 2}}
    )

    seen_by_process: dict[str, object] = {}
    seen_by_train: dict[str, object] = {}

    def _record_process_file_list(*args, **kwargs):
        # Indexing (not .get) keeps the KeyError if a config is not passed.
        for key in ("inference_config", "audio_config"):
            seen_by_process[key] = kwargs[key]
        return []

    def _record_run_train(*args, **kwargs):
        for key in ("train_config", "audio_config", "model_config"):
            seen_by_train[key] = kwargs[key]

    monkeypatch.setattr(
        "batdetect2.api_v2.process_file_list",
        _record_process_file_list,
    )
    monkeypatch.setattr("batdetect2.api_v2.run_train", _record_run_train)

    api.process_files(
        [],
        inference_config=inference_override,
        audio_config=audio_override,
    )
    api.train([], train_config=train_override, audio_config=audio_override)

    # Overrides must arrive at the runners as the very same objects.
    assert seen_by_process["inference_config"] is inference_override
    assert seen_by_process["audio_config"] is audio_override
    assert seen_by_train["train_config"] is train_override
    assert seen_by_train["audio_config"] is audio_override
    # The model config has no per-call override; the stored one is used.
    assert seen_by_train["model_config"] is api.model_config

    # The instance-level configs must not have picked up override values.
    assert api.inference_config.loader.batch_size != 7
    assert api.audio_config.samplerate != 384000
    assert api.train_config.trainer.max_epochs != 2