Ensure config is source of truth

This commit is contained in:
mbsantiago 2026-03-18 18:33:51 +00:00
parent ebe7e134e9
commit f9056eb19a
3 changed files with 157 additions and 26 deletions

View File

@ -6,24 +6,35 @@ import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioLoader, build_audio_loader
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import (
DEFAULT_EVAL_DIR,
EvaluationConfig,
EvaluatorProtocol,
build_evaluator,
run_evaluate,
save_evaluation_results,
)
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.inference import (
InferenceConfig,
process_file_list,
run_batch_inference,
)
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model, build_model_with_new_targets
from batdetect2.models import (
Model,
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.models.detectors import Detector
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
@ -39,6 +50,7 @@ from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import TargetProtocol, build_targets
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint,
run_train,
)
@ -47,7 +59,12 @@ from batdetect2.train import (
class BatDetect2API:
def __init__(
self,
config: BatDetect2Config,
model_config: ModelConfig,
audio_config: AudioConfig,
train_config: TrainingConfig,
evaluation_config: EvaluationConfig,
inference_config: InferenceConfig,
outputs_config: OutputsConfig,
targets: TargetProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
@ -57,7 +74,12 @@ class BatDetect2API:
output_transform: OutputTransformProtocol,
model: Model,
):
self.config = config
self.model_config = model_config
self.audio_config = audio_config
self.train_config = train_config
self.evaluation_config = evaluation_config
self.inference_config = inference_config
self.outputs_config = outputs_config
self.targets = targets
self.audio_loader = audio_loader
self.preprocessor = preprocessor
@ -88,15 +110,15 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
):
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=self.config.model,
train_config=self.config.train,
audio_config=self.config.audio,
model_config=self.model_config,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
train_workers=train_workers,
@ -107,6 +129,8 @@ class BatDetect2API:
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
)
return self
@ -125,6 +149,8 @@ class BatDetect2API:
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
@ -135,8 +161,7 @@ class BatDetect2API:
val_annotations=val_annotations,
model=self.model,
targets=self.targets,
model_config=self.config.model,
train_config=self.config.train,
model_config=self.model_config,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
train_workers=train_workers,
@ -147,6 +172,8 @@ class BatDetect2API:
num_epochs=num_epochs,
run_name=run_name,
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
)
return self
@ -158,6 +185,9 @@ class BatDetect2API:
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate(
self.model,
@ -165,9 +195,9 @@ class BatDetect2API:
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
audio_config=self.config.audio,
evaluation_config=self.config.evaluation,
output_config=self.config.outputs,
audio_config=audio_config or self.audio_config,
evaluation_config=evaluation_config or self.evaluation_config,
output_config=outputs_config or self.outputs_config,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
@ -256,12 +286,20 @@ class BatDetect2API:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
def process_file(
self,
audio_file: data.PathLike,
batch_size: int | None = None,
) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False)
predictions = self.process_files(
[audio_file],
batch_size=self.config.inference.loader.batch_size,
batch_size=(
batch_size
if batch_size is not None
else self.inference_config.loader.batch_size
),
)
detections = [
detection
@ -319,19 +357,22 @@ class BatDetect2API:
audio_files: Sequence[data.PathLike],
batch_size: int | None = None,
num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return process_file_list(
self.model,
audio_files,
targets=self.targets,
audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor,
inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
)
def process_clips(
@ -339,19 +380,22 @@ class BatDetect2API:
clips: Sequence[data.Clip],
batch_size: int | None = None,
num_workers: int = 0,
audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None,
) -> list[ClipDetections]:
return run_batch_inference(
self.model,
clips,
targets=self.targets,
audio_loader=self.audio_loader,
audio_config=self.config.audio,
preprocessor=self.preprocessor,
inference_config=self.config.inference,
output_config=self.config.outputs,
output_transform=self.output_transform,
batch_size=batch_size,
num_workers=num_workers,
audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config,
)
def save_predictions(
@ -435,7 +479,12 @@ class BatDetect2API:
)
return cls(
config=config,
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
@ -514,7 +563,12 @@ class BatDetect2API:
model.targets = targets
return cls(
config=config,
model_config=config.model,
audio_config=config.audio,
train_config=config.train,
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,

View File

@ -1,4 +1,4 @@
from typing import Annotated, Optional, Sequence
from typing import Annotated, Sequence
from pydantic import Field
from soundevent import data
@ -44,7 +44,7 @@ def build_task(
def evaluate_task(
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[ClipDetections],
task: Optional["str"] = None,
task: str | None = None,
targets: TargetProtocol | None = None,
config: TaskConfig | dict | None = None,
):

View File

@ -8,10 +8,12 @@ import torch
from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead
from batdetect2.train import load_model_from_checkpoint
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module
@ -328,3 +330,78 @@ def test_user_can_save_evaluation_results_to_disk(
assert isinstance(metrics, dict)
assert (tmp_path / "metrics.json").exists()
def test_process_file_uses_resolved_batch_size_by_default(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
    monkeypatch,
) -> None:
    """User story: process_file falls back to resolved inference config."""
    # Collects the arguments that process_file forwards to process_files.
    captured: dict[str, object] = {}

    def fake_process_files(
        audio_files,
        batch_size=None,
        **kwargs,
    ):
        # Record what was forwarded; return no detections so the caller's
        # post-processing has nothing to do.
        captured["audio_files"] = audio_files
        captured["batch_size"] = batch_size
        captured["kwargs"] = kwargs
        return []

    # Intercept the bound method so no real inference runs.
    monkeypatch.setattr(api_v2, "process_files", fake_process_files)

    # Call without an explicit batch_size: the API must resolve it from its
    # stored inference config rather than passing None through.
    api_v2.process_file(example_audio_files[0])

    assert captured["audio_files"] == [example_audio_files[0]]
    assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
    """User story: call-level overrides do not mutate resolved defaults."""
    api = BatDetect2API.from_config(BatDetect2Config())

    # Per-call override configs, deliberately different from the API's
    # stored defaults so leakage would be detectable below.
    override_inference = InferenceConfig.model_validate(
        {"loader": {"batch_size": 7}}
    )
    override_audio = AudioConfig.model_validate({"samplerate": 384000})
    override_train = TrainingConfig.model_validate(
        {"trainer": {"max_epochs": 2}}
    )

    captured_process: dict[str, object] = {}
    captured_train: dict[str, object] = {}

    def fake_process_file_list(*args, **kwargs):
        # Capture the configs the API forwards into inference.
        captured_process["inference_config"] = kwargs["inference_config"]
        captured_process["audio_config"] = kwargs["audio_config"]
        return []

    def fake_run_train(*args, **kwargs):
        # Capture the configs the API forwards into training.
        captured_train["train_config"] = kwargs["train_config"]
        captured_train["audio_config"] = kwargs["audio_config"]
        captured_train["model_config"] = kwargs["model_config"]
        return None

    # Patch at the module level so the API's own methods hit the fakes.
    monkeypatch.setattr(
        "batdetect2.api_v2.process_file_list", fake_process_file_list
    )
    monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)

    api.process_files(
        [], inference_config=override_inference, audio_config=override_audio
    )
    api.train([], train_config=override_train, audio_config=override_audio)

    # The exact override objects must be forwarded (identity, not copies),
    # and model_config — which has no per-call override — stays the default.
    assert captured_process["inference_config"] is override_inference
    assert captured_process["audio_config"] is override_audio
    assert captured_train["train_config"] is override_train
    assert captured_train["audio_config"] is override_audio
    assert captured_train["model_config"] is api.model_config

    # The API's stored defaults must be untouched by the per-call overrides.
    assert api.inference_config.loader.batch_size != 7
    assert api.audio_config.samplerate != 384000
    assert api.train_config.trainer.max_epochs != 2