diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 43c2c3f..10e5169 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -6,24 +6,35 @@ import torch from soundevent import data from soundevent.audio.files import get_audio_files -from batdetect2.audio import AudioLoader, build_audio_loader +from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader from batdetect2.config import BatDetect2Config from batdetect2.core import merge_configs from batdetect2.data import Dataset, load_dataset_from_config from batdetect2.evaluate import ( DEFAULT_EVAL_DIR, + EvaluationConfig, EvaluatorProtocol, build_evaluator, run_evaluate, save_evaluation_results, ) -from batdetect2.inference import process_file_list, run_batch_inference +from batdetect2.inference import ( + InferenceConfig, + process_file_list, + run_batch_inference, +) from batdetect2.logging import DEFAULT_LOGS_DIR -from batdetect2.models import Model, build_model, build_model_with_new_targets +from batdetect2.models import ( + Model, + ModelConfig, + build_model, + build_model_with_new_targets, +) from batdetect2.models.detectors import Detector from batdetect2.outputs import ( OutputFormatConfig, OutputFormatterProtocol, + OutputsConfig, OutputTransformProtocol, build_output_formatter, build_output_transform, @@ -39,6 +50,7 @@ from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.targets import TargetProtocol, build_targets from batdetect2.train import ( DEFAULT_CHECKPOINT_DIR, + TrainingConfig, load_model_from_checkpoint, run_train, ) @@ -47,7 +59,12 @@ from batdetect2.train import ( class BatDetect2API: def __init__( self, - config: BatDetect2Config, + model_config: ModelConfig, + audio_config: AudioConfig, + train_config: TrainingConfig, + evaluation_config: EvaluationConfig, + inference_config: InferenceConfig, + outputs_config: OutputsConfig, targets: TargetProtocol, audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, @@ -57,7 +74,12 @@ class BatDetect2API: output_transform: OutputTransformProtocol, model: Model, ): - self.config = config + self.model_config = model_config + self.audio_config = audio_config + self.train_config = train_config + self.evaluation_config = evaluation_config + self.inference_config = inference_config + self.outputs_config = outputs_config self.targets = targets self.audio_loader = audio_loader self.preprocessor = preprocessor @@ -88,15 +110,15 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, + audio_config: AudioConfig | None = None, + train_config: TrainingConfig | None = None, ): run_train( train_annotations=train_annotations, val_annotations=val_annotations, model=self.model, targets=self.targets, - model_config=self.config.model, - train_config=self.config.train, - audio_config=self.config.audio, + model_config=self.model_config, audio_loader=self.audio_loader, preprocessor=self.preprocessor, train_workers=train_workers, @@ -107,6 +129,8 @@ class BatDetect2API: experiment_name=experiment_name, run_name=run_name, seed=seed, + train_config=train_config or self.train_config, + audio_config=audio_config or self.audio_config, ) return self @@ -125,6 +149,8 @@ class BatDetect2API: num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, + audio_config: AudioConfig | None = None, + train_config: TrainingConfig | None = None, ) -> "BatDetect2API": """Fine-tune the model with trainable-parameter selection.""" @@ -135,8 +161,7 @@ class BatDetect2API: val_annotations=val_annotations, model=self.model, targets=self.targets, - model_config=self.config.model, - train_config=self.config.train, + model_config=self.model_config, preprocessor=self.preprocessor, audio_loader=self.audio_loader, train_workers=train_workers, @@ -147,6 +172,8 @@ class BatDetect2API: num_epochs=num_epochs, run_name=run_name, seed=seed, + audio_config=audio_config or self.audio_config, + train_config=train_config or self.train_config, ) return self @@ -158,6 +185,9 @@ class BatDetect2API: experiment_name: str | None = None, run_name: str | None = None, save_predictions: bool = True, + audio_config: AudioConfig | None = None, + evaluation_config: EvaluationConfig | None = None, + outputs_config: OutputsConfig | None = None, ) -> tuple[dict[str, float], list[ClipDetections]]: return run_evaluate( self.model, @@ -165,9 +195,9 @@ class BatDetect2API: targets=self.targets, audio_loader=self.audio_loader, preprocessor=self.preprocessor, - audio_config=self.config.audio, - evaluation_config=self.config.evaluation, - output_config=self.config.outputs, + audio_config=audio_config or self.audio_config, + evaluation_config=evaluation_config or self.evaluation_config, + output_config=outputs_config or self.outputs_config, num_workers=num_workers, output_dir=output_dir, experiment_name=experiment_name, @@ -256,12 +286,20 @@ class BatDetect2API: tensor = torch.tensor(audio).unsqueeze(0) return self.preprocessor(tensor) - def process_file(self, audio_file: data.PathLike) -> ClipDetections: + def process_file( + self, + audio_file: data.PathLike, + batch_size: int | None = None, + ) -> ClipDetections: recording = data.Recording.from_file(audio_file, compute_hash=False) predictions = self.process_files( [audio_file], - batch_size=self.config.inference.loader.batch_size, + batch_size=( + batch_size + if batch_size is not None + else self.inference_config.loader.batch_size + ), ) detections = [ detection @@ -319,19 +357,22 @@ class BatDetect2API: audio_files: Sequence[data.PathLike], batch_size: int | None = None, num_workers: int = 0, + audio_config: AudioConfig | None = None, + inference_config: InferenceConfig | None = None, + output_config: OutputsConfig | None = None, ) -> list[ClipDetections]: return process_file_list( self.model, audio_files, targets=self.targets, audio_loader=self.audio_loader, - audio_config=self.config.audio, preprocessor=self.preprocessor, - inference_config=self.config.inference, - output_config=self.config.outputs, output_transform=self.output_transform, batch_size=batch_size, num_workers=num_workers, + audio_config=audio_config or self.audio_config, + inference_config=inference_config or self.inference_config, + output_config=output_config or self.outputs_config, ) def process_clips( @@ -339,19 +380,22 @@ class BatDetect2API: clips: Sequence[data.Clip], batch_size: int | None = None, num_workers: int = 0, + audio_config: AudioConfig | None = None, + inference_config: InferenceConfig | None = None, + output_config: OutputsConfig | None = None, ) -> list[ClipDetections]: return run_batch_inference( self.model, clips, targets=self.targets, audio_loader=self.audio_loader, - audio_config=self.config.audio, preprocessor=self.preprocessor, - inference_config=self.config.inference, - output_config=self.config.outputs, output_transform=self.output_transform, batch_size=batch_size, num_workers=num_workers, + audio_config=audio_config or self.audio_config, + inference_config=inference_config or self.inference_config, + output_config=output_config or self.outputs_config, ) def save_predictions( @@ -435,7 +479,12 @@ class BatDetect2API: ) return cls( - config=config, + model_config=config.model, + audio_config=config.audio, + train_config=config.train, + evaluation_config=config.evaluation, + inference_config=config.inference, + outputs_config=config.outputs, targets=targets, audio_loader=audio_loader, preprocessor=preprocessor, @@ -514,7 +563,12 @@ class BatDetect2API: model.targets = targets return cls( - config=config, + model_config=config.model, + audio_config=config.audio, + train_config=config.train, + evaluation_config=config.evaluation, + inference_config=config.inference, + outputs_config=config.outputs, targets=targets, audio_loader=audio_loader, preprocessor=preprocessor, diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py index c35cdf6..08a173f 100644 --- a/src/batdetect2/evaluate/tasks/__init__.py +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -1,4 +1,4 @@ -from typing import Annotated, Optional, Sequence +from typing import Annotated, Sequence from pydantic import Field from soundevent import data @@ -44,7 +44,7 @@ def build_task( def evaluate_task( clip_annotations: Sequence[data.ClipAnnotation], predictions: Sequence[ClipDetections], - task: Optional["str"] = None, + task: str | None = None, targets: TargetProtocol | None = None, config: TaskConfig | dict | None = None, ): diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index e4c2d25..4b3f71e 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -8,10 +8,12 @@ import torch from soundevent.geometry import compute_bounds from batdetect2.api_v2 import BatDetect2API +from batdetect2.audio import AudioConfig from batdetect2.config import BatDetect2Config +from batdetect2.inference import InferenceConfig from batdetect2.models.detectors import Detector from batdetect2.models.heads import ClassifierHead -from batdetect2.train import load_model_from_checkpoint +from batdetect2.train import TrainingConfig, load_model_from_checkpoint from batdetect2.train.lightning import build_training_module @@ -328,3 +330,78 @@ def test_user_can_save_evaluation_results_to_disk( assert isinstance(metrics, dict) assert (tmp_path / "metrics.json").exists() + + +def test_process_file_uses_resolved_batch_size_by_default( + api_v2: BatDetect2API, + example_audio_files: list[Path], + monkeypatch, +) -> None: + """User story: process_file falls back to resolved inference config.""" + + captured: dict[str, object] = {} + + def fake_process_files( + audio_files, + batch_size=None, + **kwargs, + ): + captured["audio_files"] = audio_files + captured["batch_size"] = batch_size + captured["kwargs"] = kwargs + return [] + + monkeypatch.setattr(api_v2, "process_files", fake_process_files) + + api_v2.process_file(example_audio_files[0]) + + assert captured["audio_files"] == [example_audio_files[0]] + assert captured["batch_size"] == api_v2.inference_config.loader.batch_size + + +def test_per_call_overrides_are_ephemeral(monkeypatch) -> None: + """User story: call-level overrides do not mutate resolved defaults.""" + + api = BatDetect2API.from_config(BatDetect2Config()) + + override_inference = InferenceConfig.model_validate( + {"loader": {"batch_size": 7}} + ) + override_audio = AudioConfig.model_validate({"samplerate": 384000}) + override_train = TrainingConfig.model_validate( + {"trainer": {"max_epochs": 2}} + ) + + captured_process: dict[str, object] = {} + captured_train: dict[str, object] = {} + + def fake_process_file_list(*args, **kwargs): + captured_process["inference_config"] = kwargs["inference_config"] + captured_process["audio_config"] = kwargs["audio_config"] + return [] + + def fake_run_train(*args, **kwargs): + captured_train["train_config"] = kwargs["train_config"] + captured_train["audio_config"] = kwargs["audio_config"] + captured_train["model_config"] = kwargs["model_config"] + return None + + monkeypatch.setattr( + "batdetect2.api_v2.process_file_list", fake_process_file_list + ) + monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train) + + api.process_files( + [], inference_config=override_inference, audio_config=override_audio + ) + api.train([], train_config=override_train, audio_config=override_audio) + + assert captured_process["inference_config"] is override_inference + assert captured_process["audio_config"] is override_audio + assert captured_train["train_config"] is override_train + assert captured_train["audio_config"] is override_audio + assert captured_train["model_config"] is api.model_config + + assert api.inference_config.loader.batch_size != 7 + assert api.audio_config.samplerate != 384000 + assert api.train_config.trainer.max_epochs != 2