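"""High-level Python API for BatDetect2.

This module defines BatDetect2API, which bundles the audio loader,
preprocessor, model, postprocessor, evaluator, and output formatter behind a
single object so that training, inference, and evaluation can be driven from
one place.

Minimal usage sketch (the checkpoint and audio paths are illustrative):

    api = BatDetect2API.from_checkpoint("checkpoints/model.ckpt")
    predictions = api.process_directory("recordings/")
    api.save_predictions(predictions, path="detections.json")
"""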
import json
from pathlib import Path
from typing import Sequence

import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files

from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import (
    OutputFormatConfig,
    build_output_formatter,
    get_output_formatter,
    load_dataset_from_config,
)
from batdetect2.data.datasets import Dataset
from batdetect2.data.predictions.base import OutputFormatterProtocol
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import (
    DEFAULT_CHECKPOINT_DIR,
    load_model_from_checkpoint,
    run_train,
)
from batdetect2.typing import (
    AudioLoader,
    ClipDetections,
    Detection,
    EvaluatorProtocol,
    PostprocessorProtocol,
    PreprocessorProtocol,
    TargetProtocol,
)


class BatDetect2API:
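    """High-level interface to a configured BatDetect2 system.

    Holds the components built from a BatDetect2Config (targets, audio
    loader, pre- and postprocessors, model, evaluator, and output formatter)
    and exposes training, inference, and evaluation as methods. Instances
    are normally created via from_config or from_checkpoint.
    """
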
    def __init__(
        self,
        config: BatDetect2Config,
        targets: TargetProtocol,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        evaluator: EvaluatorProtocol,
        formatter: OutputFormatterProtocol,
        model: Model,
    ):
        self.config = config
        self.targets = targets
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.evaluator = evaluator
        self.model = model
        self.formatter = formatter

        # Start in eval mode so the instance is ready for inference right
        # after construction.
        self.model.eval()

    def load_annotations(
        self,
        path: data.PathLike,
        base_dir: data.PathLike | None = None,
    ) -> Dataset:
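        """Load a dataset of clip annotations from a dataset config file."""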
        return load_dataset_from_config(path, base_dir=base_dir)

    def train(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Sequence[data.ClipAnnotation] | None = None,
        train_workers: int | None = None,
        val_workers: int | None = None,
        checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
        log_dir: Path | None = DEFAULT_LOGS_DIR,
        experiment_name: str | None = None,
        num_epochs: int | None = None,
        run_name: str | None = None,
        seed: int | None = None,
    ):
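        """Train the model on the given annotations.

        Delegates to run_train using this API's targets, audio loader,
        preprocessor, and the model/train/audio sections of the config.
        Returns self so calls can be chained.
        """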
        run_train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            targets=self.targets,
            model_config=self.config.model,
            train_config=self.config.train,
            audio_config=self.config.audio,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            train_workers=train_workers,
            val_workers=val_workers,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            num_epochs=num_epochs,
            experiment_name=experiment_name,
            run_name=run_name,
            seed=seed,
        )
        return self

    def evaluate(
        self,
        test_annotations: Sequence[data.ClipAnnotation],
        num_workers: int | None = None,
        output_dir: data.PathLike = DEFAULT_EVAL_DIR,
        experiment_name: str | None = None,
        run_name: str | None = None,
        save_predictions: bool = True,
    ) -> tuple[dict[str, float], list[list[Detection]]]:
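        """Run inference on the test clips and score the results.

        When save_predictions is true, the API's output formatter is passed
        along so that predictions can be written under output_dir.
        """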
        return evaluate(
            self.model,
            test_annotations,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            num_workers=num_workers,
            output_dir=output_dir,
            experiment_name=experiment_name,
            run_name=run_name,
            formatter=self.formatter if save_predictions else None,
        )

    def evaluate_predictions(
        self,
        annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[ClipDetections],
        output_dir: data.PathLike | None = None,
    ):
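        """Score existing predictions against ground-truth annotations.

        Computes metrics with the configured evaluator and, if output_dir is
        given, writes them to metrics.json there along with any plots the
        evaluator generates.
        """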
        clip_evals = self.evaluator.evaluate(
            annotations,
            predictions,
        )

        metrics = self.evaluator.compute_metrics(clip_evals)

        if output_dir is not None:
            output_dir = Path(output_dir)

            if not output_dir.is_dir():
                output_dir.mkdir(parents=True)

            metrics_path = output_dir / "metrics.json"
            metrics_path.write_text(json.dumps(metrics))

            for figure_name, fig in self.evaluator.generate_plots(clip_evals):
                fig_path = output_dir / figure_name

                if not fig_path.parent.is_dir():
                    fig_path.parent.mkdir(parents=True)

                fig.savefig(fig_path)

        return metrics

    def load_audio(self, path: data.PathLike) -> np.ndarray:
        return self.audio_loader.load_file(path)

    def load_clip(self, clip: data.Clip) -> np.ndarray:
        return self.audio_loader.load_clip(clip)

    def generate_spectrogram(
        self,
        audio: np.ndarray,
    ) -> torch.Tensor:
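        """Turn a loaded waveform into a model-ready spectrogram tensor.

        The waveform gains a leading batch dimension before being passed
        through the preprocessor.
        """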
        tensor = torch.tensor(audio).unsqueeze(0)
        return self.preprocessor(tensor)

    def process_file(self, audio_file: str) -> ClipDetections:
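        """Run detection on a single audio file.

        The recording is processed in full and returned as a ClipDetections
        whose clip spans the entire duration.
        """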
        recording = data.Recording.from_file(audio_file, compute_hash=False)
        wav = self.audio_loader.load_recording(recording)
        detections = self.process_audio(wav)
        return ClipDetections(
            clip=data.Clip(
                uuid=recording.uuid,
                recording=recording,
                start_time=0,
                end_time=recording.duration,
            ),
            detections=detections,
        )

    def process_audio(
        self,
        audio: np.ndarray,
    ) -> list[Detection]:
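        """Run detection on a waveform that is already loaded in memory."""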
        spec = self.generate_spectrogram(audio)
        return self.process_spectrogram(spec)

    def process_spectrogram(
        self,
        spec: torch.Tensor,
        start_time: float = 0,
    ) -> list[Detection]:
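        """Run detection on a precomputed spectrogram.

        Accepts a single spectrogram only; a 4D input with a batch size
        greater than one raises ValueError. start_time offsets the times of
        the returned detections.
        """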
        if spec.ndim == 4 and spec.shape[0] > 1:
            raise ValueError("Batched spectrograms not supported.")

        if spec.ndim == 3:
            spec = spec.unsqueeze(0)

        outputs = self.model.detector(spec)

        detections = self.model.postprocessor(
            outputs,
            start_times=[start_time],
        )[0]

        return to_raw_predictions(detections.numpy(), targets=self.targets)

    def process_directory(
        self,
        audio_dir: data.PathLike,
    ) -> list[ClipDetections]:
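        """Run detection on every audio file found under audio_dir."""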
        files = list(get_audio_files(audio_dir))
        return self.process_files(files)

    def process_files(
        self,
        audio_files: Sequence[data.PathLike],
        num_workers: int | None = None,
    ) -> list[ClipDetections]:
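        """Run detection on a list of audio files, optionally in parallel."""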
        return process_file_list(
            self.model,
            audio_files,
            config=self.config,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            num_workers=num_workers,
        )

    def process_clips(
        self,
        clips: Sequence[data.Clip],
        batch_size: int | None = None,
        num_workers: int | None = None,
    ) -> list[ClipDetections]:
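        """Run batched inference over a sequence of predefined clips."""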
        return run_batch_inference(
            self.model,
            clips,
            targets=self.targets,
            audio_loader=self.audio_loader,
            preprocessor=self.preprocessor,
            config=self.config,
            batch_size=batch_size,
            num_workers=num_workers,
        )

    def save_predictions(
        self,
        predictions: Sequence[ClipDetections],
        path: data.PathLike,
        audio_dir: data.PathLike | None = None,
        format: str | None = None,
        config: OutputFormatConfig | None = None,
    ):
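        """Format predictions and write them to path.

        Uses this API's default output formatter unless a format name or an
        explicit OutputFormatConfig is supplied, in which case a formatter is
        built for this call.
        """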
        formatter = self.formatter

        if format is not None or config is not None:
            format = format or config.name  # type: ignore
            formatter = get_output_formatter(
                name=format,
                targets=self.targets,
                config=config,
            )

        outs = formatter.format(predictions)
        formatter.save(outs, audio_dir=audio_dir, path=path)

    def load_predictions(
        self,
        path: data.PathLike,
    ) -> list[ClipDetections]:
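        """Load previously saved predictions using the default formatter."""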
        return self.formatter.load(path)

    @classmethod
    def from_config(
        cls,
        config: BatDetect2Config,
    ):
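        """Build an API instance, with a freshly built model, from a config.

        Each component (targets, audio loader, preprocessor, postprocessor,
        evaluator, model, and output formatter) is constructed from its
        section of config.
        """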
        targets = build_targets(config=config.model.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.model.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.model.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)

        # NOTE: Better to keep separate instances of the preprocessor and
        # postprocessor, as these may be moved to another device.
        model = build_model(config=config.model)

        formatter = build_output_formatter(targets, config=config.output)

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
            formatter=formatter,
        )

    @classmethod
    def from_checkpoint(
        cls,
        path: data.PathLike,
        config: BatDetect2Config | None = None,
    ):
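        """Restore an API instance from a training checkpoint.

        The checkpoint's model config is expanded into a full
        BatDetect2Config, and any caller-supplied config is merged on top as
        overrides.
        """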
        from batdetect2.audio import AudioConfig

        model, model_config = load_model_from_checkpoint(path)

        # Reconstruct a full BatDetect2Config from the checkpoint's
        # ModelConfig, then overlay any caller-supplied overrides.
        base = BatDetect2Config(
            model=model_config,
            audio=AudioConfig(samplerate=model_config.samplerate),
        )
        config = merge_configs(base, config) if config else base

        targets = build_targets(config=config.model.targets)

        audio_loader = build_audio_loader(config=config.audio)

        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.model.preprocess,
        )

        postprocessor = build_postprocessor(
            preprocessor,
            config=config.model.postprocess,
        )

        evaluator = build_evaluator(config=config.evaluation, targets=targets)

        formatter = build_output_formatter(targets, config=config.output)

        return cls(
            config=config,
            targets=targets,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            evaluator=evaluator,
            model=model,
            formatter=formatter,
        )