Config restructuring

mbsantiago 2025-09-16 18:57:56 +01:00
parent 7d6cba5465
commit bbb96b33a2
30 changed files with 1731 additions and 1619 deletions


@ -36,31 +36,30 @@ targets:
name: anchor_bbox
anchor: top-left
preprocess:
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
audio:
samplerate: 256000
resample:
enabled: True
method: "poly"
spectrogram:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
preprocess:
stft:
window_duration: 0.002
window_overlap: 0.75
window_fn: hann
frequencies:
max_freq: 120000
min_freq: 10000
size:
height: 128
resize_factor: 0.5
spectrogram_transforms:
- name: pcen
time_constant: 0.1
gain: 0.98
bias: 2
power: 0.5
- name: spectral_mean_substraction
postprocess:
nms_kernel_size: 9
@ -113,12 +112,44 @@ train:
train_loader:
batch_size: 8
num_workers: 2
shuffle: True
clipping_strategy:
name: random_subclip
duration: 0.256
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3
val_loader:
num_workers: 2
clipping_strategy:
@ -141,32 +172,3 @@ train:
logger:
name: csv
augmentations:
enabled: true
audio:
- name: mix_audio
probability: 0.2
min_weight: 0.3
max_weight: 0.7
- name: add_echo
probability: 0.2
max_delay: 0.005
min_weight: 0.0
max_weight: 1.0
spectrogram:
- name: scale_volume
probability: 0.2
min_scaling: 0.0
max_scaling: 2.0
- name: warp
probability: 0.2
delta: 0.04
- name: mask_time
probability: 0.2
max_perc: 0.05
max_masks: 3
- name: mask_freq
probability: 0.2
max_perc: 0.10
max_masks: 3

src/batdetect2/audio.py (new file, 295 lines)

@ -0,0 +1,295 @@
from typing import Optional
import numpy as np
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.core import BaseConfig
from batdetect2.typing import AudioLoader
__all__ = [
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
enabled : bool, default=True
Whether to resample the audio when loading. If False, the waveform is
returned at its original sample rate.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: ResampleConfig = Field(default_factory=ResampleConfig)
def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
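A minimal usage sketch for the new factory; the file path is illustrative, not part of this commit:

```python
# Sketch: build an AudioLoader from the new AudioConfig and load a file.
from batdetect2.audio import AudioConfig, ResampleConfig, build_audio_loader

config = AudioConfig(
    samplerate=256_000,
    resample=ResampleConfig(enabled=True, method="poly"),
)
loader = build_audio_loader(config)

# "call.wav" is a hypothetical path.
wav = loader.load_file("call.wav")  # np.ndarray at 256 kHz
```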
class SoundEventAudioLoader(AudioLoader):
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
def resample_audio(
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
return resample_audio_poly(
wav,
sr_orig=sr,
sr_new=samplerate,
)
elif method == "fourier":
return resample_audio_fourier(
wav,
sr_orig=sr,
sr_new=samplerate,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
def resample_audio_poly(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
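A worked example of the up/down factors with assumed rates: going from 192 kHz to 256 kHz, the gcd is 64 000, so the polyphase filter upsamples by 4 and downsamples by 3.

```python
import numpy as np
from scipy.signal import resample_poly

sr_orig, sr_new = 192_000, 256_000
gcd = np.gcd(sr_orig, sr_new)             # 64_000
up, down = sr_new // gcd, sr_orig // gcd  # 4, 3

wav = np.random.randn(sr_orig)            # one second of noise
out = resample_poly(wav, up, down)
assert out.shape[-1] == sr_new            # exactly 256_000 samples
```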
def resample_audio_fourier(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate along `axis`.
Raises
------
ValueError
If the computed number of output samples is negative.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
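The Fourier variant instead derives an explicit output length from the rate ratio; with the same assumed rates:

```python
import numpy as np
from scipy.signal import resample

sr_orig, sr_new = 192_000, 256_000
wav = np.random.randn(sr_orig)

num = int(wav.shape[-1] * sr_new / sr_orig)  # 256_000 output samples
out = resample(wav, num)
assert out.shape[-1] == num
```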


@ -14,7 +14,7 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--targets", type=click.Path(exists=True))
@click.option("--targets-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@ -37,7 +37,7 @@ def train_command(
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
targets: Optional[Path] = None,
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
train_workers: int = 0,
@ -46,13 +46,13 @@ def train_command(
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
from batdetect2.train import train
logger.remove()
if verbose == 0:
@ -68,15 +68,16 @@ def train_command(
logger.info("Loading training configuration...")
conf = (
load_full_training_config(config, field=config_field)
load_full_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
else BatDetect2Config()
)
if targets is not None:
if targets_config is not None:
logger.info("Loading targets configuration...")
targets_config = load_target_config(targets)
conf = conf.model_copy(update=dict(targets=targets_config))
conf = conf.model_copy(
update=dict(targets=load_target_config(targets_config))
)
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
@ -96,6 +97,7 @@ def train_command(
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_annotations=train_annotations,
val_annotations=val_annotations,


@ -1,16 +1,40 @@
from typing import Literal
from typing import Literal, Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.backbones import BackboneConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.targets.config import TargetConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"load_full_config",
]
class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig
evaluation: EvaluationConfig
model: BackboneConfig
preprocess: PreprocessingConfig
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def load_full_config(
path: PathLike,
field: Optional[str] = None,
) -> BatDetect2Config:
return load_config(path, schema=BatDetect2Config, field=field)
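Since every section now carries a default factory, a full config can be built without a file, and single sections can be swapped in afterwards, mirroring the CLI change above. A sketch with illustrative file names:

```python
from batdetect2.config import BatDetect2Config, load_full_config
from batdetect2.targets import load_target_config

# All sections fall back to their defaults.
conf = BatDetect2Config()

# Or load from a file and override just the targets section
# ("config.yaml" and "targets.yaml" are hypothetical paths).
conf = load_full_config("config.yaml")
conf = conf.model_copy(
    update=dict(targets=load_target_config("targets.yaml"))
)
```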


@ -3,39 +3,46 @@ from typing import List, Optional, Tuple
import pandas as pd
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
from batdetect2.typing import ClipLabeller, TargetProtocol
def evaluate(
model: Model,
test_annotations: List[data.ClipAnnotation],
config: Optional[FullTrainingConfig] = None,
targets: Optional[TargetProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
labeller: Optional[ClipLabeller] = None,
config: Optional[EvaluationConfig] = None,
num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
config = config or FullTrainingConfig()
config = config or EvaluationConfig()
audio_loader = build_audio_loader(config.preprocess.audio)
audio_loader = audio_loader or build_audio_loader()
preprocessor = build_preprocessor(config.preprocess)
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
)
targets = build_targets(config.targets)
targets = targets or build_targets()
labeller = build_clip_labeler(
labeller = labeller or build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.train.labels,
)
loader = build_val_loader(
@ -43,7 +50,6 @@ def evaluate(
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config.train.val_loader,
num_workers=num_workers,
)
@ -52,7 +58,7 @@ def evaluate(
clip_annotations = []
predictions = []
evaluator = build_evaluator(config=config.evaluation)
evaluator = build_evaluator(config=config)
for batch in loader:
outputs = model.detector(batch.spec)


@ -89,7 +89,7 @@ class ExampleGallery(PlotterProtocol):
@classmethod
def from_config(cls, config: ExampleGalleryConfig):
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
audio_loader = build_audio_loader(config.preprocessing.audio_transforms)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,


@ -29,15 +29,10 @@ provided here.
from typing import List, Optional
import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
build_backbone,
load_backbone_config,
)
from batdetect2.models.blocks import (
ConvConfig,
@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import (
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.config import (
BackboneConfig,
load_backbone_config,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
@ -63,9 +62,9 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
DetectionsTensor,
@ -99,20 +98,10 @@ __all__ = [
"build_detector",
"load_backbone_config",
"Model",
"ModelConfig",
"build_model",
]
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
detector: DetectionModel
preprocessor: PreprocessorProtocol
@ -125,14 +114,12 @@ class Model(torch.nn.Module):
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
config: ModelConfig,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.config = config
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
spec = self.preprocessor(wav)
@ -140,32 +127,25 @@ class Model(torch.nn.Module):
return self.postprocessor(outputs)
def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig()
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
def build_model(
config: Optional[BackboneConfig] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
):
config = config or BackboneConfig()
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()
postprocessor = postprocessor or build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config.model,
config=config,
)
return Model(
config=config,
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,
targets=targets,
)
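With ModelConfig gone, build_model now takes optional pre-built components and falls back to defaults for the rest; a sketch of both styles:

```python
from batdetect2.models import build_model
from batdetect2.preprocess import build_preprocessor

# Everything defaulted.
model = build_model()

# Inject one custom component; the others are still built internally.
preprocessor = build_preprocessor(input_samplerate=256_000)
model = build_model(preprocessor=preprocessor)
```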
def load_model_config(
path: PathLike, field: Optional[str] = None
) -> ModelConfig:
return load_config(path, schema=ModelConfig, field=field)


@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the
network's total downsampling factor.
"""
from typing import Optional, Tuple
from typing import Tuple
import torch
import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
build_bottleneck,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
Decoder,
DecoderConfig,
build_decoder,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
Encoder,
EncoderConfig,
build_encoder,
)
from batdetect2.models.bottleneck import build_bottleneck
from batdetect2.models.config import BackboneConfig
from batdetect2.models.decoder import Decoder, build_decoder
from batdetect2.models.encoder import Encoder, build_encoder
from batdetect2.typing.models import BackboneModel
__all__ = [
"Backbone",
"BackboneConfig",
"load_backbone_config",
"build_backbone",
]
@ -161,82 +144,6 @@ class Backbone(BackboneModel):
return x
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
def build_backbone(config: BackboneConfig) -> BackboneModel:
"""Factory function to build a Backbone from configuration.


@ -0,0 +1,98 @@
from typing import Optional
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,
)
from batdetect2.models.decoder import (
DEFAULT_DECODER_CONFIG,
DecoderConfig,
)
from batdetect2.models.encoder import (
DEFAULT_ENCODER_CONFIG,
EncoderConfig,
)
__all__ = [
"BackboneConfig",
"load_backbone_config",
]
class BackboneConfig(BaseConfig):
"""Configuration for the Encoder-Decoder Backbone network.
Aggregates configurations for the encoder, bottleneck, and decoder
components, along with defining the input and final output dimensions
for the complete backbone.
Attributes
----------
input_height : int, default=128
Expected height (frequency bins) of the input spectrograms to the
backbone. Must be positive.
in_channels : int, default=1
Expected number of channels in the input spectrograms (e.g., 1 for
mono). Must be positive.
encoder : EncoderConfig, optional
Configuration for the encoder. If None or omitted,
the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the
encoder module) will be used.
bottleneck : BottleneckConfig, optional
Configuration for the bottleneck layer connecting encoder and decoder.
If None or omitted, the default bottleneck configuration will be used.
decoder : DecoderConfig, optional
Configuration for the decoder. If None or omitted,
the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the
decoder module) will be used.
out_channels : int, default=32
Desired number of channels in the final feature map output by the
backbone. Must be positive.
"""
input_height: int = 128
in_channels: int = 1
encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG
bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
out_channels: int = 32
def load_backbone_config(
path: data.PathLike,
field: Optional[str] = None,
) -> BackboneConfig:
"""Load the backbone configuration from a file.
Reads a configuration file (YAML) and validates it against the
`BackboneConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
backbone configuration (e.g., "model.backbone"). If None, the entire
file content is used.
Returns
-------
BackboneConfig
The loaded and validated backbone configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded config data does not conform to `BackboneConfig`.
KeyError, TypeError
If `field` specifies an invalid path.
"""
return load_config(path, schema=BackboneConfig, field=field)
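A short sketch of pulling only the backbone section out of a larger config file (file name and field are illustrative):

```python
from batdetect2.models.config import BackboneConfig, load_backbone_config

# Reads the nested "model" section of a full config file.
backbone = load_backbone_config("config.yaml", field="model")
assert isinstance(backbone, BackboneConfig)
print(backbone.input_height, backbone.out_channels)  # defaults: 128 32
```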


@ -5,8 +5,9 @@ import torch
from matplotlib.axes import Axes
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.plotting.common import plot_spectrogram
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
__all__ = [


@ -1,307 +1,29 @@
"""Main entry point for the BatDetect2 Postprocessing pipeline."""
from typing import List, Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.config import (
PostprocessConfig,
load_postprocess_config,
)
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.nms import (
NMS_KERNEL_SIZE,
non_max_suppression,
)
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
DetectionsTensor,
PostprocessorProtocol,
RawPrediction,
from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
get_raw_predictions,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
"DEFAULT_DETECTION_THRESHOLD",
"MAX_FREQ",
"MIN_FREQ",
"ModelOutput",
"NMS_KERNEL_SIZE",
"PostprocessConfig",
"Postprocessor",
"TOP_K_PER_SEC",
"build_postprocessor",
"convert_raw_predictions_to_clip_prediction",
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
"get_raw_predictions",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 100
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
def __init__(
self,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
):
"""Initialize the Postprocessor."""
super().__init__()
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
return detections
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
def get_raw_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch."""
detections = postprocessor.get_detections(output, start_times)
return [
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
def get_sound_event_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
clips: List[data.Clip],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]


@ -0,0 +1,94 @@
from typing import Optional
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
__all__ = [
"PostprocessConfig",
"load_postprocess_config",
]
DEFAULT_DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 100
class PostprocessConfig(BaseConfig):
"""Configuration settings for the postprocessing pipeline.
Defines tunable parameters that control how raw model outputs are
converted into final detections.
Attributes
----------
nms_kernel_size : int, default=NMS_KERNEL_SIZE
Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression.
Used to suppress weaker detections near stronger peaks. Must be
positive.
detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD
Minimum confidence score from the detection heatmap required to
consider a point as a potential detection. Must be >= 0.
classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD
Minimum confidence score for a specific class prediction to be included
in the decoded tags for a detection. Must be >= 0.
top_k_per_sec : int, default=TOP_K_PER_SEC
Desired maximum number of detections per second of audio. Used by
`get_max_detections` to calculate an absolute limit based on clip
duration before applying `extract_detections_from_array`. Must be
positive.
"""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(
default=DEFAULT_DETECTION_THRESHOLD,
ge=0,
)
classification_threshold: float = Field(
default=DEFAULT_CLASSIFICATION_THRESHOLD,
ge=0,
)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
def load_postprocess_config(
path: data.PathLike,
field: Optional[str] = None,
) -> PostprocessConfig:
"""Load the postprocessing configuration from a file.
Reads a configuration file (YAML) and validates it against the
`PostprocessConfig` schema, potentially extracting data from a nested
field.
Parameters
----------
path : PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
postprocessing configuration (e.g., "inference.postprocessing").
If None, the entire file content is used.
Returns
-------
PostprocessConfig
The loaded and validated postprocessing configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`PostprocessConfig` schema.
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path, schema=PostprocessConfig, field=field)
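The Field constraints above are enforced when the model is constructed; a quick sketch (pydantic v2 API assumed):

```python
from pydantic import ValidationError

from batdetect2.postprocess.config import PostprocessConfig

config = PostprocessConfig(detection_threshold=0.1, top_k_per_sec=200)

try:
    PostprocessConfig(nms_kernel_size=0)  # violates gt=0
except ValidationError as err:
    print(err.error_count(), "validation error(s)")
```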


@ -0,0 +1,208 @@
from typing import List, Optional
import torch
from loguru import logger
from soundevent import data
from batdetect2.postprocess.config import (
PostprocessConfig,
)
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
DetectionsTensor,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"build_postprocessor",
"Postprocessor",
]
def build_postprocessor(
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
"""Factory function to build the standard postprocessor."""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
def __init__(
self,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
):
"""Initialize the Postprocessor."""
super().__init__()
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
return detections
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
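The detection budget works out as follows (numbers assumed purely for arithmetic): at an output samplerate of 500 heatmap columns per second, a 128-column output spans 0.256 s, so top_k_per_sec=100 caps the clip at 25 detections.

```python
# Assumed numbers, for arithmetic only.
samplerate = 500   # heatmap columns per second
width = 128        # columns in output.detection_probs
top_k_per_sec = 100

duration = width / samplerate                 # 0.256 s
max_detections = int(top_k_per_sec * duration)
print(duration, max_detections)               # 0.256 25
```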
def get_raw_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch."""
detections = postprocessor.get_detections(output, start_times)
return [
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
def get_sound_event_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
clips: List[data.Clip],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]


@ -1,21 +1,19 @@
"""Main entry point for the BatDetect2 preprocessing subsystem."""
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.config import (
MAX_FREQ,
MIN_FREQ,
TARGET_SAMPLERATE_HZ,
PreprocessingConfig,
load_preprocessing_config,
)
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor
from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ
__all__ = [
"MIN_FREQ",
"MAX_FREQ",
"TARGET_SAMPLERATE_HZ",
"MIN_FREQ",
"PreprocessingConfig",
"load_preprocessing_config",
"Preprocessor",
"TARGET_SAMPLERATE_HZ",
"build_preprocessor",
"build_audio_loader",
"load_preprocessing_config",
]


@ -1,267 +1,60 @@
"""Handles loading and initial preprocessing of audio waveforms."""
from typing import Annotated, Literal, Union
from typing import Optional
import numpy as np
import torch
from numpy.typing import DTypeLike
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from pydantic import Field
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.preprocess.config import (
TARGET_SAMPLERATE_HZ,
AudioConfig,
AudioTransform,
ResampleConfig,
)
from batdetect2.typing import AudioLoader
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core import BaseConfig, Registry
from batdetect2.preprocess.common import center_tensor, peak_normalize
__all__ = [
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
"CenterAudioConfig",
"ScaleAudioConfig",
"FixDurationConfig",
"build_audio_transform",
]
class SoundEventAudioLoader(AudioLoader):
"""Concrete implementation of the `AudioLoader`."""
def __init__(
self,
samplerate: int = TARGET_SAMPLERATE_HZ,
config: Optional[ResampleConfig] = None,
):
self.samplerate = samplerate
self.config = config or ResampleConfig()
def load_file(
self,
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess audio directly from a file path."""
return load_file_audio(
path,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_recording(
self,
recording: data.Recording,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the entire audio for a Recording object."""
return load_recording_audio(
recording,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
def load_clip(
self,
clip: data.Clip,
audio_dir: Optional[data.PathLike] = None,
) -> np.ndarray:
"""Load and preprocess the audio segment defined by a Clip object."""
return load_clip_audio(
clip,
samplerate=self.samplerate,
config=self.config,
audio_dir=audio_dir,
)
audio_transforms: Registry[torch.nn.Module, [int]] = Registry(
"audio_transform"
)
def load_file_audio(
path: data.PathLike,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess audio from a file path using specified config."""
try:
recording = data.Recording.from_file(path)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {path}. Error: {e}"
) from e
return load_recording_audio(
recording,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
def load_recording_audio(
recording: data.Recording,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess the entire audio content of a recording using config."""
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
return load_clip_audio(
clip,
samplerate=samplerate,
config=config,
dtype=dtype,
audio_dir=audio_dir,
)
class CenterAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return center_tensor(wav)
@classmethod
def from_config(cls, config: CenterAudioConfig, samplerate: int):
return cls()
def load_clip_audio(
clip: data.Clip,
samplerate: Optional[int] = None,
config: Optional[ResampleConfig] = None,
audio_dir: Optional[data.PathLike] = None,
dtype: DTypeLike = np.float32, # type: ignore
) -> np.ndarray:
"""Load and preprocess a specific audio clip segment based on config."""
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
.sel(channel=0)
.astype(dtype)
)
except LibsndfileError as e:
raise FileNotFoundError(
f"Could not load the recording at path: {clip.recording.path}. "
f"Error: {e}"
) from e
if not config or not config.enabled or samplerate is None:
return wav.data.astype(dtype)
sr = int(1 / wav.time.attrs["step"])
return resample_audio(
wav.data,
sr=sr,
samplerate=samplerate,
method=config.method,
)
audio_transforms.register(CenterAudioConfig, CenterAudio)
def resample_audio(
wav: np.ndarray,
sr: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
method: str = "poly",
) -> np.ndarray:
"""Resample an audio waveform DataArray to a target sample rate."""
if sr == samplerate:
return wav
if method == "poly":
return resample_audio_poly(
wav,
sr_orig=sr,
sr_new=samplerate,
)
elif method == "fourier":
return resample_audio_fourier(
wav,
sr_orig=sr,
sr_new=samplerate,
)
else:
raise NotImplementedError(
f"Resampling method '{method}' not implemented"
)
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
def resample_audio_poly(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample_poly`.
class ScaleAudio(torch.nn.Module):
def forward(self, wav: torch.Tensor) -> torch.Tensor:
return peak_normalize(wav)
This method is often preferred for signals when the ratio of new
to old sample rates can be expressed as a rational number. It uses
polyphase filtering.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate.
Raises
------
ValueError
If sample rates are not positive.
"""
gcd = np.gcd(sr_orig, sr_new)
return resample_poly(
array,
sr_new // gcd,
sr_orig // gcd,
axis=axis,
)
@classmethod
def from_config(cls, config: ScaleAudioConfig, samplerate: int):
return cls()
def resample_audio_fourier(
array: np.ndarray,
sr_orig: int,
sr_new: int,
axis: int = -1,
) -> np.ndarray:
"""Resample a numpy array using `scipy.signal.resample`.
audio_transforms.register(ScaleAudioConfig, ScaleAudio)
This method uses FFTs to resample the signal.
Parameters
----------
array : np.ndarray
The input array to resample.
sr_orig : int
The original sample rate in Hz.
sr_new : int
The target sample rate in Hz.
axis : int, default=-1
The axis of `array` along which to resample.
Returns
-------
np.ndarray
The array resampled to the target sample rate along `axis`.
Raises
------
ValueError
If the computed number of output samples is negative.
"""
ratio = sr_new / sr_orig
return resample( # type: ignore
array,
int(array.shape[axis] * ratio),
axis=axis,
)
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class FixDuration(torch.nn.Module):
@ -282,40 +75,25 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length))
def build_audio_loader(
config: Optional[AudioConfig] = None,
) -> AudioLoader:
"""Factory function to create an AudioLoader based on configuration."""
config = config or AudioConfig()
return SoundEventAudioLoader(
samplerate=config.samplerate,
config=config.resample,
)
@classmethod
def from_config(cls, config: FixDurationConfig, samplerate: int):
return cls(samplerate=samplerate, duration=config.duration)
def build_audio_transform_step(
audio_transforms.register(FixDurationConfig, FixDuration)
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
def build_audio_transform(
config: AudioTransform,
samplerate: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
if config.name == "fix_duration":
return FixDuration(samplerate=samplerate, duration=config.duration)
if config.name == "scale_audio":
return PeakNormalize()
if config.name == "center_audio":
return CenterTensor()
raise NotImplementedError(
f"Audio preprocessing step {config.name} not implemented"
)
def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module:
return torch.nn.Sequential(
*[
build_audio_transform_step(step, samplerate=config.samplerate)
for step in config.transforms
]
)
return audio_transforms.build(config, samplerate)
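Under the new registry scheme a waveform transform is just a config/module pair; a hypothetical extension (GainConfig and Gain are not part of this commit, and this assumes Registry.build dispatches to the registered module's from_config):

```python
from typing import Literal

import torch

from batdetect2.core import BaseConfig
from batdetect2.preprocess.audio import audio_transforms, build_audio_transform


class GainConfig(BaseConfig):
    name: Literal["gain"] = "gain"
    factor: float = 2.0


class Gain(torch.nn.Module):
    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        return wav * self.factor

    @classmethod
    def from_config(cls, config: GainConfig, samplerate: int):
        # samplerate is unused here but kept for the registry signature.
        return cls(factor=config.factor)


audio_transforms.register(GainConfig, Gain)
transform = build_audio_transform(GainConfig(factor=0.5))
```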


@ -1,24 +1,22 @@
import torch
__all__ = [
"CenterTensor",
"PeakNormalize",
"center_tensor",
"peak_normalize",
]
class CenterTensor(torch.nn.Module):
def forward(self, wav: torch.Tensor):
return wav - wav.mean()
def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
return tensor - tensor.mean()
class PeakNormalize(torch.nn.Module):
def forward(self, wav: torch.Tensor):
max_value = wav.abs().max()
def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
max_value = tensor.abs().max()
denominator = torch.where(
max_value == 0,
torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
max_value,
)
denominator = torch.where(
max_value == 0,
torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
max_value,
)
return wav / denominator
return tensor / denominator
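A quick check of the functional forms, including the zero-signal guard (values illustrative):

```python
import torch

from batdetect2.preprocess.common import center_tensor, peak_normalize

wav = torch.tensor([0.5, -1.0, 0.25])
print(peak_normalize(wav))             # divides by the peak amplitude, 1.0
print(center_tensor(wav))              # mean removed

# Silence passes through unchanged instead of dividing by zero.
print(peak_normalize(torch.zeros(4)))  # tensor([0., 0., 0., 0.])
```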


@ -1,187 +1,25 @@
from collections.abc import Sequence
from typing import Annotated, List, Literal, Optional, Union
from typing import List, Optional
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import AudioTransform
from batdetect2.preprocess.spectrogram import (
FrequencyConfig,
PcenConfig,
ResizeConfig,
SpectralMeanSubstractionConfig,
SpectrogramTransform,
STFTConfig,
)
__all__ = [
"load_preprocessing_config",
"CenterAudioConfig",
"ScaleAudioConfig",
"FixDurationConfig",
"ResampleConfig",
"AudioTransform",
"AudioConfig",
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"ScaleAmplitudeConfig",
"SpectralMeanSubstractionConfig",
"ResizeConfig",
"PeakNormalizeConfig",
"SpectrogramTransform",
"SpectrogramConfig",
"PreprocessingConfig",
"TARGET_SAMPLERATE_HZ",
"MIN_FREQ",
"MAX_FREQ",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
enabled : bool, default=True
Whether to resample the audio when loading. If False, the waveform is
returned at its original sample rate.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
class STFTConfig(BaseConfig):
"""Configuration for the Short-Time Fourier Transform (STFT).
Attributes
----------
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
Frequencies above this value will be cropped. Must be > 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
@ -201,8 +39,20 @@ class PreprocessingConfig(BaseConfig):
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
audio_transforms: List[AudioTransform] = Field(default_factory=list)
spectrogram_transforms: List[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
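Note the flattened layout: the STFT, frequency, and resize settings now sit
directly on `PreprocessingConfig` instead of under a nested `spectrogram`
block. A minimal sketch of building the new config programmatically (the
import path is an assumption based on this diff):

    from batdetect2.preprocess.config import (
        AudioConfig,
        FrequencyConfig,
        PcenConfig,
        PreprocessingConfig,
        ResizeConfig,
        SpectralMeanSubstractionConfig,
        STFTConfig,
    )

    config = PreprocessingConfig(
        audio=AudioConfig(samplerate=256_000),
        stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
        size=ResizeConfig(height=128, resize_factor=0.5),
        spectrogram_transforms=[
            PcenConfig(time_constant=0.1),
            SpectralMeanSubstractionConfig(),
        ],
    )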
def load_preprocessing_config(

View File

@ -3,21 +3,25 @@ from typing import Optional
import torch
from loguru import logger
from batdetect2.preprocess.audio import build_audio_pipeline
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.preprocess.audio import build_audio_transform
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
_spec_params_from_config,
build_spectrogram_pipeline,
build_spectrogram_builder,
build_spectrogram_crop,
build_spectrogram_resizer,
build_spectrogram_transform,
)
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
from batdetect2.typing import PreprocessorProtocol
__all__ = [
"StandardPreprocessor",
"Preprocessor",
"build_preprocessor",
]
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
class Preprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
@ -28,37 +32,78 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
def __init__(
self,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
config: PreprocessingConfig,
input_samplerate: int,
output_samplerate: float,
max_freq: float,
min_freq: float,
) -> None:
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.max_freq = max_freq
self.min_freq = min_freq
self.audio_transforms = torch.nn.Sequential(
*[
build_audio_transform(step, samplerate=input_samplerate)
for step in config.audio_transforms
]
)
self.spectrogram_transforms = torch.nn.Sequential(
*[
build_spectrogram_transform(step, samplerate=input_samplerate)
for step in config.spectrogram_transforms
]
)
self.spectrogram_builder = build_spectrogram_builder(
config.stft,
samplerate=input_samplerate,
)
self.spectrogram_crop = build_spectrogram_crop(
config.frequencies,
stft=config.stft,
samplerate=input_samplerate,
)
self.spectrogram_resizer = build_spectrogram_resizer(config.size)
self.min_freq = config.frequencies.min_freq
self.max_freq = config.frequencies.max_freq
self.input_samplerate = input_samplerate
self.output_samplerate = output_samplerate
self.output_samplerate = compute_output_samplerate(
config,
input_samplerate=input_samplerate,
)
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
wav = self.audio_transforms(wav)
spec = self.spectrogram_builder(wav)
return self.process_spectrogram(spec)
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
return self.spectrogram_builder(wav)
    def process_audio(self, wav: torch.Tensor) -> torch.Tensor:
        return self.audio_transforms(wav)
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
spec = self.spectrogram_crop(spec)
spec = self.spectrogram_transforms(spec)
return self.spectrogram_resizer(spec)
def compute_output_samplerate(config: PreprocessingConfig) -> float:
samplerate = config.audio.samplerate
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
factor = config.spectrogram.size.resize_factor
return samplerate * factor / hop_size
def compute_output_samplerate(
config: PreprocessingConfig,
input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> float:
_, hop_size = _spec_params_from_config(
config.stft, samplerate=input_samplerate
)
factor = config.size.resize_factor
return input_samplerate * factor / hop_size
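With the defaults this gives a 1 kHz output frame rate; a quick worked
sketch of the arithmetic:

    samplerate = 256_000
    n_fft = int(samplerate * 0.002)        # 512 samples per window
    hop_size = int(n_fft * (1 - 0.75))     # 128 samples per hop
    print(samplerate * 0.5 / hop_size)     # 1000.0 spectrogram frames/s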
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
input_samplerate: int = TARGET_SAMPLERATE_HZ,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
@ -66,21 +111,4 @@ def build_preprocessor(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
output_samplerate = compute_output_samplerate(config)
return StandardPreprocessor(
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
input_samplerate=samplerate,
output_samplerate=output_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
return Preprocessor(config=config, input_samplerate=input_samplerate)

View File

@ -1,34 +1,64 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import Callable, Optional
from typing import Annotated, Callable, Literal, Optional, Union
import numpy as np
import torch
import torchaudio
from pydantic import Field
from batdetect2.preprocess.common import PeakNormalize
from batdetect2.preprocess.config import (
ScaleAmplitudeConfig,
SpectrogramConfig,
SpectrogramTransform,
STFTConfig,
)
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.preprocess.common import peak_normalize
__all__ = [
"STFTConfig",
"build_spectrogram_transform",
"build_spectrogram_builder",
"build_spectrogram_pipeline",
]
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class STFTConfig(BaseConfig):
"""Configuration for the Short-Time Fourier Transform (STFT).
Attributes
----------
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
window_fn=get_spectrogram_window(config.window_fn),
center=True,
power=1,
)
@ -55,16 +85,19 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
)
def _spec_params_from_config(samplerate: int, conf: STFTConfig):
n_fft = int(samplerate * conf.window_duration)
hop_length = int(n_fft * (1 - conf.window_overlap))
def _spec_params_from_config(
config: STFTConfig,
samplerate: int = TARGET_SAMPLERATE_HZ,
):
n_fft = int(samplerate * config.window_duration)
hop_length = int(n_fft * (1 - config.window_overlap))
return n_fft, hop_length
def _frequency_to_index(
freq: float,
samplerate: int,
n_fft: int,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> Optional[int]:
alpha = freq * 2 / samplerate
height = np.floor(n_fft / 2) + 1
@ -79,14 +112,49 @@ def _frequency_to_index(
return index
class FrequencyClip(torch.nn.Module):
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies above this value will be cropped. Must be >= 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=MAX_FREQ, ge=0)
min_freq: int = Field(default=MIN_FREQ, ge=0)
class FrequencyCrop(torch.nn.Module):
def __init__(
self,
low_index: Optional[int] = None,
high_index: Optional[int] = None,
samplerate: int,
n_fft: int,
min_freq: Optional[int] = None,
max_freq: Optional[int] = None,
):
super().__init__()
self.n_fft = n_fft
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
low_index = None
if min_freq is not None:
low_index = _frequency_to_index(
min_freq, self.samplerate, self.n_fft
)
self.low_index = low_index
high_index = None
if max_freq is not None:
high_index = _frequency_to_index(
max_freq, self.samplerate, self.n_fft
)
self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor:
@ -107,6 +175,72 @@ class FrequencyClip(torch.nn.Module):
)
def build_spectrogram_crop(
config: FrequencyConfig,
stft: Optional[STFTConfig] = None,
samplerate: int = TARGET_SAMPLERATE_HZ,
) -> torch.nn.Module:
stft = stft or STFTConfig()
n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
return FrequencyCrop(
samplerate=samplerate,
n_fft=n_fft,
min_freq=config.min_freq,
max_freq=config.max_freq,
)
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
self.height = height
self.time_factor = time_factor
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
return ResizeSpec(height=config.height, time_factor=config.resize_factor)
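A quick shape sketch of the resizer under the defaults (the input shape is
illustrative):

    import torch

    resizer = ResizeSpec(height=128, time_factor=0.5)
    spec = torch.rand(257, 1000)    # (freq_bins, time_frames)
    print(resizer(spec).shape)      # torch.Size([128, 500])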
spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
"spectrogram_transform"
)
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class PCEN(torch.nn.Module):
def __init__(
self,
@ -115,7 +249,7 @@ class PCEN(torch.nn.Module):
bias: float = 2.0,
power: float = 0.5,
eps: float = 1e-6,
dtype=torch.float64,
dtype=torch.float32,
):
super().__init__()
self.smoothing_constant = smoothing_constant
@ -151,6 +285,19 @@ class PCEN(torch.nn.Module):
* torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
).to(spec.dtype)
@classmethod
def from_config(cls, config: PcenConfig, samplerate: int):
smooth = _compute_smoothing_constant(samplerate, config.time_constant)
return cls(
smoothing_constant=smooth,
gain=config.gain,
bias=config.bias,
power=config.power,
)
spectrogram_transforms.register(PcenConfig, PCEN)
def _compute_smoothing_constant(
samplerate: int,
@ -164,21 +311,40 @@ def _compute_smoothing_constant(
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
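For intuition, a sketch of the formula above; the conversion of the time
constant into frames happens in the elided lines, so the frame rate used
here is an assumption:

    import numpy as np

    t_frames = 0.4 * 1000    # e.g. a 0.4 s time constant at ~1000 frames/s
    b = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
    print(b)                 # ~0.0025; a small constant means slow smoothing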
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
class ToPower(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec**2
def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
if conf.scale == "db":
return torchaudio.transforms.AmplitudeToDB()
_scalers = {
"db": torchaudio.transforms.AmplitudeToDB,
"power": ToPower,
}
if conf.scale == "power":
return ToPower()
raise NotImplementedError(
f"Amplitude scaling {conf.scale} not implemented"
)
class ScaleAmplitude(torch.nn.Module):
    def __init__(self, scale: Literal["power", "db"]):
        super().__init__()
        self.scale = scale
        self.scaler = _scalers[scale]()
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return self.scaler(spec)
@classmethod
def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
return cls(scale=config.scale)
spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class SpectralMeanSubstraction(torch.nn.Module):
@ -186,129 +352,49 @@ class SpectralMeanSubstraction(torch.nn.Module):
mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0)
@classmethod
def from_config(
cls,
config: SpectralMeanSubstractionConfig,
samplerate: int,
):
return cls()
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
self.height = height
self.time_factor = time_factor
spectrogram_transforms.register(
SpectralMeanSubstractionConfig,
SpectralMeanSubstraction,
)
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
class PeakNormalize(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
return peak_normalize(spec)
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
@classmethod
def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
return cls()
def _build_spectrogram_transform_step(
step: SpectrogramTransform,
samplerate: int,
) -> torch.nn.Module:
if step.name == "pcen":
return PCEN(
smoothing_constant=_compute_smoothing_constant(
samplerate=samplerate,
time_constant=step.time_constant,
),
gain=step.gain,
bias=step.bias,
power=step.power,
)
spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
if step.name == "scale_amplitude":
return _build_amplitude_scaler(step)
if step.name == "spectral_mean_substraction":
return SpectralMeanSubstraction()
if step.name == "peak_normalize":
return PeakNormalize()
raise NotImplementedError(
f"Spectrogram preprocessing step {step.name} not implemented"
)
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
def build_spectrogram_transform(
config: SpectrogramTransform,
samplerate: int,
conf: SpectrogramConfig,
) -> torch.nn.Module:
return torch.nn.Sequential(
*[
_build_spectrogram_transform_step(step, samplerate=samplerate)
for step in conf.transforms
]
)
class SpectrogramPipeline(torch.nn.Module):
def __init__(
self,
spec_builder: torch.nn.Module,
freq_cutter: torch.nn.Module,
transforms: torch.nn.Module,
resizer: torch.nn.Module,
):
super().__init__()
self.spec_builder = spec_builder
self.freq_cutter = freq_cutter
self.transforms = transforms
self.resizer = resizer
def forward(self, wav: torch.Tensor) -> torch.Tensor:
spec = self.spec_builder(wav)
spec = self.freq_cutter(spec)
spec = self.transforms(spec)
return self.resizer(spec)
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor:
return self.spec_builder(wav)
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor:
return self.freq_cutter(spec)
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.transforms(spec)
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor:
return self.resizer(spec)
def build_spectrogram_pipeline(
samplerate: int,
conf: SpectrogramConfig,
) -> SpectrogramPipeline:
spec_builder = build_spectrogram_builder(samplerate, conf.stft)
n_fft, _ = _spec_params_from_config(samplerate, conf.stft)
cutter = FrequencyClip(
low_index=_frequency_to_index(
conf.frequencies.min_freq, samplerate, n_fft
),
high_index=_frequency_to_index(
conf.frequencies.max_freq, samplerate, n_fft
),
)
transforms = build_spectrogram_transform(samplerate, conf)
resizer = ResizeSpec(
height=conf.size.height,
time_factor=conf.size.resize_factor,
)
return SpectrogramPipeline(
spec_builder=spec_builder,
freq_cutter=cutter,
transforms=transforms,
resizer=resizer,
)
return spectrogram_transforms.build(config, samplerate)
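The `Registry` implementation lives in `batdetect2.core.registries` and is
not shown in this diff; a minimal sketch of a registry consistent with the
`register`/`build`/`from_config` usage above (not the actual implementation):

    from typing import Dict

    class SimpleRegistry:
        def __init__(self, name: str):
            self.name = name
            self._entries: Dict[type, type] = {}

        def register(self, config_cls: type, module_cls: type) -> None:
            # Map a pydantic config class to the module that consumes it.
            self._entries[config_cls] = module_cls

        def build(self, config, *args, **kwargs):
            # Dispatch on the concrete config type and delegate construction
            # to the registered class's from_config classmethod.
            try:
                module_cls = self._entries[type(config)]
            except KeyError:
                raise NotImplementedError(
                    f"No {self.name} registered for {type(config).__name__}"
                ) from None
            return module_cls.from_config(config, *args, **kwargs)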

View File

@ -1,17 +1,6 @@
"""BatDetect2 Target Definition system."""
from collections import Counter
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
SoundEventDecoder,
SoundEventEncoder,
TargetClassConfig,
@ -19,23 +8,29 @@ from batdetect2.targets.classes import (
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
ROIMapperConfig,
ROITargetMapper,
build_roi_mapper,
)
from batdetect2.targets.targets import (
Targets,
build_targets,
iterate_encoded_sound_events,
load_targets,
)
from batdetect2.targets.terms import (
call_type,
data_source,
generic_class,
individual,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol
__all__ = [
"AnchorBBoxMapperConfig",
"DEFAULT_TARGET_CONFIG",
"ROIMapperConfig",
"ROITargetMapper",
"SoundEventDecoder",
"SoundEventEncoder",
@ -45,365 +40,13 @@ __all__ = [
"build_roi_mapper",
"build_sound_event_decoder",
"build_sound_event_encoder",
"build_targets",
"call_type",
"data_source",
"generic_class",
"get_class_names_from_config",
"individual",
"iterate_encoded_sound_events",
"load_target_config",
"load_targets",
]
class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.
Reads a configuration file (typically YAML) and validates it against the
`TargetConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : data.PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
target configuration. If None, the entire file content is used.
Returns
-------
TargetConfig
The loaded and validated unified target configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`TargetConfig` schema (including validation within nested configs
like `ClassesConfig`).
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path=path, schema=TargetConfig, field=field)
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class implements the `TargetProtocol`, holding the configured
    functions for filtering, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
    detection_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects assigned to the configured
        detection class (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
"""Initialize the Targets object."""
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
return self._filter_fn(sound_event)
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
to determine the specific class name for the annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
            The annotation to encode.
Returns
-------
str or None
The name of the matched target class, or None if the annotation
does not match any specific class rule (i.e., it belongs to the
generic category).
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
`TargetClass.tags`) to convert a class name string into a list of
`soundevent.data.Tag` objects.
Parameters
----------
class_label : str
The class name to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
        Tuple[Position, Size]
            The reference position `(time, frequency)` and the scaled size
            dimensions.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
        Delegates to the internal ROI mapper's `decode` method, which
        un-scales the dimensions and reconstructs the geometry (typically a
        `BoundingBox`).
        Parameters
        ----------
        position : Tuple[float, float]
            The reference position `(time, frequency)`.
        size : np.ndarray
            NumPy array with size dimensions (e.g., from model prediction),
            matching the order in `self.dimension_names`.
        class_name : str, optional
            If given and a per-class ROI mapper override is configured for
            this class, that mapper is used instead of the default.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
return Targets(config=config)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
) -> Targets:
"""Load a Targets object directly from a configuration file.
This convenience factory method loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to build
the fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(config)
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
yield class_name, position, size

View File

@ -0,0 +1,84 @@
from collections import Counter
from typing import List, Optional
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
TargetClassConfig,
)
from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig
__all__ = [
"TargetConfig",
"load_target_config",
]
class TargetConfig(BaseConfig):
detection_target: TargetClassConfig = Field(
default=DEFAULT_DETECTION_CLASS
)
classification_targets: List[TargetClassConfig] = Field(
default_factory=lambda: DEFAULT_CLASSES
)
roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig)
@field_validator("classification_targets")
def check_unique_class_names(cls, v: List[TargetClassConfig]):
"""Ensure all defined class names are unique."""
names = [c.name for c in v]
if len(names) != len(set(names)):
name_counts = Counter(names)
duplicates = [
name for name, count in name_counts.items() if count > 1
]
raise ValueError(
"Class names must be unique. Found duplicates: "
f"{', '.join(duplicates)}"
)
return v
def load_target_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TargetConfig:
"""Load the unified target configuration from a file.
Reads a configuration file (typically YAML) and validates it against the
`TargetConfig` schema, potentially extracting data from a nested field.
Parameters
----------
path : data.PathLike
Path to the configuration file.
field : str, optional
Dot-separated path to a nested section within the file containing the
target configuration. If None, the entire file content is used.
Returns
-------
TargetConfig
The loaded and validated unified target configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
yaml.YAMLError
If the file content is not valid YAML.
pydantic.ValidationError
If the loaded configuration data does not conform to the
`TargetConfig` schema (including validation within nested configs
like `ClassesConfig`).
KeyError, TypeError
If `field` specifies an invalid path within the loaded data.
"""
return load_config(path=path, schema=TargetConfig, field=field)
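A minimal usage sketch (the path and field name are hypothetical):

    config = load_target_config("config.yaml", field="targets")
    print([c.name for c in config.classification_targets])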

View File

@ -26,10 +26,10 @@ import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing import (
AudioLoader,
Position,
@ -265,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
"""
name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
@ -456,8 +457,11 @@ def build_roi_mapper(
)
if config.name == "peak_energy_bbox":
preprocessor = build_preprocessor(config.preprocessing)
audio_loader = build_audio_loader(config.preprocessing.audio)
audio_loader = build_audio_loader(config=config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return PeakEnergyBBoxMapper(
preprocessor=preprocessor,
audio_loader=audio_loader,

View File

@ -0,0 +1,308 @@
from typing import Iterable, List, Optional, Tuple
from loguru import logger
from soundevent import data
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,
DEFAULT_DETECTION_CLASS,
build_sound_event_decoder,
build_sound_event_encoder,
get_class_names_from_config,
)
from batdetect2.targets.config import TargetConfig, load_target_config
from batdetect2.targets.rois import (
AnchorBBoxMapperConfig,
build_roi_mapper,
)
from batdetect2.typing.targets import Position, Size, TargetProtocol
class Targets(TargetProtocol):
"""Encapsulates the complete configured target definition pipeline.
This class implements the `TargetProtocol`, holding the configured
    functions for filtering, encoding (tags to class name),
decoding (class name to tags), and mapping ROIs (geometry to position/size
and back). It provides a high-level interface to apply these steps and
access relevant metadata like class names and dimension names.
Instances are typically created using the `build_targets` factory function
or the `load_targets` convenience loader.
Attributes
----------
class_names : List[str]
An ordered list of the unique names of the specific target classes
defined in the configuration.
    detection_class_tags : List[data.Tag]
        A list of `soundevent.data.Tag` objects assigned to the configured
        detection class (used when no specific class matches).
dimension_names : List[str]
The names of the size dimensions handled by the ROI mapper
(e.g., ['width', 'height']).
"""
class_names: List[str]
detection_class_tags: List[data.Tag]
dimension_names: List[str]
detection_class_name: str
def __init__(self, config: TargetConfig):
"""Initialize the Targets object."""
self.config = config
self._filter_fn = build_sound_event_condition(
config.detection_target.match_if
)
self._encode_fn = build_sound_event_encoder(
config.classification_targets
)
self._decode_fn = build_sound_event_decoder(
config.classification_targets
)
self._roi_mapper = build_roi_mapper(config.roi)
self.dimension_names = self._roi_mapper.dimension_names
self.class_names = get_class_names_from_config(
config.classification_targets
)
self.detection_class_name = config.detection_target.name
self.detection_class_tags = config.detection_target.assign_tags
self._roi_mapper_overrides = {
class_config.name: build_roi_mapper(class_config.roi)
for class_config in config.classification_targets
if class_config.roi is not None
}
for class_name in self._roi_mapper_overrides:
if class_name not in self.class_names:
# TODO: improve this warning
logger.warning(
"The ROI mapper overrides contains a class ({class_name}) "
"not present in the class names.",
class_name=class_name,
)
def filter(self, sound_event: data.SoundEventAnnotation) -> bool:
"""Apply the configured filter to a sound event annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation to filter.
Returns
-------
bool
True if the annotation should be kept (passes the filter),
False otherwise. If no filter was configured, always returns True.
"""
return self._filter_fn(sound_event)
def encode_class(
self, sound_event: data.SoundEventAnnotation
) -> Optional[str]:
"""Encode a sound event annotation to its target class name.
Applies the configured class definition rules (including priority)
to determine the specific class name for the annotation.
Parameters
----------
sound_event : data.SoundEventAnnotation
            The annotation to encode.
Returns
-------
str or None
The name of the matched target class, or None if the annotation
does not match any specific class rule (i.e., it belongs to the
generic category).
"""
return self._encode_fn(sound_event)
def decode_class(self, class_label: str) -> List[data.Tag]:
"""Decode a predicted class name back into representative tags.
Uses the configured mapping (based on `TargetClass.output_tags` or
`TargetClass.tags`) to convert a class name string into a list of
`soundevent.data.Tag` objects.
Parameters
----------
class_label : str
The class name to decode.
Returns
-------
List[data.Tag]
The list of tags corresponding to the input class name.
"""
return self._decode_fn(class_label)
def encode_roi(
self, sound_event: data.SoundEventAnnotation
) -> tuple[Position, Size]:
"""Extract the target reference position from the annotation's roi.
Delegates to the internal ROI mapper's `get_roi_position` method.
Parameters
----------
sound_event : data.SoundEventAnnotation
The annotation containing the geometry (ROI).
Returns
-------
        Tuple[Position, Size]
            The reference position `(time, frequency)` and the scaled size
            dimensions.
Raises
------
ValueError
If the annotation lacks geometry.
"""
class_name = self.encode_class(sound_event)
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].encode(
sound_event.sound_event
)
return self._roi_mapper.encode(sound_event.sound_event)
def decode_roi(
self,
position: Position,
size: Size,
class_name: Optional[str] = None,
) -> data.Geometry:
"""Recover an approximate geometric ROI from a position and dimensions.
        Delegates to the internal ROI mapper's `decode` method, which
        un-scales the dimensions and reconstructs the geometry (typically a
        `BoundingBox`).
        Parameters
        ----------
        position : Tuple[float, float]
            The reference position `(time, frequency)`.
        size : np.ndarray
            NumPy array with size dimensions (e.g., from model prediction),
            matching the order in `self.dimension_names`.
        class_name : str, optional
            If given and a per-class ROI mapper override is configured for
            this class, that mapper is used instead of the default.
Returns
-------
data.Geometry
The reconstructed geometry (typically `BoundingBox`).
"""
if class_name in self._roi_mapper_overrides:
return self._roi_mapper_overrides[class_name].decode(
position,
size,
)
return self._roi_mapper.decode(position, size)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
classification_targets=DEFAULT_CLASSES,
detection_target=DEFAULT_DETECTION_CLASS,
roi=AnchorBBoxMapperConfig(),
)
def build_targets(config: Optional[TargetConfig] = None) -> Targets:
"""Build a Targets object from a loaded TargetConfig.
Parameters
----------
config : TargetConfig
The loaded and validated unified target configuration object.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
KeyError
If term keys or derivation function keys specified in the `config`
are not found in their respective registries.
ImportError, AttributeError, TypeError
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
return Targets(config=config)
def load_targets(
config_path: data.PathLike,
field: Optional[str] = None,
) -> Targets:
"""Load a Targets object directly from a configuration file.
This convenience factory method loads the `TargetConfig` from the
    specified file path and then calls `build_targets` to build
the fully initialized `Targets` object.
Parameters
----------
config_path : data.PathLike
Path to the configuration file (e.g., YAML).
field : str, optional
Dot-separated path to a nested section within the file containing
the target configuration. If None, the entire file content is used.
Returns
-------
Targets
An initialized `Targets` object ready for use.
Raises
------
FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError,
TypeError
Errors raised during file loading, validation, or extraction via
`load_target_config`.
KeyError, ImportError, AttributeError, TypeError
        Errors raised during the build process by `build_targets`
(e.g., missing keys in registries, failed imports).
"""
config = load_target_config(
config_path,
field=field,
)
return build_targets(config)
def iterate_encoded_sound_events(
sound_events: Iterable[data.SoundEventAnnotation],
targets: TargetProtocol,
) -> Iterable[Tuple[Optional[str], Position, Size]]:
for sound_event in sound_events:
if not targets.filter(sound_event):
continue
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
class_name = targets.encode_class(sound_event)
position, size = targets.encode_roi(sound_event)
yield class_name, position, size
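A usage sketch of the pipeline above, assuming `annotations` is an iterable
of `data.SoundEventAnnotation` loaded elsewhere:

    targets = build_targets()  # falls back to DEFAULT_TARGET_CONFIG

    for class_name, position, size in iterate_encoded_sound_events(
        annotations, targets
    ):
        # class_name is None for sounds that match no specific class
        geometry = targets.decode_roi(position, size, class_name=class_name)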

View File

@ -16,10 +16,8 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
FullTrainingConfig,
PLTrainerConfig,
TrainingConfig,
load_full_training_config,
load_train_config,
)
from batdetect2.train.dataset import (
@ -48,7 +46,6 @@ __all__ = [
"DetectionLossConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"FullTrainingConfig",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
@ -71,7 +68,6 @@ __all__ = [
"build_trainer",
"build_val_dataset",
"build_val_loader",
"load_full_training_config",
"load_label_config",
"load_train_config",
"mask_frequency",

View File

@ -4,8 +4,6 @@ from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -22,8 +20,6 @@ from batdetect2.train.losses import LossConfig
__all__ = [
"TrainingConfig",
"load_train_config",
"FullTrainingConfig",
"load_full_training_config",
]
@ -93,18 +89,3 @@ def load_train_config(
field: Optional[str] = None,
) -> TrainingConfig:
return load_config(path, schema=TrainingConfig, field=field)
class FullTrainingConfig(ModelConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)

View File

@ -5,8 +5,9 @@ from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
import lightning as L
import torch
@ -6,11 +6,17 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.preprocess import build_preprocessor
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
__all__ = [
"TrainingModule",
]
@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
config: FullTrainingConfig,
config: "BatDetect2Config",
input_samplerate: int = TARGET_SAMPLERATE_HZ,
learning_rate: float = 0.001,
t_max: int = 100,
model: Optional[Model] = None,
@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule):
self.save_hyperparameters(logger=False)
self.input_samplerate = input_samplerate
self.config = config
self.learning_rate = learning_rate
self.t_max = t_max
@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule):
loss = build_loss(self.config.train.loss)
if model is None:
model = build_model(self.config)
targets = build_targets(self.config.targets)
preprocessor = build_preprocessor(
config=self.config.preprocess,
input_samplerate=self.input_samplerate,
)
postprocessor = build_postprocessor(
preprocessor, config=self.config.postprocess
)
model = build_model(
config=self.config.model,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
self.loss = loss
self.model = model
@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule):
def load_model_from_checkpoint(
path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
) -> Tuple[Model, "BatDetect2Config"]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config
def build_training_module(
config: Optional[FullTrainingConfig] = None,
config: Optional["BatDetect2Config"] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
return TrainingModule(
config=config,
learning_rate=config.train.optimizer.learning_rate,

View File

@ -1,29 +1,31 @@
from collections.abc import Sequence
from pathlib import Path
from typing import List, Optional
from typing import TYPE_CHECKING, List, Optional
from lightning import Trainer, seed_everything
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import (
FullTrainingConfig,
)
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.logging import build_logger
from batdetect2.typing import (
TargetProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import ClipLabeller
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
PreprocessorProtocol,
TargetProtocol,
)
__all__ = [
"build_trainer",
@ -36,12 +38,13 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
evaluator: Optional[Evaluator] = None,
trainer: Optional[Trainer] = None,
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
labeller: Optional[ClipLabeller] = None,
config: Optional[FullTrainingConfig] = None,
targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
config: Optional["BatDetect2Config"] = None,
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
@ -51,17 +54,20 @@ def train(
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
from batdetect2.config import BatDetect2Config
if seed is not None:
seed_everything(seed)
config = config or FullTrainingConfig()
config = config or BatDetect2Config()
targets = targets or build_targets(config.targets)
targets = targets or build_targets(config=config.targets)
preprocessor = preprocessor or build_preprocessor(config.preprocess)
audio_loader = audio_loader or build_audio_loader(config=config.audio)
audio_loader = audio_loader or build_audio_loader(
config=config.preprocess.audio
preprocessor = preprocessor or build_preprocessor(
input_samplerate=audio_loader.samplerate,
config=config.preprocess,
)
labeller = labeller or build_clip_labeler(
@ -95,7 +101,7 @@ def train(
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
module = TrainingModule.load_from_checkpoint(Path(model_path))
else:
module = build_training_module(
config,
@ -103,8 +109,9 @@ def train(
)
trainer = trainer or build_trainer(
config,
config.train,
targets=targets,
evaluator=evaluator,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,
@ -121,8 +128,8 @@ def train(
def build_trainer_callbacks(
targets: TargetProtocol,
config: FullTrainingConfig,
targets: "TargetProtocol",
evaluator: Optional[Evaluator] = None,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
@ -136,7 +143,7 @@ def build_trainer_callbacks(
if run_name is not None:
checkpoint_dir = checkpoint_dir / run_name
evaluator = build_evaluator(config=config.evaluation, targets=targets)
evaluator = evaluator or build_evaluator(targets=targets)
return [
ModelCheckpoint(
@ -150,20 +157,21 @@ def build_trainer_callbacks(
def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
conf: TrainingConfig,
targets: "TargetProtocol",
evaluator: Optional[Evaluator] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Trainer:
trainer_conf = conf.train.trainer
trainer_conf = conf.trainer
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(
conf.train.logger,
conf.logger,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
@ -181,7 +189,7 @@ def build_trainer(
logger=train_logger,
callbacks=build_trainer_callbacks(
targets,
config=conf,
evaluator=evaluator,
checkpoint_dir=checkpoint_dir,
experiment_name=experiment_name,
run_name=run_name,

View File

@ -13,8 +13,6 @@ from batdetect2.typing.postprocess import (
from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
SpectrogramPipeline,
)
from batdetect2.typing.targets import (
Position,
@ -60,8 +58,6 @@ __all__ = [
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SpectrogramBuilder",
"SpectrogramPipeline",
"TargetProtocol",
"TrainExample",
]

View File

@ -32,6 +32,8 @@ class AudioLoader(Protocol):
allows for different loading strategies or implementations.
"""
samplerate: int
def load_file(
self,
path: data.PathLike,
@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol):
...
class AudioPipeline(Protocol):
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
class SpectrogramPipeline(Protocol):
def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ...
def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
class PreprocessorProtocol(Protocol):
"""Defines a high-level interface for the complete preprocessing pipeline."""
@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol):
output_samplerate: float
audio_pipeline: AudioPipeline
spectrogram_pipeline: SpectrogramPipeline
def __call__(self, wav: torch.Tensor) -> torch.Tensor: ...
def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ...
def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ...
def process_numpy(self, wav: np.ndarray) -> np.ndarray:
return self(torch.tensor(wav)).numpy()
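A usage sketch of the numpy entry point with the default configuration (the
one-second noise buffer is purely illustrative):

    import numpy as np

    from batdetect2.preprocess import build_preprocessor

    preprocessor = build_preprocessor()  # assumes 256 kHz input audio
    wav = np.random.randn(256_000).astype(np.float32)
    spec = preprocessor.process_numpy(wav)  # roughly (128, 1000) spectrogram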