Restructuring

mbsantiago 2025-09-16 13:38:38 +01:00
parent 60e922d565
commit 7d6cba5465
46 changed files with 474 additions and 463 deletions

View File

@ -1,10 +1,15 @@
import os
import click
from batdetect2 import api
from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
@cli.command()
@ -74,6 +79,9 @@ def detect(
Input files should be short in duration e.g. < 30 seconds.
"""
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"])
@ -123,7 +131,7 @@ def detect(
click.echo(f" {err}")
def print_config(config: ProcessingConfiguration):
def print_config(config):
"""Print the processing configuration."""
click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")

View File

@ -4,7 +4,6 @@ from typing import Optional
import click
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
__all__ = ["data"]
@ -33,6 +32,8 @@ def summary(
field: Optional[str] = None,
base_dir: Optional[Path] = None,
):
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config,

View File

@ -6,9 +6,6 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"]
@ -31,6 +28,10 @@ def evaluate_command(
workers: Optional[int] = None,
verbose: int = 0,
):
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
logger.remove()
if verbose == 0:
log_level = "WARNING"

View File

@ -6,13 +6,6 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
__all__ = ["train_command"]
@ -53,6 +46,14 @@ def train_command(
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
logger.remove()
if verbose == 0:
log_level = "WARNING"

View File

@ -11,7 +11,6 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import (
Annotation,
AudioLoaderAnnotationGroup,
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
data.Tag(key=label_key, value=annotation["class"]),
data.Tag(key=event_key, value=annotation["event"]),
data.Tag(key=individual_key, value=str(annotation["individual"])),
],
)
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
tags=[
data.PredictedTag(
score=annotation["class_prob"],
tag=data.Tag(
term=get_term_from_key(label_key),
value=annotation["class"],
),
tag=data.Tag(key=label_key, value=annotation["class"]),
),
data.PredictedTag(
score=annotation["det_prob"],
tag=data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
tag=data.Tag(key=event_key, value=annotation["event"]),
),
],
)
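Editor's note: both hunks above swap `data.Tag(term=get_term_from_key(...), value=...)` for soundevent's `key=` shorthand, dropping the explicit term lookup. A small sketch of the shorthand as used here (values illustrative):

from soundevent import data

# `key=` lets soundevent resolve the term internally, replacing the
# removed get_term_from_key() call; the value is illustrative.
tag = data.Tag(key="class", value="Myotis myotis")
predicted = data.PredictedTag(tag=tag, score=0.9)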

src/batdetect2/config.py Normal file
View File

@ -0,0 +1,16 @@
from typing import Literal
from batdetect2.core import BaseConfig
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.backbones import BackboneConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.config import TrainingConfig
class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig
evaluation: EvaluationConfig
model: BackboneConfig
preprocess: PreprocessingConfig
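Editor's note: the new top-level schema simply nests the subsystem configs. A hedged construction sketch; it assumes each nested config validates with its defaults, which this diff does not confirm for EvaluationConfig and BackboneConfig:

from batdetect2.config import BatDetect2Config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.backbones import BackboneConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.config import TrainingConfig

# Assumption: the nested configs all accept default construction.
config = BatDetect2Config(
    train=TrainingConfig(),
    evaluation=EvaluationConfig(),
    model=BackboneConfig(),
    preprocess=PreprocessingConfig(),
)
assert config.config_version == "v1"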

View File

@ -0,0 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
]
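Editor's note: `Registry` is used elsewhere in this diff to pair a config class with its implementation (e.g. `metrics_registry.register(DetectionAPConfig, DetectionAP)` in the metrics file below). A rough, hypothetical sketch of that shape, not the actual `batdetect2.core.registries` API:

from typing import Dict, Type


class RegistrySketch:
    """Hypothetical: maps a config class to the class implementing it."""

    def __init__(self) -> None:
        self._implementations: Dict[Type, Type] = {}

    def register(self, config_cls: Type, impl_cls: Type) -> None:
        self._implementations[config_cls] = impl_cls

    def get(self, config) -> Type:
        # Dispatch on the runtime type of the config instance.
        return self._implementations[type(config)]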

View File

@ -1,7 +1,12 @@
import sys
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
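Editor's note: `ParamSpec` was added to `typing` in Python 3.10, hence the version guard above. A self-contained illustration of what it enables: a generic whose second parameter captures a callable's argument list, mirroring the `Registry[AffinityFunction, []]` annotation seen later in this diff:

import sys
from typing import Callable, Generic, TypeVar

if sys.version_info >= (3, 10):
    from typing import ParamSpec
else:
    from typing_extensions import ParamSpec

P = ParamSpec("P")
T = TypeVar("T")


class Factory(Generic[T, P]):
    """Illustration: P captures the factory's extra build arguments."""

    def __init__(self, fn: Callable[P, T]) -> None:
        self.fn = fn

    def build(self, *args: P.args, **kwargs: P.kwargs) -> T:
        return self.fn(*args, **kwargs)


doubler = Factory(lambda x: x * 2)
print(doubler.build(21))  # 42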
__all__ = [
"Registry",

View File

@ -18,7 +18,7 @@ from uuid import uuid5
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [

View File

@ -33,7 +33,7 @@ from loguru import logger
from pydantic import Field, ValidationError
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.legacy import (
FileAnnotation,
file_annotation_to_clip,

View File

@ -1,6 +1,6 @@
from pathlib import Path
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"AnnotatedDataset",

View File

@ -5,8 +5,8 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

View File

@ -25,7 +25,7 @@ from loguru import logger
from pydantic import Field
from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.annotations import (
AnnotatedDataset,
AnnotationFormats,

View File

@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.data.conditions import (
SoundEventCondition,
SoundEventConditionConfig,

View File

@ -4,8 +4,8 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import compute_affinity
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry(

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import (
ClassificationAPConfig,

View File

@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import (
AffinityConfig,
GeometricIOUConfig,

View File

@ -15,8 +15,8 @@ from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation
@ -31,7 +31,7 @@ AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap"
implementation: AveragePrecisionImplementation = "pascal_voc"
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
def pascal_voc_average_precision(y_true, y_score) -> float:
@ -96,7 +96,7 @@ class DetectionAP(MetricsProtocol):
@classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.implementation)
return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP)
@ -104,6 +104,7 @@ metrics_registry.register(DetectionAPConfig, DetectionAP)
class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap"
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
@ -193,6 +194,7 @@ class ClassificationAP(MetricsProtocol):
):
return cls(
class_names,
implementation=config.ap_implementation,
include=config.include,
exclude=config.exclude,
)
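Editor's note: this hunk renames the config field from `implementation` to `ap_implementation` and threads it through both `from_config` methods. A short usage sketch of the renamed field (module path taken from the imports earlier in this diff; assumes the config otherwise validates with defaults):

from batdetect2.evaluate.metrics import DetectionAPConfig

config = DetectionAPConfig(ap_implementation="sklearn")
assert config.name == "detection_ap"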

View File

@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
import pandas as pd
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor

View File

@ -32,7 +32,7 @@ import torch
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,

View File

@ -25,7 +25,7 @@ import torch.nn.functional as F
from soundevent import data
from torch import nn
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig,

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
__all__ = [
"ConvBlock",

View File

@ -20,7 +20,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
SelfAttentionConfig,
VerticalConv,

View File

@ -24,7 +24,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvUpConfig,

View File

@ -26,7 +26,7 @@ import torch
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,

View File

@ -8,8 +8,7 @@ from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
__all__ = [
"plot_matches",

View File

@ -7,7 +7,7 @@ from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,

View File

@ -1,176 +1,21 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem.
"""Main entry point for the BatDetect2 preprocessing subsystem."""
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
for converting raw audio input (from files or data objects) into processed
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
build_audio_loader,
build_audio_pipeline,
)
from batdetect2.preprocess.spectrogram import (
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.preprocess.config import (
MAX_FREQ,
MIN_FREQ,
FrequencyConfig,
PcenConfig,
SpectrogramConfig,
SpectrogramPipeline,
STFTConfig,
_spec_params_from_config,
build_spectrogram_builder,
build_spectrogram_pipeline,
TARGET_SAMPLERATE_HZ,
PreprocessingConfig,
load_preprocessing_config,
)
from batdetect2.typing import PreprocessorProtocol
from batdetect2.preprocess.preprocessor import build_preprocessor
__all__ = [
"AudioConfig",
"DEFAULT_DURATION",
"FrequencyConfig",
"MAX_FREQ",
"MIN_FREQ",
"PcenConfig",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpectrogramConfig",
"MAX_FREQ",
"TARGET_SAMPLERATE_HZ",
"build_audio_loader",
"build_spectrogram_builder",
"PreprocessingConfig",
"load_preprocessing_config",
"build_preprocessor",
"build_audio_loader",
]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
output_samplerate: float
max_freq: float
min_freq: float
def __init__(
self,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
input_samplerate: int,
output_samplerate: float,
max_freq: float,
min_freq: float,
) -> None:
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.max_freq = max_freq
self.min_freq = min_freq
self.input_samplerate = input_samplerate
self.output_samplerate = output_samplerate
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
def compute_output_samplerate(config: PreprocessingConfig) -> float:
samplerate = config.audio.samplerate
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
factor = config.spectrogram.size.resize_factor
return samplerate * factor / hop_size
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
output_samplerate = compute_output_samplerate(config)
return StandardPreprocessor(
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
input_samplerate=samplerate,
output_samplerate=output_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
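Editor's note: after this restructuring, `batdetect2.preprocess/__init__.py` only re-exports from the new `config` and `preprocessor` submodules. Typical downstream usage, with names taken from the new `__all__` above:

from batdetect2.preprocess import PreprocessingConfig, build_preprocessor

config = PreprocessingConfig()          # all-default pipeline
preprocessor = build_preprocessor(config)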

View File

@ -1,64 +1,34 @@
"""Handles loading and initial preprocessing of audio waveforms."""
from typing import Annotated, List, Literal, Optional, Union
from typing import Optional
import numpy as np
import torch
from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly
from soundevent import audio, data
from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.preprocess.config import (
TARGET_SAMPLERATE_HZ,
AudioConfig,
AudioTransform,
ResampleConfig,
)
from batdetect2.typing import AudioLoader
__all__ = [
"ResampleConfig",
"AudioConfig",
"SoundEventAudioLoader",
"build_audio_loader",
"load_file_audio",
"load_recording_audio",
"load_clip_audio",
"resample_audio",
"TARGET_SAMPLERATE_HZ",
"SCALE_RAW_AUDIO",
"DEFAULT_DURATION",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
SCALE_RAW_AUDIO = False
"""Default setting for whether to perform peak normalization."""
DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class SoundEventAudioLoader:
class SoundEventAudioLoader(AudioLoader):
"""Concrete implementation of the `AudioLoader`."""
def __init__(
@ -294,19 +264,6 @@ def resample_audio_fourier(
)
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class FixDuration(torch.nn.Module):
def __init__(self, samplerate: int, duration: float):
super().__init__()
@ -326,24 +283,6 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length))
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
def build_audio_loader(
config: Optional[AudioConfig] = None,
) -> AudioLoader:

View File

@ -0,0 +1,212 @@
from collections.abc import Sequence
from typing import Annotated, List, Literal, Optional, Union
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.core.configs import BaseConfig, load_config
__all__ = [
"load_preprocessing_config",
"CenterAudioConfig",
"ScaleAudioConfig",
"FixDurationConfig",
"ResampleConfig",
"AudioTransform",
"AudioConfig",
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"ScaleAmplitudeConfig",
"SpectralMeanSubstractionConfig",
"ResizeConfig",
"PeakNormalizeConfig",
"SpectrogramTransform",
"SpectrogramConfig",
"PreprocessingConfig",
"TARGET_SAMPLERATE_HZ",
"MIN_FREQ",
"MAX_FREQ",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
class STFTConfig(BaseConfig):
"""Configuration for the Short-Time Fourier Transform (STFT).
Attributes
----------
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
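Editor's note: a worked example of how these defaults translate into STFT parameters at the 256 kHz target rate. It assumes `_spec_params_from_config` (not shown in this diff) derives n_fft from window_duration and the hop from the overlap:

# Assumed derivation: n_fft = samplerate * window_duration,
# hop = n_fft * (1 - window_overlap). Not confirmed by this diff.
samplerate = 256_000
window_duration = 0.002
window_overlap = 0.75

n_fft = int(samplerate * window_duration)       # 512 samples
hop_length = int(n_fft * (1 - window_overlap))  # 128 samples -> 0.5 ms
print(n_fft, hop_length)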
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
Frequencies above this value will be cropped. Must be > 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
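Editor's note: the `name` literals above act as pydantic discriminators, so a transform list selects its steps by name. A construction sketch mirroring the defaults (file-loading path is illustrative):

from batdetect2.preprocess.config import (
    AudioConfig,
    FixDurationConfig,
    PcenConfig,
    PreprocessingConfig,
    SpectralMeanSubstractionConfig,
    SpectrogramConfig,
)

config = PreprocessingConfig(
    audio=AudioConfig(transforms=[FixDurationConfig(duration=0.5)]),
    spectrogram=SpectrogramConfig(
        transforms=[PcenConfig(), SpectralMeanSubstractionConfig()],
    ),
)

# Or from a file (path illustrative):
# config = load_preprocessing_config("preprocessing.yaml")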

View File

@ -0,0 +1,86 @@
from typing import Optional
import torch
from loguru import logger
from batdetect2.preprocess.audio import build_audio_pipeline
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
_spec_params_from_config,
build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
__all__ = [
"StandardPreprocessor",
"build_preprocessor",
]
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
output_samplerate: float
max_freq: float
min_freq: float
def __init__(
self,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
input_samplerate: int,
output_samplerate: float,
max_freq: float,
min_freq: float,
) -> None:
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.max_freq = max_freq
self.min_freq = min_freq
self.input_samplerate = input_samplerate
self.output_samplerate = output_samplerate
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
def compute_output_samplerate(config: PreprocessingConfig) -> float:
samplerate = config.audio.samplerate
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
factor = config.spectrogram.size.resize_factor
return samplerate * factor / hop_size
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
output_samplerate = compute_output_samplerate(config)
return StandardPreprocessor(
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
input_samplerate=samplerate,
output_samplerate=output_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)
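Editor's note: worked arithmetic for `compute_output_samplerate` with all defaults, under the same hop-size assumption as in the STFT note above:

samplerate = 256_000
hop_size = 128        # assumed; _spec_params_from_config is not shown here
resize_factor = 0.5   # ResizeConfig default

output_samplerate = samplerate * resize_factor / hop_size
print(output_samplerate)  # 1000.0 spectrogram frames per second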

View File

@ -1,63 +1,37 @@
"""Computes spectrograms from audio waveforms with configurable parameters."""
from typing import (
Annotated,
Callable,
List,
Literal,
Optional,
Sequence,
Union,
)
from typing import Callable, Optional
import numpy as np
import torch
import torchaudio
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import PeakNormalize
from batdetect2.preprocess.config import (
ScaleAmplitudeConfig,
SpectrogramConfig,
SpectrogramTransform,
STFTConfig,
)
__all__ = [
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"SpectrogramConfig",
"build_spectrogram_builder",
"MIN_FREQ",
"MAX_FREQ",
"build_spectrogram_pipeline",
]
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class STFTConfig(BaseConfig):
"""Configuration for the Short-Time Fourier Transform (STFT).
Attributes
----------
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
center=True,
power=1,
)
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
@ -87,37 +61,6 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
return n_fft, hop_length
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
center=True,
power=1,
)
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
Frequencies above this value will be cropped. Must be > 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
def _frequency_to_index(
freq: float,
samplerate: int,
@ -164,16 +107,6 @@ class FrequencyClip(torch.nn.Module):
)
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class PCEN(torch.nn.Module):
def __init__(
self,
@ -231,11 +164,6 @@ def _compute_smoothing_constant(
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
class ToPower(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec**2
@ -253,22 +181,12 @@ def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
)
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class SpectralMeanSubstraction(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0)
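Editor's note: a tiny worked example of the forward pass above; the per-frequency (last-axis) mean is subtracted and negatives are clamped to zero:

import torch

spec = torch.tensor([[1.0, 3.0], [2.0, 2.0]])
mean = spec.mean(-1, keepdim=True)   # [[2.0], [2.0]]
out = (spec - mean).clamp(min=0)     # [[0.0, 1.0], [0.0, 0.0]]
print(out)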
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float):
super().__init__()
@ -295,33 +213,6 @@ class ResizeSpec(torch.nn.Module):
return resized
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
def _build_spectrogram_transform_step(
step: SpectrogramTransform,
samplerate: int,

View File

@ -7,7 +7,7 @@ from loguru import logger
from pydantic import Field, field_validator
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import (
DEFAULT_CLASSES,

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional
from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data.conditions import (
AllOfConfig,
HasAllTagsConfig,

View File

@ -26,12 +26,17 @@ import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.targets import Position, ROITargetMapper, Size
from batdetect2.utils.arrays import spec_to_xarray
from batdetect2.typing import (
AudioLoader,
Position,
PreprocessorProtocol,
ROITargetMapper,
Size,
)
__all__ = [
"Anchor",

View File

@ -11,11 +11,10 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import Augmentation
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import adjust_width
from batdetect2.typing import AudioLoader, Augmentation
__all__ = [
"AugmentationConfig",

View File

@ -10,10 +10,12 @@ from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.typing.models import ModelOutput
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.train import TrainExample
from batdetect2.typing import (
ClipEvaluation,
ModelOutput,
RawPrediction,
TrainExample,
)
class ValidationMetrics(Callback):

View File

@ -6,8 +6,7 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.configs import BaseConfig
from batdetect2.data._core import Registry
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256

View File

@ -3,7 +3,7 @@ from typing import Optional, Union
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import (
@ -80,7 +80,6 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig):
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig)
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)

View File

@ -5,8 +5,8 @@ from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,
@ -14,10 +14,14 @@ from batdetect2.train.augmentations import (
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
from batdetect2.utils.arrays import adjust_width
from batdetect2.typing import (
AudioLoader,
Augmentation,
ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
TrainExample,
)
__all__ = [
"TrainingDataset",

View File

@ -13,14 +13,10 @@ import torch
from loguru import logger
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.typing import (
ClipLabeller,
Heatmaps,
TargetProtocol,
)
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
__all__ = [
"LabelConfig",

View File

@ -18,7 +18,7 @@ from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"

View File

@ -27,7 +27,7 @@ from loguru import logger
from pydantic import Field
from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
__all__ = [

View File

@ -1,4 +1,8 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
from batdetect2.typing.evaluate import (
ClipEvaluation,
MatchEvaluation,
MetricsProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
@ -10,9 +14,11 @@ from batdetect2.typing.preprocess import (
AudioLoader,
PreprocessorProtocol,
SpectrogramBuilder,
SpectrogramPipeline,
)
from batdetect2.typing.targets import (
Position,
ROITargetMapper,
Size,
SoundEventDecoder,
SoundEventEncoder,
@ -34,6 +40,7 @@ __all__ = [
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipEvaluation",
"ClipLabeller",
"ClipperProtocol",
"DetectionModel",
@ -47,12 +54,14 @@ __all__ = [
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
"ROITargetMapper",
"RawPrediction",
"Size",
"SoundEventDecoder",
"SoundEventEncoder",
"SoundEventFilter",
"SpectrogramBuilder",
"SpectrogramPipeline",
"TargetProtocol",
"TrainExample",
]