Make sure api_v2 loads fast

This commit is contained in:
mbsantiago 2026-03-29 15:10:18 +01:00
parent 6d09133dca
commit c24056214c
2 changed files with 89 additions and 112 deletions

View File

@ -1,68 +1,46 @@
from __future__ import annotations
from pathlib import Path from pathlib import Path
from typing import Literal, Sequence, cast from typing import TYPE_CHECKING, Literal
import numpy as np import numpy as np
import torch
from soundevent import data from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config from collections.abc import Sequence
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import ( import torch
DEFAULT_EVAL_DIR,
EvaluationConfig, from batdetect2.audio import AudioConfig, AudioLoader
EvaluatorProtocol, from batdetect2.config import BatDetect2Config
build_evaluator, from batdetect2.data import Dataset
run_evaluate, from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
save_evaluation_results, from batdetect2.inference import InferenceConfig
) from batdetect2.logging import AppLoggingConfig, LoggerConfig
from batdetect2.inference import ( from batdetect2.models import Model, ModelConfig
InferenceConfig, from batdetect2.outputs import (
process_file_list, OutputFormatConfig,
run_batch_inference, OutputFormatterProtocol,
) OutputsConfig,
from batdetect2.logging import ( OutputTransformProtocol,
DEFAULT_LOGS_DIR, )
AppLoggingConfig, from batdetect2.postprocess import (
LoggerConfig, ClipDetections,
) Detection,
from batdetect2.models import ( PostprocessorProtocol,
Model, )
ModelConfig, from batdetect2.preprocess import PreprocessorProtocol
build_model, from batdetect2.targets import (
build_model_with_new_targets, ROIMapperProtocol,
) TargetConfig,
from batdetect2.models.detectors import Detector TargetProtocol,
from batdetect2.outputs import ( )
OutputFormatConfig, from batdetect2.train import TrainingConfig
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol, DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
build_output_formatter, DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
build_output_transform, DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
get_output_formatter,
)
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
)
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint,
run_train,
)
class BatDetect2API: class BatDetect2API:
@ -109,6 +87,8 @@ class BatDetect2API:
path: data.PathLike, path: data.PathLike,
base_dir: data.PathLike | None = None, base_dir: data.PathLike | None = None,
) -> Dataset: ) -> Dataset:
from batdetect2.data import load_dataset_from_config
return load_dataset_from_config(path, base_dir=base_dir) return load_dataset_from_config(path, base_dir=base_dir)
def train( def train(
@ -128,6 +108,8 @@ class BatDetect2API:
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
): ):
from batdetect2.train import run_train
run_train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
val_annotations=val_annotations, val_annotations=val_annotations,
@ -172,6 +154,7 @@ class BatDetect2API:
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection.""" """Fine-tune the model with trainable-parameter selection."""
from batdetect2.train import run_train
self._set_trainable_parameters(trainable) self._set_trainable_parameters(trainable)
@ -211,6 +194,8 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None, outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None, logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
from batdetect2.evaluate import run_evaluate
return run_evaluate( return run_evaluate(
self.model, self.model,
test_annotations, test_annotations,
@ -235,6 +220,8 @@ class BatDetect2API:
predictions: Sequence[ClipDetections], predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None, output_dir: data.PathLike | None = None,
): ):
from batdetect2.evaluate import save_evaluation_results
clip_evals = self.evaluator.evaluate( clip_evals = self.evaluator.evaluate(
annotations, annotations,
predictions, predictions,
@ -307,6 +294,8 @@ class BatDetect2API:
self, self,
audio: np.ndarray, audio: np.ndarray,
) -> torch.Tensor: ) -> torch.Tensor:
import torch
tensor = torch.tensor(audio).unsqueeze(0) tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor) return self.preprocessor(tensor)
@ -316,6 +305,8 @@ class BatDetect2API:
batch_size: int | None = None, batch_size: int | None = None,
detection_threshold: float | None = None, detection_threshold: float | None = None,
) -> ClipDetections: ) -> ClipDetections:
from batdetect2.postprocess import ClipDetections
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
predictions = self.process_files( predictions = self.process_files(
@ -382,6 +373,8 @@ class BatDetect2API:
audio_dir: data.PathLike, audio_dir: data.PathLike,
detection_threshold: float | None = None, detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
from soundevent.audio.files import get_audio_files
files = list(get_audio_files(audio_dir)) files = list(get_audio_files(audio_dir))
return self.process_files( return self.process_files(
files, files,
@ -398,6 +391,8 @@ class BatDetect2API:
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
detection_threshold: float | None = None, detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
from batdetect2.inference import process_file_list
return process_file_list( return process_file_list(
self.model, self.model,
audio_files, audio_files,
@ -424,6 +419,8 @@ class BatDetect2API:
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
detection_threshold: float | None = None, detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
from batdetect2.inference import run_batch_inference
return run_batch_inference( return run_batch_inference(
self.model, self.model,
clips, clips,
@ -448,6 +445,8 @@ class BatDetect2API:
format: str | None = None, format: str | None = None,
config: OutputFormatConfig | None = None, config: OutputFormatConfig | None = None,
): ):
from batdetect2.outputs import get_output_formatter
formatter = self.formatter formatter = self.formatter
if format is not None or config is not None: if format is not None or config is not None:
@ -467,6 +466,8 @@ class BatDetect2API:
format: str | None = None, format: str | None = None,
config: OutputFormatConfig | None = None, config: OutputFormatConfig | None = None,
) -> list[object]: ) -> list[object]:
from batdetect2.outputs import get_output_formatter
formatter = self.formatter formatter = self.formatter
if format is not None or config is not None: if format is not None or config is not None:
@ -484,6 +485,17 @@ class BatDetect2API:
cls, cls,
config: BatDetect2Config, config: BatDetect2Config,
) -> "BatDetect2API": ) -> "BatDetect2API":
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model
from batdetect2.outputs import (
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
targets = build_targets(config=config.model.targets) targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi) roi_mapper = build_roi_mapping(config=config.model.targets.roi)
@ -563,6 +575,21 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None, outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None, logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
model, model_config = load_model_from_checkpoint(path) model, model_config = load_model_from_checkpoint(path)
audio_config = audio_config or AudioConfig( audio_config = audio_config or AudioConfig(
@ -645,7 +672,7 @@ class BatDetect2API:
self, self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"], trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None: ) -> None:
detector = cast(Detector, self.model.detector) detector = self.model.detector
for parameter in detector.parameters(): for parameter in detector.parameters():
parameter.requires_grad = False parameter.requires_grad = False

View File

@ -8,12 +8,10 @@ import torch
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead from batdetect2.models.heads import ClassifierHead
from batdetect2.train import TrainingConfig, load_model_from_checkpoint from batdetect2.train import load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module from batdetect2.train.lightning import build_training_module
@ -452,51 +450,3 @@ def test_detection_threshold_override_changes_spectrogram_results(
) )
assert len(strict_detections) <= len(default_detections) assert len(strict_detections) <= len(default_detections)
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
    """User story: call-level overrides do not mutate resolved defaults.

    ``process_file_list`` and ``run_train`` are replaced with recording
    stubs, so the test only checks which config objects the API forwards
    and that the defaults stored on the API object are unchanged afterwards.
    """
    # Imported locally so the test carries its own dependencies and does not
    # rely on the module-level import list of this test file.
    from batdetect2.audio import AudioConfig
    from batdetect2.inference import InferenceConfig
    from batdetect2.train import TrainingConfig

    api = BatDetect2API.from_config(BatDetect2Config())

    override_inference = InferenceConfig.model_validate(
        {"loader": {"batch_size": 7}}
    )
    override_audio = AudioConfig.model_validate({"samplerate": 384000})
    override_train = TrainingConfig.model_validate(
        {"trainer": {"max_epochs": 2}}
    )

    captured_process: dict[str, object] = {}
    captured_train: dict[str, object] = {}

    def fake_process_file_list(*args, **kwargs):
        # Record the configs the API forwarded instead of running inference.
        captured_process["inference_config"] = kwargs["inference_config"]
        captured_process["audio_config"] = kwargs["audio_config"]
        return []

    def fake_run_train(*args, **kwargs):
        # Record the configs the API forwarded instead of training.
        captured_train["train_config"] = kwargs["train_config"]
        captured_train["audio_config"] = kwargs["audio_config"]
        captured_train["model_config"] = kwargs["model_config"]
        return None

    # api_v2 resolves these helpers with function-scope imports at call time,
    # so the stubs must be installed on the packages the names are imported
    # from; an attribute on batdetect2.api_v2 alone would never be consulted.
    monkeypatch.setattr(
        "batdetect2.inference.process_file_list", fake_process_file_list
    )
    monkeypatch.setattr("batdetect2.train.run_train", fake_run_train)

    # Also cover a layout where api_v2 binds the names at module level.
    # raising=False keeps this a harmless no-op when the attribute is absent.
    monkeypatch.setattr(
        "batdetect2.api_v2.process_file_list",
        fake_process_file_list,
        raising=False,
    )
    monkeypatch.setattr(
        "batdetect2.api_v2.run_train", fake_run_train, raising=False
    )

    api.process_files(
        [], inference_config=override_inference, audio_config=override_audio
    )
    api.train([], train_config=override_train, audio_config=override_audio)

    # The exact override objects must reach the underlying helpers.
    assert captured_process["inference_config"] is override_inference
    assert captured_process["audio_config"] is override_audio
    assert captured_train["train_config"] is override_train
    assert captured_train["audio_config"] is override_audio
    assert captured_train["model_config"] is api.model_config

    # The defaults held by the API object must not have picked up the
    # per-call override values.
    assert api.inference_config.loader.batch_size != 7
    assert api.audio_config.samplerate != 384000
    assert api.train_config.trainer.max_epochs != 2