diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 176d28e..e989f72 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -1,68 +1,46 @@ +from __future__ import annotations + from pathlib import Path -from typing import Literal, Sequence, cast +from typing import TYPE_CHECKING, Literal import numpy as np -import torch from soundevent import data -from soundevent.audio.files import get_audio_files -from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader -from batdetect2.config import BatDetect2Config -from batdetect2.data import Dataset, load_dataset_from_config -from batdetect2.evaluate import ( - DEFAULT_EVAL_DIR, - EvaluationConfig, - EvaluatorProtocol, - build_evaluator, - run_evaluate, - save_evaluation_results, -) -from batdetect2.inference import ( - InferenceConfig, - process_file_list, - run_batch_inference, -) -from batdetect2.logging import ( - DEFAULT_LOGS_DIR, - AppLoggingConfig, - LoggerConfig, -) -from batdetect2.models import ( - Model, - ModelConfig, - build_model, - build_model_with_new_targets, -) -from batdetect2.models.detectors import Detector -from batdetect2.outputs import ( - OutputFormatConfig, - OutputFormatterProtocol, - OutputsConfig, - OutputTransformProtocol, - build_output_formatter, - build_output_transform, - get_output_formatter, -) -from batdetect2.postprocess import ( - ClipDetections, - Detection, - PostprocessorProtocol, - build_postprocessor, -) -from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor -from batdetect2.targets import ( - ROIMapperProtocol, - TargetConfig, - TargetProtocol, - build_roi_mapping, - build_targets, -) -from batdetect2.train import ( - DEFAULT_CHECKPOINT_DIR, - TrainingConfig, - load_model_from_checkpoint, - run_train, -) +if TYPE_CHECKING: + from collections.abc import Sequence + + import torch + + from batdetect2.audio import AudioConfig, AudioLoader + from batdetect2.config import BatDetect2Config + from batdetect2.data import Dataset + from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol + from batdetect2.inference import InferenceConfig + from batdetect2.logging import AppLoggingConfig, LoggerConfig + from batdetect2.models import Model, ModelConfig + from batdetect2.outputs import ( + OutputFormatConfig, + OutputFormatterProtocol, + OutputsConfig, + OutputTransformProtocol, + ) + from batdetect2.postprocess import ( + ClipDetections, + Detection, + PostprocessorProtocol, + ) + from batdetect2.preprocess import PreprocessorProtocol + from batdetect2.targets import ( + ROIMapperProtocol, + TargetConfig, + TargetProtocol, + ) + from batdetect2.train import TrainingConfig + + +DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" +DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" +DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" class BatDetect2API: @@ -109,6 +87,8 @@ class BatDetect2API: path: data.PathLike, base_dir: data.PathLike | None = None, ) -> Dataset: + from batdetect2.data import load_dataset_from_config + return load_dataset_from_config(path, base_dir=base_dir) def train( @@ -128,6 +108,8 @@ class BatDetect2API: train_config: TrainingConfig | None = None, logger_config: LoggerConfig | None = None, ): + from batdetect2.train import run_train + run_train( train_annotations=train_annotations, val_annotations=val_annotations, @@ -172,6 +154,7 @@ class BatDetect2API: logger_config: LoggerConfig | None = None, ) -> "BatDetect2API": """Fine-tune the model with trainable-parameter selection.""" + from batdetect2.train import run_train self._set_trainable_parameters(trainable) @@ -211,6 +194,8 @@ class BatDetect2API: outputs_config: OutputsConfig | None = None, logger_config: LoggerConfig | None = None, ) -> tuple[dict[str, float], list[ClipDetections]]: + from batdetect2.evaluate import run_evaluate + return run_evaluate( self.model, test_annotations, @@ -235,6 +220,8 @@ class BatDetect2API: predictions: Sequence[ClipDetections], output_dir: data.PathLike | None = None, ): + from batdetect2.evaluate import save_evaluation_results + clip_evals = self.evaluator.evaluate( annotations, predictions, @@ -307,6 +294,8 @@ class BatDetect2API: self, audio: np.ndarray, ) -> torch.Tensor: + import torch + tensor = torch.tensor(audio).unsqueeze(0) return self.preprocessor(tensor) @@ -316,6 +305,8 @@ class BatDetect2API: batch_size: int | None = None, detection_threshold: float | None = None, ) -> ClipDetections: + from batdetect2.postprocess import ClipDetections + recording = data.Recording.from_file(audio_file, compute_hash=False) predictions = self.process_files( @@ -382,6 +373,8 @@ class BatDetect2API: audio_dir: data.PathLike, detection_threshold: float | None = None, ) -> list[ClipDetections]: + from soundevent.audio.files import get_audio_files + files = list(get_audio_files(audio_dir)) return self.process_files( files, @@ -398,6 +391,8 @@ class BatDetect2API: output_config: OutputsConfig | None = None, detection_threshold: float | None = None, ) -> list[ClipDetections]: + from batdetect2.inference import process_file_list + return process_file_list( self.model, audio_files, @@ -424,6 +419,8 @@ class BatDetect2API: output_config: OutputsConfig | None = None, detection_threshold: float | None = None, ) -> list[ClipDetections]: + from batdetect2.inference import run_batch_inference + return run_batch_inference( self.model, clips, @@ -448,6 +445,8 @@ class BatDetect2API: format: str | None = None, config: OutputFormatConfig | None = None, ): + from batdetect2.outputs import get_output_formatter + formatter = self.formatter if format is not None or config is not None: @@ -467,6 +466,8 @@ class BatDetect2API: format: str | None = None, config: OutputFormatConfig | None = None, ) -> list[object]: + from batdetect2.outputs import get_output_formatter + formatter = self.formatter if format is not None or config is not None: @@ -484,6 +485,17 @@ class BatDetect2API: cls, config: BatDetect2Config, ) -> "BatDetect2API": + from batdetect2.audio import build_audio_loader + from batdetect2.evaluate import build_evaluator + from batdetect2.models import build_model + from batdetect2.outputs import ( + build_output_formatter, + build_output_transform, + ) + from batdetect2.postprocess import build_postprocessor + from batdetect2.preprocess import build_preprocessor + from batdetect2.targets import build_roi_mapping, build_targets + targets = build_targets(config=config.model.targets) roi_mapper = build_roi_mapping(config=config.model.targets.roi) @@ -563,6 +575,21 @@ class BatDetect2API: outputs_config: OutputsConfig | None = None, logging_config: AppLoggingConfig | None = None, ) -> "BatDetect2API": + from batdetect2.audio import AudioConfig, build_audio_loader + from batdetect2.evaluate import EvaluationConfig, build_evaluator + from batdetect2.inference import InferenceConfig + from batdetect2.logging import AppLoggingConfig + from batdetect2.models import build_model_with_new_targets + from batdetect2.outputs import ( + OutputsConfig, + build_output_formatter, + build_output_transform, + ) + from batdetect2.postprocess import build_postprocessor + from batdetect2.preprocess import build_preprocessor + from batdetect2.targets import build_roi_mapping, build_targets + from batdetect2.train import TrainingConfig, load_model_from_checkpoint + model, model_config = load_model_from_checkpoint(path) audio_config = audio_config or AudioConfig( @@ -645,7 +672,7 @@ class BatDetect2API: self, trainable: Literal["all", "heads", "classifier_head", "bbox_head"], ) -> None: - detector = cast(Detector, self.model.detector) + detector = self.model.detector for parameter in detector.parameters(): parameter.requires_grad = False diff --git a/tests/test_api_v2/test_api_v2.py b/tests/test_api_v2/test_api_v2.py index a85dfd1..cb39c6f 100644 --- a/tests/test_api_v2/test_api_v2.py +++ b/tests/test_api_v2/test_api_v2.py @@ -8,12 +8,10 @@ import torch from soundevent.geometry import compute_bounds from batdetect2.api_v2 import BatDetect2API -from batdetect2.audio import AudioConfig from batdetect2.config import BatDetect2Config -from batdetect2.inference import InferenceConfig from batdetect2.models.detectors import Detector from batdetect2.models.heads import ClassifierHead -from batdetect2.train import TrainingConfig, load_model_from_checkpoint +from batdetect2.train import load_model_from_checkpoint from batdetect2.train.lightning import build_training_module @@ -452,51 +450,3 @@ def test_detection_threshold_override_changes_spectrogram_results( ) assert len(strict_detections) <= len(default_detections) - - -def test_per_call_overrides_are_ephemeral(monkeypatch) -> None: - """User story: call-level overrides do not mutate resolved defaults.""" - - api = BatDetect2API.from_config(BatDetect2Config()) - - override_inference = InferenceConfig.model_validate( - {"loader": {"batch_size": 7}} - ) - override_audio = AudioConfig.model_validate({"samplerate": 384000}) - override_train = TrainingConfig.model_validate( - {"trainer": {"max_epochs": 2}} - ) - - captured_process: dict[str, object] = {} - captured_train: dict[str, object] = {} - - def fake_process_file_list(*args, **kwargs): - captured_process["inference_config"] = kwargs["inference_config"] - captured_process["audio_config"] = kwargs["audio_config"] - return [] - - def fake_run_train(*args, **kwargs): - captured_train["train_config"] = kwargs["train_config"] - captured_train["audio_config"] = kwargs["audio_config"] - captured_train["model_config"] = kwargs["model_config"] - return None - - monkeypatch.setattr( - "batdetect2.api_v2.process_file_list", fake_process_file_list - ) - monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train) - - api.process_files( - [], inference_config=override_inference, audio_config=override_audio - ) - api.train([], train_config=override_train, audio_config=override_audio) - - assert captured_process["inference_config"] is override_inference - assert captured_process["audio_config"] is override_audio - assert captured_train["train_config"] is override_train - assert captured_train["audio_config"] is override_audio - assert captured_train["model_config"] is api.model_config - - assert api.inference_config.loader.batch_size != 7 - assert api.audio_config.samplerate != 384000 - assert api.train_config.trainer.max_epochs != 2