Config restructuring

commit bbb96b33a2 (parent 7d6cba5465)
@@ -36,14 +36,13 @@ targets:
     name: anchor_bbox
     anchor: top-left
 
-preprocess:
-  audio:
+audio:
+  samplerate: 256000
+  resample:
+    enabled: True
+    method: "poly"
 
-  spectrogram:
+preprocess:
   stft:
     window_duration: 0.002
     window_overlap: 0.75
@@ -54,7 +53,7 @@ preprocess:
   size:
     height: 128
     resize_factor: 0.5
-  transforms:
+  spectrogram_transforms:
     - name: pcen
       time_constant: 0.1
       gain: 0.98
@@ -113,35 +112,15 @@ train:
 
   train_loader:
     batch_size: 8
     num_workers: 2
     shuffle: True
     clipping_strategy:
      name: random_subclip
      duration: 0.256
 
-  val_loader:
-    num_workers: 2
-    clipping_strategy:
-      name: whole_audio_padded
-      chunk_size: 0.256
-
-  loss:
-    detection:
-      weight: 1.0
-      focal:
-        beta: 4
-        alpha: 2
-    classification:
-      weight: 2.0
-      focal:
-        beta: 4
-        alpha: 2
-    size:
-      weight: 0.1
-
-  logger:
-    name: csv
 
   augmentations:
     enabled: true
     audio:
@@ -170,3 +149,26 @@ train:
       probability: 0.2
       max_perc: 0.10
       max_masks: 3
+
+  val_loader:
+    num_workers: 2
+    clipping_strategy:
+      name: whole_audio_padded
+      chunk_size: 0.256
+
+  loss:
+    detection:
+      weight: 1.0
+      focal:
+        beta: 4
+        alpha: 2
+    classification:
+      weight: 2.0
+      focal:
+        beta: 4
+        alpha: 2
+    size:
+      weight: 0.1
+
+  logger:
+    name: csv
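For orientation, the restructured training config implied by the hunks above would look roughly like the following. This is a hedged sketch assembled from the changed lines, not a verbatim file; the nesting is inferred:

audio:
  samplerate: 256000
  resample:
    enabled: True
    method: "poly"

preprocess:
  stft:
    window_duration: 0.002
    window_overlap: 0.75
  size:
    height: 128
    resize_factor: 0.5
  spectrogram_transforms:
    - name: pcen
      time_constant: 0.1
      gain: 0.98

train:
  train_loader:
    batch_size: 8
    num_workers: 2
    shuffle: True
    clipping_strategy:
      name: random_subclip
      duration: 0.256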
src/batdetect2/audio.py (new file, 295 lines)
@@ -0,0 +1,295 @@
from typing import Optional

import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError

from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader

__all__ = [
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "resample_audio",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    enabled : bool, default=True
        Whether to resample loaded audio to the target sample rate.
    method : str, default="poly"
        The resampling algorithm to use. Options:

        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer resampling
          factors differently.
    """

    enabled: bool = True
    method: str = "poly"
class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: ResampleConfig = Field(default_factory=ResampleConfig)


def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
    )
class SoundEventAudioLoader(AudioLoader):
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )
def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {path}. Error: {e}"
        ) from e

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )
def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
            .sel(channel=0)
            .astype(dtype)
        )
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {clip.recording.path}. "
            f"Error: {e}"
        ) from e

    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )
def resample_audio(
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform array to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        return resample_audio_poly(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    elif method == "fourier":
        return resample_audio_fourier(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )


def resample_audio_poly(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample_poly`.

    This method is often preferred for signals when the ratio of new
    to old sample rates can be expressed as a rational number. It uses
    polyphase filtering.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )
def resample_audio_fourier(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate along `axis`.

    Raises
    ------
    ValueError
        If the computed number of output samples is negative.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
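A minimal usage sketch for the new module, assuming the layout above; "recording.wav" is a placeholder path:

from batdetect2.audio import AudioConfig, ResampleConfig, build_audio_loader

# Resample everything to 256 kHz using polyphase filtering.
config = AudioConfig(
    samplerate=256_000,
    resample=ResampleConfig(enabled=True, method="poly"),
)
loader = build_audio_loader(config)

# "recording.wav" is a placeholder; the loader resamples on load if needed.
wav = loader.load_file("recording.wav")
print(wav.shape, wav.dtype)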
@ -14,7 +14,7 @@ __all__ = ["train_command"]
|
||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||
@click.option("--model-path", type=click.Path(exists=True))
|
||||
@click.option("--targets", type=click.Path(exists=True))
|
||||
@click.option("--targets-config", type=click.Path(exists=True))
|
||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||
@click.option("--log-dir", type=click.Path(exists=True))
|
||||
@click.option("--config", type=click.Path(exists=True))
|
||||
@ -37,7 +37,7 @@ def train_command(
|
||||
ckpt_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
config: Optional[Path] = None,
|
||||
targets: Optional[Path] = None,
|
||||
targets_config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
@ -46,13 +46,13 @@ def train_command(
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.config import (
|
||||
BatDetect2Config,
|
||||
load_full_config,
|
||||
)
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import (
|
||||
FullTrainingConfig,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
from batdetect2.train import train
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
@ -68,15 +68,16 @@ def train_command(
|
||||
logger.info("Loading training configuration...")
|
||||
|
||||
conf = (
|
||||
load_full_training_config(config, field=config_field)
|
||||
load_full_config(config, field=config_field)
|
||||
if config is not None
|
||||
else FullTrainingConfig()
|
||||
else BatDetect2Config()
|
||||
)
|
||||
|
||||
if targets is not None:
|
||||
if targets_config is not None:
|
||||
logger.info("Loading targets configuration...")
|
||||
targets_config = load_target_config(targets)
|
||||
conf = conf.model_copy(update=dict(targets=targets_config))
|
||||
conf = conf.model_copy(
|
||||
update=dict(targets=load_target_config(targets_config))
|
||||
)
|
||||
|
||||
logger.info("Loading training dataset...")
|
||||
train_annotations = load_dataset_from_config(train_dataset)
|
||||
@ -96,6 +97,7 @@ def train_command(
|
||||
logger.debug("No validation directory provided.")
|
||||
|
||||
logger.info("Configuration and data loaded. Starting training...")
|
||||
|
||||
train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
|
||||
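The new loading logic in the CLI boils down to the following sketch; "targets.yaml" is a placeholder path:

from batdetect2.config import BatDetect2Config, load_full_config
from batdetect2.targets import load_target_config

config_path = None  # e.g. a Path passed via --config

# Fall back to a fully-default config when no file is given.
conf = (
    load_full_config(config_path)
    if config_path is not None
    else BatDetect2Config()
)

# --targets-config overrides just the targets section via model_copy.
conf = conf.model_copy(update=dict(targets=load_target_config("targets.yaml")))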
@@ -1,16 +1,40 @@
-from typing import Literal
+from typing import Literal, Optional
 
+from pydantic import Field
+from soundevent.data import PathLike
+
+from batdetect2.audio import AudioConfig
 from batdetect2.core import BaseConfig
+from batdetect2.core.configs import load_config
 from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.models.backbones import BackboneConfig
-from batdetect2.preprocess import PreprocessingConfig
+from batdetect2.models.config import BackboneConfig
+from batdetect2.postprocess.config import PostprocessConfig
+from batdetect2.preprocess.config import PreprocessingConfig
 from batdetect2.targets.config import TargetConfig
 from batdetect2.train.config import TrainingConfig
 
+__all__ = [
+    "BatDetect2Config",
+    "load_full_config",
+]
+
 
 class BatDetect2Config(BaseConfig):
     config_version: Literal["v1"] = "v1"
 
-    train: TrainingConfig
-    evaluation: EvaluationConfig
-    model: BackboneConfig
-    preprocess: PreprocessingConfig
+    train: TrainingConfig = Field(default_factory=TrainingConfig)
+    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+
+
+def load_full_config(
+    path: PathLike,
+    field: Optional[str] = None,
+) -> BatDetect2Config:
+    return load_config(path, schema=BatDetect2Config, field=field)
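Because every section now carries a default factory, an empty construction is valid; a quick sketch:

from batdetect2.config import BatDetect2Config

conf = BatDetect2Config()
print(conf.config_version)    # "v1"
print(conf.audio.samplerate)  # 256000, from AudioConfig's default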
@@ -3,39 +3,46 @@ from typing import List, Optional, Tuple
 import pandas as pd
 from soundevent import data
 
+from batdetect2.audio import build_audio_loader
+from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.evaluate.dataframe import extract_matches_dataframe
 from batdetect2.evaluate.evaluator import build_evaluator
 from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
 from batdetect2.models import Model
-from batdetect2.plotting.clips import build_audio_loader
+from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
 from batdetect2.postprocess import get_raw_predictions
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
-from batdetect2.train.config import FullTrainingConfig
 from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.train import build_val_loader
 from batdetect2.typing import ClipLabeller, TargetProtocol
 
 
 def evaluate(
     model: Model,
     test_annotations: List[data.ClipAnnotation],
-    config: Optional[FullTrainingConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    labeller: Optional[ClipLabeller] = None,
+    config: Optional[EvaluationConfig] = None,
     num_workers: Optional[int] = None,
 ) -> Tuple[pd.DataFrame, dict]:
-    config = config or FullTrainingConfig()
+    config = config or EvaluationConfig()
 
-    audio_loader = build_audio_loader(config.preprocess.audio)
+    audio_loader = audio_loader or build_audio_loader()
 
-    preprocessor = build_preprocessor(config.preprocess)
+    preprocessor = preprocessor or build_preprocessor(
+        input_samplerate=audio_loader.samplerate,
+    )
 
-    targets = build_targets(config.targets)
+    targets = targets or build_targets()
 
-    labeller = build_clip_labeler(
+    labeller = labeller or build_clip_labeler(
         targets,
         min_freq=preprocessor.min_freq,
         max_freq=preprocessor.max_freq,
-        config=config.train.labels,
     )
 
     loader = build_val_loader(
@@ -43,7 +50,6 @@ def evaluate(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.val_loader,
         num_workers=num_workers,
     )
 
@@ -52,7 +58,7 @@ def evaluate(
     clip_annotations = []
     predictions = []
 
-    evaluator = build_evaluator(config=config.evaluation)
+    evaluator = build_evaluator(config=config)
 
     for batch in loader:
         outputs = model.detector(batch.spec)
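With every component now optional, a caller only needs a model and annotations; the rest is built from defaults. A hedged sketch, where `model` and `test_annotations` are placeholders that would come from a checkpoint and a dataset:

from batdetect2.evaluate.config import EvaluationConfig

# `model` and `test_annotations` are placeholders.
matches_df, metrics = evaluate(
    model,
    test_annotations,
    config=EvaluationConfig(),
    num_workers=2,
)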
@@ -89,7 +89,7 @@ class ExampleGallery(PlotterProtocol):
     @classmethod
     def from_config(cls, config: ExampleGalleryConfig):
         preprocessor = build_preprocessor(config.preprocessing)
-        audio_loader = build_audio_loader(config.preprocessing.audio)
+        audio_loader = build_audio_loader(config.preprocessing.audio_transforms)
         return cls(
             examples_per_class=config.examples_per_class,
             preprocessor=preprocessor,
@@ -29,15 +29,10 @@ provided here.
 from typing import List, Optional
 
 import torch
-from pydantic import Field
-from soundevent.data import PathLike
 
-from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.models.backbones import (
     Backbone,
-    BackboneConfig,
     build_backbone,
-    load_backbone_config,
 )
 from batdetect2.models.blocks import (
     ConvConfig,
@@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import (
     BottleneckConfig,
     build_bottleneck,
 )
+from batdetect2.models.config import (
+    BackboneConfig,
+    load_backbone_config,
+)
 from batdetect2.models.decoder import (
     DEFAULT_DECODER_CONFIG,
     DecoderConfig,
@@ -63,9 +62,9 @@ from batdetect2.models.encoder import (
     build_encoder,
 )
 from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
-from batdetect2.postprocess import PostprocessConfig, build_postprocessor
-from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.targets import TargetConfig, build_targets
+from batdetect2.postprocess import build_postprocessor
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.typing.models import DetectionModel
 from batdetect2.typing.postprocess import (
     DetectionsTensor,
@@ -99,20 +98,10 @@ __all__ = [
     "build_detector",
     "load_backbone_config",
     "Model",
-    "ModelConfig",
     "build_model",
 ]
 
 
-class ModelConfig(BaseConfig):
-    model: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-
-
 class Model(torch.nn.Module):
     detector: DetectionModel
     preprocessor: PreprocessorProtocol
@@ -125,14 +114,12 @@ class Model(torch.nn.Module):
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
         targets: TargetProtocol,
-        config: ModelConfig,
     ):
         super().__init__()
         self.detector = detector
         self.preprocessor = preprocessor
         self.postprocessor = postprocessor
         self.targets = targets
-        self.config = config
 
     def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
         spec = self.preprocessor(wav)
@@ -140,32 +127,25 @@ class Model(torch.nn.Module):
         return self.postprocessor(outputs)
 
 
-def build_model(config: Optional[ModelConfig] = None):
-    config = config or ModelConfig()
-
-    targets = build_targets(config=config.targets)
-
-    preprocessor = build_preprocessor(config=config.preprocess)
-
-    postprocessor = build_postprocessor(
+def build_model(
+    config: Optional[BackboneConfig] = None,
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    postprocessor: Optional[PostprocessorProtocol] = None,
+):
+    config = config or BackboneConfig()
+    targets = targets or build_targets()
+    preprocessor = preprocessor or build_preprocessor()
+    postprocessor = postprocessor or build_postprocessor(
         preprocessor=preprocessor,
-        config=config.postprocess,
     )
 
     detector = build_detector(
         num_classes=len(targets.class_names),
-        config=config.model,
+        config=config,
     )
     return Model(
-        config=config,
         detector=detector,
         postprocessor=postprocessor,
        preprocessor=preprocessor,
        targets=targets,
    )
-
-
-def load_model_config(
-    path: PathLike, field: Optional[str] = None
-) -> ModelConfig:
-    return load_config(path, schema=ModelConfig, field=field)
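The refactored factory accepts pre-built components, so dependency injection and defaults both work; a sketch:

from batdetect2.models import build_model
from batdetect2.targets import build_targets

# Everything from defaults:
model = build_model()

# Or inject one component while defaulting the rest:
targets = build_targets()
model = build_model(targets=targets)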
@@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
 network's total downsampling factor.
 """
 
-from typing import Optional, Tuple
+from typing import Tuple
 
 import torch
 import torch.nn.functional as F
-from soundevent import data
 from torch import nn
 
-from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.models.bottleneck import (
-    DEFAULT_BOTTLENECK_CONFIG,
-    BottleneckConfig,
-    build_bottleneck,
-)
-from batdetect2.models.decoder import (
-    DEFAULT_DECODER_CONFIG,
-    Decoder,
-    DecoderConfig,
-    build_decoder,
-)
-from batdetect2.models.encoder import (
-    DEFAULT_ENCODER_CONFIG,
-    Encoder,
-    EncoderConfig,
-    build_encoder,
-)
+from batdetect2.models.bottleneck import build_bottleneck
+from batdetect2.models.config import BackboneConfig
+from batdetect2.models.decoder import Decoder, build_decoder
+from batdetect2.models.encoder import Encoder, build_encoder
 from batdetect2.typing.models import BackboneModel
 
 __all__ = [
     "Backbone",
     "BackboneConfig",
-    "load_backbone_config",
     "build_backbone",
 ]
@@ -161,82 +144,6 @@ class Backbone(BackboneModel):
         return x
 
 
-class BackboneConfig(BaseConfig):
-    """Configuration for the Encoder-Decoder Backbone network.
-
-    Aggregates configurations for the encoder, bottleneck, and decoder
-    components, along with defining the input and final output dimensions
-    for the complete backbone.
-
-    Attributes
-    ----------
-    input_height : int, default=128
-        Expected height (frequency bins) of the input spectrograms to the
-        backbone. Must be positive.
-    in_channels : int, default=1
-        Expected number of channels in the input spectrograms (e.g., 1 for
-        mono). Must be positive.
-    encoder : EncoderConfig, optional
-        Configuration for the encoder. If None or omitted,
-        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
-        encoder module) will be used.
-    bottleneck : BottleneckConfig, optional
-        Configuration for the bottleneck layer connecting encoder and decoder.
-        If None or omitted, the default bottleneck configuration will be used.
-    decoder : DecoderConfig, optional
-        Configuration for the decoder. If None or omitted,
-        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
-        decoder module) will be used.
-    out_channels : int, default=32
-        Desired number of channels in the final feature map output by the
-        backbone. Must be positive.
-    """
-
-    input_height: int = 128
-    in_channels: int = 1
-    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
-    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
-    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
-    out_channels: int = 32
-
-
-def load_backbone_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> BackboneConfig:
-    """Load the backbone configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `BackboneConfig` schema, potentially extracting data from a nested field.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file.
-    field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        backbone configuration (e.g., "model.backbone"). If None, the entire
-        file content is used.
-
-    Returns
-    -------
-    BackboneConfig
-        The loaded and validated backbone configuration object.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file path does not exist.
-    yaml.YAMLError
-        If the file content is not valid YAML.
-    pydantic.ValidationError
-        If the loaded config data does not conform to `BackboneConfig`.
-    KeyError, TypeError
-        If `field` specifies an invalid path.
-    """
-    return load_config(path, schema=BackboneConfig, field=field)
-
-
 def build_backbone(config: BackboneConfig) -> BackboneModel:
     """Factory function to build a Backbone from configuration.
src/batdetect2/models/config.py (new file, 98 lines)
@@ -0,0 +1,98 @@
from typing import Optional

from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
    DEFAULT_BOTTLENECK_CONFIG,
    BottleneckConfig,
)
from batdetect2.models.decoder import (
    DEFAULT_DECODER_CONFIG,
    DecoderConfig,
)
from batdetect2.models.encoder import (
    DEFAULT_ENCODER_CONFIG,
    EncoderConfig,
)

__all__ = [
    "BackboneConfig",
    "load_backbone_config",
]


class BackboneConfig(BaseConfig):
    """Configuration for the Encoder-Decoder Backbone network.

    Aggregates configurations for the encoder, bottleneck, and decoder
    components, along with defining the input and final output dimensions
    for the complete backbone.

    Attributes
    ----------
    input_height : int, default=128
        Expected height (frequency bins) of the input spectrograms to the
        backbone. Must be positive.
    in_channels : int, default=1
        Expected number of channels in the input spectrograms (e.g., 1 for
        mono). Must be positive.
    encoder : EncoderConfig, optional
        Configuration for the encoder. If None or omitted,
        the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
        encoder module) will be used.
    bottleneck : BottleneckConfig, optional
        Configuration for the bottleneck layer connecting encoder and decoder.
        If None or omitted, the default bottleneck configuration will be used.
    decoder : DecoderConfig, optional
        Configuration for the decoder. If None or omitted,
        the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
        decoder module) will be used.
    out_channels : int, default=32
        Desired number of channels in the final feature map output by the
        backbone. Must be positive.
    """

    input_height: int = 128
    in_channels: int = 1
    encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
    bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
    decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
    out_channels: int = 32


def load_backbone_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> BackboneConfig:
    """Load the backbone configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `BackboneConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        backbone configuration (e.g., "model.backbone"). If None, the entire
        file content is used.

    Returns
    -------
    BackboneConfig
        The loaded and validated backbone configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `BackboneConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=BackboneConfig, field=field)
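Loading just the backbone section from a larger file, per the docstring above; "config.yaml" and the "model" field name are placeholders:

from batdetect2.models.config import BackboneConfig, load_backbone_config

# Defaults only:
config = BackboneConfig()

# Or pull a nested section out of a bigger YAML file:
config = load_backbone_config("config.yaml", field="model")
print(config.input_height, config.out_channels)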
@@ -5,8 +5,9 @@ import torch
 from matplotlib.axes import Axes
 from soundevent import data
 
+from batdetect2.audio import build_audio_loader
 from batdetect2.plotting.common import plot_spectrogram
-from batdetect2.preprocess import build_audio_loader, build_preprocessor
+from batdetect2.preprocess import build_preprocessor
 from batdetect2.typing import AudioLoader, PreprocessorProtocol
 
 __all__ = [
@@ -1,307 +1,29 @@
 """Main entry point for the BatDetect2 Postprocessing pipeline."""
 
-from typing import List, Optional
-
-import torch
-from loguru import logger
-from pydantic import Field
-from soundevent import data
-
-from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.postprocess.config import (
+    PostprocessConfig,
+    load_postprocess_config,
+)
 from batdetect2.postprocess.decoding import (
     DEFAULT_CLASSIFICATION_THRESHOLD,
     convert_raw_prediction_to_sound_event_prediction,
     convert_raw_predictions_to_clip_prediction,
     to_raw_predictions,
 )
-from batdetect2.postprocess.extraction import extract_prediction_tensor
 from batdetect2.postprocess.nms import (
     NMS_KERNEL_SIZE,
     non_max_suppression,
 )
-from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
 from batdetect2.typing import ModelOutput
-from batdetect2.typing.postprocess import (
-    BatDetect2Prediction,
-    DetectionsTensor,
-    PostprocessorProtocol,
-    RawPrediction,
+from batdetect2.postprocess.postprocessor import (
+    Postprocessor,
+    build_postprocessor,
+    get_raw_predictions,
 )
-from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "DEFAULT_CLASSIFICATION_THRESHOLD",
     "DEFAULT_DETECTION_THRESHOLD",
     "MAX_FREQ",
     "MIN_FREQ",
     "ModelOutput",
     "NMS_KERNEL_SIZE",
     "PostprocessConfig",
     "Postprocessor",
     "TOP_K_PER_SEC",
     "build_postprocessor",
     "convert_raw_predictions_to_clip_prediction",
     "to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
     "get_raw_predictions",
 ]
 
 DEFAULT_DETECTION_THRESHOLD = 0.01
 
 
 TOP_K_PER_SEC = 100
 
 
-class PostprocessConfig(BaseConfig):
-    """Configuration settings for the postprocessing pipeline.
-
-    Defines tunable parameters that control how raw model outputs are
-    converted into final detections.
-
-    Attributes
-    ----------
-    nms_kernel_size : int, default=NMS_KERNEL_SIZE
-        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
-        Used to suppress weaker detections near stronger peaks. Must be
-        positive.
-    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
-        Minimum confidence score from the detection heatmap required to
-        consider a point as a potential detection. Must be >= 0.
-    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
-        Minimum confidence score for a specific class prediction to be included
-        in the decoded tags for a detection. Must be >= 0.
-    top_k_per_sec : int, default=TOP_K_PER_SEC
-        Desired maximum number of detections per second of audio. Used by
-        `get_max_detections` to calculate an absolute limit based on clip
-        duration before applying `extract_detections_from_array`. Must be
-        positive.
-    """
-
-    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
-    detection_threshold: float = Field(
-        default=DEFAULT_DETECTION_THRESHOLD,
-        ge=0,
-    )
-    classification_threshold: float = Field(
-        default=DEFAULT_CLASSIFICATION_THRESHOLD,
-        ge=0,
-    )
-    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
-
-
-def load_postprocess_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> PostprocessConfig:
-    """Load the postprocessing configuration from a file.
-
-    Reads a configuration file (YAML) and validates it against the
-    `PostprocessConfig` schema, potentially extracting data from a nested
-    field.
-
-    Parameters
-    ----------
-    path : PathLike
-        Path to the configuration file.
-    field : str, optional
-        Dot-separated path to a nested section within the file containing the
-        postprocessing configuration (e.g., "inference.postprocessing").
-        If None, the entire file content is used.
-
-    Returns
-    -------
-    PostprocessConfig
-        The loaded and validated postprocessing configuration object.
-
-    Raises
-    ------
-    FileNotFoundError
-        If the config file path does not exist.
-    yaml.YAMLError
-        If the file content is not valid YAML.
-    pydantic.ValidationError
-        If the loaded configuration data does not conform to the
-        `PostprocessConfig` schema.
-    KeyError, TypeError
-        If `field` specifies an invalid path within the loaded data.
-    """
-    return load_config(path, schema=PostprocessConfig, field=field)
-
-
-def build_postprocessor(
-    preprocessor: PreprocessorProtocol,
-    config: Optional[PostprocessConfig] = None,
-) -> PostprocessorProtocol:
-    """Factory function to build the standard postprocessor."""
-    config = config or PostprocessConfig()
-    logger.opt(lazy=True).debug(
-        "Building postprocessor with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
-    return Postprocessor(
-        samplerate=preprocessor.output_samplerate,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-        top_k_per_sec=config.top_k_per_sec,
-        detection_threshold=config.detection_threshold,
-    )
-
-
-class Postprocessor(torch.nn.Module, PostprocessorProtocol):
-    """Standard implementation of the postprocessing pipeline."""
-
-    def __init__(
-        self,
-        samplerate: float,
-        min_freq: float,
-        max_freq: float,
-        top_k_per_sec: int = 200,
-        detection_threshold: float = 0.01,
-    ):
-        """Initialize the Postprocessor."""
-        super().__init__()
-        self.samplerate = samplerate
-        self.min_freq = min_freq
-        self.max_freq = max_freq
-        self.top_k_per_sec = top_k_per_sec
-        self.detection_threshold = detection_threshold
-
-    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        max_detections = int(self.top_k_per_sec * duration)
-        detections = extract_prediction_tensor(
-            output,
-            max_detections=max_detections,
-            threshold=self.detection_threshold,
-        )
-        return [
-            map_detection_to_clip(
-                detection,
-                start_time=0,
-                end_time=duration,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for detection in detections
-        ]
-
-    def get_detections(
-        self,
-        output: ModelOutput,
-        start_times: Optional[List[float]] = None,
-    ) -> List[DetectionsTensor]:
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        max_detections = int(self.top_k_per_sec * duration)
-
-        detections = extract_prediction_tensor(
-            output,
-            max_detections=max_detections,
-            threshold=self.detection_threshold,
-        )
-
-        if start_times is None:
-            return detections
-
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        return [
-            map_detection_to_clip(
-                detection,
-                start_time=start_time,
-                end_time=start_time + duration,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for detection, start_time in zip(detections, start_times)
-        ]
-
-
-def get_raw_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    start_times: Optional[List[float]] = None,
-) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch."""
-    detections = postprocessor.get_detections(output, start_times)
-    return [
-        to_raw_predictions(detection.numpy(), targets=targets)
-        for detection in detections
-    ]
-
-
-def get_sound_event_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    clips: List[data.Clip],
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[List[BatDetect2Prediction]]:
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        [
-            BatDetect2Prediction(
-                raw=raw,
-                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                    raw,
-                    recording=clip.recording,
-                    targets=targets,
-                    classification_threshold=classification_threshold,
-                ),
-            )
-            for raw in predictions
-        ]
-        for predictions, clip in zip(raw_predictions, clips)
-    ]
-
-
-def get_predictions(
-    output: ModelOutput,
-    clips: List[data.Clip],
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[data.ClipPrediction]:
-    """Perform the full postprocessing pipeline for a batch.
-
-    Takes raw model output and corresponding clips, applies the entire
-    configured chain (NMS, remapping, extraction, geometry recovery, class
-    decoding), producing final `soundevent.data.ClipPrediction` objects.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[data.ClipPrediction]
-        List containing one `ClipPrediction` object for each input clip,
-        populated with `SoundEventPrediction` objects.
-    """
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        convert_raw_predictions_to_clip_prediction(
-            prediction,
-            clip,
-            targets=targets,
-            classification_threshold=classification_threshold,
-        )
-        for prediction, clip in zip(raw_predictions, clips)
-    ]
src/batdetect2/postprocess/config.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from typing import Optional

from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE

__all__ = [
    "PostprocessConfig",
    "load_postprocess_config",
]

DEFAULT_DETECTION_THRESHOLD = 0.01


TOP_K_PER_SEC = 100


class PostprocessConfig(BaseConfig):
    """Configuration settings for the postprocessing pipeline.

    Defines tunable parameters that control how raw model outputs are
    converted into final detections.

    Attributes
    ----------
    nms_kernel_size : int, default=NMS_KERNEL_SIZE
        Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
        Used to suppress weaker detections near stronger peaks. Must be
        positive.
    detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
        Minimum confidence score from the detection heatmap required to
        consider a point as a potential detection. Must be >= 0.
    classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
        Minimum confidence score for a specific class prediction to be included
        in the decoded tags for a detection. Must be >= 0.
    top_k_per_sec : int, default=TOP_K_PER_SEC
        Desired maximum number of detections per second of audio. Used by
        `get_max_detections` to calculate an absolute limit based on clip
        duration before applying `extract_detections_from_array`. Must be
        positive.
    """

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(
        default=DEFAULT_DETECTION_THRESHOLD,
        ge=0,
    )
    classification_threshold: float = Field(
        default=DEFAULT_CLASSIFICATION_THRESHOLD,
        ge=0,
    )
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


def load_postprocess_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> PostprocessConfig:
    """Load the postprocessing configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `PostprocessConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        postprocessing configuration (e.g., "inference.postprocessing").
        If None, the entire file content is used.

    Returns
    -------
    PostprocessConfig
        The loaded and validated postprocessing configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `PostprocessConfig` schema.
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path, schema=PostprocessConfig, field=field)
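A small sketch of tightening the thresholds relative to the defaults:

from batdetect2.postprocess.config import PostprocessConfig

# Defaults: detection_threshold=0.01, classification_threshold from decoding,
# top_k_per_sec=100.
config = PostprocessConfig(
    detection_threshold=0.3,
    classification_threshold=0.5,
)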
src/batdetect2/postprocess/postprocessor.py (new file, 208 lines)
@@ -0,0 +1,208 @@
from typing import List, Optional

import torch
from loguru import logger
from soundevent import data

from batdetect2.postprocess.config import (
    PostprocessConfig,
)
from batdetect2.postprocess.decoding import (
    DEFAULT_CLASSIFICATION_THRESHOLD,
    convert_raw_prediction_to_sound_event_prediction,
    convert_raw_predictions_to_clip_prediction,
    to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
    BatDetect2Prediction,
    DetectionsTensor,
    PostprocessorProtocol,
    RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "build_postprocessor",
    "Postprocessor",
]


def build_postprocessor(
    preprocessor: PreprocessorProtocol,
    config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
    """Factory function to build the standard postprocessor."""
    config = config or PostprocessConfig()
    logger.opt(lazy=True).debug(
        "Building postprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )
    return Postprocessor(
        samplerate=preprocessor.output_samplerate,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        top_k_per_sec=config.top_k_per_sec,
        detection_threshold=config.detection_threshold,
    )


class Postprocessor(torch.nn.Module, PostprocessorProtocol):
    """Standard implementation of the postprocessing pipeline."""

    def __init__(
        self,
        samplerate: float,
        min_freq: float,
        max_freq: float,
        top_k_per_sec: int = 200,
        detection_threshold: float = 0.01,
    ):
        """Initialize the Postprocessor."""
        super().__init__()
        self.samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq
        self.top_k_per_sec = top_k_per_sec
        self.detection_threshold = detection_threshold

    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        max_detections = int(self.top_k_per_sec * duration)
        detections = extract_prediction_tensor(
            output,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )
        return [
            map_detection_to_clip(
                detection,
                start_time=0,
                end_time=duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection in detections
        ]

    def get_detections(
        self,
        output: ModelOutput,
        start_times: Optional[List[float]] = None,
    ) -> List[DetectionsTensor]:
        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        max_detections = int(self.top_k_per_sec * duration)

        detections = extract_prediction_tensor(
            output,
            max_detections=max_detections,
            threshold=self.detection_threshold,
        )

        if start_times is None:
            return detections

        width = output.detection_probs.shape[-1]
        duration = width / self.samplerate
        return [
            map_detection_to_clip(
                detection,
                start_time=start_time,
                end_time=start_time + duration,
                min_freq=self.min_freq,
                max_freq=self.max_freq,
            )
            for detection, start_time in zip(detections, start_times)
        ]


def get_raw_predictions(
    output: ModelOutput,
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
    """Extract intermediate RawPrediction objects for a batch."""
    detections = postprocessor.get_detections(output, start_times)
    return [
        to_raw_predictions(detection.numpy(), targets=targets)
        for detection in detections
    ]


def get_sound_event_predictions(
    output: ModelOutput,
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    clips: List[data.Clip],
    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
    raw_predictions = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        start_times=[clip.start_time for clip in clips],
    )
    return [
        [
            BatDetect2Prediction(
                raw=raw,
                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
                    raw,
                    recording=clip.recording,
                    targets=targets,
                    classification_threshold=classification_threshold,
                ),
            )
            for raw in predictions
        ]
        for predictions, clip in zip(raw_predictions, clips)
    ]


def get_predictions(
    output: ModelOutput,
    clips: List[data.Clip],
    targets: TargetProtocol,
    postprocessor: PostprocessorProtocol,
    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
    """Perform the full postprocessing pipeline for a batch.

    Takes raw model output and corresponding clips, applies the entire
    configured chain (NMS, remapping, extraction, geometry recovery, class
    decoding), producing final `soundevent.data.ClipPrediction` objects.

    Parameters
    ----------
    output : ModelOutput
        Raw output from the neural network model for a batch.
    clips : List[data.Clip]
        List of `soundevent.data.Clip` objects corresponding to the batch.

    Returns
    -------
    List[data.ClipPrediction]
        List containing one `ClipPrediction` object for each input clip,
        populated with `SoundEventPrediction` objects.
    """
    raw_predictions = get_raw_predictions(
        output,
        targets=targets,
        postprocessor=postprocessor,
        start_times=[clip.start_time for clip in clips],
    )
    return [
        convert_raw_predictions_to_clip_prediction(
            prediction,
            clip,
            targets=targets,
            classification_threshold=classification_threshold,
        )
        for prediction, clip in zip(raw_predictions, clips)
    ]
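The Postprocessor can also be constructed directly when no preprocessor is at hand; the values below are illustrative, not project defaults:

from batdetect2.postprocess.postprocessor import Postprocessor

post = Postprocessor(
    samplerate=500.0,      # detection-heatmap frames per second (illustrative)
    min_freq=10_000,
    max_freq=120_000,
    top_k_per_sec=100,
    detection_threshold=0.3,
)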
@@ -1,21 +1,19 @@
 """Main entry point for the BatDetect2 preprocessing subsystem."""
 
-from batdetect2.preprocess.audio import build_audio_loader
+from batdetect2.audio import TARGET_SAMPLERATE_HZ
 from batdetect2.preprocess.config import (
-    MAX_FREQ,
-    MIN_FREQ,
-    TARGET_SAMPLERATE_HZ,
     PreprocessingConfig,
     load_preprocessing_config,
 )
-from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
+from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
 
 __all__ = [
-    "MIN_FREQ",
     "MAX_FREQ",
-    "TARGET_SAMPLERATE_HZ",
+    "MIN_FREQ",
     "PreprocessingConfig",
-    "load_preprocessing_config",
+    "Preprocessor",
+    "TARGET_SAMPLERATE_HZ",
     "build_preprocessor",
-    "build_audio_loader",
+    "load_preprocessing_config",
 ]
@ -1,267 +1,60 @@
|
||||
"""Handles loading and initial preprocessing of audio waveforms."""
|
from typing import Annotated, Literal, Union
from typing import Optional

import numpy as np
import torch
from numpy.typing import DTypeLike
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from pydantic import Field

from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.preprocess.config import (
    TARGET_SAMPLERATE_HZ,
    AudioConfig,
    AudioTransform,
    ResampleConfig,
)
from batdetect2.typing import AudioLoader
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core import BaseConfig, Registry
from batdetect2.preprocess.common import center_tensor, peak_normalize

__all__ = [
    "SoundEventAudioLoader",
    "build_audio_loader",
    "load_file_audio",
    "load_recording_audio",
    "load_clip_audio",
    "resample_audio",
    "CenterAudioConfig",
    "ScaleAudioConfig",
    "FixDurationConfig",
    "build_audio_transform",
]


class SoundEventAudioLoader(AudioLoader):
    """Concrete implementation of the `AudioLoader`."""

    def __init__(
        self,
        samplerate: int = TARGET_SAMPLERATE_HZ,
        config: Optional[ResampleConfig] = None,
    ):
        self.samplerate = samplerate
        self.config = config or ResampleConfig()

    def load_file(
        self,
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess audio directly from a file path."""
        return load_file_audio(
            path,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_recording(
        self,
        recording: data.Recording,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the entire audio for a Recording object."""
        return load_recording_audio(
            recording,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )

    def load_clip(
        self,
        clip: data.Clip,
        audio_dir: Optional[data.PathLike] = None,
    ) -> np.ndarray:
        """Load and preprocess the audio segment defined by a Clip object."""
        return load_clip_audio(
            clip,
            samplerate=self.samplerate,
            config=self.config,
            audio_dir=audio_dir,
        )


audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
    "audio_transform"
)


def load_file_audio(
    path: data.PathLike,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess audio from a file path using specified config."""
    try:
        recording = data.Recording.from_file(path)
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {path}. Error: {e}"
        ) from e

    return load_recording_audio(
        recording,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


class CenterAudioConfig(BaseConfig):
    name: Literal["center_audio"] = "center_audio"


def load_recording_audio(
    recording: data.Recording,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess the entire audio content of a recording using config."""
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(
        clip,
        samplerate=samplerate,
        config=config,
        dtype=dtype,
        audio_dir=audio_dir,
    )


class CenterAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return center_tensor(wav)

    @classmethod
    def from_config(cls, config: CenterAudioConfig, samplerate: int):
        return cls()


def load_clip_audio(
    clip: data.Clip,
    samplerate: Optional[int] = None,
    config: Optional[ResampleConfig] = None,
    audio_dir: Optional[data.PathLike] = None,
    dtype: DTypeLike = np.float32,  # type: ignore
) -> np.ndarray:
    """Load and preprocess a specific audio clip segment based on config."""
    try:
        wav = (
            audio.load_clip(clip, audio_dir=audio_dir)
            .sel(channel=0)
            .astype(dtype)
        )
    except LibsndfileError as e:
        raise FileNotFoundError(
            f"Could not load the recording at path: {clip.recording.path}. "
            f"Error: {e}"
        ) from e

    if not config or not config.enabled or samplerate is None:
        return wav.data.astype(dtype)

    sr = int(1 / wav.time.attrs["step"])
    return resample_audio(
        wav.data,
        sr=sr,
        samplerate=samplerate,
        method=config.method,
    )

audio_transforms.register(CenterAudioConfig, CenterAudio)
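The registry pairs each config class with the module that implements it, and later builds a module by dispatching on the config's type. The real `batdetect2.core.Registry` API is only inferred from its usage in this diff (`register(ConfigCls, ModuleCls)` and `build(config, samplerate)`); a minimal sketch of how such a config-keyed registry could work, assuming each registered module exposes a `from_config(config, samplerate)` classmethod:

# Hypothetical sketch only; the actual Registry in batdetect2.core.registries
# may differ in generics and error handling.
from typing import Dict, Type

import torch
from pydantic import BaseModel


class SimpleRegistry:
    def __init__(self, name: str):
        self.name = name
        self._entries: Dict[Type[BaseModel], type] = {}

    def register(self, config_cls: Type[BaseModel], module_cls: type) -> None:
        # Map the config class to the module that knows how to build itself
        # from that config.
        self._entries[config_cls] = module_cls

    def build(self, config: BaseModel, samplerate: int) -> torch.nn.Module:
        try:
            module_cls = self._entries[type(config)]
        except KeyError:
            raise NotImplementedError(
                f"No {self.name} registered for {type(config).__name__}"
            ) from None
        # Delegate construction to the module's own from_config classmethod.
        return module_cls.from_config(config, samplerate)

This keeps the dispatch table next to the class definitions, so adding a new transform no longer requires touching a central if/elif chain.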


def resample_audio(
    wav: np.ndarray,
    sr: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
    method: str = "poly",
) -> np.ndarray:
    """Resample an audio waveform DataArray to a target sample rate."""
    if sr == samplerate:
        return wav

    if method == "poly":
        return resample_audio_poly(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    elif method == "fourier":
        return resample_audio_fourier(
            wav,
            sr_orig=sr,
            sr_new=samplerate,
        )
    else:
        raise NotImplementedError(
            f"Resampling method '{method}' not implemented"
        )


class ScaleAudioConfig(BaseConfig):
    name: Literal["scale_audio"] = "scale_audio"


class ScaleAudio(torch.nn.Module):
    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return peak_normalize(wav)

    @classmethod
    def from_config(cls, config: ScaleAudioConfig, samplerate: int):
        return cls()


audio_transforms.register(ScaleAudioConfig, ScaleAudio)


def resample_audio_poly(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample_poly`.

    This method is often preferred for signals when the ratio of new
    to old sample rates can be expressed as a rational number. It uses
    polyphase filtering.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    gcd = np.gcd(sr_orig, sr_new)
    return resample_poly(
        array,
        sr_new // gcd,
        sr_orig // gcd,
        axis=axis,
    )


def resample_audio_fourier(
    array: np.ndarray,
    sr_orig: int,
    sr_new: int,
    axis: int = -1,
) -> np.ndarray:
    """Resample a numpy array using `scipy.signal.resample`.

    This method uses FFTs to resample the signal.

    Parameters
    ----------
    array : np.ndarray
        The input array to resample.
    sr_orig : int
        The original sample rate in Hz.
    sr_new : int
        The target sample rate in Hz.
    axis : int, default=-1
        The axis of `array` along which to resample.

    Returns
    -------
    np.ndarray
        The array resampled to the target sample rate along `axis`.

    Raises
    ------
    ValueError
        If sample rates are not positive.
    """
    ratio = sr_new / sr_orig
    return resample(  # type: ignore
        array,
        int(array.shape[axis] * ratio),
        axis=axis,
    )
)
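
A quick check of the two paths. The polyphase route reduces the rate ratio through the GCD (e.g., 192 kHz to 256 kHz reduces to up 4 / down 3); the example below uses an exact 2x ratio so both methods produce identical lengths:

import numpy as np

wav = np.random.randn(128_000).astype(np.float32)  # 1 s at 128 kHz

out_poly = resample_audio_poly(wav, sr_orig=128_000, sr_new=256_000)
out_fourier = resample_audio_fourier(wav, sr_orig=128_000, sr_new=256_000)

# Both yield 1 s at 256 kHz; for non-power-of-two ratios the Fourier
# path's int(n * ratio) can land a sample short due to float rounding.
assert out_poly.shape == out_fourier.shape == (256_000,)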


class FixDurationConfig(BaseConfig):
    name: Literal["fix_duration"] = "fix_duration"
    duration: float = 0.5


class FixDuration(torch.nn.Module):
@@ -282,40 +75,25 @@ class FixDuration(torch.nn.Module):
        return torch.nn.functional.pad(wav, (0, self.length - length))

    @classmethod
    def from_config(cls, config: FixDurationConfig, samplerate: int):
        return cls(samplerate=samplerate, duration=config.duration)


audio_transforms.register(FixDurationConfig, FixDuration)


def build_audio_loader(
    config: Optional[AudioConfig] = None,
) -> AudioLoader:
    """Factory function to create an AudioLoader based on configuration."""
    config = config or AudioConfig()
    return SoundEventAudioLoader(
        samplerate=config.samplerate,
        config=config.resample,
)
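
Typical usage of the factory (the file path below is hypothetical):

loader = build_audio_loader(AudioConfig(samplerate=256_000))
wav = loader.load_file("example_recording.wav")  # hypothetical path
# `wav` is a float32 numpy array, resampled to 256 kHz when the
# recording's native rate differs and resampling is enabled.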


AudioTransform = Annotated[
    Union[
        FixDurationConfig,
        ScaleAudioConfig,
        CenterAudioConfig,
    ],
    Field(discriminator="name"),
]
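
Because the union is discriminated on `name`, pydantic can validate a heterogeneous list of transform settings straight from YAML/dict data. A small sketch, assuming pydantic v2's `TypeAdapter`:

from pydantic import TypeAdapter

steps = TypeAdapter(list[AudioTransform]).validate_python(
    [
        {"name": "center_audio"},
        {"name": "fix_duration", "duration": 0.256},
    ]
)
# steps == [CenterAudioConfig(), FixDurationConfig(duration=0.256)]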


def build_audio_transform_step(
def build_audio_transform(
    config: AudioTransform,
    samplerate: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
    if config.name == "fix_duration":
        return FixDuration(samplerate=samplerate, duration=config.duration)

    if config.name == "scale_audio":
        return PeakNormalize()

    if config.name == "center_audio":
        return CenterTensor()

    raise NotImplementedError(
        f"Audio preprocessing step {config.name} not implemented"
    )


def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            build_audio_transform_step(step, samplerate=config.samplerate)
            for step in config.transforms
        ]
    )
    return audio_transforms.build(config, samplerate)

@@ -1,24 +1,22 @@
import torch

__all__ = [
    "CenterTensor",
    "PeakNormalize",
    "center_tensor",
    "peak_normalize",
]


class CenterTensor(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        return wav - wav.mean()
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
    return tensor - tensor.mean()


class PeakNormalize(torch.nn.Module):
    def forward(self, wav: torch.Tensor):
        max_value = wav.abs().max()
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
    max_value = tensor.abs().max()

    denominator = torch.where(
        max_value == 0,
        torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
        torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
        max_value,
    )

    return wav / denominator
return tensor / denominator
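
Note the `torch.where` guard keeps all-zero inputs unchanged instead of dividing by zero (the peak is taken with `.max()`; the `.min()` in the original would normalize by the smallest magnitude). A quick check of the expected behaviour:

import torch

x = torch.tensor([0.5, -2.0, 1.0])
assert torch.equal(peak_normalize(x), torch.tensor([0.25, -1.0, 0.5]))

silence = torch.zeros(4)
assert torch.equal(peak_normalize(silence), silence)  # divisor falls back to 1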


@@ -1,187 +1,25 @@
from collections.abc import Sequence
from typing import Annotated, List, Literal, Optional, Union
from typing import List, Optional

from pydantic import Field
from soundevent.data import PathLike

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import AudioTransform
from batdetect2.preprocess.spectrogram import (
    FrequencyConfig,
    PcenConfig,
    ResizeConfig,
    SpectralMeanSubstractionConfig,
    SpectrogramTransform,
    STFTConfig,
)

__all__ = [
    "load_preprocessing_config",
    "CenterAudioConfig",
    "ScaleAudioConfig",
    "FixDurationConfig",
    "ResampleConfig",
    "AudioTransform",
    "AudioConfig",
    "STFTConfig",
    "FrequencyConfig",
    "PcenConfig",
    "ScaleAmplitudeConfig",
    "SpectralMeanSubstractionConfig",
    "ResizeConfig",
    "PeakNormalizeConfig",
    "SpectrogramTransform",
    "SpectrogramConfig",
    "PreprocessingConfig",
    "TARGET_SAMPLERATE_HZ",
    "MIN_FREQ",
    "MAX_FREQ",
]

TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""

MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""

MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""


class CenterAudioConfig(BaseConfig):
    name: Literal["center_audio"] = "center_audio"


class ScaleAudioConfig(BaseConfig):
    name: Literal["scale_audio"] = "scale_audio"


class FixDurationConfig(BaseConfig):
    name: Literal["fix_duration"] = "fix_duration"
    duration: float = 0.5


class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    Attributes
    ----------
    samplerate : int, default=256000
        The target sample rate in Hz to resample the audio to. Must be > 0.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    enabled: bool = True
    method: str = "poly"


AudioTransform = Annotated[
    Union[
        FixDurationConfig,
        ScaleAudioConfig,
        CenterAudioConfig,
    ],
    Field(discriminator="name"),
]


class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing."""

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    transforms: List[AudioTransform] = Field(default_factory=list)


class STFTConfig(BaseConfig):
    """Configuration for the Short-Time Fourier Transform (STFT).

    Attributes
    ----------
    window_duration : float, default=0.002
        Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
        > 0. Determines frequency resolution (longer window = finer frequency
        resolution).
    window_overlap : float, default=0.75
        Fraction of overlap between consecutive STFT windows (e.g., 0.75
        for 75%). Must be >= 0 and < 1. Determines time resolution
        (higher overlap = finer time resolution).
    window_fn : str, default="hann"
        Name of the window function to apply before FFT calculation. Common
        options include "hann", "hamming", "blackman". See
        `scipy.signal.get_window`.
    """

    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"


class FrequencyConfig(BaseConfig):
    """Configuration for frequency axis parameters.

    Attributes
    ----------
    max_freq : int, default=120000
        Maximum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies above this value will be cropped. Must be > 0.
    min_freq : int, default=10000
        Minimum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies below this value will be cropped. Must be >= 0.
    """

    max_freq: int = Field(default=120_000, ge=0)
    min_freq: int = Field(default=10_000, ge=0)


class PcenConfig(BaseConfig):
    """Configuration for Per-Channel Energy Normalization (PCEN)."""

    name: Literal["pcen"] = "pcen"
    time_constant: float = 0.4
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5


class ScaleAmplitudeConfig(BaseConfig):
    name: Literal["scale_amplitude"] = "scale_amplitude"
    scale: Literal["power", "db"] = "db"


class SpectralMeanSubstractionConfig(BaseConfig):
    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"


class ResizeConfig(BaseConfig):
    name: Literal["resize_spec"] = "resize_spec"
    height: int = 128
    resize_factor: float = 0.5


class PeakNormalizeConfig(BaseConfig):
    name: Literal["peak_normalize"] = "peak_normalize"


SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
        ScaleAmplitudeConfig,
        SpectralMeanSubstractionConfig,
        PeakNormalizeConfig,
    ],
    Field(discriminator="name"),
]


class SpectrogramConfig(BaseConfig):
    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )


class PreprocessingConfig(BaseConfig):
    """Unified configuration for the audio preprocessing pipeline.
@@ -201,8 +39,20 @@ class PreprocessingConfig(BaseConfig):
        resizing). Defaults to default `SpectrogramConfig` settings if omitted.
    """

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
    audio_transforms: List[AudioTransform] = Field(default_factory=list)

    spectrogram_transforms: List[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )

    stft: STFTConfig = Field(default_factory=STFTConfig)

    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)

    size: ResizeConfig = Field(default_factory=ResizeConfig)


def load_preprocessing_config(

@@ -3,21 +3,25 @@ from typing import Optional
import torch
from loguru import logger

from batdetect2.preprocess.audio import build_audio_pipeline
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.audio import build_audio_transform
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
    _spec_params_from_config,
    build_spectrogram_pipeline,
    build_spectrogram_builder,
    build_spectrogram_crop,
    build_spectrogram_resizer,
    build_spectrogram_transform,
)
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
from batdetect2.typing import PreprocessorProtocol

__all__ = [
    "StandardPreprocessor",
    "Preprocessor",
    "build_preprocessor",
]


class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol."""

    input_samplerate: int
@@ -28,37 +32,78 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):

    def __init__(
        self,
        audio_pipeline: torch.nn.Module,
        spectrogram_pipeline: SpectrogramPipeline,
        config: PreprocessingConfig,
        input_samplerate: int,
        output_samplerate: float,
        max_freq: float,
        min_freq: float,
    ) -> None:
        super().__init__()
        self.audio_pipeline = audio_pipeline
        self.spectrogram_pipeline = spectrogram_pipeline

        self.max_freq = max_freq
        self.min_freq = min_freq
        self.audio_transforms = torch.nn.Sequential(
            *[
                build_audio_transform(step, samplerate=input_samplerate)
                for step in config.audio_transforms
            ]
        )

        self.spectrogram_transforms = torch.nn.Sequential(
            *[
                build_spectrogram_transform(step, samplerate=input_samplerate)
                for step in config.spectrogram_transforms
            ]
        )

        self.spectrogram_builder = build_spectrogram_builder(
            config.stft,
            samplerate=input_samplerate,
        )

        self.spectrogram_crop = build_spectrogram_crop(
            config.frequencies,
            stft=config.stft,
            samplerate=input_samplerate,
        )

        self.spectrogram_resizer = build_spectrogram_resizer(config.size)

        self.min_freq = config.frequencies.min_freq
        self.max_freq = config.frequencies.max_freq

        self.input_samplerate = input_samplerate
        self.output_samplerate = output_samplerate
        self.output_samplerate = compute_output_samplerate(
            config,
            input_samplerate=input_samplerate,
        )

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        wav = self.audio_pipeline(wav)
        return self.spectrogram_pipeline(wav)
        wav = self.audio_transforms(wav)
        spec = self.spectrogram_builder(wav)
        return self.process_spectrogram(spec)

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return self.spectrogram_builder(wav)

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        return self(wav)

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        spec = self.spectrogram_crop(spec)
        spec = self.spectrogram_transforms(spec)
        return self.spectrogram_resizer(spec)


def compute_output_samplerate(config: PreprocessingConfig) -> float:
    samplerate = config.audio.samplerate
    _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
    factor = config.spectrogram.size.resize_factor
    return samplerate * factor / hop_size
def compute_output_samplerate(
    config: PreprocessingConfig,
    input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> float:
    _, hop_size = _spec_params_from_config(
        config.stft, samplerate=input_samplerate
    )
    factor = config.size.resize_factor
return input_samplerate * factor / hop_size
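
With the defaults shown in this diff this is the effective frame rate of the final spectrogram: at 256 kHz a 2 ms window gives n_fft = 512 and, with 75% overlap, a 128-sample hop, so 256000 * 0.5 / 128 = 1000 time bins per second after the 0.5x time resize:

cfg = PreprocessingConfig()  # default stft and size, per this diff
assert compute_output_samplerate(cfg, input_samplerate=256_000) == 1000.0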


def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
    input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration."""
    config = config or PreprocessingConfig()
@@ -66,21 +111,4 @@ def build_preprocessor(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    samplerate = config.audio.samplerate

    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq

    output_samplerate = compute_output_samplerate(config)

    return StandardPreprocessor(
        audio_pipeline=build_audio_pipeline(config.audio),
        spectrogram_pipeline=build_spectrogram_pipeline(
            samplerate, config.spectrogram
        ),
        input_samplerate=samplerate,
        output_samplerate=output_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )
    return Preprocessor(config=config, input_samplerate=input_samplerate)

@@ -1,34 +1,64 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""

from typing import Callable, Optional
from typing import Annotated, Callable, Literal, Optional, Union

import numpy as np
import torch
import torchaudio
from pydantic import Field

from batdetect2.preprocess.common import PeakNormalize
from batdetect2.preprocess.config import (
    ScaleAmplitudeConfig,
    SpectrogramConfig,
    SpectrogramTransform,
    STFTConfig,
)
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.preprocess.common import peak_normalize

__all__ = [
    "STFTConfig",
    "build_spectrogram_transform",
    "build_spectrogram_builder",
    "build_spectrogram_pipeline",
]

MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""

MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""


class STFTConfig(BaseConfig):
    """Configuration for the Short-Time Fourier Transform (STFT).

    Attributes
    ----------
    window_duration : float, default=0.002
        Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
        > 0. Determines frequency resolution (longer window = finer frequency
        resolution).
    window_overlap : float, default=0.75
        Fraction of overlap between consecutive STFT windows (e.g., 0.75
        for 75%). Must be >= 0 and < 1. Determines time resolution
        (higher overlap = finer time resolution).
    window_fn : str, default="hann"
        Name of the window function to apply before FFT calculation. Common
        options include "hann", "hamming", "blackman". See
        `scipy.signal.get_window`.
    """

    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"


def build_spectrogram_builder(
    samplerate: int,
    conf: STFTConfig,
    config: STFTConfig,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
    n_fft, hop_length = _spec_params_from_config(samplerate, conf)
    n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
    return torchaudio.transforms.Spectrogram(
        n_fft=n_fft,
        hop_length=hop_length,
        window_fn=get_spectrogram_window(conf.window_fn),
        window_fn=get_spectrogram_window(config.window_fn),
        center=True,
        power=1,
    )
@@ -55,16 +85,19 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
    )


def _spec_params_from_config(samplerate: int, conf: STFTConfig):
    n_fft = int(samplerate * conf.window_duration)
    hop_length = int(n_fft * (1 - conf.window_overlap))
def _spec_params_from_config(
    config: STFTConfig,
    samplerate: int = TARGET_SAMPLERATE_HZ,
):
    n_fft = int(samplerate * config.window_duration)
    hop_length = int(n_fft * (1 - config.window_overlap))
return n_fft, hop_length
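
Concretely, at the 256 kHz target rate the default 2 ms window gives n_fft = 512 and, with 75% overlap, a hop of 128 samples:

n_fft, hop = _spec_params_from_config(STFTConfig(), samplerate=256_000)
assert (n_fft, hop) == (512, 128)  # 256000 * 0.002 and 512 * (1 - 0.75)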


def _frequency_to_index(
    freq: float,
    samplerate: int,
    n_fft: int,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]:
    alpha = freq * 2 / samplerate
    height = np.floor(n_fft / 2) + 1
@@ -79,14 +112,49 @@
return index
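
The hunk elides the rounding and clamping, but the mapping is linear in frequency: `alpha = 2 * freq / samplerate`, scaled by the `n_fft / 2 + 1` frequency bins. A worked example at the defaults (exact indices depend on the elided rounding):

# n_fft = 512 at 256 kHz, so height = 257 bins:
#   10 kHz:  alpha = 2 * 10_000 / 256_000  = 0.078125 -> 0.078125 * 257 ≈ 20
#   120 kHz: alpha = 2 * 120_000 / 256_000 = 0.9375   -> 0.9375 * 257  ≈ 241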


class FrequencyClip(torch.nn.Module):
class FrequencyConfig(BaseConfig):
    """Configuration for frequency axis parameters.

    Attributes
    ----------
    max_freq : int, default=120000
        Maximum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies above this value will be cropped. Must be > 0.
    min_freq : int, default=10000
        Minimum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies below this value will be cropped. Must be >= 0.
    """

    max_freq: int = Field(default=MAX_FREQ, ge=0)
    min_freq: int = Field(default=MIN_FREQ, ge=0)


class FrequencyCrop(torch.nn.Module):
    def __init__(
        self,
        low_index: Optional[int] = None,
        high_index: Optional[int] = None,
        samplerate: int,
        n_fft: int,
        min_freq: Optional[int] = None,
        max_freq: Optional[int] = None,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.samplerate = samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq

        low_index = None
        if min_freq is not None:
            low_index = _frequency_to_index(
                min_freq, self.samplerate, self.n_fft
            )
        self.low_index = low_index

        high_index = None
        if max_freq is not None:
            high_index = _frequency_to_index(
                max_freq, self.samplerate, self.n_fft
            )
        self.high_index = high_index

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
@@ -107,6 +175,72 @@ class FrequencyClip(torch.nn.Module):
        )


def build_spectrogram_crop(
    config: FrequencyConfig,
    stft: Optional[STFTConfig] = None,
    samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
    stft = stft or STFTConfig()
    n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
    return FrequencyCrop(
        samplerate=samplerate,
        n_fft=n_fft,
        min_freq=config.min_freq,
        max_freq=config.max_freq,
    )


class ResizeConfig(BaseConfig):
    name: Literal["resize_spec"] = "resize_spec"
    height: int = 128
    resize_factor: float = 0.5


class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: float):
        super().__init__()
        self.height = height
        self.time_factor = time_factor

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        current_length = spec.shape[-1]
        target_length = int(self.time_factor * current_length)

        original_ndim = spec.ndim
        while spec.ndim < 4:
            spec = spec.unsqueeze(0)

        resized = torch.nn.functional.interpolate(
            spec,
            size=(self.height, target_length),
            mode="bilinear",
        )

        while resized.ndim != original_ndim:
            resized = resized.squeeze(0)

        return resized


def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
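
For example, with the default height of 128 and resize factor 0.5, a (257, 2000) spectrogram becomes (128, 1000); the unsqueeze/squeeze loop lets the same module handle unbatched, batched, and channelled inputs:

import torch

resizer = build_spectrogram_resizer(ResizeConfig())
spec = torch.rand(257, 2000)  # freq bins x time frames
assert resizer(spec).shape == (128, 1000)
assert resizer(spec.unsqueeze(0)).shape == (1, 128, 1000)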


spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
    "spectrogram_transform"
)


class PcenConfig(BaseConfig):
    """Configuration for Per-Channel Energy Normalization (PCEN)."""

    name: Literal["pcen"] = "pcen"
    time_constant: float = 0.4
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5


class PCEN(torch.nn.Module):
    def __init__(
        self,
@@ -115,7 +249,7 @@ class PCEN(torch.nn.Module):
        bias: float = 2.0,
        power: float = 0.5,
        eps: float = 1e-6,
        dtype=torch.float64,
        dtype=torch.float32,
    ):
        super().__init__()
        self.smoothing_constant = smoothing_constant
@@ -151,6 +285,19 @@ class PCEN(torch.nn.Module):
            * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
        ).to(spec.dtype)

    @classmethod
    def from_config(cls, config: PcenConfig, samplerate: int):
        smooth = _compute_smoothing_constant(samplerate, config.time_constant)
        return cls(
            smoothing_constant=smooth,
            gain=config.gain,
            bias=config.bias,
            power=config.power,
        )

spectrogram_transforms.register(PcenConfig, PCEN)
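
With the registration in place, a transform is built from its config through the registry rather than an if/elif chain, assuming the `build(config, samplerate)` signature used at the end of this file:

pcen = spectrogram_transforms.build(PcenConfig(time_constant=0.1), 256_000)
# Equivalent to PCEN.from_config(PcenConfig(time_constant=0.1), 256_000).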


def _compute_smoothing_constant(
    samplerate: int,
@@ -164,21 +311,40 @@ def _compute_smoothing_constant(
    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)


class ScaleAmplitudeConfig(BaseConfig):
    name: Literal["scale_amplitude"] = "scale_amplitude"
    scale: Literal["power", "db"] = "db"


class ToPower(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return spec**2


def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
    if conf.scale == "db":
        return torchaudio.transforms.AmplitudeToDB()
_scalers = {
    "db": torchaudio.transforms.AmplitudeToDB,
    "power": ToPower,
}

    if conf.scale == "power":
        return ToPower()

    raise NotImplementedError(
        f"Amplitude scaling {conf.scale} not implemented"
    )
class ScaleAmplitude(torch.nn.Module):
    def __init__(self, scale: Literal["power", "db"]):
        super().__init__()
        self.scale = scale
        self.scaler = _scalers[scale]()

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return self.scaler(spec)

    @classmethod
    def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
        return cls(scale=config.scale)


spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)


class SpectralMeanSubstractionConfig(BaseConfig):
    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"


class SpectralMeanSubstraction(torch.nn.Module):
@@ -186,129 +352,49 @@ class SpectralMeanSubstraction(torch.nn.Module):
        mean = spec.mean(-1, keepdim=True)
        return (spec - mean).clamp(min=0)

    @classmethod
    def from_config(
        cls,
        config: SpectralMeanSubstractionConfig,
        samplerate: int,
    ):
        return cls()


spectrogram_transforms.register(
    SpectralMeanSubstractionConfig,
    SpectralMeanSubstraction,
)


class ResizeSpec(torch.nn.Module):
    def __init__(self, height: int, time_factor: float):
        super().__init__()
        self.height = height
        self.time_factor = time_factor

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        current_length = spec.shape[-1]
        target_length = int(self.time_factor * current_length)

        original_ndim = spec.ndim
        while spec.ndim < 4:
            spec = spec.unsqueeze(0)

        resized = torch.nn.functional.interpolate(
            spec,
            size=(self.height, target_length),
            mode="bilinear",
        )

        while resized.ndim != original_ndim:
            resized = resized.squeeze(0)

        return resized


def _build_spectrogram_transform_step(
    step: SpectrogramTransform,
    samplerate: int,
) -> torch.nn.Module:
    if step.name == "pcen":
        return PCEN(
            smoothing_constant=_compute_smoothing_constant(
                samplerate=samplerate,
                time_constant=step.time_constant,
            ),
            gain=step.gain,
            bias=step.bias,
            power=step.power,
        )

    if step.name == "scale_amplitude":
        return _build_amplitude_scaler(step)

    if step.name == "spectral_mean_substraction":
        return SpectralMeanSubstraction()

    if step.name == "peak_normalize":
        return PeakNormalize()

    raise NotImplementedError(
        f"Spectrogram preprocessing step {step.name} not implemented"
    )


class PeakNormalizeConfig(BaseConfig):
    name: Literal["peak_normalize"] = "peak_normalize"


class PeakNormalize(torch.nn.Module):
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        return peak_normalize(spec)

    @classmethod
    def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
        return cls()


spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)


SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
        ScaleAmplitudeConfig,
        SpectralMeanSubstractionConfig,
        PeakNormalizeConfig,
    ],
    Field(discriminator="name"),
]


def build_spectrogram_transform(
    config: SpectrogramTransform,
    samplerate: int,
    conf: SpectrogramConfig,
) -> torch.nn.Module:
    return torch.nn.Sequential(
        *[
            _build_spectrogram_transform_step(step, samplerate=samplerate)
            for step in conf.transforms
        ]
    )


class SpectrogramPipeline(torch.nn.Module):
    def __init__(
        self,
        spec_builder: torch.nn.Module,
        freq_cutter: torch.nn.Module,
        transforms: torch.nn.Module,
        resizer: torch.nn.Module,
    ):
        super().__init__()
        self.spec_builder = spec_builder
        self.freq_cutter = freq_cutter
        self.transforms = transforms
        self.resizer = resizer

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        spec = self.spec_builder(wav)
        spec = self.freq_cutter(spec)
        spec = self.transforms(spec)
        return self.resizer(spec)

    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return self.spec_builder(wav)

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
        return self.freq_cutter(spec)

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.transforms(spec)

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return self.resizer(spec)


def build_spectrogram_pipeline(
    samplerate: int,
    conf: SpectrogramConfig,
) -> SpectrogramPipeline:
    spec_builder = build_spectrogram_builder(samplerate, conf.stft)
    n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
    cutter = FrequencyClip(
        low_index=_frequency_to_index(
            conf.frequencies.min_freq, samplerate, n_fft
        ),
        high_index=_frequency_to_index(
            conf.frequencies.max_freq, samplerate, n_fft
        ),
    )
    transforms = build_spectrogram_transform(samplerate, conf)
    resizer = ResizeSpec(
        height=conf.size.height,
        time_factor=conf.size.resize_factor,
    )
    return SpectrogramPipeline(
        spec_builder=spec_builder,
        freq_cutter=cutter,
        transforms=transforms,
        resizer=resizer,
    )
return spectrogram_transforms.build(config, samplerate)

@@ -1,17 +1,6 @@
"""BatDetect2 Target Definition system."""

from collections import Counter
from typing import Iterable, List, Optional, Tuple

from loguru import logger
from pydantic import Field, field_validator
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
    DEFAULT_CLASSES,
    DEFAULT_DETECTION_CLASS,
    SoundEventDecoder,
    SoundEventEncoder,
    TargetClassConfig,
@@ -19,23 +8,29 @@ from batdetect2.targets.classes import (
    build_sound_event_encoder,
    get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
    AnchorBBoxMapperConfig,
    ROIMapperConfig,
    ROITargetMapper,
    build_roi_mapper,
)
from batdetect2.targets.targets import (
    Targets,
    build_targets,
    iterate_encoded_sound_events,
    load_targets,
)
from batdetect2.targets.terms import (
    call_type,
    data_source,
    generic_class,
    individual,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol

__all__ = [
    "AnchorBBoxMapperConfig",
    "DEFAULT_TARGET_CONFIG",
    "ROIMapperConfig",
    "ROITargetMapper",
    "SoundEventDecoder",
    "SoundEventEncoder",
@@ -45,365 +40,13 @@ __all__ = [
    "build_roi_mapper",
    "build_sound_event_decoder",
    "build_sound_event_encoder",
    "build_targets",
    "call_type",
    "data_source",
    "generic_class",
    "get_class_names_from_config",
    "individual",
    "iterate_encoded_sound_events",
    "load_target_config",
    "load_targets",
]


class TargetConfig(BaseConfig):
    detection_target: TargetClassConfig = Field(
        default=DEFAULT_DETECTION_CLASS
    )

    classification_targets: List[TargetClassConfig] = Field(
        default_factory=lambda: DEFAULT_CLASSES
    )

    roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)

    @field_validator("classification_targets")
    def check_unique_class_names(cls, v: List[TargetClassConfig]):
        """Ensure all defined class names are unique."""
        names = [c.name for c in v]

        if len(names) != len(set(names)):
            name_counts = Counter(names)
            duplicates = [
                name for name, count in name_counts.items() if count > 1
            ]
            raise ValueError(
                "Class names must be unique. Found duplicates: "
                f"{', '.join(duplicates)}"
            )
        return v


def load_target_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> TargetConfig:
    """Load the unified target configuration from a file.

    Reads a configuration file (typically YAML) and validates it against the
    `TargetConfig` schema, potentially extracting data from a nested field.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        target configuration. If None, the entire file content is used.

    Returns
    -------
    TargetConfig
        The loaded and validated unified target configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded configuration data does not conform to the
        `TargetConfig` schema (including validation within nested configs
        like `ClassesConfig`).
    KeyError, TypeError
        If `field` specifies an invalid path within the loaded data.
    """
    return load_config(path=path, schema=TargetConfig, field=field)


class Targets(TargetProtocol):
    """Encapsulates the complete configured target definition pipeline.

    This class implements the `TargetProtocol`, holding the configured
    functions for filtering, transforming, encoding (tags to class name),
    decoding (class name to tags), and mapping ROIs (geometry to position/size
    and back). It provides a high-level interface to apply these steps and
    access relevant metadata like class names and dimension names.

    Instances are typically created using the `build_targets` factory function
    or the `load_targets` convenience loader.

    Attributes
    ----------
    class_names : List[str]
        An ordered list of the unique names of the specific target classes
        defined in the configuration.
    generic_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects representing the configured
        generic class category (used when no specific class matches).
    dimension_names : List[str]
        The names of the size dimensions handled by the ROI mapper
        (e.g., ['width', 'height']).
    """

    class_names: List[str]
    detection_class_tags: List[data.Tag]
    dimension_names: List[str]
    detection_class_name: str

    def __init__(self, config: TargetConfig):
        """Initialize the Targets object."""
        self.config = config

        self._filter_fn = build_sound_event_condition(
            config.detection_target.match_if
        )
        self._encode_fn = build_sound_event_encoder(
            config.classification_targets
        )
        self._decode_fn = build_sound_event_decoder(
            config.classification_targets
        )

        self._roi_mapper = build_roi_mapper(config.roi)

        self.dimension_names = self._roi_mapper.dimension_names

        self.class_names = get_class_names_from_config(
            config.classification_targets
        )

        self.detection_class_name = config.detection_target.name
        self.detection_class_tags = config.detection_target.assign_tags

        self._roi_mapper_overrides = {
            class_config.name: build_roi_mapper(class_config.roi)
            for class_config in config.classification_targets
            if class_config.roi is not None
        }

        for class_name in self._roi_mapper_overrides:
            if class_name not in self.class_names:
                # TODO: improve this warning
                logger.warning(
                    "The ROI mapper overrides contain a class ({class_name}) "
                    "not present in the class names.",
                    class_name=class_name,
                )

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the configured filter to a sound event annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. If no filter was configured, always returns True.
        """
        return self._filter_fn(sound_event)

    def encode_class(
        self, sound_event: data.SoundEventAnnotation
    ) -> Optional[str]:
        """Encode a sound event annotation to its target class name.

        Applies the configured class definition rules (including priority)
        to determine the specific class name for the annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to encode. Note: This should typically be called
            *after* applying any transformations via the `transform` method.

        Returns
        -------
        str or None
            The name of the matched target class, or None if the annotation
            does not match any specific class rule (i.e., it belongs to the
            generic category).
        """
        return self._encode_fn(sound_event)

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Uses the configured mapping (based on `TargetClass.output_tags` or
        `TargetClass.tags`) to convert a class name string into a list of
        `soundevent.data.Tag` objects.

        Parameters
        ----------
        class_label : str
            The class name to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name.
        """
        return self._decode_fn(class_label)

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Extract the target reference position from the annotation's roi.

        Delegates to the internal ROI mapper's `encode` method.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI).

        Returns
        -------
        Tuple[float, float]
            The reference position `(time, frequency)`.

        Raises
        ------
        ValueError
            If the annotation lacks geometry.
        """
        class_name = self.encode_class(sound_event)

        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].encode(
                sound_event.sound_event
            )

        return self._roi_mapper.encode(sound_event.sound_event)

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: Optional[str] = None,
    ) -> data.Geometry:
        """Recover an approximate geometric ROI from a position and dimensions.

        Delegates to the internal ROI mapper's `decode` method, which
        un-scales the dimensions and reconstructs the geometry (typically a
        `BoundingBox`).

        Parameters
        ----------
        position : Tuple[float, float]
            The reference position `(time, frequency)`.
        size : np.ndarray
            NumPy array with size dimensions (e.g., from model prediction),
            matching the order in `self.dimension_names`.

        Returns
        -------
        data.Geometry
            The reconstructed geometry (typically `BoundingBox`).
        """
        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].decode(
                position,
                size,
            )

        return self._roi_mapper.decode(position, size)


DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
    classification_targets=DEFAULT_CLASSES,
    detection_target=DEFAULT_DETECTION_CLASS,
    roi=AnchorBBoxMapperConfig(),
)


def build_targets(config: Optional[TargetConfig] = None) -> Targets:
    """Build a Targets object from a loaded TargetConfig.

    Parameters
    ----------
    config : TargetConfig
        The loaded and validated unified target configuration object.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    KeyError
        If term keys or derivation function keys specified in the `config`
        are not found in their respective registries.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function fails (when configured).
    """
    config = config or DEFAULT_TARGET_CONFIG
    logger.opt(lazy=True).debug(
        "Building targets with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    return Targets(config=config)


def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
) -> Targets:
    """Load a Targets object directly from a configuration file.

    This convenience factory method loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to build
    the fully initialized `Targets` object.

    Parameters
    ----------
    config_path : data.PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        Errors raised during file loading, validation, or extraction via
        `load_target_config`.
    KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
        (e.g., missing keys in registries, failed imports).
    """
    config = load_target_config(
        config_path,
        field=field,
    )
    return build_targets(config)


def iterate_encoded_sound_events(
    sound_events: Iterable[data.SoundEventAnnotation],
    targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
    for sound_event in sound_events:
        if not targets.filter(sound_event):
            continue

        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        class_name = targets.encode_class(sound_event)
        position, size = targets.encode_roi(sound_event)

yield class_name, position, size
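
A sketch of how the generator is meant to be consumed when assembling training targets; `clip_annotation` is a placeholder for any object holding a list of `data.SoundEventAnnotation`:

targets = build_targets()  # default TargetConfig

for class_name, position, size in iterate_encoded_sound_events(
    clip_annotation.sound_events,  # hypothetical ClipAnnotation
    targets=targets,
):
    ...  # e.g. paint a heatmap peak at `position` with extent `size`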


84
src/batdetect2/targets/config.py
Normal file
@@ -0,0 +1,84 @@
from collections import Counter
from typing import List, Optional

from pydantic import Field, field_validator
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.targets.classes import (
    DEFAULT_CLASSES,
    DEFAULT_DETECTION_CLASS,
    TargetClassConfig,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig

__all__ = [
    "TargetConfig",
    "load_target_config",
]


class TargetConfig(BaseConfig):
    detection_target: TargetClassConfig = Field(
        default=DEFAULT_DETECTION_CLASS
    )

    classification_targets: List[TargetClassConfig] = Field(
        default_factory=lambda: DEFAULT_CLASSES
    )

    roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)

    @field_validator("classification_targets")
    def check_unique_class_names(cls, v: List[TargetClassConfig]):
        """Ensure all defined class names are unique."""
        names = [c.name for c in v]

        if len(names) != len(set(names)):
            name_counts = Counter(names)
            duplicates = [
                name for name, count in name_counts.items() if count > 1
            ]
            raise ValueError(
                "Class names must be unique. Found duplicates: "
                f"{', '.join(duplicates)}"
            )
return v
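
The validator makes duplicated class names fail at config-load time rather than surfacing later during training. A sketch, assuming the remaining `TargetClassConfig` fields have defaults:

try:
    TargetConfig(
        classification_targets=[
            TargetClassConfig(name="myotis"),  # hypothetical entries
            TargetClassConfig(name="myotis"),
        ]
    )
except ValueError as err:
    print(err)  # "... Class names must be unique. Found duplicates: myotis"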
|
||||
|
||||
|
||||
def load_target_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> TargetConfig:
|
||||
"""Load the unified target configuration from a file.
|
||||
|
||||
Reads a configuration file (typically YAML) and validates it against the
|
||||
`TargetConfig` schema, potentially extracting data from a nested field.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
target configuration. If None, the entire file content is used.
|
||||
|
||||
Returns
|
||||
-------
|
||||
TargetConfig
|
||||
The loaded and validated unified target configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
yaml.YAMLError
|
||||
If the file content is not valid YAML.
|
||||
pydantic.ValidationError
|
||||
If the loaded configuration data does not conform to the
|
||||
`TargetConfig` schema (including validation within nested configs
|
||||
like `ClassesConfig`).
|
||||
KeyError, TypeError
|
||||
If `field` specifies an invalid path within the loaded data.
|
||||
"""
|
||||
return load_config(path=path, schema=TargetConfig, field=field)
|
||||
@ -26,10 +26,10 @@ import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Position,
|
||||
@ -265,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
|
||||
"""
|
||||
|
||||
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
@ -456,8 +457,11 @@ def build_roi_mapper(
|
||||
)
|
||||
|
||||
if config.name == "peak_energy_bbox":
|
||||
preprocessor = build_preprocessor(config.preprocessing)
|
||||
audio_loader = build_audio_loader(config.preprocessing.audio)
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return PeakEnergyBBoxMapper(
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
|
||||
308
src/batdetect2/targets/targets.py
Normal file
308
src/batdetect2/targets/targets.py
Normal file
@ -0,0 +1,308 @@
|
||||
from typing import Iterable, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import build_sound_event_condition
|
||||
from batdetect2.targets.classes import (
|
||||
DEFAULT_CLASSES,
|
||||
DEFAULT_DETECTION_CLASS,
|
||||
build_sound_event_decoder,
|
||||
build_sound_event_encoder,
|
||||
get_class_names_from_config,
|
||||
)
|
||||
from batdetect2.targets.config import TargetConfig, load_target_config
|
||||
from batdetect2.targets.rois import (
|
||||
AnchorBBoxMapperConfig,
|
||||
build_roi_mapper,
|
||||
)
|
||||
from batdetect2.typing.targets import Position, Size, TargetProtocol
|
||||
|
||||
|
||||
class Targets(TargetProtocol):
|
||||
"""Encapsulates the complete configured target definition pipeline.
|
||||
|
||||
This class implements the `TargetProtocol`, holding the configured
|
||||
functions for filtering, transforming, encoding (tags to class name),
|
||||
decoding (class name to tags), and mapping ROIs (geometry to position/size
|
||||
and back). It provides a high-level interface to apply these steps and
|
||||
access relevant metadata like class names and dimension names.
|
||||
|
||||
Instances are typically created using the `build_targets` factory function
|
||||
or the `load_targets` convenience loader.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
class_names : List[str]
|
||||
An ordered list of the unique names of the specific target classes
|
||||
defined in the configuration.
|
||||
generic_class_tags : List[data.Tag]
|
||||
A list of `soundevent.data.Tag` objects representing the configured
|
||||
generic class category (used when no specific class matches).
|
||||
dimension_names : List[str]
|
||||
The names of the size dimensions handled by the ROI mapper
|
||||
(e.g., ['width', 'height']).
|
||||
"""
|
||||
|
||||
    class_names: List[str]
    detection_class_tags: List[data.Tag]
    dimension_names: List[str]
    detection_class_name: str

    def __init__(self, config: TargetConfig):
        """Initialize the Targets object."""
        self.config = config

        self._filter_fn = build_sound_event_condition(
            config.detection_target.match_if
        )
        self._encode_fn = build_sound_event_encoder(
            config.classification_targets
        )
        self._decode_fn = build_sound_event_decoder(
            config.classification_targets
        )

        self._roi_mapper = build_roi_mapper(config.roi)

        self.dimension_names = self._roi_mapper.dimension_names

        self.class_names = get_class_names_from_config(
            config.classification_targets
        )

        self.detection_class_name = config.detection_target.name
        self.detection_class_tags = config.detection_target.assign_tags

        self._roi_mapper_overrides = {
            class_config.name: build_roi_mapper(class_config.roi)
            for class_config in config.classification_targets
            if class_config.roi is not None
        }

        for class_name in self._roi_mapper_overrides:
            if class_name not in self.class_names:
                # TODO: improve this warning
                logger.warning(
                    "The ROI mapper overrides contain a class ({class_name}) "
                    "that is not present in the class names.",
                    class_name=class_name,
                )

    def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
        """Apply the configured filter to a sound event annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to filter.

        Returns
        -------
        bool
            True if the annotation should be kept (passes the filter),
            False otherwise. If no filter was configured, always returns True.
        """
        return self._filter_fn(sound_event)

    def encode_class(
        self, sound_event: data.SoundEventAnnotation
    ) -> Optional[str]:
        """Encode a sound event annotation to its target class name.

        Applies the configured class definition rules (including priority)
        to determine the specific class name for the annotation.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation to encode. Note: This should typically be called
            *after* applying any transformations via the `transform` method.

        Returns
        -------
        str or None
            The name of the matched target class, or None if the annotation
            does not match any specific class rule (i.e., it belongs to the
            generic category).
        """
        return self._encode_fn(sound_event)

    def decode_class(self, class_label: str) -> List[data.Tag]:
        """Decode a predicted class name back into representative tags.

        Uses the configured mapping (based on `TargetClass.output_tags` or
        `TargetClass.tags`) to convert a class name string into a list of
        `soundevent.data.Tag` objects.

        Parameters
        ----------
        class_label : str
            The class name to decode.

        Returns
        -------
        List[data.Tag]
            The list of tags corresponding to the input class name.
        """
        return self._decode_fn(class_label)

    def encode_roi(
        self, sound_event: data.SoundEventAnnotation
    ) -> tuple[Position, Size]:
        """Extract the target position and size from the annotation's ROI.

        Delegates to the configured ROI mapper's `encode` method, using a
        per-class override mapper when one is configured for the
        annotation's class.

        Parameters
        ----------
        sound_event : data.SoundEventAnnotation
            The annotation containing the geometry (ROI).

        Returns
        -------
        Tuple[Position, Size]
            The reference position `(time, frequency)` and the scaled size
            dimensions.

        Raises
        ------
        ValueError
            If the annotation lacks geometry.
        """
        class_name = self.encode_class(sound_event)

        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].encode(
                sound_event.sound_event
            )

        return self._roi_mapper.encode(sound_event.sound_event)

    def decode_roi(
        self,
        position: Position,
        size: Size,
        class_name: Optional[str] = None,
    ) -> data.Geometry:
        """Recover an approximate geometric ROI from a position and size.

        Delegates to the configured ROI mapper's `decode` method, which
        un-scales the size dimensions and reconstructs the geometry
        (typically a `BoundingBox`). When a per-class override mapper is
        configured for `class_name`, that mapper is used instead.

        Parameters
        ----------
        position : Position
            The reference position `(time, frequency)`.
        size : Size
            The size dimensions (e.g., from a model prediction), matching
            the order in `self.dimension_names`.
        class_name : str, optional
            The predicted class name, used to select a per-class ROI mapper
            override if one exists.

        Returns
        -------
        data.Geometry
            The reconstructed geometry (typically `BoundingBox`).
        """
        if class_name in self._roi_mapper_overrides:
            return self._roi_mapper_overrides[class_name].decode(
                position,
                size,
            )

        return self._roi_mapper.decode(position, size)


DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
    classification_targets=DEFAULT_CLASSES,
    detection_target=DEFAULT_DETECTION_CLASS,
    roi=AnchorBBoxMapperConfig(),
)


def build_targets(config: Optional[TargetConfig] = None) -> Targets:
    """Build a Targets object from a loaded TargetConfig.

    Parameters
    ----------
    config : TargetConfig, optional
        The loaded and validated unified target configuration object. If
        None, `DEFAULT_TARGET_CONFIG` is used.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    KeyError
        If term keys or derivation function keys specified in the `config`
        are not found in their respective registries.
    ImportError, AttributeError, TypeError
        If dynamic import of a derivation function fails (when configured).
    """
    config = config or DEFAULT_TARGET_CONFIG
    logger.opt(lazy=True).debug(
        "Building targets with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    return Targets(config=config)


def load_targets(
    config_path: data.PathLike,
    field: Optional[str] = None,
) -> Targets:
    """Load a Targets object directly from a configuration file.

    This convenience factory loads the `TargetConfig` from the specified
    file path and then calls `build_targets` to create the fully
    initialized `Targets` object.

    Parameters
    ----------
    config_path : data.PathLike
        Path to the configuration file (e.g., YAML).
    field : str, optional
        Dot-separated path to a nested section within the file containing
        the target configuration. If None, the entire file content is used.

    Returns
    -------
    Targets
        An initialized `Targets` object ready for use.

    Raises
    ------
    FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
    TypeError
        Errors raised during file loading, validation, or extraction via
        `load_target_config`.
    KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
        (e.g., missing keys in registries, failed imports).
    """
    config = load_target_config(
        config_path,
        field=field,
    )
    return build_targets(config)


def iterate_encoded_sound_events(
    sound_events: Iterable[data.SoundEventAnnotation],
    targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
    """Filter, encode, and map each annotation, yielding targets lazily."""
    for sound_event in sound_events:
        if not targets.filter(sound_event):
            continue

        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        class_name = targets.encode_class(sound_event)
        position, size = targets.encode_roi(sound_event)

        yield class_name, position, size
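For orientation, a hedged usage sketch of the new module: `build_targets()` without arguments falls back to `DEFAULT_TARGET_CONFIG`, and `decode_roi`/`decode_class` invert the encoding, honouring per-class ROI overrides. The `clip_annotation` variable is a stand-in for an annotation loaded elsewhere, not part of this commit:

from batdetect2.targets.targets import (
    build_targets,
    iterate_encoded_sound_events,
)

targets = build_targets()  # falls back to DEFAULT_TARGET_CONFIG

for class_name, position, size in iterate_encoded_sound_events(
    clip_annotation.sound_events,  # assumed loaded elsewhere
    targets=targets,
):
    # Round-trip: recover an approximate geometry and representative tags.
    geometry = targets.decode_roi(position, size, class_name=class_name)
    tags = targets.decode_class(class_name) if class_name else []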
@ -16,10 +16,8 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
    FullTrainingConfig,
    PLTrainerConfig,
    TrainingConfig,
    load_full_training_config,
    load_train_config,
)
from batdetect2.train.dataset import (
@ -48,7 +46,6 @@ __all__ = [
    "DetectionLossConfig",
    "EchoAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "FullTrainingConfig",
    "LossConfig",
    "LossFunction",
    "PLTrainerConfig",
@ -71,7 +68,6 @@ __all__ = [
    "build_trainer",
    "build_val_dataset",
    "build_val_loader",
    "load_full_training_config",
    "load_label_config",
    "load_train_config",
    "mask_frequency",

@ -4,8 +4,6 @@ from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
    DEFAULT_AUGMENTATION_CONFIG,
    AugmentationsConfig,
@ -22,8 +20,6 @@ from batdetect2.train.losses import LossConfig
__all__ = [
    "TrainingConfig",
    "load_train_config",
    "FullTrainingConfig",
    "load_full_training_config",
]


@ -93,18 +89,3 @@ def load_train_config(
    field: Optional[str] = None,
) -> TrainingConfig:
    return load_config(path, schema=TrainingConfig, field=field)


class FullTrainingConfig(ModelConfig):
    """Full training configuration."""

    train: TrainingConfig = Field(default_factory=TrainingConfig)
    evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)


def load_full_training_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> FullTrainingConfig:
    """Load the full training configuration."""
    return load_config(path, schema=FullTrainingConfig, field=field)

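With `FullTrainingConfig` removed, the per-section loader remains the way to read training settings. A hedged sketch, assuming a combined YAML file whose training options live under a top-level "train" key (matching the layout used elsewhere in this commit):

# Sketch: load only the training section of a combined config file.
from batdetect2.train.config import load_train_config

train_config = load_train_config("config.yaml", field="train")
print(train_config.trainer)  # the sub-config later consumed by build_trainer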
@ -5,8 +5,9 @@ from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import build_audio_loader
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
    RandomAudioSource,
    build_augmentations,

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import lightning as L
import torch
@ -6,11 +6,17 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.plotting.clips import build_preprocessor
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config

__all__ = [
    "TrainingModule",
]
@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule):

    def __init__(
        self,
        config: FullTrainingConfig,
        config: "BatDetect2Config",
        input_samplerate: int = TARGET_SAMPLERATE_HZ,
        learning_rate: float = 0.001,
        t_max: int = 100,
        model: Optional[Model] = None,
@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule):

        self.save_hyperparameters(logger=False)

        self.input_samplerate = input_samplerate
        self.config = config
        self.learning_rate = learning_rate
        self.t_max = t_max
@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule):
        loss = build_loss(self.config.train.loss)

        if model is None:
            model = build_model(self.config)
            targets = build_targets(self.config.targets)

            preprocessor = build_preprocessor(
                config=self.config.preprocess,
                input_samplerate=self.input_samplerate,
            )

            postprocessor = build_postprocessor(
                preprocessor, config=self.config.postprocess
            )

            model = build_model(
                config=self.config.model,
                targets=targets,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
            )

        self.loss = loss
        self.model = model
@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule):

def load_model_from_checkpoint(
    path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
) -> Tuple[Model, "BatDetect2Config"]:
    module = TrainingModule.load_from_checkpoint(path)  # type: ignore
    return module.model, module.config


def build_training_module(
    config: Optional[FullTrainingConfig] = None,
    config: Optional["BatDetect2Config"] = None,
    t_max: int = 200,
) -> TrainingModule:
    config = config or FullTrainingConfig()
    from batdetect2.config import BatDetect2Config

    config = config or BatDetect2Config()
    return TrainingModule(
        config=config,
        learning_rate=config.train.optimizer.learning_rate,

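A hedged sketch of driving the reworked constructor: when no `model` is passed, the module now assembles targets, preprocessor, postprocessor, and model itself from the unified config:

# Sketch only: builds a training module from the unified config.
from batdetect2.config import BatDetect2Config
from batdetect2.train.lightning import build_training_module

config = BatDetect2Config()  # default unified configuration
module = build_training_module(config, t_max=200)

# The module records the sample rate it expects raw audio to arrive at.
print(module.input_samplerate)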
@ -1,29 +1,31 @@
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional

from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data

from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import (
    FullTrainingConfig,
)
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.logging import build_logger
from batdetect2.typing import (

if TYPE_CHECKING:
    from batdetect2.config import BatDetect2Config
    from batdetect2.typing import (
        AudioLoader,
        ClipLabeller,
        PreprocessorProtocol,
        TargetProtocol,
    )
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipLabeller
)

__all__ = [
    "build_trainer",
@ -36,12 +38,13 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
    train_annotations: Sequence[data.ClipAnnotation],
    val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
    evaluator: Optional[Evaluator] = None,
    trainer: Optional[Trainer] = None,
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    config: Optional[FullTrainingConfig] = None,
    targets: Optional["TargetProtocol"] = None,
    preprocessor: Optional["PreprocessorProtocol"] = None,
    audio_loader: Optional["AudioLoader"] = None,
    labeller: Optional["ClipLabeller"] = None,
    config: Optional["BatDetect2Config"] = None,
    model_path: Optional[data.PathLike] = None,
    train_workers: Optional[int] = None,
    val_workers: Optional[int] = None,
@ -51,17 +54,20 @@ def train(
    run_name: Optional[str] = None,
    seed: Optional[int] = None,
):
    from batdetect2.config import BatDetect2Config

    if seed is not None:
        seed_everything(seed)

    config = config or FullTrainingConfig()
    config = config or BatDetect2Config()

    targets = targets or build_targets(config.targets)
    targets = targets or build_targets(config=config.targets)

    preprocessor = preprocessor or build_preprocessor(config.preprocess)
    audio_loader = audio_loader or build_audio_loader(config=config.audio)

    audio_loader = audio_loader or build_audio_loader(
        config=config.preprocess.audio
    preprocessor = preprocessor or build_preprocessor(
        input_samplerate=audio_loader.samplerate,
        config=config.preprocess,
    )

    labeller = labeller or build_clip_labeler(
@ -95,7 +101,7 @@ def train(

    if model_path is not None:
        logger.debug("Loading model from: {path}", path=model_path)
        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
        module = TrainingModule.load_from_checkpoint(Path(model_path))
    else:
        module = build_training_module(
            config,
@ -103,8 +109,9 @@ def train(
        )

    trainer = trainer or build_trainer(
        config,
        config.train,
        targets=targets,
        evaluator=evaluator,
        checkpoint_dir=checkpoint_dir,
        log_dir=log_dir,
        experiment_name=experiment_name,
@ -121,8 +128,8 @@ def train(


def build_trainer_callbacks(
    targets: TargetProtocol,
    config: FullTrainingConfig,
    targets: "TargetProtocol",
    evaluator: Optional[Evaluator] = None,
    checkpoint_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
@ -136,7 +143,7 @@ def build_trainer_callbacks(
    if run_name is not None:
        checkpoint_dir = checkpoint_dir / run_name

    evaluator = build_evaluator(config=config.evaluation, targets=targets)
    evaluator = evaluator or build_evaluator(targets=targets)

    return [
        ModelCheckpoint(
@ -150,20 +157,21 @@ def build_trainer_callbacks(


def build_trainer(
    conf: FullTrainingConfig,
    targets: TargetProtocol,
    conf: TrainingConfig,
    targets: "TargetProtocol",
    evaluator: Optional[Evaluator] = None,
    checkpoint_dir: Optional[Path] = None,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Trainer:
    trainer_conf = conf.train.trainer
    trainer_conf = conf.trainer
    logger.opt(lazy=True).debug(
        "Building trainer with config: \n{config}",
        config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
    )
    train_logger = build_logger(
        conf.train.logger,
        conf.logger,
        log_dir=log_dir,
        experiment_name=experiment_name,
        run_name=run_name,
@ -181,7 +189,7 @@ def build_trainer(
        logger=train_logger,
        callbacks=build_trainer_callbacks(
            targets,
            config=conf,
            evaluator=evaluator,
            checkpoint_dir=checkpoint_dir,
            experiment_name=experiment_name,
            run_name=run_name,

|
||||
from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
SpectrogramBuilder,
|
||||
SpectrogramPipeline,
|
||||
)
|
||||
from batdetect2.typing.targets import (
|
||||
Position,
|
||||
@ -60,8 +58,6 @@ __all__ = [
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"SpectrogramBuilder",
|
||||
"SpectrogramPipeline",
|
||||
"TargetProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
@ -32,6 +32,8 @@ class AudioLoader(Protocol):
    allows for different loading strategies or implementations.
    """

    samplerate: int

    def load_file(
        self,
        path: data.PathLike,
@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol):
        ...


class AudioPipeline(Protocol):
    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class SpectrogramPipeline(Protocol):
    def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...

    def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...


class PreprocessorProtocol(Protocol):
    """Defines a high-level interface for the complete preprocessing pipeline."""

@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol):

    output_samplerate: float

    audio_pipeline: AudioPipeline

    spectrogram_pipeline: SpectrogramPipeline

    def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        return self(torch.tensor(wav)).numpy()
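Because `PreprocessorProtocol` is structural (`typing.Protocol`), any object exposing the attributes and methods above conforms. A minimal, hypothetical identity preprocessor to illustrate; only `output_samplerate` appears in the hunk above, so the protocol's other attributes are elided here:

# Sketch: a trivially conforming preprocessor, not part of the library.
import numpy as np
import torch

class IdentityPreprocessor:
    output_samplerate: float = 256_000.0

    def __call__(self, wav: torch.Tensor) -> torch.Tensor:
        return wav  # a real implementation would return a spectrogram

    def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
        return wav

    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        return wav

    def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
        return spec

    def process_numpy(self, wav: np.ndarray) -> np.ndarray:
        # Same default as the protocol body above.
        return self(torch.tensor(wav)).numpy()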