mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Make sure api_v2 loads fast
This commit is contained in:
parent
6d09133dca
commit
c24056214c
@ -1,68 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Sequence, cast
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
EvaluationConfig,
|
||||
EvaluatorProtocol,
|
||||
build_evaluator,
|
||||
run_evaluate,
|
||||
save_evaluation_results,
|
||||
)
|
||||
from batdetect2.inference import (
|
||||
InferenceConfig,
|
||||
process_file_list,
|
||||
run_batch_inference,
|
||||
)
|
||||
from batdetect2.logging import (
|
||||
DEFAULT_LOGS_DIR,
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
)
|
||||
from batdetect2.models import (
|
||||
Model,
|
||||
ModelConfig,
|
||||
build_model,
|
||||
build_model_with_new_targets,
|
||||
)
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import Dataset
|
||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
)
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig
|
||||
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
class BatDetect2API:
|
||||
@ -109,6 +87,8 @@ class BatDetect2API:
|
||||
path: data.PathLike,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
return load_dataset_from_config(path, base_dir=base_dir)
|
||||
|
||||
def train(
|
||||
@ -128,6 +108,8 @@ class BatDetect2API:
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
):
|
||||
from batdetect2.train import run_train
|
||||
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
@ -172,6 +154,7 @@ class BatDetect2API:
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
from batdetect2.train import run_train
|
||||
|
||||
self._set_trainable_parameters(trainable)
|
||||
|
||||
@ -211,6 +194,8 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
from batdetect2.evaluate import run_evaluate
|
||||
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -235,6 +220,8 @@ class BatDetect2API:
|
||||
predictions: Sequence[ClipDetections],
|
||||
output_dir: data.PathLike | None = None,
|
||||
):
|
||||
from batdetect2.evaluate import save_evaluation_results
|
||||
|
||||
clip_evals = self.evaluator.evaluate(
|
||||
annotations,
|
||||
predictions,
|
||||
@ -307,6 +294,8 @@ class BatDetect2API:
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
import torch
|
||||
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
return self.preprocessor(tensor)
|
||||
|
||||
@ -316,6 +305,8 @@ class BatDetect2API:
|
||||
batch_size: int | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> ClipDetections:
|
||||
from batdetect2.postprocess import ClipDetections
|
||||
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
|
||||
predictions = self.process_files(
|
||||
@ -382,6 +373,8 @@ class BatDetect2API:
|
||||
audio_dir: data.PathLike,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
files = list(get_audio_files(audio_dir))
|
||||
return self.process_files(
|
||||
files,
|
||||
@ -398,6 +391,8 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from batdetect2.inference import process_file_list
|
||||
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
@ -424,6 +419,8 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from batdetect2.inference import run_batch_inference
|
||||
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
clips,
|
||||
@ -448,6 +445,8 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
):
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
|
||||
if format is not None or config is not None:
|
||||
@ -467,6 +466,8 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> list[object]:
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
|
||||
if format is not None or config is not None:
|
||||
@ -484,6 +485,17 @@ class BatDetect2API:
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.outputs import (
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
|
||||
targets = build_targets(config=config.model.targets)
|
||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||
|
||||
@ -563,6 +575,21 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import build_model_with_new_targets
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
|
||||
audio_config = audio_config or AudioConfig(
|
||||
@ -645,7 +672,7 @@ class BatDetect2API:
|
||||
self,
|
||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
) -> None:
|
||||
detector = cast(Detector, self.model.detector)
|
||||
detector = self.model.detector
|
||||
|
||||
for parameter in detector.parameters():
|
||||
parameter.requires_grad = False
|
||||
|
||||
@ -8,12 +8,10 @@ import torch
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.models.heads import ClassifierHead
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
|
||||
@ -452,51 +450,3 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
||||
)
|
||||
|
||||
assert len(strict_detections) <= len(default_detections)
|
||||
|
||||
|
||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
||||
"""User story: call-level overrides do not mutate resolved defaults."""
|
||||
|
||||
api = BatDetect2API.from_config(BatDetect2Config())
|
||||
|
||||
override_inference = InferenceConfig.model_validate(
|
||||
{"loader": {"batch_size": 7}}
|
||||
)
|
||||
override_audio = AudioConfig.model_validate({"samplerate": 384000})
|
||||
override_train = TrainingConfig.model_validate(
|
||||
{"trainer": {"max_epochs": 2}}
|
||||
)
|
||||
|
||||
captured_process: dict[str, object] = {}
|
||||
captured_train: dict[str, object] = {}
|
||||
|
||||
def fake_process_file_list(*args, **kwargs):
|
||||
captured_process["inference_config"] = kwargs["inference_config"]
|
||||
captured_process["audio_config"] = kwargs["audio_config"]
|
||||
return []
|
||||
|
||||
def fake_run_train(*args, **kwargs):
|
||||
captured_train["train_config"] = kwargs["train_config"]
|
||||
captured_train["audio_config"] = kwargs["audio_config"]
|
||||
captured_train["model_config"] = kwargs["model_config"]
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.api_v2.process_file_list", fake_process_file_list
|
||||
)
|
||||
monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)
|
||||
|
||||
api.process_files(
|
||||
[], inference_config=override_inference, audio_config=override_audio
|
||||
)
|
||||
api.train([], train_config=override_train, audio_config=override_audio)
|
||||
|
||||
assert captured_process["inference_config"] is override_inference
|
||||
assert captured_process["audio_config"] is override_audio
|
||||
assert captured_train["train_config"] is override_train
|
||||
assert captured_train["audio_config"] is override_audio
|
||||
assert captured_train["model_config"] is api.model_config
|
||||
|
||||
assert api.inference_config.loader.batch_size != 7
|
||||
assert api.audio_config.samplerate != 384000
|
||||
assert api.train_config.trainer.max_epochs != 2
|
||||
|
||||
Loading…
Reference in New Issue
Block a user