diff --git a/example_data/config.yaml b/example_data/config.yaml index 27a5ecd..8305ff2 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -36,31 +36,30 @@ targets: name: anchor_bbox anchor: top-left -preprocess: - audio: - samplerate: 256000 - resample: - enabled: True - method: "poly" +audio: + samplerate: 256000 + resample: + enabled: True + method: "poly" - spectrogram: - stft: - window_duration: 0.002 - window_overlap: 0.75 - window_fn: hann - frequencies: - max_freq: 120000 - min_freq: 10000 - size: - height: 128 - resize_factor: 0.5 - transforms: - - name: pcen - time_constant: 0.1 - gain: 0.98 - bias: 2 - power: 0.5 - - name: spectral_mean_substraction +preprocess: + stft: + window_duration: 0.002 + window_overlap: 0.75 + window_fn: hann + frequencies: + max_freq: 120000 + min_freq: 10000 + size: + height: 128 + resize_factor: 0.5 + spectrogram_transforms: + - name: pcen + time_constant: 0.1 + gain: 0.98 + bias: 2 + power: 0.5 + - name: spectral_mean_substraction postprocess: nms_kernel_size: 9 @@ -113,12 +112,44 @@ train: train_loader: batch_size: 8 + num_workers: 2 + shuffle: True + clipping_strategy: name: random_subclip duration: 0.256 + augmentations: + enabled: true + audio: + - name: mix_audio + probability: 0.2 + min_weight: 0.3 + max_weight: 0.7 + - name: add_echo + probability: 0.2 + max_delay: 0.005 + min_weight: 0.0 + max_weight: 1.0 + spectrogram: + - name: scale_volume + probability: 0.2 + min_scaling: 0.0 + max_scaling: 2.0 + - name: warp + probability: 0.2 + delta: 0.04 + - name: mask_time + probability: 0.2 + max_perc: 0.05 + max_masks: 3 + - name: mask_freq + probability: 0.2 + max_perc: 0.10 + max_masks: 3 + val_loader: num_workers: 2 clipping_strategy: @@ -141,32 +172,3 @@ train: logger: name: csv - - augmentations: - enabled: true - audio: - - name: mix_audio - probability: 0.2 - min_weight: 0.3 - max_weight: 0.7 - - name: add_echo - probability: 0.2 - max_delay: 0.005 - min_weight: 0.0 - max_weight: 1.0 - spectrogram: - - name: scale_volume - probability: 0.2 - min_scaling: 0.0 - max_scaling: 2.0 - - name: warp - probability: 0.2 - delta: 0.04 - - name: mask_time - probability: 0.2 - max_perc: 0.05 - max_masks: 3 - - name: mask_freq - probability: 0.2 - max_perc: 0.10 - max_masks: 3 diff --git a/src/batdetect2/audio.py b/src/batdetect2/audio.py new file mode 100644 index 0000000..76d60a5 --- /dev/null +++ b/src/batdetect2/audio.py @@ -0,0 +1,295 @@ +from typing import Optional + +import numpy as np +from numpy.typing import DTypeLike +from pydantic import Field +from scipy.signal import resample, resample_poly +from soundevent import audio, data +from soundfile import LibsndfileError + +from batdetect2.core import BaseConfig +from batdetect2.typing import AudioLoader + +__all__ = [ + "SoundEventAudioLoader", + "build_audio_loader", + "load_file_audio", + "load_recording_audio", + "load_clip_audio", + "resample_audio", +] + +TARGET_SAMPLERATE_HZ = 256_000 +"""Default target sample rate in Hz used if resampling is enabled.""" + + +class ResampleConfig(BaseConfig): + """Configuration for audio resampling. + + Attributes + ---------- + samplerate : int, default=256000 + The target sample rate in Hz to resample the audio to. Must be > 0. + method : str, default="poly" + The resampling algorithm to use. Options: + - "poly": Polyphase resampling using `scipy.signal.resample_poly`. + Generally fast. + - "fourier": Resampling via Fourier method using + `scipy.signal.resample`. May handle non-integer + resampling factors differently. 
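+    enabled : bool, default=True
+        Whether resampling is applied at all. When False, audio is kept
+        at its native sample rate.
+
+    Notes
+    -----
+    The target sample rate itself is taken from `AudioConfig.samplerate`;
+    this config only controls whether resampling runs and which method
+    is used.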
+ """ + + enabled: bool = True + method: str = "poly" + + +class AudioConfig(BaseConfig): + """Configuration for loading and initial audio preprocessing.""" + + samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) + resample: ResampleConfig = Field(default_factory=ResampleConfig) + + +def build_audio_loader(config: Optional[AudioConfig] = None) -> AudioLoader: + """Factory function to create an AudioLoader based on configuration.""" + config = config or AudioConfig() + return SoundEventAudioLoader( + samplerate=config.samplerate, + config=config.resample, + ) + + +class SoundEventAudioLoader(AudioLoader): + """Concrete implementation of the `AudioLoader`.""" + + def __init__( + self, + samplerate: int = TARGET_SAMPLERATE_HZ, + config: Optional[ResampleConfig] = None, + ): + self.samplerate = samplerate + self.config = config or ResampleConfig() + + def load_file( + self, + path: data.PathLike, + audio_dir: Optional[data.PathLike] = None, + ) -> np.ndarray: + """Load and preprocess audio directly from a file path.""" + return load_file_audio( + path, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, + ) + + def load_recording( + self, + recording: data.Recording, + audio_dir: Optional[data.PathLike] = None, + ) -> np.ndarray: + """Load and preprocess the entire audio for a Recording object.""" + return load_recording_audio( + recording, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, + ) + + def load_clip( + self, + clip: data.Clip, + audio_dir: Optional[data.PathLike] = None, + ) -> np.ndarray: + """Load and preprocess the audio segment defined by a Clip object.""" + return load_clip_audio( + clip, + samplerate=self.samplerate, + config=self.config, + audio_dir=audio_dir, + ) + + +def load_file_audio( + path: data.PathLike, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, + audio_dir: Optional[data.PathLike] = None, + dtype: DTypeLike = np.float32, # type: ignore +) -> np.ndarray: + """Load and preprocess audio from a file path using specified config.""" + try: + recording = data.Recording.from_file(path) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {path}. Error: {e}" + ) from e + + return load_recording_audio( + recording, + samplerate=samplerate, + config=config, + dtype=dtype, + audio_dir=audio_dir, + ) + + +def load_recording_audio( + recording: data.Recording, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, + audio_dir: Optional[data.PathLike] = None, + dtype: DTypeLike = np.float32, # type: ignore +) -> np.ndarray: + """Load and preprocess the entire audio content of a recording using config.""" + clip = data.Clip( + recording=recording, + start_time=0, + end_time=recording.duration, + ) + return load_clip_audio( + clip, + samplerate=samplerate, + config=config, + dtype=dtype, + audio_dir=audio_dir, + ) + + +def load_clip_audio( + clip: data.Clip, + samplerate: Optional[int] = None, + config: Optional[ResampleConfig] = None, + audio_dir: Optional[data.PathLike] = None, + dtype: DTypeLike = np.float32, # type: ignore +) -> np.ndarray: + """Load and preprocess a specific audio clip segment based on config.""" + try: + wav = ( + audio.load_clip(clip, audio_dir=audio_dir) + .sel(channel=0) + .astype(dtype) + ) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {clip.recording.path}. 
" + f"Error: {e}" + ) from e + + if not config or not config.enabled or samplerate is None: + return wav.data.astype(dtype) + + sr = int(1 / wav.time.attrs["step"]) + return resample_audio( + wav.data, + sr=sr, + samplerate=samplerate, + method=config.method, + ) + + +def resample_audio( + wav: np.ndarray, + sr: int, + samplerate: int = TARGET_SAMPLERATE_HZ, + method: str = "poly", +) -> np.ndarray: + """Resample an audio waveform DataArray to a target sample rate.""" + if sr == samplerate: + return wav + + if method == "poly": + return resample_audio_poly( + wav, + sr_orig=sr, + sr_new=samplerate, + ) + elif method == "fourier": + return resample_audio_fourier( + wav, + sr_orig=sr, + sr_new=samplerate, + ) + else: + raise NotImplementedError( + f"Resampling method '{method}' not implemented" + ) + + +def resample_audio_poly( + array: np.ndarray, + sr_orig: int, + sr_new: int, + axis: int = -1, +) -> np.ndarray: + """Resample a numpy array using `scipy.signal.resample_poly`. + + This method is often preferred for signals when the ratio of new + to old sample rates can be expressed as a rational number. It uses + polyphase filtering. + + Parameters + ---------- + array : np.ndarray + The input array to resample. + sr_orig : int + The original sample rate in Hz. + sr_new : int + The target sample rate in Hz. + axis : int, default=-1 + The axis of `array` along which to resample. + + Returns + ------- + np.ndarray + The array resampled to the target sample rate. + + Raises + ------ + ValueError + If sample rates are not positive. + """ + gcd = np.gcd(sr_orig, sr_new) + return resample_poly( + array, + sr_new // gcd, + sr_orig // gcd, + axis=axis, + ) + + +def resample_audio_fourier( + array: np.ndarray, + sr_orig: int, + sr_new: int, + axis: int = -1, +) -> np.ndarray: + """Resample a numpy array using `scipy.signal.resample`. + + This method uses FFTs to resample the signal. + + Parameters + ---------- + array : np.ndarray + The input array to resample. + num : int + The desired number of samples in the output array along `axis`. + axis : int, default=-1 + The axis of `array` along which to resample. + + Returns + ------- + np.ndarray + The array resampled to have `num` samples along `axis`. + + Raises + ------ + ValueError + If `num` is negative. 
+ """ + ratio = sr_new / sr_orig + return resample( # type: ignore + array, + int(array.shape[axis] * ratio), + axis=axis, + ) diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 2562a48..77e17df 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -14,7 +14,7 @@ __all__ = ["train_command"] @click.argument("train_dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True)) -@click.option("--targets", type=click.Path(exists=True)) +@click.option("--targets-config", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True)) @@ -37,7 +37,7 @@ def train_command( ckpt_dir: Optional[Path] = None, log_dir: Optional[Path] = None, config: Optional[Path] = None, - targets: Optional[Path] = None, + targets_config: Optional[Path] = None, config_field: Optional[str] = None, seed: Optional[int] = None, train_workers: int = 0, @@ -46,13 +46,13 @@ def train_command( run_name: Optional[str] = None, verbose: int = 0, ): + from batdetect2.config import ( + BatDetect2Config, + load_full_config, + ) from batdetect2.data import load_dataset_from_config from batdetect2.targets import load_target_config - from batdetect2.train import ( - FullTrainingConfig, - load_full_training_config, - train, - ) + from batdetect2.train import train logger.remove() if verbose == 0: @@ -68,15 +68,16 @@ def train_command( logger.info("Loading training configuration...") conf = ( - load_full_training_config(config, field=config_field) + load_full_config(config, field=config_field) if config is not None - else FullTrainingConfig() + else BatDetect2Config() ) - if targets is not None: + if targets_config is not None: logger.info("Loading targets configuration...") - targets_config = load_target_config(targets) - conf = conf.model_copy(update=dict(targets=targets_config)) + conf = conf.model_copy( + update=dict(targets=load_target_config(targets_config)) + ) logger.info("Loading training dataset...") train_annotations = load_dataset_from_config(train_dataset) @@ -96,6 +97,7 @@ def train_command( logger.debug("No validation directory provided.") logger.info("Configuration and data loaded. 
Starting training...") + train( train_annotations=train_annotations, val_annotations=val_annotations, diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py index 0c49c55..bffd563 100644 --- a/src/batdetect2/config.py +++ b/src/batdetect2/config.py @@ -1,16 +1,40 @@ -from typing import Literal +from typing import Literal, Optional +from pydantic import Field +from soundevent.data import PathLike + +from batdetect2.audio import AudioConfig from batdetect2.core import BaseConfig +from batdetect2.core.configs import load_config from batdetect2.evaluate.config import EvaluationConfig -from batdetect2.models.backbones import BackboneConfig -from batdetect2.preprocess import PreprocessingConfig +from batdetect2.models.config import BackboneConfig +from batdetect2.postprocess.config import PostprocessConfig +from batdetect2.preprocess.config import PreprocessingConfig +from batdetect2.targets.config import TargetConfig from batdetect2.train.config import TrainingConfig +__all__ = [ + "BatDetect2Config", + "load_full_config", +] + class BatDetect2Config(BaseConfig): config_version: Literal["v1"] = "v1" - train: TrainingConfig - evaluation: EvaluationConfig - model: BackboneConfig - preprocess: PreprocessingConfig + train: TrainingConfig = Field(default_factory=TrainingConfig) + evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) + model: BackboneConfig = Field(default_factory=BackboneConfig) + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) + audio: AudioConfig = Field(default_factory=AudioConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + + +def load_full_config( + path: PathLike, + field: Optional[str] = None, +) -> BatDetect2Config: + return load_config(path, schema=BatDetect2Config, field=field) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 7bcc71b..768673a 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -3,39 +3,46 @@ from typing import List, Optional, Tuple import pandas as pd from soundevent import data +from batdetect2.audio import build_audio_loader +from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.dataframe import extract_matches_dataframe from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP from batdetect2.models import Model -from batdetect2.plotting.clips import build_audio_loader +from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol from batdetect2.postprocess import get_raw_predictions from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets -from batdetect2.train.config import FullTrainingConfig from batdetect2.train.dataset import ValidationDataset from batdetect2.train.labels import build_clip_labeler from batdetect2.train.train import build_val_loader +from batdetect2.typing import ClipLabeller, TargetProtocol def evaluate( model: Model, test_annotations: List[data.ClipAnnotation], - config: Optional[FullTrainingConfig] = None, + targets: Optional[TargetProtocol] = None, + audio_loader: Optional[AudioLoader] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + labeller: Optional[ClipLabeller] = None, + config: Optional[EvaluationConfig] = None, num_workers: Optional[int] = None, ) -> Tuple[pd.DataFrame, dict]: - config = config or FullTrainingConfig() + 
config = config or EvaluationConfig() - audio_loader = build_audio_loader(config.preprocess.audio) + audio_loader = audio_loader or build_audio_loader() - preprocessor = build_preprocessor(config.preprocess) + preprocessor = preprocessor or build_preprocessor( + input_samplerate=audio_loader.samplerate, + ) - targets = build_targets(config.targets) + targets = targets or build_targets() - labeller = build_clip_labeler( + labeller = labeller or build_clip_labeler( targets, min_freq=preprocessor.min_freq, max_freq=preprocessor.max_freq, - config=config.train.labels, ) loader = build_val_loader( @@ -43,7 +50,6 @@ def evaluate( audio_loader=audio_loader, labeller=labeller, preprocessor=preprocessor, - config=config.train.val_loader, num_workers=num_workers, ) @@ -52,7 +58,7 @@ def evaluate( clip_annotations = [] predictions = [] - evaluator = build_evaluator(config=config.evaluation) + evaluator = build_evaluator(config=config) for batch in loader: outputs = model.detector(batch.spec) diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py index ae921ec..436e094 100644 --- a/src/batdetect2/evaluate/plots.py +++ b/src/batdetect2/evaluate/plots.py @@ -89,7 +89,7 @@ class ExampleGallery(PlotterProtocol): @classmethod def from_config(cls, config: ExampleGalleryConfig): preprocessor = build_preprocessor(config.preprocessing) - audio_loader = build_audio_loader(config.preprocessing.audio) + audio_loader = build_audio_loader(config.preprocessing.audio_transforms) return cls( examples_per_class=config.examples_per_class, preprocessor=preprocessor, diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 1bc7c14..b28d3a8 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -29,15 +29,10 @@ provided here. 
from typing import List, Optional import torch -from pydantic import Field -from soundevent.data import PathLike -from batdetect2.core.configs import BaseConfig, load_config from batdetect2.models.backbones import ( Backbone, - BackboneConfig, build_backbone, - load_backbone_config, ) from batdetect2.models.blocks import ( ConvConfig, @@ -51,6 +46,10 @@ from batdetect2.models.bottleneck import ( BottleneckConfig, build_bottleneck, ) +from batdetect2.models.config import ( + BackboneConfig, + load_backbone_config, +) from batdetect2.models.decoder import ( DEFAULT_DECODER_CONFIG, DecoderConfig, @@ -63,9 +62,9 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead -from batdetect2.postprocess import PostprocessConfig, build_postprocessor -from batdetect2.preprocess import PreprocessingConfig, build_preprocessor -from batdetect2.targets import TargetConfig, build_targets +from batdetect2.postprocess import build_postprocessor +from batdetect2.preprocess import build_preprocessor +from batdetect2.targets import build_targets from batdetect2.typing.models import DetectionModel from batdetect2.typing.postprocess import ( DetectionsTensor, @@ -99,20 +98,10 @@ __all__ = [ "build_detector", "load_backbone_config", "Model", - "ModelConfig", "build_model", ] -class ModelConfig(BaseConfig): - model: BackboneConfig = Field(default_factory=BackboneConfig) - preprocess: PreprocessingConfig = Field( - default_factory=PreprocessingConfig - ) - postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig) - targets: TargetConfig = Field(default_factory=TargetConfig) - - class Model(torch.nn.Module): detector: DetectionModel preprocessor: PreprocessorProtocol @@ -125,14 +114,12 @@ class Model(torch.nn.Module): preprocessor: PreprocessorProtocol, postprocessor: PostprocessorProtocol, targets: TargetProtocol, - config: ModelConfig, ): super().__init__() self.detector = detector self.preprocessor = preprocessor self.postprocessor = postprocessor self.targets = targets - self.config = config def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]: spec = self.preprocessor(wav) @@ -140,32 +127,25 @@ class Model(torch.nn.Module): return self.postprocessor(outputs) -def build_model(config: Optional[ModelConfig] = None): - config = config or ModelConfig() - - targets = build_targets(config=config.targets) - - preprocessor = build_preprocessor(config=config.preprocess) - - postprocessor = build_postprocessor( +def build_model( + config: Optional[BackboneConfig] = None, + targets: Optional[TargetProtocol] = None, + preprocessor: Optional[PreprocessorProtocol] = None, + postprocessor: Optional[PostprocessorProtocol] = None, +): + config = config or BackboneConfig() + targets = targets or build_targets() + preprocessor = preprocessor or build_preprocessor() + postprocessor = postprocessor or build_postprocessor( preprocessor=preprocessor, - config=config.postprocess, ) - detector = build_detector( num_classes=len(targets.class_names), - config=config.model, + config=config, ) return Model( - config=config, detector=detector, postprocessor=postprocessor, preprocessor=preprocessor, targets=targets, ) - - -def load_model_config( - path: PathLike, field: Optional[str] = None -) -> ModelConfig: - return load_config(path, schema=ModelConfig, field=field) diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index cf5f3b8..fd55c80 100644 --- a/src/batdetect2/models/backbones.py +++ 
b/src/batdetect2/models/backbones.py @@ -18,37 +18,20 @@ automatic padding to handle input sizes not perfectly divisible by the network's total downsampling factor. """ -from typing import Optional, Tuple +from typing import Tuple import torch import torch.nn.functional as F -from soundevent import data from torch import nn -from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.models.bottleneck import ( - DEFAULT_BOTTLENECK_CONFIG, - BottleneckConfig, - build_bottleneck, -) -from batdetect2.models.decoder import ( - DEFAULT_DECODER_CONFIG, - Decoder, - DecoderConfig, - build_decoder, -) -from batdetect2.models.encoder import ( - DEFAULT_ENCODER_CONFIG, - Encoder, - EncoderConfig, - build_encoder, -) +from batdetect2.models.bottleneck import build_bottleneck +from batdetect2.models.config import BackboneConfig +from batdetect2.models.decoder import Decoder, build_decoder +from batdetect2.models.encoder import Encoder, build_encoder from batdetect2.typing.models import BackboneModel __all__ = [ "Backbone", - "BackboneConfig", - "load_backbone_config", "build_backbone", ] @@ -161,82 +144,6 @@ class Backbone(BackboneModel): return x -class BackboneConfig(BaseConfig): - """Configuration for the Encoder-Decoder Backbone network. - - Aggregates configurations for the encoder, bottleneck, and decoder - components, along with defining the input and final output dimensions - for the complete backbone. - - Attributes - ---------- - input_height : int, default=128 - Expected height (frequency bins) of the input spectrograms to the - backbone. Must be positive. - in_channels : int, default=1 - Expected number of channels in the input spectrograms (e.g., 1 for - mono). Must be positive. - encoder : EncoderConfig, optional - Configuration for the encoder. If None or omitted, - the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the - encoder module) will be used. - bottleneck : BottleneckConfig, optional - Configuration for the bottleneck layer connecting encoder and decoder. - If None or omitted, the default bottleneck configuration will be used. - decoder : DecoderConfig, optional - Configuration for the decoder. If None or omitted, - the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the - decoder module) will be used. - out_channels : int, default=32 - Desired number of channels in the final feature map output by the - backbone. Must be positive. - """ - - input_height: int = 128 - in_channels: int = 1 - encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG - bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG - decoder: DecoderConfig = DEFAULT_DECODER_CONFIG - out_channels: int = 32 - - -def load_backbone_config( - path: data.PathLike, - field: Optional[str] = None, -) -> BackboneConfig: - """Load the backbone configuration from a file. - - Reads a configuration file (YAML) and validates it against the - `BackboneConfig` schema, potentially extracting data from a nested field. - - Parameters - ---------- - path : PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - backbone configuration (e.g., "model.backbone"). If None, the entire - file content is used. - - Returns - ------- - BackboneConfig - The loaded and validated backbone configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. 
- pydantic.ValidationError - If the loaded config data does not conform to `BackboneConfig`. - KeyError, TypeError - If `field` specifies an invalid path. - """ - return load_config(path, schema=BackboneConfig, field=field) - - def build_backbone(config: BackboneConfig) -> BackboneModel: """Factory function to build a Backbone from configuration. diff --git a/src/batdetect2/models/config.py b/src/batdetect2/models/config.py new file mode 100644 index 0000000..0f34b34 --- /dev/null +++ b/src/batdetect2/models/config.py @@ -0,0 +1,98 @@ +from typing import Optional + +from soundevent import data + +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.models.bottleneck import ( + DEFAULT_BOTTLENECK_CONFIG, + BottleneckConfig, +) +from batdetect2.models.decoder import ( + DEFAULT_DECODER_CONFIG, + DecoderConfig, +) +from batdetect2.models.encoder import ( + DEFAULT_ENCODER_CONFIG, + EncoderConfig, +) + +__all__ = [ + "BackboneConfig", + "load_backbone_config", +] + + +class BackboneConfig(BaseConfig): + """Configuration for the Encoder-Decoder Backbone network. + + Aggregates configurations for the encoder, bottleneck, and decoder + components, along with defining the input and final output dimensions + for the complete backbone. + + Attributes + ---------- + input_height : int, default=128 + Expected height (frequency bins) of the input spectrograms to the + backbone. Must be positive. + in_channels : int, default=1 + Expected number of channels in the input spectrograms (e.g., 1 for + mono). Must be positive. + encoder : EncoderConfig, optional + Configuration for the encoder. If None or omitted, + the default encoder configuration (`DEFAULT_ENCODER_CONFIG` from the + encoder module) will be used. + bottleneck : BottleneckConfig, optional + Configuration for the bottleneck layer connecting encoder and decoder. + If None or omitted, the default bottleneck configuration will be used. + decoder : DecoderConfig, optional + Configuration for the decoder. If None or omitted, + the default decoder configuration (`DEFAULT_DECODER_CONFIG` from the + decoder module) will be used. + out_channels : int, default=32 + Desired number of channels in the final feature map output by the + backbone. Must be positive. + """ + + input_height: int = 128 + in_channels: int = 1 + encoder: EncoderConfig = DEFAULT_ENCODER_CONFIG + bottleneck: BottleneckConfig = DEFAULT_BOTTLENECK_CONFIG + decoder: DecoderConfig = DEFAULT_DECODER_CONFIG + out_channels: int = 32 + + +def load_backbone_config( + path: data.PathLike, + field: Optional[str] = None, +) -> BackboneConfig: + """Load the backbone configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `BackboneConfig` schema, potentially extracting data from a nested field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + backbone configuration (e.g., "model.backbone"). If None, the entire + file content is used. + + Returns + ------- + BackboneConfig + The loaded and validated backbone configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded config data does not conform to `BackboneConfig`. + KeyError, TypeError + If `field` specifies an invalid path. 
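+
+    Examples
+    --------
+    Assuming a YAML file in which the backbone settings sit under a
+    top-level ``model`` key:
+
+    >>> config = load_backbone_config("config.yaml", field="model")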
+ """ + return load_config(path, schema=BackboneConfig, field=field) diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py index 051bf6c..63e978f 100644 --- a/src/batdetect2/plotting/clips.py +++ b/src/batdetect2/plotting/clips.py @@ -5,8 +5,9 @@ import torch from matplotlib.axes import Axes from soundevent import data +from batdetect2.audio import build_audio_loader from batdetect2.plotting.common import plot_spectrogram -from batdetect2.preprocess import build_audio_loader, build_preprocessor +from batdetect2.preprocess import build_preprocessor from batdetect2.typing import AudioLoader, PreprocessorProtocol __all__ = [ diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 58ceafc..c4ad923 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -1,307 +1,29 @@ """Main entry point for the BatDetect2 Postprocessing pipeline.""" -from typing import List, Optional - -import torch -from loguru import logger -from pydantic import Field -from soundevent import data - -from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.postprocess.config import ( + PostprocessConfig, + load_postprocess_config, +) from batdetect2.postprocess.decoding import ( - DEFAULT_CLASSIFICATION_THRESHOLD, - convert_raw_prediction_to_sound_event_prediction, convert_raw_predictions_to_clip_prediction, to_raw_predictions, ) -from batdetect2.postprocess.extraction import extract_prediction_tensor from batdetect2.postprocess.nms import ( - NMS_KERNEL_SIZE, non_max_suppression, ) -from batdetect2.postprocess.remapping import map_detection_to_clip -from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing import ModelOutput -from batdetect2.typing.postprocess import ( - BatDetect2Prediction, - DetectionsTensor, - PostprocessorProtocol, - RawPrediction, +from batdetect2.postprocess.postprocessor import ( + Postprocessor, + build_postprocessor, + get_raw_predictions, ) -from batdetect2.typing.preprocess import PreprocessorProtocol -from batdetect2.typing.targets import TargetProtocol __all__ = [ - "DEFAULT_CLASSIFICATION_THRESHOLD", - "DEFAULT_DETECTION_THRESHOLD", - "MAX_FREQ", - "MIN_FREQ", - "ModelOutput", - "NMS_KERNEL_SIZE", "PostprocessConfig", "Postprocessor", - "TOP_K_PER_SEC", "build_postprocessor", "convert_raw_predictions_to_clip_prediction", "to_raw_predictions", "load_postprocess_config", "non_max_suppression", + "get_raw_predictions", ] - -DEFAULT_DETECTION_THRESHOLD = 0.01 - - -TOP_K_PER_SEC = 100 - - -class PostprocessConfig(BaseConfig): - """Configuration settings for the postprocessing pipeline. - - Defines tunable parameters that control how raw model outputs are - converted into final detections. - - Attributes - ---------- - nms_kernel_size : int, default=NMS_KERNEL_SIZE - Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression. - Used to suppress weaker detections near stronger peaks. Must be - positive. - detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD - Minimum confidence score from the detection heatmap required to - consider a point as a potential detection. Must be >= 0. - classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD - Minimum confidence score for a specific class prediction to be included - in the decoded tags for a detection. Must be >= 0. - top_k_per_sec : int, default=TOP_K_PER_SEC - Desired maximum number of detections per second of audio. 
Used by - `get_max_detections` to calculate an absolute limit based on clip - duration before applying `extract_detections_from_array`. Must be - positive. - """ - - nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) - detection_threshold: float = Field( - default=DEFAULT_DETECTION_THRESHOLD, - ge=0, - ) - classification_threshold: float = Field( - default=DEFAULT_CLASSIFICATION_THRESHOLD, - ge=0, - ) - top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) - - -def load_postprocess_config( - path: data.PathLike, - field: Optional[str] = None, -) -> PostprocessConfig: - """Load the postprocessing configuration from a file. - - Reads a configuration file (YAML) and validates it against the - `PostprocessConfig` schema, potentially extracting data from a nested - field. - - Parameters - ---------- - path : PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - postprocessing configuration (e.g., "inference.postprocessing"). - If None, the entire file content is used. - - Returns - ------- - PostprocessConfig - The loaded and validated postprocessing configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. - pydantic.ValidationError - If the loaded configuration data does not conform to the - `PostprocessConfig` schema. - KeyError, TypeError - If `field` specifies an invalid path within the loaded data. - """ - return load_config(path, schema=PostprocessConfig, field=field) - - -def build_postprocessor( - preprocessor: PreprocessorProtocol, - config: Optional[PostprocessConfig] = None, -) -> PostprocessorProtocol: - """Factory function to build the standard postprocessor.""" - config = config or PostprocessConfig() - logger.opt(lazy=True).debug( - "Building postprocessor with config: \n{}", - lambda: config.to_yaml_string(), - ) - return Postprocessor( - samplerate=preprocessor.output_samplerate, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, - top_k_per_sec=config.top_k_per_sec, - detection_threshold=config.detection_threshold, - ) - - -class Postprocessor(torch.nn.Module, PostprocessorProtocol): - """Standard implementation of the postprocessing pipeline.""" - - def __init__( - self, - samplerate: float, - min_freq: float, - max_freq: float, - top_k_per_sec: int = 200, - detection_threshold: float = 0.01, - ): - """Initialize the Postprocessor.""" - super().__init__() - self.samplerate = samplerate - self.min_freq = min_freq - self.max_freq = max_freq - self.top_k_per_sec = top_k_per_sec - self.detection_threshold = detection_threshold - - def forward(self, output: ModelOutput) -> List[DetectionsTensor]: - width = output.detection_probs.shape[-1] - duration = width / self.samplerate - max_detections = int(self.top_k_per_sec * duration) - detections = extract_prediction_tensor( - output, - max_detections=max_detections, - threshold=self.detection_threshold, - ) - return [ - map_detection_to_clip( - detection, - start_time=0, - end_time=duration, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for detection in detections - ] - - def get_detections( - self, - output: ModelOutput, - start_times: Optional[List[float]] = None, - ) -> List[DetectionsTensor]: - width = output.detection_probs.shape[-1] - duration = width / self.samplerate - max_detections = int(self.top_k_per_sec * duration) - - detections = extract_prediction_tensor( - output, - 
max_detections=max_detections, - threshold=self.detection_threshold, - ) - - if start_times is None: - return detections - - width = output.detection_probs.shape[-1] - duration = width / self.samplerate - return [ - map_detection_to_clip( - detection, - start_time=start_time, - end_time=start_time + duration, - min_freq=self.min_freq, - max_freq=self.max_freq, - ) - for detection, start_time in zip(detections, start_times) - ] - - -def get_raw_predictions( - output: ModelOutput, - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - start_times: Optional[List[float]] = None, -) -> List[List[RawPrediction]]: - """Extract intermediate RawPrediction objects for a batch.""" - detections = postprocessor.get_detections(output, start_times) - return [ - to_raw_predictions(detection.numpy(), targets=targets) - for detection in detections - ] - - -def get_sound_event_predictions( - output: ModelOutput, - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - clips: List[data.Clip], - classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, -) -> List[List[BatDetect2Prediction]]: - raw_predictions = get_raw_predictions( - output, - targets=targets, - postprocessor=postprocessor, - start_times=[clip.start_time for clip in clips], - ) - return [ - [ - BatDetect2Prediction( - raw=raw, - sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( - raw, - recording=clip.recording, - targets=targets, - classification_threshold=classification_threshold, - ), - ) - for raw in predictions - ] - for predictions, clip in zip(raw_predictions, clips) - ] - - -def get_predictions( - output: ModelOutput, - clips: List[data.Clip], - targets: TargetProtocol, - postprocessor: PostprocessorProtocol, - classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, -) -> List[data.ClipPrediction]: - """Perform the full postprocessing pipeline for a batch. - - Takes raw model output and corresponding clips, applies the entire - configured chain (NMS, remapping, extraction, geometry recovery, class - decoding), producing final `soundevent.data.ClipPrediction` objects. - - Parameters - ---------- - output : ModelOutput - Raw output from the neural network model for a batch. - clips : List[data.Clip] - List of `soundevent.data.Clip` objects corresponding to the batch. - - Returns - ------- - List[data.ClipPrediction] - List containing one `ClipPrediction` object for each input clip, - populated with `SoundEventPrediction` objects. 
- """ - raw_predictions = get_raw_predictions( - output, - targets=targets, - postprocessor=postprocessor, - start_times=[clip.start_time for clip in clips], - ) - return [ - convert_raw_predictions_to_clip_prediction( - prediction, - clip, - targets=targets, - classification_threshold=classification_threshold, - ) - for prediction, clip in zip(raw_predictions, clips) - ] diff --git a/src/batdetect2/postprocess/config.py b/src/batdetect2/postprocess/config.py new file mode 100644 index 0000000..8299d5c --- /dev/null +++ b/src/batdetect2/postprocess/config.py @@ -0,0 +1,94 @@ +from typing import Optional + +from pydantic import Field +from soundevent import data + +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD +from batdetect2.postprocess.nms import NMS_KERNEL_SIZE + +__all__ = [ + "PostprocessConfig", + "load_postprocess_config", +] + +DEFAULT_DETECTION_THRESHOLD = 0.01 + + +TOP_K_PER_SEC = 100 + + +class PostprocessConfig(BaseConfig): + """Configuration settings for the postprocessing pipeline. + + Defines tunable parameters that control how raw model outputs are + converted into final detections. + + Attributes + ---------- + nms_kernel_size : int, default=NMS_KERNEL_SIZE + Size (pixels) of the kernel/neighborhood for Non-Maximum Suppression. + Used to suppress weaker detections near stronger peaks. Must be + positive. + detection_threshold : float, default=DEFAULT_DETECTION_THRESHOLD + Minimum confidence score from the detection heatmap required to + consider a point as a potential detection. Must be >= 0. + classification_threshold : float, default=DEFAULT_CLASSIFICATION_THRESHOLD + Minimum confidence score for a specific class prediction to be included + in the decoded tags for a detection. Must be >= 0. + top_k_per_sec : int, default=TOP_K_PER_SEC + Desired maximum number of detections per second of audio. Used by + `get_max_detections` to calculate an absolute limit based on clip + duration before applying `extract_detections_from_array`. Must be + positive. + """ + + nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0) + detection_threshold: float = Field( + default=DEFAULT_DETECTION_THRESHOLD, + ge=0, + ) + classification_threshold: float = Field( + default=DEFAULT_CLASSIFICATION_THRESHOLD, + ge=0, + ) + top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0) + + +def load_postprocess_config( + path: data.PathLike, + field: Optional[str] = None, +) -> PostprocessConfig: + """Load the postprocessing configuration from a file. + + Reads a configuration file (YAML) and validates it against the + `PostprocessConfig` schema, potentially extracting data from a nested + field. + + Parameters + ---------- + path : PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + postprocessing configuration (e.g., "inference.postprocessing"). + If None, the entire file content is used. + + Returns + ------- + PostprocessConfig + The loaded and validated postprocessing configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded configuration data does not conform to the + `PostprocessConfig` schema. + KeyError, TypeError + If `field` specifies an invalid path within the loaded data. 
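+
+    Examples
+    --------
+    Assuming a config file where these settings are nested under
+    ``inference.postprocessing``:
+
+    >>> config = load_postprocess_config(
+    ...     "config.yaml",
+    ...     field="inference.postprocessing",
+    ... )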
+ """ + return load_config(path, schema=PostprocessConfig, field=field) diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py new file mode 100644 index 0000000..5c13bd3 --- /dev/null +++ b/src/batdetect2/postprocess/postprocessor.py @@ -0,0 +1,208 @@ +from typing import List, Optional + +import torch +from loguru import logger +from soundevent import data + +from batdetect2.postprocess.config import ( + PostprocessConfig, +) +from batdetect2.postprocess.decoding import ( + DEFAULT_CLASSIFICATION_THRESHOLD, + convert_raw_prediction_to_sound_event_prediction, + convert_raw_predictions_to_clip_prediction, + to_raw_predictions, +) +from batdetect2.postprocess.extraction import extract_prediction_tensor +from batdetect2.postprocess.remapping import map_detection_to_clip +from batdetect2.typing import ModelOutput +from batdetect2.typing.postprocess import ( + BatDetect2Prediction, + DetectionsTensor, + PostprocessorProtocol, + RawPrediction, +) +from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.typing.targets import TargetProtocol + +__all__ = [ + "build_postprocessor", + "Postprocessor", +] + + +def build_postprocessor( + preprocessor: PreprocessorProtocol, + config: Optional[PostprocessConfig] = None, +) -> PostprocessorProtocol: + """Factory function to build the standard postprocessor.""" + config = config or PostprocessConfig() + logger.opt(lazy=True).debug( + "Building postprocessor with config: \n{}", + lambda: config.to_yaml_string(), + ) + return Postprocessor( + samplerate=preprocessor.output_samplerate, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + top_k_per_sec=config.top_k_per_sec, + detection_threshold=config.detection_threshold, + ) + + +class Postprocessor(torch.nn.Module, PostprocessorProtocol): + """Standard implementation of the postprocessing pipeline.""" + + def __init__( + self, + samplerate: float, + min_freq: float, + max_freq: float, + top_k_per_sec: int = 200, + detection_threshold: float = 0.01, + ): + """Initialize the Postprocessor.""" + super().__init__() + self.samplerate = samplerate + self.min_freq = min_freq + self.max_freq = max_freq + self.top_k_per_sec = top_k_per_sec + self.detection_threshold = detection_threshold + + def forward(self, output: ModelOutput) -> List[DetectionsTensor]: + width = output.detection_probs.shape[-1] + duration = width / self.samplerate + max_detections = int(self.top_k_per_sec * duration) + detections = extract_prediction_tensor( + output, + max_detections=max_detections, + threshold=self.detection_threshold, + ) + return [ + map_detection_to_clip( + detection, + start_time=0, + end_time=duration, + min_freq=self.min_freq, + max_freq=self.max_freq, + ) + for detection in detections + ] + + def get_detections( + self, + output: ModelOutput, + start_times: Optional[List[float]] = None, + ) -> List[DetectionsTensor]: + width = output.detection_probs.shape[-1] + duration = width / self.samplerate + max_detections = int(self.top_k_per_sec * duration) + + detections = extract_prediction_tensor( + output, + max_detections=max_detections, + threshold=self.detection_threshold, + ) + + if start_times is None: + return detections + + width = output.detection_probs.shape[-1] + duration = width / self.samplerate + return [ + map_detection_to_clip( + detection, + start_time=start_time, + end_time=start_time + duration, + min_freq=self.min_freq, + max_freq=self.max_freq, + ) + for detection, start_time in zip(detections, start_times) + ] + + 
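+
+# Minimal usage sketch (assumed context: a `preprocessor` built by
+# `batdetect2.preprocess.build_preprocessor`, a detector returning a
+# `ModelOutput`, and `targets` from `batdetect2.targets.build_targets`):
+#
+#     postprocessor = build_postprocessor(preprocessor)
+#     raw = get_raw_predictions(
+#         output,
+#         targets=targets,
+#         postprocessor=postprocessor,
+#         start_times=[clip.start_time for clip in clips],
+#     )
+#
+# `Postprocessor.forward` maps detections onto a clip starting at t=0,
+# while `get_detections` accepts explicit `start_times` so detections
+# from batched windows land back in recording coordinates.
+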
+def get_raw_predictions( + output: ModelOutput, + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, + start_times: Optional[List[float]] = None, +) -> List[List[RawPrediction]]: + """Extract intermediate RawPrediction objects for a batch.""" + detections = postprocessor.get_detections(output, start_times) + return [ + to_raw_predictions(detection.numpy(), targets=targets) + for detection in detections + ] + + +def get_sound_event_predictions( + output: ModelOutput, + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, + clips: List[data.Clip], + classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, +) -> List[List[BatDetect2Prediction]]: + raw_predictions = get_raw_predictions( + output, + targets=targets, + postprocessor=postprocessor, + start_times=[clip.start_time for clip in clips], + ) + return [ + [ + BatDetect2Prediction( + raw=raw, + sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( + raw, + recording=clip.recording, + targets=targets, + classification_threshold=classification_threshold, + ), + ) + for raw in predictions + ] + for predictions, clip in zip(raw_predictions, clips) + ] + + +def get_predictions( + output: ModelOutput, + clips: List[data.Clip], + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, + classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, +) -> List[data.ClipPrediction]: + """Perform the full postprocessing pipeline for a batch. + + Takes raw model output and corresponding clips, applies the entire + configured chain (NMS, remapping, extraction, geometry recovery, class + decoding), producing final `soundevent.data.ClipPrediction` objects. + + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. + + Returns + ------- + List[data.ClipPrediction] + List containing one `ClipPrediction` object for each input clip, + populated with `SoundEventPrediction` objects. 
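+
+    Notes
+    -----
+    Besides `output` and `clips`, this function needs the `targets` and
+    `postprocessor` the model was built with; `classification_threshold`
+    (default `DEFAULT_CLASSIFICATION_THRESHOLD`) controls which class
+    tags are attached to each prediction.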
+ """ + raw_predictions = get_raw_predictions( + output, + targets=targets, + postprocessor=postprocessor, + start_times=[clip.start_time for clip in clips], + ) + return [ + convert_raw_predictions_to_clip_prediction( + prediction, + clip, + targets=targets, + classification_threshold=classification_threshold, + ) + for prediction, clip in zip(raw_predictions, clips) + ] diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index 4118a5f..38c34ea 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -1,21 +1,19 @@ """Main entry point for the BatDetect2 preprocessing subsystem.""" -from batdetect2.preprocess.audio import build_audio_loader +from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.preprocess.config import ( - MAX_FREQ, - MIN_FREQ, - TARGET_SAMPLERATE_HZ, PreprocessingConfig, load_preprocessing_config, ) -from batdetect2.preprocess.preprocessor import build_preprocessor +from batdetect2.preprocess.preprocessor import Preprocessor, build_preprocessor +from batdetect2.preprocess.spectrogram import MAX_FREQ, MIN_FREQ __all__ = [ - "MIN_FREQ", "MAX_FREQ", - "TARGET_SAMPLERATE_HZ", + "MIN_FREQ", "PreprocessingConfig", - "load_preprocessing_config", + "Preprocessor", + "TARGET_SAMPLERATE_HZ", "build_preprocessor", - "build_audio_loader", + "load_preprocessing_config", ] diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index 9ca5984..a6debb0 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -1,267 +1,60 @@ -"""Handles loading and initial preprocessing of audio waveforms.""" +from typing import Annotated, Literal, Union -from typing import Optional - -import numpy as np import torch -from numpy.typing import DTypeLike -from scipy.signal import resample, resample_poly -from soundevent import audio, data -from soundfile import LibsndfileError +from pydantic import Field -from batdetect2.preprocess.common import CenterTensor, PeakNormalize -from batdetect2.preprocess.config import ( - TARGET_SAMPLERATE_HZ, - AudioConfig, - AudioTransform, - ResampleConfig, -) -from batdetect2.typing import AudioLoader +from batdetect2.audio import TARGET_SAMPLERATE_HZ +from batdetect2.core import BaseConfig, Registry +from batdetect2.preprocess.common import center_tensor, peak_normalize __all__ = [ - "SoundEventAudioLoader", - "build_audio_loader", - "load_file_audio", - "load_recording_audio", - "load_clip_audio", - "resample_audio", + "CenterAudioConfig", + "ScaleAudioConfig", + "FixDurationConfig", + "build_audio_transform", ] -class SoundEventAudioLoader(AudioLoader): - """Concrete implementation of the `AudioLoader`.""" - - def __init__( - self, - samplerate: int = TARGET_SAMPLERATE_HZ, - config: Optional[ResampleConfig] = None, - ): - self.samplerate = samplerate - self.config = config or ResampleConfig() - - def load_file( - self, - path: data.PathLike, - audio_dir: Optional[data.PathLike] = None, - ) -> np.ndarray: - """Load and preprocess audio directly from a file path.""" - return load_file_audio( - path, - samplerate=self.samplerate, - config=self.config, - audio_dir=audio_dir, - ) - - def load_recording( - self, - recording: data.Recording, - audio_dir: Optional[data.PathLike] = None, - ) -> np.ndarray: - """Load and preprocess the entire audio for a Recording object.""" - return load_recording_audio( - recording, - samplerate=self.samplerate, - config=self.config, - audio_dir=audio_dir, - ) - - def load_clip( - self, - 
clip: data.Clip, - audio_dir: Optional[data.PathLike] = None, - ) -> np.ndarray: - """Load and preprocess the audio segment defined by a Clip object.""" - return load_clip_audio( - clip, - samplerate=self.samplerate, - config=self.config, - audio_dir=audio_dir, - ) +audio_transforms: Registry[torch.nn.Module, [int]] = Registry( + "audio_transform" +) -def load_file_audio( - path: data.PathLike, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, # type: ignore -) -> np.ndarray: - """Load and preprocess audio from a file path using specified config.""" - try: - recording = data.Recording.from_file(path) - except LibsndfileError as e: - raise FileNotFoundError( - f"Could not load the recording at path: {path}. Error: {e}" - ) from e - - return load_recording_audio( - recording, - samplerate=samplerate, - config=config, - dtype=dtype, - audio_dir=audio_dir, - ) +class CenterAudioConfig(BaseConfig): + name: Literal["center_audio"] = "center_audio" -def load_recording_audio( - recording: data.Recording, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, # type: ignore -) -> np.ndarray: - """Load and preprocess the entire audio content of a recording using config.""" - clip = data.Clip( - recording=recording, - start_time=0, - end_time=recording.duration, - ) - return load_clip_audio( - clip, - samplerate=samplerate, - config=config, - dtype=dtype, - audio_dir=audio_dir, - ) +class CenterAudio(torch.nn.Module): + def forward(self, wav: torch.Tensor) -> torch.Tensor: + return center_tensor(wav) + + @classmethod + def from_config(cls, config: CenterAudioConfig, samplerate: int): + return cls() -def load_clip_audio( - clip: data.Clip, - samplerate: Optional[int] = None, - config: Optional[ResampleConfig] = None, - audio_dir: Optional[data.PathLike] = None, - dtype: DTypeLike = np.float32, # type: ignore -) -> np.ndarray: - """Load and preprocess a specific audio clip segment based on config.""" - try: - wav = ( - audio.load_clip(clip, audio_dir=audio_dir) - .sel(channel=0) - .astype(dtype) - ) - except LibsndfileError as e: - raise FileNotFoundError( - f"Could not load the recording at path: {clip.recording.path}. " - f"Error: {e}" - ) from e - - if not config or not config.enabled or samplerate is None: - return wav.data.astype(dtype) - - sr = int(1 / wav.time.attrs["step"]) - return resample_audio( - wav.data, - sr=sr, - samplerate=samplerate, - method=config.method, - ) +audio_transforms.register(CenterAudioConfig, CenterAudio) -def resample_audio( - wav: np.ndarray, - sr: int, - samplerate: int = TARGET_SAMPLERATE_HZ, - method: str = "poly", -) -> np.ndarray: - """Resample an audio waveform DataArray to a target sample rate.""" - if sr == samplerate: - return wav - - if method == "poly": - return resample_audio_poly( - wav, - sr_orig=sr, - sr_new=samplerate, - ) - elif method == "fourier": - return resample_audio_fourier( - wav, - sr_orig=sr, - sr_new=samplerate, - ) - else: - raise NotImplementedError( - f"Resampling method '{method}' not implemented" - ) +class ScaleAudioConfig(BaseConfig): + name: Literal["scale_audio"] = "scale_audio" -def resample_audio_poly( - array: np.ndarray, - sr_orig: int, - sr_new: int, - axis: int = -1, -) -> np.ndarray: - """Resample a numpy array using `scipy.signal.resample_poly`. 
+class ScaleAudio(torch.nn.Module): + def forward(self, wav: torch.Tensor) -> torch.Tensor: + return peak_normalize(wav) - This method is often preferred for signals when the ratio of new - to old sample rates can be expressed as a rational number. It uses - polyphase filtering. - - Parameters - ---------- - array : np.ndarray - The input array to resample. - sr_orig : int - The original sample rate in Hz. - sr_new : int - The target sample rate in Hz. - axis : int, default=-1 - The axis of `array` along which to resample. - - Returns - ------- - np.ndarray - The array resampled to the target sample rate. - - Raises - ------ - ValueError - If sample rates are not positive. - """ - gcd = np.gcd(sr_orig, sr_new) - return resample_poly( - array, - sr_new // gcd, - sr_orig // gcd, - axis=axis, - ) + @classmethod + def from_config(cls, config: ScaleAudioConfig, samplerate: int): + return cls() -def resample_audio_fourier( - array: np.ndarray, - sr_orig: int, - sr_new: int, - axis: int = -1, -) -> np.ndarray: - """Resample a numpy array using `scipy.signal.resample`. +audio_transforms.register(ScaleAudioConfig, ScaleAudio) - This method uses FFTs to resample the signal. - Parameters - ---------- - array : np.ndarray - The input array to resample. - num : int - The desired number of samples in the output array along `axis`. - axis : int, default=-1 - The axis of `array` along which to resample. - - Returns - ------- - np.ndarray - The array resampled to have `num` samples along `axis`. - - Raises - ------ - ValueError - If `num` is negative. - """ - ratio = sr_new / sr_orig - return resample( # type: ignore - array, - int(array.shape[axis] * ratio), - axis=axis, - ) +class FixDurationConfig(BaseConfig): + name: Literal["fix_duration"] = "fix_duration" + duration: float = 0.5 class FixDuration(torch.nn.Module): @@ -282,40 +75,25 @@ class FixDuration(torch.nn.Module): return torch.nn.functional.pad(wav, (0, self.length - length)) - -def build_audio_loader( - config: Optional[AudioConfig] = None, -) -> AudioLoader: - """Factory function to create an AudioLoader based on configuration.""" - config = config or AudioConfig() - return SoundEventAudioLoader( - samplerate=config.samplerate, - config=config.resample, - ) + @classmethod + def from_config(cls, config: FixDurationConfig, samplerate: int): + return cls(samplerate=samplerate, duration=config.duration) -def build_audio_transform_step( +audio_transforms.register(FixDurationConfig, FixDuration) + +AudioTransform = Annotated[ + Union[ + FixDurationConfig, + ScaleAudioConfig, + CenterAudioConfig, + ], + Field(discriminator="name"), +] + + +def build_audio_transform( config: AudioTransform, - samplerate: int, + samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: - if config.name == "fix_duration": - return FixDuration(samplerate=samplerate, duration=config.duration) - - if config.name == "scale_audio": - return PeakNormalize() - - if config.name == "center_audio": - return CenterTensor() - - raise NotImplementedError( - f"Audio preprocessing step {config.name} not implemented" - ) - - -def build_audio_pipeline(config: AudioConfig) -> torch.nn.Module: - return torch.nn.Sequential( - *[ - build_audio_transform_step(step, samplerate=config.samplerate) - for step in config.transforms - ] - ) + return audio_transforms.build(config, samplerate) diff --git a/src/batdetect2/preprocess/common.py b/src/batdetect2/preprocess/common.py index 37bf2a2..c498063 100644 --- a/src/batdetect2/preprocess/common.py +++ b/src/batdetect2/preprocess/common.py @@ 
-1,24 +1,23 @@
 import torch
 
 __all__ = [
-    "CenterTensor",
-    "PeakNormalize",
+    "center_tensor",
+    "peak_normalize",
 ]
 
 
-class CenterTensor(torch.nn.Module):
-    def forward(self, wav: torch.Tensor):
-        return wav - wav.mean()
+def center_tensor(tensor: torch.Tensor) -> torch.Tensor:
+    return tensor - tensor.mean()
 
 
-class PeakNormalize(torch.nn.Module):
-    def forward(self, wav: torch.Tensor):
-        max_value = wav.abs().min()
+def peak_normalize(tensor: torch.Tensor) -> torch.Tensor:
+    # Normalize by the peak absolute amplitude: .max(), not .min(),
+    # which is zero for any zero-crossing signal and disables scaling.
+    max_value = tensor.abs().max()
 
-        denominator = torch.where(
-            max_value == 0,
-            torch.tensor(1.0, device=wav.device, dtype=wav.dtype),
-            max_value,
-        )
+    denominator = torch.where(
+        max_value == 0,
+        torch.tensor(1.0, device=tensor.device, dtype=tensor.dtype),
+        max_value,
+    )
 
-        return wav / denominator
+    return tensor / denominator
diff --git a/src/batdetect2/preprocess/config.py b/src/batdetect2/preprocess/config.py
index b60c067..2ac8150 100644
--- a/src/batdetect2/preprocess/config.py
+++ b/src/batdetect2/preprocess/config.py
@@ -1,187 +1,25 @@
-from collections.abc import Sequence
-from typing import Annotated, List, Literal, Optional, Union
+from typing import List, Optional
 
 from pydantic import Field
 from soundevent.data import PathLike
 
 from batdetect2.core.configs import BaseConfig, load_config
+from batdetect2.preprocess.audio import AudioTransform
+from batdetect2.preprocess.spectrogram import (
+    FrequencyConfig,
+    PcenConfig,
+    ResizeConfig,
+    SpectralMeanSubstractionConfig,
+    SpectrogramTransform,
+    STFTConfig,
+)
 
 __all__ = [
     "load_preprocessing_config",
-    "CenterAudioConfig",
-    "ScaleAudioConfig",
-    "FixDurationConfig",
-    "ResampleConfig",
     "AudioTransform",
-    "AudioConfig",
-    "STFTConfig",
-    "FrequencyConfig",
-    "PcenConfig",
-    "ScaleAmplitudeConfig",
-    "SpectralMeanSubstractionConfig",
-    "ResizeConfig",
-    "PeakNormalizeConfig",
-    "SpectrogramTransform",
-    "SpectrogramConfig",
     "PreprocessingConfig",
-    "TARGET_SAMPLERATE_HZ",
-    "MIN_FREQ",
-    "MAX_FREQ",
 ]
 
-TARGET_SAMPLERATE_HZ = 256_000
-"""Default target sample rate in Hz used if resampling is enabled."""
-
-MIN_FREQ = 10_000
-"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
-
-MAX_FREQ = 120_000
-"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
-
-
-class CenterAudioConfig(BaseConfig):
-    name: Literal["center_audio"] = "center_audio"
-
-
-class ScaleAudioConfig(BaseConfig):
-    name: Literal["scale_audio"] = "scale_audio"
-
-
-class FixDurationConfig(BaseConfig):
-    name: Literal["fix_duration"] = "fix_duration"
-    duration: float = 0.5
-
-
-class ResampleConfig(BaseConfig):
-    """Configuration for audio resampling.
-
-    Attributes
-    ----------
-    samplerate : int, default=256000
-        The target sample rate in Hz to resample the audio to. Must be > 0.
-    method : str, default="poly"
-        The resampling algorithm to use. Options:
-        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
-          Generally fast.
-        - "fourier": Resampling via Fourier method using
-          `scipy.signal.resample`. May handle non-integer
-          resampling factors differently.
- """ - - enabled: bool = True - method: str = "poly" - - -AudioTransform = Annotated[ - Union[ - FixDurationConfig, - ScaleAudioConfig, - CenterAudioConfig, - ], - Field(discriminator="name"), -] - - -class AudioConfig(BaseConfig): - """Configuration for loading and initial audio preprocessing.""" - - samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) - resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig) - transforms: List[AudioTransform] = Field(default_factory=list) - - -class STFTConfig(BaseConfig): - """Configuration for the Short-Time Fourier Transform (STFT). - - Attributes - ---------- - window_duration : float, default=0.002 - Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be - > 0. Determines frequency resolution (longer window = finer frequency - resolution). - window_overlap : float, default=0.75 - Fraction of overlap between consecutive STFT windows (e.g., 0.75 - for 75%). Must be >= 0 and < 1. Determines time resolution - (higher overlap = finer time resolution). - window_fn : str, default="hann" - Name of the window function to apply before FFT calculation. Common - options include "hann", "hamming", "blackman". See - `scipy.signal.get_window`. - """ - - window_duration: float = Field(default=0.002, gt=0) - window_overlap: float = Field(default=0.75, ge=0, lt=1) - window_fn: str = "hann" - - -class FrequencyConfig(BaseConfig): - """Configuration for frequency axis parameters. - - Attributes - ---------- - max_freq : int, default=120000 - Maximum frequency in Hz to retain in the spectrogram after STFT. - Frequencies above this value will be cropped. Must be > 0. - min_freq : int, default=10000 - Minimum frequency in Hz to retain in the spectrogram after STFT. - Frequencies below this value will be cropped. Must be >= 0. - """ - - max_freq: int = Field(default=120_000, ge=0) - min_freq: int = Field(default=10_000, ge=0) - - -class PcenConfig(BaseConfig): - """Configuration for Per-Channel Energy Normalization (PCEN).""" - - name: Literal["pcen"] = "pcen" - time_constant: float = 0.4 - gain: float = 0.98 - bias: float = 2 - power: float = 0.5 - - -class ScaleAmplitudeConfig(BaseConfig): - name: Literal["scale_amplitude"] = "scale_amplitude" - scale: Literal["power", "db"] = "db" - - -class SpectralMeanSubstractionConfig(BaseConfig): - name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" - - -class ResizeConfig(BaseConfig): - name: Literal["resize_spec"] = "resize_spec" - height: int = 128 - resize_factor: float = 0.5 - - -class PeakNormalizeConfig(BaseConfig): - name: Literal["peak_normalize"] = "peak_normalize" - - -SpectrogramTransform = Annotated[ - Union[ - PcenConfig, - ScaleAmplitudeConfig, - SpectralMeanSubstractionConfig, - PeakNormalizeConfig, - ], - Field(discriminator="name"), -] - - -class SpectrogramConfig(BaseConfig): - stft: STFTConfig = Field(default_factory=STFTConfig) - frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) - size: ResizeConfig = Field(default_factory=ResizeConfig) - transforms: Sequence[SpectrogramTransform] = Field( - default_factory=lambda: [ - PcenConfig(), - SpectralMeanSubstractionConfig(), - ] - ) - class PreprocessingConfig(BaseConfig): """Unified configuration for the audio preprocessing pipeline. @@ -201,8 +39,20 @@ class PreprocessingConfig(BaseConfig): resizing). Defaults to default `SpectrogramConfig` settings if omitted. 
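+
+    Examples
+    --------
+    A minimal sketch of the flattened layout (illustrative values; all
+    names are defined in this module's imports):
+
+    >>> config = PreprocessingConfig(
+    ...     stft=STFTConfig(window_duration=0.002, window_overlap=0.75),
+    ...     spectrogram_transforms=[
+    ...         PcenConfig(),
+    ...         SpectralMeanSubstractionConfig(),
+    ...     ],
+    ... )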
""" - audio: AudioConfig = Field(default_factory=AudioConfig) - spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) + audio_transforms: List[AudioTransform] = Field(default_factory=list) + + spectrogram_transforms: List[SpectrogramTransform] = Field( + default_factory=lambda: [ + PcenConfig(), + SpectralMeanSubstractionConfig(), + ] + ) + + stft: STFTConfig = Field(default_factory=STFTConfig) + + frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) + + size: ResizeConfig = Field(default_factory=ResizeConfig) def load_preprocessing_config( diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py index e2ef27d..ccd0f46 100644 --- a/src/batdetect2/preprocess/preprocessor.py +++ b/src/batdetect2/preprocess/preprocessor.py @@ -3,21 +3,25 @@ from typing import Optional import torch from loguru import logger -from batdetect2.preprocess.audio import build_audio_pipeline +from batdetect2.audio import TARGET_SAMPLERATE_HZ +from batdetect2.preprocess.audio import build_audio_transform from batdetect2.preprocess.config import PreprocessingConfig from batdetect2.preprocess.spectrogram import ( _spec_params_from_config, - build_spectrogram_pipeline, + build_spectrogram_builder, + build_spectrogram_crop, + build_spectrogram_resizer, + build_spectrogram_transform, ) -from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline +from batdetect2.typing import PreprocessorProtocol __all__ = [ - "StandardPreprocessor", + "Preprocessor", "build_preprocessor", ] -class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): +class Preprocessor(torch.nn.Module, PreprocessorProtocol): """Standard implementation of the `Preprocessor` protocol.""" input_samplerate: int @@ -28,37 +32,78 @@ class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): def __init__( self, - audio_pipeline: torch.nn.Module, - spectrogram_pipeline: SpectrogramPipeline, + config: PreprocessingConfig, input_samplerate: int, - output_samplerate: float, - max_freq: float, - min_freq: float, ) -> None: super().__init__() - self.audio_pipeline = audio_pipeline - self.spectrogram_pipeline = spectrogram_pipeline - self.max_freq = max_freq - self.min_freq = min_freq + self.audio_transforms = torch.nn.Sequential( + *[ + build_audio_transform(step, samplerate=input_samplerate) + for step in config.audio_transforms + ] + ) + + self.spectrogram_transforms = torch.nn.Sequential( + *[ + build_spectrogram_transform(step, samplerate=input_samplerate) + for step in config.spectrogram_transforms + ] + ) + + self.spectrogram_builder = build_spectrogram_builder( + config.stft, + samplerate=input_samplerate, + ) + + self.spectrogram_crop = build_spectrogram_crop( + config.frequencies, + stft=config.stft, + samplerate=input_samplerate, + ) + + self.spectrogram_resizer = build_spectrogram_resizer(config.size) + + self.min_freq = config.frequencies.min_freq + self.max_freq = config.frequencies.max_freq self.input_samplerate = input_samplerate - self.output_samplerate = output_samplerate + self.output_samplerate = compute_output_samplerate( + config, + input_samplerate=input_samplerate, + ) def forward(self, wav: torch.Tensor) -> torch.Tensor: - wav = self.audio_pipeline(wav) - return self.spectrogram_pipeline(wav) + wav = self.audio_transforms(wav) + spec = self.spectrogram_builder(wav) + return self.process_spectrogram(spec) + + def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: + return self.spectrogram_builder(wav) + + def process_audio(self, 
wav: torch.Tensor) -> torch.Tensor: + return self(wav) + + def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: + spec = self.spectrogram_crop(spec) + spec = self.spectrogram_transforms(spec) + return self.spectrogram_resizer(spec) -def compute_output_samplerate(config: PreprocessingConfig) -> float: - samplerate = config.audio.samplerate - _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft) - factor = config.spectrogram.size.resize_factor - return samplerate * factor / hop_size +def compute_output_samplerate( + config: PreprocessingConfig, + input_samplerate: int = TARGET_SAMPLERATE_HZ, +) -> float: + _, hop_size = _spec_params_from_config( + config.stft, samplerate=input_samplerate + ) + factor = config.size.resize_factor + return input_samplerate * factor / hop_size def build_preprocessor( config: Optional[PreprocessingConfig] = None, + input_samplerate: int = TARGET_SAMPLERATE_HZ, ) -> PreprocessorProtocol: """Factory function to build the standard preprocessor from configuration.""" config = config or PreprocessingConfig() @@ -66,21 +111,4 @@ def build_preprocessor( "Building preprocessor with config: \n{}", lambda: config.to_yaml_string(), ) - - samplerate = config.audio.samplerate - - min_freq = config.spectrogram.frequencies.min_freq - max_freq = config.spectrogram.frequencies.max_freq - - output_samplerate = compute_output_samplerate(config) - - return StandardPreprocessor( - audio_pipeline=build_audio_pipeline(config.audio), - spectrogram_pipeline=build_spectrogram_pipeline( - samplerate, config.spectrogram - ), - input_samplerate=samplerate, - output_samplerate=output_samplerate, - min_freq=min_freq, - max_freq=max_freq, - ) + return Preprocessor(config=config, input_samplerate=input_samplerate) diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 2fa3938..9b8fa7a 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -1,34 +1,64 @@ """Computes spectrograms from audio waveforms with configurable parameters.""" -from typing import Callable, Optional +from typing import Annotated, Callable, Literal, Optional, Union import numpy as np import torch import torchaudio +from pydantic import Field -from batdetect2.preprocess.common import PeakNormalize -from batdetect2.preprocess.config import ( - ScaleAmplitudeConfig, - SpectrogramConfig, - SpectrogramTransform, - STFTConfig, -) +from batdetect2.audio import TARGET_SAMPLERATE_HZ +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry +from batdetect2.preprocess.common import peak_normalize __all__ = [ + "STFTConfig", + "build_spectrogram_transform", "build_spectrogram_builder", - "build_spectrogram_pipeline", ] +MIN_FREQ = 10_000 +"""Default minimum frequency (Hz) for spectrogram frequency cropping.""" + +MAX_FREQ = 120_000 +"""Default maximum frequency (Hz) for spectrogram frequency cropping.""" + + +class STFTConfig(BaseConfig): + """Configuration for the Short-Time Fourier Transform (STFT). + + Attributes + ---------- + window_duration : float, default=0.002 + Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be + > 0. Determines frequency resolution (longer window = finer frequency + resolution). + window_overlap : float, default=0.75 + Fraction of overlap between consecutive STFT windows (e.g., 0.75 + for 75%). Must be >= 0 and < 1. Determines time resolution + (higher overlap = finer time resolution). 
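+        For example, with the default 256 kHz sample rate a 0.002 s window
+        gives n_fft = 512 samples, and 75% overlap gives a hop of 128
+        samples (see `_spec_params_from_config`).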
+ window_fn : str, default="hann" + Name of the window function to apply before FFT calculation. Common + options include "hann", "hamming", "blackman". See + `scipy.signal.get_window`. + """ + + window_duration: float = Field(default=0.002, gt=0) + window_overlap: float = Field(default=0.75, ge=0, lt=1) + window_fn: str = "hann" + + def build_spectrogram_builder( - samplerate: int, - conf: STFTConfig, + config: STFTConfig, + samplerate: int = TARGET_SAMPLERATE_HZ, ) -> torch.nn.Module: - n_fft, hop_length = _spec_params_from_config(samplerate, conf) + n_fft, hop_length = _spec_params_from_config(config, samplerate=samplerate) return torchaudio.transforms.Spectrogram( n_fft=n_fft, hop_length=hop_length, - window_fn=get_spectrogram_window(conf.window_fn), + window_fn=get_spectrogram_window(config.window_fn), center=True, power=1, ) @@ -55,16 +85,19 @@ def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: ) -def _spec_params_from_config(samplerate: int, conf: STFTConfig): - n_fft = int(samplerate * conf.window_duration) - hop_length = int(n_fft * (1 - conf.window_overlap)) +def _spec_params_from_config( + config: STFTConfig, + samplerate: int = TARGET_SAMPLERATE_HZ, +): + n_fft = int(samplerate * config.window_duration) + hop_length = int(n_fft * (1 - config.window_overlap)) return n_fft, hop_length def _frequency_to_index( freq: float, - samplerate: int, n_fft: int, + samplerate: int = TARGET_SAMPLERATE_HZ, ) -> Optional[int]: alpha = freq * 2 / samplerate height = np.floor(n_fft / 2) + 1 @@ -79,14 +112,49 @@ def _frequency_to_index( return index -class FrequencyClip(torch.nn.Module): +class FrequencyConfig(BaseConfig): + """Configuration for frequency axis parameters. + + Attributes + ---------- + max_freq : int, default=120000 + Maximum frequency in Hz to retain in the spectrogram after STFT. + Frequencies above this value will be cropped. Must be > 0. + min_freq : int, default=10000 + Minimum frequency in Hz to retain in the spectrogram after STFT. + Frequencies below this value will be cropped. Must be >= 0. 
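+
+    Examples
+    --------
+    Crop to a narrower band (illustrative values):
+
+    >>> config = FrequencyConfig(min_freq=15_000, max_freq=110_000)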
+    """
+
+    max_freq: int = Field(default=MAX_FREQ, ge=0)
+    min_freq: int = Field(default=MIN_FREQ, ge=0)
+
+
+class FrequencyCrop(torch.nn.Module):
     def __init__(
         self,
-        low_index: Optional[int] = None,
-        high_index: Optional[int] = None,
+        samplerate: int,
+        n_fft: int,
+        min_freq: Optional[int] = None,
+        max_freq: Optional[int] = None,
     ):
         super().__init__()
+        self.n_fft = n_fft
+        self.samplerate = samplerate
+        self.min_freq = min_freq
+        self.max_freq = max_freq
+
+        low_index = None
+        if min_freq is not None:
+            low_index = _frequency_to_index(
+                min_freq, self.n_fft, samplerate=self.samplerate
+            )
         self.low_index = low_index
+
+        high_index = None
+        if max_freq is not None:
+            high_index = _frequency_to_index(
+                max_freq, self.n_fft, samplerate=self.samplerate
+            )
         self.high_index = high_index
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
@@ -107,6 +175,72 @@ class FrequencyClip(torch.nn.Module):
         )
 
 
+def build_spectrogram_crop(
+    config: FrequencyConfig,
+    stft: Optional[STFTConfig] = None,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
+) -> torch.nn.Module:
+    stft = stft or STFTConfig()
+    n_fft, _ = _spec_params_from_config(stft, samplerate=samplerate)
+    return FrequencyCrop(
+        samplerate=samplerate,
+        n_fft=n_fft,
+        min_freq=config.min_freq,
+        max_freq=config.max_freq,
+    )
+
+
+class ResizeConfig(BaseConfig):
+    name: Literal["resize_spec"] = "resize_spec"
+    height: int = 128
+    resize_factor: float = 0.5
+
+
+class ResizeSpec(torch.nn.Module):
+    def __init__(self, height: int, time_factor: float):
+        super().__init__()
+        self.height = height
+        self.time_factor = time_factor
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        current_length = spec.shape[-1]
+        target_length = int(self.time_factor * current_length)
+
+        original_ndim = spec.ndim
+        while spec.ndim < 4:
+            spec = spec.unsqueeze(0)
+
+        resized = torch.nn.functional.interpolate(
+            spec,
+            size=(self.height, target_length),
+            mode="bilinear",
+        )
+
+        while resized.ndim != original_ndim:
+            resized = resized.squeeze(0)
+
+        return resized
+
+
+def build_spectrogram_resizer(config: ResizeConfig) -> torch.nn.Module:
+    return ResizeSpec(height=config.height, time_factor=config.resize_factor)
+
+
+spectrogram_transforms: Registry[torch.nn.Module, [int]] = Registry(
+    "spectrogram_transform"
+)
+
+
+class PcenConfig(BaseConfig):
+    """Configuration for Per-Channel Energy Normalization (PCEN)."""
+
+    name: Literal["pcen"] = "pcen"
+    time_constant: float = 0.4
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
 class PCEN(torch.nn.Module):
     def __init__(
         self,
@@ -115,7 +249,7 @@ class PCEN(torch.nn.Module):
         bias: float = 2.0,
         power: float = 0.5,
         eps: float = 1e-6,
-        dtype=torch.float64,
+        dtype=torch.float32,
     ):
         super().__init__()
         self.smoothing_constant = smoothing_constant
@@ -151,6 +285,19 @@ class PCEN(torch.nn.Module):
             * torch.expm1(self.power * torch.log1p(S * smooth / self.bias))
         ).to(spec.dtype)
 
+    @classmethod
+    def from_config(cls, config: PcenConfig, samplerate: int):
+        smooth = _compute_smoothing_constant(samplerate, config.time_constant)
+        return cls(
+            smoothing_constant=smooth,
+            gain=config.gain,
+            bias=config.bias,
+            power=config.power,
+        )
+
+
+spectrogram_transforms.register(PcenConfig, PCEN)
+
 
 def _compute_smoothing_constant(
     samplerate: int,
@@ -164,21 +311,42 @@ def _compute_smoothing_constant(
     return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
 
 
+class ScaleAmplitudeConfig(BaseConfig):
+    name: Literal["scale_amplitude"] = "scale_amplitude"
+    scale: Literal["power", "db"] = "db"
+
+
 class ToPower(torch.nn.Module): 
def forward(self, spec: torch.Tensor) -> torch.Tensor:
         return spec**2
 
 
-def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
-    if conf.scale == "db":
-        return torchaudio.transforms.AmplitudeToDB()
+_scalers = {
+    "db": torchaudio.transforms.AmplitudeToDB,
+    "power": ToPower,
+}
 
-    if conf.scale == "power":
-        return ToPower()
 
-    raise NotImplementedError(
-        f"Amplitude scaling {conf.scale} not implemented"
-    )
+class ScaleAmplitude(torch.nn.Module):
+    def __init__(self, scale: Literal["power", "db"]):
+        super().__init__()
+        self.scale = scale
+        # Instantiate the configured scaler submodule ("db" or "power").
+        self.scaler = _scalers[scale]()
+
+    def forward(self, spec: torch.Tensor) -> torch.Tensor:
+        return self.scaler(spec)
+
+    @classmethod
+    def from_config(cls, config: ScaleAmplitudeConfig, samplerate: int):
+        return cls(scale=config.scale)
+
+
+spectrogram_transforms.register(ScaleAmplitudeConfig, ScaleAmplitude)
+
+
+class SpectralMeanSubstractionConfig(BaseConfig):
+    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
 
 
 class SpectralMeanSubstraction(torch.nn.Module):
@@ -186,129 +352,49 @@ class SpectralMeanSubstraction(torch.nn.Module):
         mean = spec.mean(-1, keepdim=True)
         return (spec - mean).clamp(min=0)
 
+    @classmethod
+    def from_config(
+        cls,
+        config: SpectralMeanSubstractionConfig,
+        samplerate: int,
+    ):
+        return cls()
 
-class ResizeSpec(torch.nn.Module):
-    def __init__(self, height: int, time_factor: float):
-        super().__init__()
-        self.height = height
-        self.time_factor = time_factor
 
+spectrogram_transforms.register(
+    SpectralMeanSubstractionConfig,
+    SpectralMeanSubstraction,
+)
+
+
+class PeakNormalizeConfig(BaseConfig):
+    name: Literal["peak_normalize"] = "peak_normalize"
+
+
+class PeakNormalize(torch.nn.Module):
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        current_length = spec.shape[-1]
-        target_length = int(self.time_factor * current_length)
+        return peak_normalize(spec)
 
-        original_ndim = spec.ndim
-        while spec.ndim < 4:
-            spec = spec.unsqueeze(0)
-
-        resized = torch.nn.functional.interpolate(
-            spec,
-            size=(self.height, target_length),
-            mode="bilinear",
-        )
-
-        while resized.ndim != original_ndim:
-            resized = resized.squeeze(0)
-
-        return resized
+    @classmethod
+    def from_config(cls, config: PeakNormalizeConfig, samplerate: int):
+        return cls()
 
 
-def _build_spectrogram_transform_step(
-    step: SpectrogramTransform,
-    samplerate: int,
-) -> torch.nn.Module:
-    if step.name == "pcen":
-        return PCEN(
-            smoothing_constant=_compute_smoothing_constant(
-                samplerate=samplerate,
-                time_constant=step.time_constant,
-            ),
-            gain=step.gain,
-            bias=step.bias,
-            power=step.power,
-        )
+spectrogram_transforms.register(PeakNormalizeConfig, PeakNormalize)
 
-    if step.name == "scale_amplitude":
-        return _build_amplitude_scaler(step)
-
-    if step.name == "spectral_mean_substraction":
-        return SpectralMeanSubstraction()
-
-    if step.name == "peak_normalize":
-        return PeakNormalize()
-
-    raise NotImplementedError(
-        f"Spectrogram preprocessing step {step.name} not implemented"
-    )
+SpectrogramTransform = Annotated[
+    Union[
+        PcenConfig,
+        ScaleAmplitudeConfig,
+        SpectralMeanSubstractionConfig,
+        PeakNormalizeConfig,
+    ],
+    Field(discriminator="name"),
+]
 
 
 def build_spectrogram_transform(
+    config: SpectrogramTransform,
     samplerate: int,
-    conf: SpectrogramConfig,
 ) -> torch.nn.Module:
-    return torch.nn.Sequential(
-        *[
-            _build_spectrogram_transform_step(step, samplerate=samplerate)
-            for step in conf.transforms
-        ]
-    )
-
-
-class SpectrogramPipeline(torch.nn.Module):
-    def __init__(
-        self,
-        spec_builder: torch.nn.Module,
-        
freq_cutter: torch.nn.Module, - transforms: torch.nn.Module, - resizer: torch.nn.Module, - ): - super().__init__() - self.spec_builder = spec_builder - self.freq_cutter = freq_cutter - self.transforms = transforms - self.resizer = resizer - - def forward(self, wav: torch.Tensor) -> torch.Tensor: - spec = self.spec_builder(wav) - spec = self.freq_cutter(spec) - spec = self.transforms(spec) - return self.resizer(spec) - - def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: - return self.spec_builder(wav) - - def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: - return self.freq_cutter(spec) - - def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: - return self.transforms(spec) - - def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: - return self.resizer(spec) - - -def build_spectrogram_pipeline( - samplerate: int, - conf: SpectrogramConfig, -) -> SpectrogramPipeline: - spec_builder = build_spectrogram_builder(samplerate, conf.stft) - n_fft, _ = _spec_params_from_config(samplerate, conf.stft) - cutter = FrequencyClip( - low_index=_frequency_to_index( - conf.frequencies.min_freq, samplerate, n_fft - ), - high_index=_frequency_to_index( - conf.frequencies.max_freq, samplerate, n_fft - ), - ) - transforms = build_spectrogram_transform(samplerate, conf) - resizer = ResizeSpec( - height=conf.size.height, - time_factor=conf.size.resize_factor, - ) - return SpectrogramPipeline( - spec_builder=spec_builder, - freq_cutter=cutter, - transforms=transforms, - resizer=resizer, - ) + return spectrogram_transforms.build(config, samplerate) diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index fae0507..8a09b66 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -1,17 +1,6 @@ """BatDetect2 Target Definition system.""" -from collections import Counter -from typing import Iterable, List, Optional, Tuple - -from loguru import logger -from pydantic import Field, field_validator -from soundevent import data - -from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.data.conditions import build_sound_event_condition from batdetect2.targets.classes import ( - DEFAULT_CLASSES, - DEFAULT_DETECTION_CLASS, SoundEventDecoder, SoundEventEncoder, TargetClassConfig, @@ -19,23 +8,29 @@ from batdetect2.targets.classes import ( build_sound_event_encoder, get_class_names_from_config, ) +from batdetect2.targets.config import TargetConfig, load_target_config from batdetect2.targets.rois import ( AnchorBBoxMapperConfig, ROIMapperConfig, ROITargetMapper, build_roi_mapper, ) +from batdetect2.targets.targets import ( + Targets, + build_targets, + iterate_encoded_sound_events, + load_targets, +) from batdetect2.targets.terms import ( call_type, data_source, generic_class, individual, ) -from batdetect2.typing.targets import Position, Size, TargetProtocol __all__ = [ "AnchorBBoxMapperConfig", - "DEFAULT_TARGET_CONFIG", + "ROIMapperConfig", "ROITargetMapper", "SoundEventDecoder", "SoundEventEncoder", @@ -45,365 +40,13 @@ __all__ = [ "build_roi_mapper", "build_sound_event_decoder", "build_sound_event_encoder", + "build_targets", "call_type", "data_source", "generic_class", "get_class_names_from_config", "individual", + "iterate_encoded_sound_events", "load_target_config", + "load_targets", ] - - -class TargetConfig(BaseConfig): - detection_target: TargetClassConfig = Field( - default=DEFAULT_DETECTION_CLASS - ) - - classification_targets: List[TargetClassConfig] = Field( - 
default_factory=lambda: DEFAULT_CLASSES - ) - - roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) - - @field_validator("classification_targets") - def check_unique_class_names(cls, v: List[TargetClassConfig]): - """Ensure all defined class names are unique.""" - names = [c.name for c in v] - - if len(names) != len(set(names)): - name_counts = Counter(names) - duplicates = [ - name for name, count in name_counts.items() if count > 1 - ] - raise ValueError( - "Class names must be unique. Found duplicates: " - f"{', '.join(duplicates)}" - ) - return v - - -def load_target_config( - path: data.PathLike, - field: Optional[str] = None, -) -> TargetConfig: - """Load the unified target configuration from a file. - - Reads a configuration file (typically YAML) and validates it against the - `TargetConfig` schema, potentially extracting data from a nested field. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - target configuration. If None, the entire file content is used. - - Returns - ------- - TargetConfig - The loaded and validated unified target configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. - pydantic.ValidationError - If the loaded configuration data does not conform to the - `TargetConfig` schema (including validation within nested configs - like `ClassesConfig`). - KeyError, TypeError - If `field` specifies an invalid path within the loaded data. - """ - return load_config(path=path, schema=TargetConfig, field=field) - - -class Targets(TargetProtocol): - """Encapsulates the complete configured target definition pipeline. - - This class implements the `TargetProtocol`, holding the configured - functions for filtering, transforming, encoding (tags to class name), - decoding (class name to tags), and mapping ROIs (geometry to position/size - and back). It provides a high-level interface to apply these steps and - access relevant metadata like class names and dimension names. - - Instances are typically created using the `build_targets` factory function - or the `load_targets` convenience loader. - - Attributes - ---------- - class_names : List[str] - An ordered list of the unique names of the specific target classes - defined in the configuration. - generic_class_tags : List[data.Tag] - A list of `soundevent.data.Tag` objects representing the configured - generic class category (used when no specific class matches). - dimension_names : List[str] - The names of the size dimensions handled by the ROI mapper - (e.g., ['width', 'height']). 
- """ - - class_names: List[str] - detection_class_tags: List[data.Tag] - dimension_names: List[str] - detection_class_name: str - - def __init__(self, config: TargetConfig): - """Initialize the Targets object.""" - self.config = config - - self._filter_fn = build_sound_event_condition( - config.detection_target.match_if - ) - self._encode_fn = build_sound_event_encoder( - config.classification_targets - ) - self._decode_fn = build_sound_event_decoder( - config.classification_targets - ) - - self._roi_mapper = build_roi_mapper(config.roi) - - self.dimension_names = self._roi_mapper.dimension_names - - self.class_names = get_class_names_from_config( - config.classification_targets - ) - - self.detection_class_name = config.detection_target.name - self.detection_class_tags = config.detection_target.assign_tags - - self._roi_mapper_overrides = { - class_config.name: build_roi_mapper(class_config.roi) - for class_config in config.classification_targets - if class_config.roi is not None - } - - for class_name in self._roi_mapper_overrides: - if class_name not in self.class_names: - # TODO: improve this warning - logger.warning( - "The ROI mapper overrides contains a class ({class_name}) " - "not present in the class names.", - class_name=class_name, - ) - - def filter(self, sound_event: data.SoundEventAnnotation) -> bool: - """Apply the configured filter to a sound event annotation. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation to filter. - - Returns - ------- - bool - True if the annotation should be kept (passes the filter), - False otherwise. If no filter was configured, always returns True. - """ - return self._filter_fn(sound_event) - - def encode_class( - self, sound_event: data.SoundEventAnnotation - ) -> Optional[str]: - """Encode a sound event annotation to its target class name. - - Applies the configured class definition rules (including priority) - to determine the specific class name for the annotation. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation to encode. Note: This should typically be called - *after* applying any transformations via the `transform` method. - - Returns - ------- - str or None - The name of the matched target class, or None if the annotation - does not match any specific class rule (i.e., it belongs to the - generic category). - """ - return self._encode_fn(sound_event) - - def decode_class(self, class_label: str) -> List[data.Tag]: - """Decode a predicted class name back into representative tags. - - Uses the configured mapping (based on `TargetClass.output_tags` or - `TargetClass.tags`) to convert a class name string into a list of - `soundevent.data.Tag` objects. - - Parameters - ---------- - class_label : str - The class name to decode. - - Returns - ------- - List[data.Tag] - The list of tags corresponding to the input class name. - """ - return self._decode_fn(class_label) - - def encode_roi( - self, sound_event: data.SoundEventAnnotation - ) -> tuple[Position, Size]: - """Extract the target reference position from the annotation's roi. - - Delegates to the internal ROI mapper's `get_roi_position` method. - - Parameters - ---------- - sound_event : data.SoundEventAnnotation - The annotation containing the geometry (ROI). - - Returns - ------- - Tuple[float, float] - The reference position `(time, frequency)`. - - Raises - ------ - ValueError - If the annotation lacks geometry. 
- """ - class_name = self.encode_class(sound_event) - - if class_name in self._roi_mapper_overrides: - return self._roi_mapper_overrides[class_name].encode( - sound_event.sound_event - ) - - return self._roi_mapper.encode(sound_event.sound_event) - - def decode_roi( - self, - position: Position, - size: Size, - class_name: Optional[str] = None, - ) -> data.Geometry: - """Recover an approximate geometric ROI from a position and dimensions. - - Delegates to the internal ROI mapper's `recover_roi` method, which - un-scales the dimensions and reconstructs the geometry (typically a - `BoundingBox`). - - Parameters - ---------- - pos : Tuple[float, float] - The reference position `(time, frequency)`. - dims : np.ndarray - NumPy array with size dimensions (e.g., from model prediction), - matching the order in `self.dimension_names`. - - Returns - ------- - data.Geometry - The reconstructed geometry (typically `BoundingBox`). - """ - if class_name in self._roi_mapper_overrides: - return self._roi_mapper_overrides[class_name].decode( - position, - size, - ) - - return self._roi_mapper.decode(position, size) - - -DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( - classification_targets=DEFAULT_CLASSES, - detection_target=DEFAULT_DETECTION_CLASS, - roi=AnchorBBoxMapperConfig(), -) - - -def build_targets(config: Optional[TargetConfig] = None) -> Targets: - """Build a Targets object from a loaded TargetConfig. - - Parameters - ---------- - config : TargetConfig - The loaded and validated unified target configuration object. - - Returns - ------- - Targets - An initialized `Targets` object ready for use. - - Raises - ------ - KeyError - If term keys or derivation function keys specified in the `config` - are not found in their respective registries. - ImportError, AttributeError, TypeError - If dynamic import of a derivation function fails (when configured). - """ - config = config or DEFAULT_TARGET_CONFIG - logger.opt(lazy=True).debug( - "Building targets with config: \n{}", - lambda: config.to_yaml_string(), - ) - - return Targets(config=config) - - -def load_targets( - config_path: data.PathLike, - field: Optional[str] = None, -) -> Targets: - """Load a Targets object directly from a configuration file. - - This convenience factory method loads the `TargetConfig` from the - specified file path and then calls `Targets.from_config` to build - the fully initialized `Targets` object. - - Parameters - ---------- - config_path : data.PathLike - Path to the configuration file (e.g., YAML). - field : str, optional - Dot-separated path to a nested section within the file containing - the target configuration. If None, the entire file content is used. - - Returns - ------- - Targets - An initialized `Targets` object ready for use. - - Raises - ------ - FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, - TypeError - Errors raised during file loading, validation, or extraction via - `load_target_config`. - KeyError, ImportError, AttributeError, TypeError - Errors raised during the build process by `Targets.from_config` - (e.g., missing keys in registries, failed imports). 
- """ - config = load_target_config( - config_path, - field=field, - ) - return build_targets(config) - - -def iterate_encoded_sound_events( - sound_events: Iterable[data.SoundEventAnnotation], - targets: TargetProtocol, -) -> Iterable[Tuple[Optional[str], Position, Size]]: - for sound_event in sound_events: - if not targets.filter(sound_event): - continue - - geometry = sound_event.sound_event.geometry - - if geometry is None: - continue - - class_name = targets.encode_class(sound_event) - position, size = targets.encode_roi(sound_event) - - yield class_name, position, size diff --git a/src/batdetect2/targets/config.py b/src/batdetect2/targets/config.py new file mode 100644 index 0000000..73207d3 --- /dev/null +++ b/src/batdetect2/targets/config.py @@ -0,0 +1,84 @@ +from collections import Counter +from typing import List, Optional + +from pydantic import Field, field_validator +from soundevent import data + +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.targets.classes import ( + DEFAULT_CLASSES, + DEFAULT_DETECTION_CLASS, + TargetClassConfig, +) +from batdetect2.targets.rois import AnchorBBoxMapperConfig, ROIMapperConfig + +__all__ = [ + "TargetConfig", + "load_target_config", +] + + +class TargetConfig(BaseConfig): + detection_target: TargetClassConfig = Field( + default=DEFAULT_DETECTION_CLASS + ) + + classification_targets: List[TargetClassConfig] = Field( + default_factory=lambda: DEFAULT_CLASSES + ) + + roi: ROIMapperConfig = Field(default_factory=AnchorBBoxMapperConfig) + + @field_validator("classification_targets") + def check_unique_class_names(cls, v: List[TargetClassConfig]): + """Ensure all defined class names are unique.""" + names = [c.name for c in v] + + if len(names) != len(set(names)): + name_counts = Counter(names) + duplicates = [ + name for name, count in name_counts.items() if count > 1 + ] + raise ValueError( + "Class names must be unique. Found duplicates: " + f"{', '.join(duplicates)}" + ) + return v + + +def load_target_config( + path: data.PathLike, + field: Optional[str] = None, +) -> TargetConfig: + """Load the unified target configuration from a file. + + Reads a configuration file (typically YAML) and validates it against the + `TargetConfig` schema, potentially extracting data from a nested field. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file. + field : str, optional + Dot-separated path to a nested section within the file containing the + target configuration. If None, the entire file content is used. + + Returns + ------- + TargetConfig + The loaded and validated unified target configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + yaml.YAMLError + If the file content is not valid YAML. + pydantic.ValidationError + If the loaded configuration data does not conform to the + `TargetConfig` schema (including validation within nested configs + like `ClassesConfig`). + KeyError, TypeError + If `field` specifies an invalid path within the loaded data. 
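+
+    Examples
+    --------
+    Assuming a YAML file with a top-level ``targets`` section
+    (hypothetical path):
+
+    >>> config = load_target_config("config.yaml", field="targets")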
+    """
+    return load_config(path=path, schema=TargetConfig, field=field)
diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py
index 495c981..ee81e73 100644
--- a/src/batdetect2/targets/rois.py
+++ b/src/batdetect2/targets/rois.py
@@ -26,10 +26,10 @@ import numpy as np
 from pydantic import Field
 from soundevent import data
 
+from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.core.arrays import spec_to_xarray
 from batdetect2.core.configs import BaseConfig
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.audio import build_audio_loader
 from batdetect2.typing import (
     AudioLoader,
     Position,
@@ -265,6 +265,7 @@ class PeakEnergyBBoxMapperConfig(BaseConfig):
     """
 
     name: Literal["peak_energy_bbox"] = "peak_energy_bbox"
+    audio: AudioConfig = Field(default_factory=AudioConfig)
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
@@ -456,8 +457,11 @@ def build_roi_mapper(
     )
 
     if config.name == "peak_energy_bbox":
-        preprocessor = build_preprocessor(config.preprocessing)
-        audio_loader = build_audio_loader(config.preprocessing.audio)
+        audio_loader = build_audio_loader(config=config.audio)
+        preprocessor = build_preprocessor(
+            config.preprocessing,
+            input_samplerate=audio_loader.samplerate,
+        )
         return PeakEnergyBBoxMapper(
             preprocessor=preprocessor,
             audio_loader=audio_loader,
diff --git a/src/batdetect2/targets/targets.py b/src/batdetect2/targets/targets.py
new file mode 100644
index 0000000..a692025
--- /dev/null
+++ b/src/batdetect2/targets/targets.py
@@ -0,0 +1,308 @@
+from typing import Iterable, List, Optional, Tuple
+
+from loguru import logger
+from soundevent import data
+
+from batdetect2.data.conditions import build_sound_event_condition
+from batdetect2.targets.classes import (
+    DEFAULT_CLASSES,
+    DEFAULT_DETECTION_CLASS,
+    build_sound_event_decoder,
+    build_sound_event_encoder,
+    get_class_names_from_config,
+)
+from batdetect2.targets.config import TargetConfig, load_target_config
+from batdetect2.targets.rois import (
+    AnchorBBoxMapperConfig,
+    build_roi_mapper,
+)
+from batdetect2.typing.targets import Position, Size, TargetProtocol
+
+
+class Targets(TargetProtocol):
+    """Encapsulates the complete configured target definition pipeline.
+
+    This class implements the `TargetProtocol`, holding the configured
+    functions for filtering, transforming, encoding (tags to class name),
+    decoding (class name to tags), and mapping ROIs (geometry to position/size
+    and back). It provides a high-level interface to apply these steps and
+    access relevant metadata like class names and dimension names.
+
+    Instances are typically created using the `build_targets` factory function
+    or the `load_targets` convenience loader.
+
+    Attributes
+    ----------
+    class_names : List[str]
+        An ordered list of the unique names of the specific target classes
+        defined in the configuration.
+    detection_class_tags : List[data.Tag]
+        The `soundevent.data.Tag` objects assigned to the configured
+        detection target (the generic category used when no class matches).
+    dimension_names : List[str]
+        The names of the size dimensions handled by the ROI mapper
+        (e.g., ['width', 'height']). 
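+
+    Examples
+    --------
+    A typical encoding round trip, where ``annotation`` stands in for a
+    `soundevent.data.SoundEventAnnotation` (hypothetical variable):
+
+    >>> targets = build_targets()
+    >>> if targets.filter(annotation):
+    ...     class_name = targets.encode_class(annotation)
+    ...     position, size = targets.encode_roi(annotation)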
+ """ + + class_names: List[str] + detection_class_tags: List[data.Tag] + dimension_names: List[str] + detection_class_name: str + + def __init__(self, config: TargetConfig): + """Initialize the Targets object.""" + self.config = config + + self._filter_fn = build_sound_event_condition( + config.detection_target.match_if + ) + self._encode_fn = build_sound_event_encoder( + config.classification_targets + ) + self._decode_fn = build_sound_event_decoder( + config.classification_targets + ) + + self._roi_mapper = build_roi_mapper(config.roi) + + self.dimension_names = self._roi_mapper.dimension_names + + self.class_names = get_class_names_from_config( + config.classification_targets + ) + + self.detection_class_name = config.detection_target.name + self.detection_class_tags = config.detection_target.assign_tags + + self._roi_mapper_overrides = { + class_config.name: build_roi_mapper(class_config.roi) + for class_config in config.classification_targets + if class_config.roi is not None + } + + for class_name in self._roi_mapper_overrides: + if class_name not in self.class_names: + # TODO: improve this warning + logger.warning( + "The ROI mapper overrides contains a class ({class_name}) " + "not present in the class names.", + class_name=class_name, + ) + + def filter(self, sound_event: data.SoundEventAnnotation) -> bool: + """Apply the configured filter to a sound event annotation. + + Parameters + ---------- + sound_event : data.SoundEventAnnotation + The annotation to filter. + + Returns + ------- + bool + True if the annotation should be kept (passes the filter), + False otherwise. If no filter was configured, always returns True. + """ + return self._filter_fn(sound_event) + + def encode_class( + self, sound_event: data.SoundEventAnnotation + ) -> Optional[str]: + """Encode a sound event annotation to its target class name. + + Applies the configured class definition rules (including priority) + to determine the specific class name for the annotation. + + Parameters + ---------- + sound_event : data.SoundEventAnnotation + The annotation to encode. Note: This should typically be called + *after* applying any transformations via the `transform` method. + + Returns + ------- + str or None + The name of the matched target class, or None if the annotation + does not match any specific class rule (i.e., it belongs to the + generic category). + """ + return self._encode_fn(sound_event) + + def decode_class(self, class_label: str) -> List[data.Tag]: + """Decode a predicted class name back into representative tags. + + Uses the configured mapping (based on `TargetClass.output_tags` or + `TargetClass.tags`) to convert a class name string into a list of + `soundevent.data.Tag` objects. + + Parameters + ---------- + class_label : str + The class name to decode. + + Returns + ------- + List[data.Tag] + The list of tags corresponding to the input class name. + """ + return self._decode_fn(class_label) + + def encode_roi( + self, sound_event: data.SoundEventAnnotation + ) -> tuple[Position, Size]: + """Extract the target reference position from the annotation's roi. + + Delegates to the internal ROI mapper's `get_roi_position` method. + + Parameters + ---------- + sound_event : data.SoundEventAnnotation + The annotation containing the geometry (ROI). + + Returns + ------- + Tuple[float, float] + The reference position `(time, frequency)`. + + Raises + ------ + ValueError + If the annotation lacks geometry. 
+ """ + class_name = self.encode_class(sound_event) + + if class_name in self._roi_mapper_overrides: + return self._roi_mapper_overrides[class_name].encode( + sound_event.sound_event + ) + + return self._roi_mapper.encode(sound_event.sound_event) + + def decode_roi( + self, + position: Position, + size: Size, + class_name: Optional[str] = None, + ) -> data.Geometry: + """Recover an approximate geometric ROI from a position and dimensions. + + Delegates to the internal ROI mapper's `recover_roi` method, which + un-scales the dimensions and reconstructs the geometry (typically a + `BoundingBox`). + + Parameters + ---------- + pos : Tuple[float, float] + The reference position `(time, frequency)`. + dims : np.ndarray + NumPy array with size dimensions (e.g., from model prediction), + matching the order in `self.dimension_names`. + + Returns + ------- + data.Geometry + The reconstructed geometry (typically `BoundingBox`). + """ + if class_name in self._roi_mapper_overrides: + return self._roi_mapper_overrides[class_name].decode( + position, + size, + ) + + return self._roi_mapper.decode(position, size) + + +DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( + classification_targets=DEFAULT_CLASSES, + detection_target=DEFAULT_DETECTION_CLASS, + roi=AnchorBBoxMapperConfig(), +) + + +def build_targets(config: Optional[TargetConfig] = None) -> Targets: + """Build a Targets object from a loaded TargetConfig. + + Parameters + ---------- + config : TargetConfig + The loaded and validated unified target configuration object. + + Returns + ------- + Targets + An initialized `Targets` object ready for use. + + Raises + ------ + KeyError + If term keys or derivation function keys specified in the `config` + are not found in their respective registries. + ImportError, AttributeError, TypeError + If dynamic import of a derivation function fails (when configured). + """ + config = config or DEFAULT_TARGET_CONFIG + logger.opt(lazy=True).debug( + "Building targets with config: \n{}", + lambda: config.to_yaml_string(), + ) + + return Targets(config=config) + + +def load_targets( + config_path: data.PathLike, + field: Optional[str] = None, +) -> Targets: + """Load a Targets object directly from a configuration file. + + This convenience factory method loads the `TargetConfig` from the + specified file path and then calls `Targets.from_config` to build + the fully initialized `Targets` object. + + Parameters + ---------- + config_path : data.PathLike + Path to the configuration file (e.g., YAML). + field : str, optional + Dot-separated path to a nested section within the file containing + the target configuration. If None, the entire file content is used. + + Returns + ------- + Targets + An initialized `Targets` object ready for use. + + Raises + ------ + FileNotFoundError, yaml.YAMLError, pydantic.ValidationError, KeyError, + TypeError + Errors raised during file loading, validation, or extraction via + `load_target_config`. + KeyError, ImportError, AttributeError, TypeError + Errors raised during the build process by `Targets.from_config` + (e.g., missing keys in registries, failed imports). 
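+
+    Examples
+    --------
+    >>> targets = load_targets("config.yaml", field="targets")  # hypothetical path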
+ """ + config = load_target_config( + config_path, + field=field, + ) + return build_targets(config) + + +def iterate_encoded_sound_events( + sound_events: Iterable[data.SoundEventAnnotation], + targets: TargetProtocol, +) -> Iterable[Tuple[Optional[str], Position, Size]]: + for sound_event in sound_events: + if not targets.filter(sound_event): + continue + + geometry = sound_event.sound_event.geometry + + if geometry is None: + continue + + class_name = targets.encode_class(sound_event) + position, size = targets.encode_roi(sound_event) + + yield class_name, position, size diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index e30bc25..7a7207a 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -16,10 +16,8 @@ from batdetect2.train.augmentations import ( ) from batdetect2.train.clips import build_clipper, select_subclip from batdetect2.train.config import ( - FullTrainingConfig, PLTrainerConfig, TrainingConfig, - load_full_training_config, load_train_config, ) from batdetect2.train.dataset import ( @@ -48,7 +46,6 @@ __all__ = [ "DetectionLossConfig", "EchoAugmentationConfig", "FrequencyMaskAugmentationConfig", - "FullTrainingConfig", "LossConfig", "LossFunction", "PLTrainerConfig", @@ -71,7 +68,6 @@ __all__ = [ "build_trainer", "build_val_dataset", "build_val_loader", - "load_full_training_config", "load_label_config", "load_train_config", "mask_frequency", diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 66aa2b5..a5ec359 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -4,8 +4,6 @@ from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.evaluate import EvaluationConfig -from batdetect2.models import ModelConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, @@ -22,8 +20,6 @@ from batdetect2.train.losses import LossConfig __all__ = [ "TrainingConfig", "load_train_config", - "FullTrainingConfig", - "load_full_training_config", ] @@ -93,18 +89,3 @@ def load_train_config( field: Optional[str] = None, ) -> TrainingConfig: return load_config(path, schema=TrainingConfig, field=field) - - -class FullTrainingConfig(ModelConfig): - """Full training configuration.""" - - train: TrainingConfig = Field(default_factory=TrainingConfig) - evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig) - - -def load_full_training_config( - path: data.PathLike, - field: Optional[str] = None, -) -> FullTrainingConfig: - """Load the full training configuration.""" - return load_config(path, schema=FullTrainingConfig, field=field) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 03fb8b3..d1eeb3c 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -5,8 +5,9 @@ from loguru import logger from soundevent import data from torch.utils.data import DataLoader, Dataset +from batdetect2.audio import build_audio_loader from batdetect2.core.arrays import adjust_width -from batdetect2.preprocess import build_audio_loader, build_preprocessor +from batdetect2.preprocess import build_preprocessor from batdetect2.train.augmentations import ( RandomAudioSource, build_augmentations, diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 8970c0e..2527212 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -1,4 +1,4 @@ -from 
typing import Optional, Tuple +from typing import TYPE_CHECKING, Optional, Tuple import lightning as L import torch @@ -6,11 +6,17 @@ from soundevent.data import PathLike from torch.optim.adam import Adam from torch.optim.lr_scheduler import CosineAnnealingLR +from batdetect2.audio import TARGET_SAMPLERATE_HZ from batdetect2.models import Model, build_model -from batdetect2.train.config import FullTrainingConfig +from batdetect2.plotting.clips import build_preprocessor +from batdetect2.postprocess import build_postprocessor +from batdetect2.targets.targets import build_targets from batdetect2.train.losses import build_loss from batdetect2.typing import ModelOutput, TrainExample +if TYPE_CHECKING: + from batdetect2.config import BatDetect2Config + __all__ = [ "TrainingModule", ] @@ -21,7 +27,8 @@ class TrainingModule(L.LightningModule): def __init__( self, - config: FullTrainingConfig, + config: "BatDetect2Config", + input_samplerate: int = TARGET_SAMPLERATE_HZ, learning_rate: float = 0.001, t_max: int = 100, model: Optional[Model] = None, @@ -31,6 +38,7 @@ class TrainingModule(L.LightningModule): self.save_hyperparameters(logger=False) + self.input_samplerate = input_samplerate self.config = config self.learning_rate = learning_rate self.t_max = t_max @@ -39,7 +47,23 @@ class TrainingModule(L.LightningModule): loss = build_loss(self.config.train.loss) if model is None: - model = build_model(self.config) + targets = build_targets(self.config.targets) + + preprocessor = build_preprocessor( + config=self.config.preprocess, + input_samplerate=self.input_samplerate, + ) + + postprocessor = build_postprocessor( + preprocessor, config=self.config.postprocess + ) + + model = build_model( + config=self.config.model, + targets=targets, + preprocessor=preprocessor, + postprocessor=postprocessor, + ) self.loss = loss self.model = model @@ -74,16 +98,18 @@ class TrainingModule(L.LightningModule): def load_model_from_checkpoint( path: PathLike, -) -> Tuple[Model, FullTrainingConfig]: +) -> Tuple[Model, "BatDetect2Config"]: module = TrainingModule.load_from_checkpoint(path) # type: ignore return module.model, module.config def build_training_module( - config: Optional[FullTrainingConfig] = None, + config: Optional["BatDetect2Config"] = None, t_max: int = 200, ) -> TrainingModule: - config = config or FullTrainingConfig() + from batdetect2.config import BatDetect2Config + + config = config or BatDetect2Config() return TrainingModule( config=config, learning_rate=config.train.optimizer.learning_rate, diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index a071db2..dcdfff0 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -1,29 +1,31 @@ from collections.abc import Sequence from pathlib import Path -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional from lightning import Trainer, seed_everything from lightning.pytorch.callbacks import Callback, ModelCheckpoint from loguru import logger from soundevent import data -from batdetect2.evaluate.evaluator import build_evaluator -from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader +from batdetect2.audio import build_audio_loader +from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets from batdetect2.train.callbacks import ValidationMetrics -from batdetect2.train.config import ( - FullTrainingConfig, -) +from batdetect2.train.config import 
TrainingConfig from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import TrainingModule, build_training_module from batdetect2.train.logging import build_logger -from batdetect2.typing import ( - TargetProtocol, -) -from batdetect2.typing.preprocess import AudioLoader -from batdetect2.typing.train import ClipLabeller + +if TYPE_CHECKING: + from batdetect2.config import BatDetect2Config + from batdetect2.typing import ( + AudioLoader, + ClipLabeller, + PreprocessorProtocol, + TargetProtocol, + ) __all__ = [ "build_trainer", @@ -36,12 +38,13 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" def train( train_annotations: Sequence[data.ClipAnnotation], val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, + evaluator: Optional[Evaluator] = None, trainer: Optional[Trainer] = None, - targets: Optional[TargetProtocol] = None, - preprocessor: Optional[PreprocessorProtocol] = None, - audio_loader: Optional[AudioLoader] = None, - labeller: Optional[ClipLabeller] = None, - config: Optional[FullTrainingConfig] = None, + targets: Optional["TargetProtocol"] = None, + preprocessor: Optional["PreprocessorProtocol"] = None, + audio_loader: Optional["AudioLoader"] = None, + labeller: Optional["ClipLabeller"] = None, + config: Optional["BatDetect2Config"] = None, model_path: Optional[data.PathLike] = None, train_workers: Optional[int] = None, val_workers: Optional[int] = None, @@ -51,17 +54,20 @@ def train( run_name: Optional[str] = None, seed: Optional[int] = None, ): + from batdetect2.config import BatDetect2Config + if seed is not None: seed_everything(seed) - config = config or FullTrainingConfig() + config = config or BatDetect2Config() - targets = targets or build_targets(config.targets) + targets = targets or build_targets(config=config.targets) - preprocessor = preprocessor or build_preprocessor(config.preprocess) + audio_loader = audio_loader or build_audio_loader(config=config.audio) - audio_loader = audio_loader or build_audio_loader( - config=config.preprocess.audio + preprocessor = preprocessor or build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=config.preprocess, ) labeller = labeller or build_clip_labeler( @@ -95,7 +101,7 @@ def train( if model_path is not None: logger.debug("Loading model from: {path}", path=model_path) - module = TrainingModule.load_from_checkpoint(model_path) # type: ignore + module = TrainingModule.load_from_checkpoint(Path(model_path)) else: module = build_training_module( config, @@ -103,8 +109,9 @@ def train( ) trainer = trainer or build_trainer( - config, + config.train, targets=targets, + evaluator=evaluator, checkpoint_dir=checkpoint_dir, log_dir=log_dir, experiment_name=experiment_name, @@ -121,8 +128,8 @@ def train( def build_trainer_callbacks( - targets: TargetProtocol, - config: FullTrainingConfig, + targets: "TargetProtocol", + evaluator: Optional[Evaluator] = None, checkpoint_dir: Optional[Path] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, @@ -136,7 +143,7 @@ def build_trainer_callbacks( if run_name is not None: checkpoint_dir = checkpoint_dir / run_name - evaluator = build_evaluator(config=config.evaluation, targets=targets) + evaluator = evaluator or build_evaluator(targets=targets) return [ ModelCheckpoint( @@ -150,20 +157,21 @@ def build_trainer_callbacks( def build_trainer( - conf: FullTrainingConfig, - targets: TargetProtocol, + conf: TrainingConfig, + targets: 
"TargetProtocol", + evaluator: Optional[Evaluator] = None, checkpoint_dir: Optional[Path] = None, log_dir: Optional[Path] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, ) -> Trainer: - trainer_conf = conf.train.trainer + trainer_conf = conf.trainer logger.opt(lazy=True).debug( "Building trainer with config: \n{config}", config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) train_logger = build_logger( - conf.train.logger, + conf.logger, log_dir=log_dir, experiment_name=experiment_name, run_name=run_name, @@ -181,7 +189,7 @@ def build_trainer( logger=train_logger, callbacks=build_trainer_callbacks( targets, - config=conf, + evaluator=evaluator, checkpoint_dir=checkpoint_dir, experiment_name=experiment_name, run_name=run_name, diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index c51b4e7..5395a38 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -13,8 +13,6 @@ from batdetect2.typing.postprocess import ( from batdetect2.typing.preprocess import ( AudioLoader, PreprocessorProtocol, - SpectrogramBuilder, - SpectrogramPipeline, ) from batdetect2.typing.targets import ( Position, @@ -60,8 +58,6 @@ __all__ = [ "SoundEventDecoder", "SoundEventEncoder", "SoundEventFilter", - "SpectrogramBuilder", - "SpectrogramPipeline", "TargetProtocol", "TrainExample", ] diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py index 31e7603..1e660f3 100644 --- a/src/batdetect2/typing/preprocess.py +++ b/src/batdetect2/typing/preprocess.py @@ -32,6 +32,8 @@ class AudioLoader(Protocol): allows for different loading strategies or implementations. """ + samplerate: int + def load_file( self, path: data.PathLike, @@ -125,22 +127,6 @@ class SpectrogramBuilder(Protocol): ... -class AudioPipeline(Protocol): - def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... - - -class SpectrogramPipeline(Protocol): - def compute_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ... - - def select_frequencies(self, spec: torch.Tensor) -> torch.Tensor: ... - - def transform_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... - - def resize_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... - - def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... - - class PreprocessorProtocol(Protocol): """Defines a high-level interface for the complete preprocessing pipeline.""" @@ -152,11 +138,13 @@ class PreprocessorProtocol(Protocol): output_samplerate: float - audio_pipeline: AudioPipeline - - spectrogram_pipeline: SpectrogramPipeline - def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... + def generate_spectrogram(self, wav: torch.Tensor) -> torch.Tensor: ... + + def process_audio(self, wav: torch.Tensor) -> torch.Tensor: ... + + def process_spectrogram(self, spec: torch.Tensor) -> torch.Tensor: ... + def process_numpy(self, wav: np.ndarray) -> np.ndarray: return self(torch.tensor(wav)).numpy()