mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Ensure config is source of truth
This commit is contained in:
parent
ebe7e134e9
commit
f9056eb19a
@ -6,24 +6,35 @@ import torch
|
||||
from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import AudioLoader, build_audio_loader
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
EvaluationConfig,
|
||||
EvaluatorProtocol,
|
||||
build_evaluator,
|
||||
run_evaluate,
|
||||
save_evaluation_results,
|
||||
)
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.inference import (
|
||||
InferenceConfig,
|
||||
process_file_list,
|
||||
run_batch_inference,
|
||||
)
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model, build_model_with_new_targets
|
||||
from batdetect2.models import (
|
||||
Model,
|
||||
ModelConfig,
|
||||
build_model,
|
||||
build_model_with_new_targets,
|
||||
)
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
@ -39,6 +50,7 @@ from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import TargetProtocol, build_targets
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
@ -47,7 +59,12 @@ from batdetect2.train import (
|
||||
class BatDetect2API:
|
||||
def __init__(
|
||||
self,
|
||||
config: BatDetect2Config,
|
||||
model_config: ModelConfig,
|
||||
audio_config: AudioConfig,
|
||||
train_config: TrainingConfig,
|
||||
evaluation_config: EvaluationConfig,
|
||||
inference_config: InferenceConfig,
|
||||
outputs_config: OutputsConfig,
|
||||
targets: TargetProtocol,
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
@ -57,7 +74,12 @@ class BatDetect2API:
|
||||
output_transform: OutputTransformProtocol,
|
||||
model: Model,
|
||||
):
|
||||
self.config = config
|
||||
self.model_config = model_config
|
||||
self.audio_config = audio_config
|
||||
self.train_config = train_config
|
||||
self.evaluation_config = evaluation_config
|
||||
self.inference_config = inference_config
|
||||
self.outputs_config = outputs_config
|
||||
self.targets = targets
|
||||
self.audio_loader = audio_loader
|
||||
self.preprocessor = preprocessor
|
||||
@ -88,15 +110,15 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
):
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.config.model,
|
||||
train_config=self.config.train,
|
||||
audio_config=self.config.audio,
|
||||
model_config=self.model_config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
train_workers=train_workers,
|
||||
@ -107,6 +129,8 @@ class BatDetect2API:
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
train_config=train_config or self.train_config,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
)
|
||||
return self
|
||||
|
||||
@ -125,6 +149,8 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
|
||||
@ -135,8 +161,7 @@ class BatDetect2API:
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.config.model,
|
||||
train_config=self.config.train,
|
||||
model_config=self.model_config,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
train_workers=train_workers,
|
||||
@ -147,6 +172,8 @@ class BatDetect2API:
|
||||
num_epochs=num_epochs,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
train_config=train_config or self.train_config,
|
||||
)
|
||||
return self
|
||||
|
||||
@ -158,6 +185,9 @@ class BatDetect2API:
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
audio_config: AudioConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
@ -165,9 +195,9 @@ class BatDetect2API:
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_config=self.config.audio,
|
||||
evaluation_config=self.config.evaluation,
|
||||
output_config=self.config.outputs,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
evaluation_config=evaluation_config or self.evaluation_config,
|
||||
output_config=outputs_config or self.outputs_config,
|
||||
num_workers=num_workers,
|
||||
output_dir=output_dir,
|
||||
experiment_name=experiment_name,
|
||||
@ -256,12 +286,20 @@ class BatDetect2API:
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
return self.preprocessor(tensor)
|
||||
|
||||
def process_file(self, audio_file: data.PathLike) -> ClipDetections:
|
||||
def process_file(
|
||||
self,
|
||||
audio_file: data.PathLike,
|
||||
batch_size: int | None = None,
|
||||
) -> ClipDetections:
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
|
||||
predictions = self.process_files(
|
||||
[audio_file],
|
||||
batch_size=self.config.inference.loader.batch_size,
|
||||
batch_size=(
|
||||
batch_size
|
||||
if batch_size is not None
|
||||
else self.inference_config.loader.batch_size
|
||||
),
|
||||
)
|
||||
detections = [
|
||||
detection
|
||||
@ -319,19 +357,22 @@ class BatDetect2API:
|
||||
audio_files: Sequence[data.PathLike],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int = 0,
|
||||
audio_config: AudioConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
audio_config=self.config.audio,
|
||||
preprocessor=self.preprocessor,
|
||||
inference_config=self.config.inference,
|
||||
output_config=self.config.outputs,
|
||||
output_transform=self.output_transform,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
inference_config=inference_config or self.inference_config,
|
||||
output_config=output_config or self.outputs_config,
|
||||
)
|
||||
|
||||
def process_clips(
|
||||
@ -339,19 +380,22 @@ class BatDetect2API:
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int = 0,
|
||||
audio_config: AudioConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
clips,
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
audio_config=self.config.audio,
|
||||
preprocessor=self.preprocessor,
|
||||
inference_config=self.config.inference,
|
||||
output_config=self.config.outputs,
|
||||
output_transform=self.output_transform,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
inference_config=inference_config or self.inference_config,
|
||||
output_config=output_config or self.outputs_config,
|
||||
)
|
||||
|
||||
def save_predictions(
|
||||
@ -435,7 +479,12 @@ class BatDetect2API:
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config.train,
|
||||
evaluation_config=config.evaluation,
|
||||
inference_config=config.inference,
|
||||
outputs_config=config.outputs,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
@ -514,7 +563,12 @@ class BatDetect2API:
|
||||
model.targets = targets
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config.train,
|
||||
evaluation_config=config.evaluation,
|
||||
inference_config=config.inference,
|
||||
outputs_config=config.outputs,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Optional, Sequence
|
||||
from typing import Annotated, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -44,7 +44,7 @@ def build_task(
|
||||
def evaluate_task(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[ClipDetections],
|
||||
task: Optional["str"] = None,
|
||||
task: str | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
config: TaskConfig | dict | None = None,
|
||||
):
|
||||
|
||||
@ -8,10 +8,12 @@ import torch
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.models.heads import ClassifierHead
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
|
||||
@ -328,3 +330,78 @@ def test_user_can_save_evaluation_results_to_disk(
|
||||
|
||||
assert isinstance(metrics, dict)
|
||||
assert (tmp_path / "metrics.json").exists()
|
||||
|
||||
|
||||
def test_process_file_uses_resolved_batch_size_by_default(
    api_v2: BatDetect2API,
    example_audio_files: list[Path],
    monkeypatch,
) -> None:
    """User story: process_file falls back to resolved inference config.

    ``process_files`` is replaced on the API instance with a stub that
    records its arguments, so the test only checks what ``process_file``
    forwards when the caller does not pass ``batch_size``.
    """
    # Filled in by the stub below with whatever ``process_file`` forwards.
    captured: dict[str, object] = {}

    def fake_process_files(audio_files, batch_size=None, **kwargs):
        captured.update(
            audio_files=audio_files,
            batch_size=batch_size,
            kwargs=kwargs,
        )
        # No detections: the test is about the forwarded arguments only.
        return []

    monkeypatch.setattr(api_v2, "process_files", fake_process_files)

    target_file = example_audio_files[0]
    api_v2.process_file(target_file)

    # The single file is wrapped in a list for the batch entry point.
    assert captured["audio_files"] == [target_file]
    # Without an explicit batch size the resolved inference config is used.
    assert captured["batch_size"] == api_v2.inference_config.loader.batch_size
|
||||
|
||||
|
||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
    """User story: call-level overrides do not mutate resolved defaults.

    ``process_file_list`` and ``run_train`` are stubbed in
    ``batdetect2.api_v2`` so that only the configuration objects forwarded
    by ``BatDetect2API`` are inspected; nothing is actually processed or
    trained.
    """
    api = BatDetect2API.from_config(BatDetect2Config())

    # Snapshot the resolved defaults *before* any call. Comparing against
    # this snapshot afterwards proves the defaults were neither replaced
    # nor mutated in place, which a bare ``!= override_value`` check cannot.
    default_configs = {
        "inference_config": api.inference_config,
        "audio_config": api.audio_config,
        "train_config": api.train_config,
    }
    default_dumps = {
        name: config.model_dump() for name, config in default_configs.items()
    }

    override_inference = InferenceConfig.model_validate(
        {"loader": {"batch_size": 7}}
    )
    override_audio = AudioConfig.model_validate({"samplerate": 384000})
    override_train = TrainingConfig.model_validate(
        {"trainer": {"max_epochs": 2}}
    )

    captured_process: dict[str, object] = {}
    captured_train: dict[str, object] = {}

    def fake_process_file_list(*args, **kwargs):
        captured_process["inference_config"] = kwargs["inference_config"]
        captured_process["audio_config"] = kwargs["audio_config"]
        return []

    def fake_run_train(*args, **kwargs):
        captured_train["train_config"] = kwargs["train_config"]
        captured_train["audio_config"] = kwargs["audio_config"]
        captured_train["model_config"] = kwargs["model_config"]
        return None

    monkeypatch.setattr(
        "batdetect2.api_v2.process_file_list", fake_process_file_list
    )
    monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)

    api.process_files(
        [], inference_config=override_inference, audio_config=override_audio
    )
    api.train([], train_config=override_train, audio_config=override_audio)

    # The overrides are forwarded as-is for the call they were given to...
    assert captured_process["inference_config"] is override_inference
    assert captured_process["audio_config"] is override_audio
    assert captured_train["train_config"] is override_train
    assert captured_train["audio_config"] is override_audio
    # ...while anything that was not overridden still comes from the API.
    assert captured_train["model_config"] is api.model_config

    # The override values never leak into the resolved defaults.
    assert api.inference_config.loader.batch_size != 7
    assert api.audio_config.samplerate != 384000
    assert api.train_config.trainer.max_epochs != 2

    # The defaults are the very same objects with the very same content.
    for name, config in default_configs.items():
        assert getattr(api, name) is config
        assert config.model_dump() == default_dumps[name]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user