mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Make sure api_v2 loads fast
This commit is contained in:
parent
6d09133dca
commit
c24056214c
@ -1,68 +1,46 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, Sequence, cast
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.audio.files import get_audio_files
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from collections.abc import Sequence
|
||||||
from batdetect2.data import Dataset, load_dataset_from_config
|
|
||||||
from batdetect2.evaluate import (
|
import torch
|
||||||
DEFAULT_EVAL_DIR,
|
|
||||||
EvaluationConfig,
|
from batdetect2.audio import AudioConfig, AudioLoader
|
||||||
EvaluatorProtocol,
|
from batdetect2.config import BatDetect2Config
|
||||||
build_evaluator,
|
from batdetect2.data import Dataset
|
||||||
run_evaluate,
|
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||||
save_evaluation_results,
|
from batdetect2.inference import InferenceConfig
|
||||||
)
|
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
||||||
from batdetect2.inference import (
|
from batdetect2.models import Model, ModelConfig
|
||||||
InferenceConfig,
|
from batdetect2.outputs import (
|
||||||
process_file_list,
|
|
||||||
run_batch_inference,
|
|
||||||
)
|
|
||||||
from batdetect2.logging import (
|
|
||||||
DEFAULT_LOGS_DIR,
|
|
||||||
AppLoggingConfig,
|
|
||||||
LoggerConfig,
|
|
||||||
)
|
|
||||||
from batdetect2.models import (
|
|
||||||
Model,
|
|
||||||
ModelConfig,
|
|
||||||
build_model,
|
|
||||||
build_model_with_new_targets,
|
|
||||||
)
|
|
||||||
from batdetect2.models.detectors import Detector
|
|
||||||
from batdetect2.outputs import (
|
|
||||||
OutputFormatConfig,
|
OutputFormatConfig,
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
OutputsConfig,
|
OutputsConfig,
|
||||||
OutputTransformProtocol,
|
OutputTransformProtocol,
|
||||||
build_output_formatter,
|
)
|
||||||
build_output_transform,
|
from batdetect2.postprocess import (
|
||||||
get_output_formatter,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess import (
|
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
Detection,
|
Detection,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
build_postprocessor,
|
)
|
||||||
)
|
from batdetect2.preprocess import PreprocessorProtocol
|
||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.targets import (
|
||||||
from batdetect2.targets import (
|
|
||||||
ROIMapperProtocol,
|
ROIMapperProtocol,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
build_roi_mapping,
|
)
|
||||||
build_targets,
|
from batdetect2.train import TrainingConfig
|
||||||
)
|
|
||||||
from batdetect2.train import (
|
|
||||||
DEFAULT_CHECKPOINT_DIR,
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
TrainingConfig,
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
load_model_from_checkpoint,
|
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||||
run_train,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2API:
|
class BatDetect2API:
|
||||||
@ -109,6 +87,8 @@ class BatDetect2API:
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
|
||||||
return load_dataset_from_config(path, base_dir=base_dir)
|
return load_dataset_from_config(path, base_dir=base_dir)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -128,6 +108,8 @@ class BatDetect2API:
|
|||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
@ -172,6 +154,7 @@ class BatDetect2API:
|
|||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
"""Fine-tune the model with trainable-parameter selection."""
|
"""Fine-tune the model with trainable-parameter selection."""
|
||||||
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
self._set_trainable_parameters(trainable)
|
self._set_trainable_parameters(trainable)
|
||||||
|
|
||||||
@ -211,6 +194,8 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||||
|
from batdetect2.evaluate import run_evaluate
|
||||||
|
|
||||||
return run_evaluate(
|
return run_evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -235,6 +220,8 @@ class BatDetect2API:
|
|||||||
predictions: Sequence[ClipDetections],
|
predictions: Sequence[ClipDetections],
|
||||||
output_dir: data.PathLike | None = None,
|
output_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.evaluate import save_evaluation_results
|
||||||
|
|
||||||
clip_evals = self.evaluator.evaluate(
|
clip_evals = self.evaluator.evaluate(
|
||||||
annotations,
|
annotations,
|
||||||
predictions,
|
predictions,
|
||||||
@ -307,6 +294,8 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
import torch
|
||||||
|
|
||||||
tensor = torch.tensor(audio).unsqueeze(0)
|
tensor = torch.tensor(audio).unsqueeze(0)
|
||||||
return self.preprocessor(tensor)
|
return self.preprocessor(tensor)
|
||||||
|
|
||||||
@ -316,6 +305,8 @@ class BatDetect2API:
|
|||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> ClipDetections:
|
) -> ClipDetections:
|
||||||
|
from batdetect2.postprocess import ClipDetections
|
||||||
|
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
|
|
||||||
predictions = self.process_files(
|
predictions = self.process_files(
|
||||||
@ -382,6 +373,8 @@ class BatDetect2API:
|
|||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
return self.process_files(
|
return self.process_files(
|
||||||
files,
|
files,
|
||||||
@ -398,6 +391,8 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
from batdetect2.inference import process_file_list
|
||||||
|
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
@ -424,6 +419,8 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
|
from batdetect2.inference import run_batch_inference
|
||||||
|
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
clips,
|
clips,
|
||||||
@ -448,6 +445,8 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.outputs import get_output_formatter
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
|
|
||||||
if format is not None or config is not None:
|
if format is not None or config is not None:
|
||||||
@ -467,6 +466,8 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
) -> list[object]:
|
) -> list[object]:
|
||||||
|
from batdetect2.outputs import get_output_formatter
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
|
|
||||||
if format is not None or config is not None:
|
if format is not None or config is not None:
|
||||||
@ -484,6 +485,17 @@ class BatDetect2API:
|
|||||||
cls,
|
cls,
|
||||||
config: BatDetect2Config,
|
config: BatDetect2Config,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
|
from batdetect2.audio import build_audio_loader
|
||||||
|
from batdetect2.evaluate import build_evaluator
|
||||||
|
from batdetect2.models import build_model
|
||||||
|
from batdetect2.outputs import (
|
||||||
|
build_output_formatter,
|
||||||
|
build_output_transform,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
|
|
||||||
targets = build_targets(config=config.model.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||||
|
|
||||||
@ -563,6 +575,21 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logging_config: AppLoggingConfig | None = None,
|
logging_config: AppLoggingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
|
from batdetect2.inference import InferenceConfig
|
||||||
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.models import build_model_with_new_targets
|
||||||
|
from batdetect2.outputs import (
|
||||||
|
OutputsConfig,
|
||||||
|
build_output_formatter,
|
||||||
|
build_output_transform,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_roi_mapping, build_targets
|
||||||
|
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
@ -645,7 +672,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||||
) -> None:
|
) -> None:
|
||||||
detector = cast(Detector, self.model.detector)
|
detector = self.model.detector
|
||||||
|
|
||||||
for parameter in detector.parameters():
|
for parameter in detector.parameters():
|
||||||
parameter.requires_grad = False
|
parameter.requires_grad = False
|
||||||
|
|||||||
@ -8,12 +8,10 @@ import torch
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.inference import InferenceConfig
|
|
||||||
from batdetect2.models.detectors import Detector
|
from batdetect2.models.detectors import Detector
|
||||||
from batdetect2.models.heads import ClassifierHead
|
from batdetect2.models.heads import ClassifierHead
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
from batdetect2.train import load_model_from_checkpoint
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@ -452,51 +450,3 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert len(strict_detections) <= len(default_detections)
|
assert len(strict_detections) <= len(default_detections)
|
||||||
|
|
||||||
|
|
||||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
|
||||||
"""User story: call-level overrides do not mutate resolved defaults."""
|
|
||||||
|
|
||||||
api = BatDetect2API.from_config(BatDetect2Config())
|
|
||||||
|
|
||||||
override_inference = InferenceConfig.model_validate(
|
|
||||||
{"loader": {"batch_size": 7}}
|
|
||||||
)
|
|
||||||
override_audio = AudioConfig.model_validate({"samplerate": 384000})
|
|
||||||
override_train = TrainingConfig.model_validate(
|
|
||||||
{"trainer": {"max_epochs": 2}}
|
|
||||||
)
|
|
||||||
|
|
||||||
captured_process: dict[str, object] = {}
|
|
||||||
captured_train: dict[str, object] = {}
|
|
||||||
|
|
||||||
def fake_process_file_list(*args, **kwargs):
|
|
||||||
captured_process["inference_config"] = kwargs["inference_config"]
|
|
||||||
captured_process["audio_config"] = kwargs["audio_config"]
|
|
||||||
return []
|
|
||||||
|
|
||||||
def fake_run_train(*args, **kwargs):
|
|
||||||
captured_train["train_config"] = kwargs["train_config"]
|
|
||||||
captured_train["audio_config"] = kwargs["audio_config"]
|
|
||||||
captured_train["model_config"] = kwargs["model_config"]
|
|
||||||
return None
|
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
"batdetect2.api_v2.process_file_list", fake_process_file_list
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)
|
|
||||||
|
|
||||||
api.process_files(
|
|
||||||
[], inference_config=override_inference, audio_config=override_audio
|
|
||||||
)
|
|
||||||
api.train([], train_config=override_train, audio_config=override_audio)
|
|
||||||
|
|
||||||
assert captured_process["inference_config"] is override_inference
|
|
||||||
assert captured_process["audio_config"] is override_audio
|
|
||||||
assert captured_train["train_config"] is override_train
|
|
||||||
assert captured_train["audio_config"] is override_audio
|
|
||||||
assert captured_train["model_config"] is api.model_config
|
|
||||||
|
|
||||||
assert api.inference_config.loader.batch_size != 7
|
|
||||||
assert api.audio_config.samplerate != 384000
|
|
||||||
assert api.train_config.trainer.max_epochs != 2
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user