Restructuring

This commit is contained in:
mbsantiago 2025-09-16 13:38:38 +01:00
parent 60e922d565
commit 7d6cba5465
46 changed files with 474 additions and 463 deletions

View File

@ -1,10 +1,15 @@
import os
import click import click
from batdetect2 import api
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration DEFAULT_MODEL_PATH = os.path.join(
from batdetect2.utils.detector_utils import save_results_to_file os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
@cli.command() @cli.command()
@ -74,6 +79,9 @@ def detect(
Input files should be short in duration e.g. < 30 seconds. Input files should be short in duration e.g. < 30 seconds.
""" """
from batdetect2 import api
from batdetect2.utils.detector_utils import save_results_to_file
click.echo(f"Loading model: {args['model_path']}") click.echo(f"Loading model: {args['model_path']}")
model, params = api.load_model(args["model_path"]) model, params = api.load_model(args["model_path"])
@ -123,7 +131,7 @@ def detect(
click.echo(f" {err}") click.echo(f" {err}")
def print_config(config: ProcessingConfiguration): def print_config(config):
"""Print the processing configuration.""" """Print the processing configuration."""
click.echo("\nProcessing Configuration:") click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")

View File

@ -4,7 +4,6 @@ from typing import Optional
import click import click
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
__all__ = ["data"] __all__ = ["data"]
@ -33,6 +32,8 @@ def summary(
field: Optional[str] = None, field: Optional[str] = None,
base_dir: Optional[Path] = None, base_dir: Optional[Path] = None,
): ):
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd() base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config( dataset = load_dataset_from_config(
dataset_config, dataset_config,

View File

@ -6,9 +6,6 @@ import click
from loguru import logger from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"] __all__ = ["evaluate_command"]
@ -31,6 +28,10 @@ def evaluate_command(
workers: Optional[int] = None, workers: Optional[int] = None,
verbose: int = 0, verbose: int = 0,
): ):
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
logger.remove() logger.remove()
if verbose == 0: if verbose == 0:
log_level = "WARNING" log_level = "WARNING"

View File

@ -6,13 +6,6 @@ import click
from loguru import logger from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
__all__ = ["train_command"] __all__ = ["train_command"]
@ -53,6 +46,14 @@ def train_command(
run_name: Optional[str] = None, run_name: Optional[str] = None,
verbose: int = 0, verbose: int = 0,
): ):
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
logger.remove() logger.remove()
if verbose == 0: if verbose == 0:
log_level = "WARNING" log_level = "WARNING"

View File

@ -11,7 +11,6 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from soundevent.types import ClassMapper from soundevent.types import ClassMapper
from batdetect2.targets.terms import get_term_from_key
from batdetect2.types import ( from batdetect2.types import (
Annotation, Annotation,
AudioLoaderAnnotationGroup, AudioLoaderAnnotationGroup,
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"), uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event, sound_event=sound_event,
tags=[ tags=[
data.Tag( data.Tag(key=label_key, value=annotation["class"]),
term=get_term_from_key(label_key), data.Tag(key=event_key, value=annotation["event"]),
value=annotation["class"], data.Tag(key=individual_key, value=str(annotation["individual"])),
),
data.Tag(
term=get_term_from_key(event_key),
value=annotation["event"],
),
data.Tag(
term=get_term_from_key(individual_key),
value=str(annotation["individual"]),
),
], ],
) )
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
tags=[ tags=[
data.PredictedTag( data.PredictedTag(
score=annotation["class_prob"], score=annotation["class_prob"],
tag=data.Tag( tag=data.Tag(key=label_key, value=annotation["class"]),
term=get_term_from_key(label_key),
value=annotation["class"],
),
), ),
data.PredictedTag( data.PredictedTag(
score=annotation["det_prob"], score=annotation["det_prob"],
tag=data.Tag( tag=data.Tag(key=event_key, value=annotation["event"]),
term=get_term_from_key(event_key),
value=annotation["event"],
),
), ),
], ],
) )

16
src/batdetect2/config.py Normal file
View File

@ -0,0 +1,16 @@
from typing import Literal
from batdetect2.core import BaseConfig
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.models.backbones import BackboneConfig
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.config import TrainingConfig
class BatDetect2Config(BaseConfig):
    """Top-level configuration grouping the per-subsystem configurations.

    Attributes
    ----------
    config_version : Literal["v1"], default="v1"
        Schema version tag. "v1" is the only value the annotation accepts.
    train : TrainingConfig
        Training settings. Declared without a default.
    evaluation : EvaluationConfig
        Evaluation settings. Declared without a default.
    model : BackboneConfig
        Model backbone settings. Declared without a default.
    preprocess : PreprocessingConfig
        Preprocessing settings. Declared without a default.
    """

    config_version: Literal["v1"] = "v1"
    train: TrainingConfig
    evaluation: EvaluationConfig
    model: BackboneConfig
    preprocess: PreprocessingConfig

View File

@ -0,0 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
]

View File

@ -1,7 +1,12 @@
import sys
from typing import Generic, Protocol, Type, TypeVar from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import ParamSpec
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
__all__ = [ __all__ = [
"Registry", "Registry",

View File

@ -18,7 +18,7 @@ from uuid import uuid5
from pydantic import Field from pydantic import Field
from soundevent import data, io from soundevent import data, io
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.types import AnnotatedDataset from batdetect2.data.annotations.types import AnnotatedDataset
__all__ = [ __all__ = [

View File

@ -33,7 +33,7 @@ from loguru import logger
from pydantic import Field, ValidationError from pydantic import Field, ValidationError
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data.annotations.legacy import ( from batdetect2.data.annotations.legacy import (
FileAnnotation, FileAnnotation,
file_annotation_to_clip, file_annotation_to_clip,

View File

@ -1,6 +1,6 @@
from pathlib import Path from pathlib import Path
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"AnnotatedDataset", "AnnotatedDataset",

View File

@ -5,8 +5,8 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

View File

@ -25,7 +25,7 @@ from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data, io from soundevent import data, io
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.annotations import ( from batdetect2.data.annotations import (
AnnotatedDataset, AnnotatedDataset,
AnnotationFormats, AnnotationFormats,

View File

@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
SoundEventCondition, SoundEventCondition,
SoundEventConditionConfig, SoundEventConditionConfig,

View File

@ -4,8 +4,8 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.evaluation import compute_affinity
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
from batdetect2.typing.evaluate import AffinityFunction from batdetect2.typing.evaluate import AffinityFunction
affinity_functions: Registry[AffinityFunction, []] = Registry( affinity_functions: Registry[AffinityFunction, []] = Registry(

View File

@ -3,7 +3,7 @@ from typing import List, Optional
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
from batdetect2.evaluate.metrics import ( from batdetect2.evaluate.metrics import (
ClassificationAPConfig, ClassificationAPConfig,

View File

@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
from batdetect2.evaluate.affinity import ( from batdetect2.evaluate.affinity import (
AffinityConfig, AffinityConfig,
GeometricIOUConfig, GeometricIOUConfig,

View File

@ -15,8 +15,8 @@ from pydantic import Field
from sklearn import metrics from sklearn import metrics
from sklearn.preprocessing import label_binarize from sklearn.preprocessing import label_binarize
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
from batdetect2.typing import MetricsProtocol from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation from batdetect2.typing.evaluate import ClipEvaluation
@ -31,7 +31,7 @@ AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
class DetectionAPConfig(BaseConfig): class DetectionAPConfig(BaseConfig):
name: Literal["detection_ap"] = "detection_ap" name: Literal["detection_ap"] = "detection_ap"
implementation: AveragePrecisionImplementation = "pascal_voc" ap_implementation: AveragePrecisionImplementation = "pascal_voc"
def pascal_voc_average_precision(y_true, y_score) -> float: def pascal_voc_average_precision(y_true, y_score) -> float:
@ -96,7 +96,7 @@ class DetectionAP(MetricsProtocol):
@classmethod @classmethod
def from_config(cls, config: DetectionAPConfig, class_names: List[str]): def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
return cls(implementation=config.implementation) return cls(implementation=config.ap_implementation)
metrics_registry.register(DetectionAPConfig, DetectionAP) metrics_registry.register(DetectionAPConfig, DetectionAP)
@ -104,6 +104,7 @@ metrics_registry.register(DetectionAPConfig, DetectionAP)
class ClassificationAPConfig(BaseConfig): class ClassificationAPConfig(BaseConfig):
name: Literal["classification_ap"] = "classification_ap" name: Literal["classification_ap"] = "classification_ap"
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
include: Optional[List[str]] = None include: Optional[List[str]] = None
exclude: Optional[List[str]] = None exclude: Optional[List[str]] = None
@ -193,6 +194,7 @@ class ClassificationAP(MetricsProtocol):
): ):
return cls( return cls(
class_names, class_names,
implementation=config.ap_implementation,
include=config.include, include=config.include,
exclude=config.exclude, exclude=config.exclude,
) )

View File

@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
from pydantic import Field from pydantic import Field
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data._core import Registry from batdetect2.core.registries import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor

View File

@ -32,7 +32,7 @@ import torch
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.backbones import ( from batdetect2.models.backbones import (
Backbone, Backbone,
BackboneConfig, BackboneConfig,

View File

@ -25,7 +25,7 @@ import torch.nn.functional as F
from soundevent import data from soundevent import data
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.models.bottleneck import ( from batdetect2.models.bottleneck import (
DEFAULT_BOTTLENECK_CONFIG, DEFAULT_BOTTLENECK_CONFIG,
BottleneckConfig, BottleneckConfig,

View File

@ -34,7 +34,7 @@ import torch.nn.functional as F
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
__all__ = [ __all__ = [
"ConvBlock", "ConvBlock",

View File

@ -20,7 +20,7 @@ import torch
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
SelfAttentionConfig, SelfAttentionConfig,
VerticalConv, VerticalConv,

View File

@ -24,7 +24,7 @@ import torch
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
ConvConfig, ConvConfig,
FreqCoordConvUpConfig, FreqCoordConvUpConfig,

View File

@ -26,7 +26,7 @@ import torch
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.models.blocks import ( from batdetect2.models.blocks import (
ConvConfig, ConvConfig,
FreqCoordConvDownConfig, FreqCoordConvDownConfig,

View File

@ -8,8 +8,7 @@ from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clip_predictions import plot_prediction from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import AudioLoader, plot_clip from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.preprocess import PreprocessorProtocol from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
from batdetect2.typing.evaluate import MatchEvaluation
__all__ = [ __all__ = [
"plot_matches", "plot_matches",

View File

@ -7,7 +7,7 @@ from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import ( from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD, DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction, convert_raw_prediction_to_sound_event_prediction,

View File

@ -1,176 +1,21 @@
"""Main entry point for the BatDetect2 Preprocessing subsystem. """Main entry point for the BatDetect2 preprocessing subsystem."""
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline from batdetect2.preprocess.audio import build_audio_loader
for converting raw audio input (from files or data objects) into processed from batdetect2.preprocess.config import (
spectrograms suitable for input to BatDetect2 models. This ensures consistent
data handling between model training and inference.
The preprocessing pipeline consists of two main stages, configured via nested
data structures:
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
processing like resampling, duration adjustment, centering, and scaling.
Configured via `AudioConfig`.
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
the processed waveform using STFT, followed by frequency cropping, optional
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
resizing, and optional peak normalization. Configured via
`SpectrogramConfig`.
This module provides the primary interface:
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
and `SpectrogramConfig`.
- `load_preprocessing_config`: Function to load the unified configuration.
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
instance from a `PreprocessingConfig`.
"""
from typing import Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess.audio import (
DEFAULT_DURATION,
SCALE_RAW_AUDIO,
TARGET_SAMPLERATE_HZ,
AudioConfig,
ResampleConfig,
build_audio_loader,
build_audio_pipeline,
)
from batdetect2.preprocess.spectrogram import (
MAX_FREQ, MAX_FREQ,
MIN_FREQ, MIN_FREQ,
FrequencyConfig, TARGET_SAMPLERATE_HZ,
PcenConfig, PreprocessingConfig,
SpectrogramConfig, load_preprocessing_config,
SpectrogramPipeline,
STFTConfig,
_spec_params_from_config,
build_spectrogram_builder,
build_spectrogram_pipeline,
) )
from batdetect2.typing import PreprocessorProtocol from batdetect2.preprocess.preprocessor import build_preprocessor
__all__ = [ __all__ = [
"AudioConfig",
"DEFAULT_DURATION",
"FrequencyConfig",
"MAX_FREQ",
"MIN_FREQ", "MIN_FREQ",
"PcenConfig", "MAX_FREQ",
"PreprocessingConfig",
"ResampleConfig",
"SCALE_RAW_AUDIO",
"STFTConfig",
"SpectrogramConfig",
"TARGET_SAMPLERATE_HZ", "TARGET_SAMPLERATE_HZ",
"build_audio_loader", "PreprocessingConfig",
"build_spectrogram_builder",
"load_preprocessing_config", "load_preprocessing_config",
"build_preprocessor",
"build_audio_loader",
] ]
class PreprocessingConfig(BaseConfig):
"""Unified configuration for the audio preprocessing pipeline.
Aggregates the configuration for both the initial audio processing stage
and the subsequent spectrogram generation stage.
Attributes
----------
audio : AudioConfig
Configuration settings for the audio loading and initial waveform
processing steps (e.g., resampling, duration adjustment, scaling).
Defaults to default `AudioConfig` settings if omitted.
spectrogram : SpectrogramConfig
Configuration settings for the spectrogram generation process
(e.g., STFT parameters, frequency cropping, scaling, denoising,
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
"""
audio: AudioConfig = Field(default_factory=AudioConfig)
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
path: PathLike,
field: Optional[str] = None,
) -> PreprocessingConfig:
return load_config(path, schema=PreprocessingConfig, field=field)
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
"""Standard implementation of the `Preprocessor` protocol."""
input_samplerate: int
output_samplerate: float
max_freq: float
min_freq: float
def __init__(
self,
audio_pipeline: torch.nn.Module,
spectrogram_pipeline: SpectrogramPipeline,
input_samplerate: int,
output_samplerate: float,
max_freq: float,
min_freq: float,
) -> None:
super().__init__()
self.audio_pipeline = audio_pipeline
self.spectrogram_pipeline = spectrogram_pipeline
self.max_freq = max_freq
self.min_freq = min_freq
self.input_samplerate = input_samplerate
self.output_samplerate = output_samplerate
def forward(self, wav: torch.Tensor) -> torch.Tensor:
wav = self.audio_pipeline(wav)
return self.spectrogram_pipeline(wav)
def compute_output_samplerate(config: PreprocessingConfig) -> float:
samplerate = config.audio.samplerate
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
factor = config.spectrogram.size.resize_factor
return samplerate * factor / hop_size
def build_preprocessor(
config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
"""Factory function to build the standard preprocessor from configuration."""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
samplerate = config.audio.samplerate
min_freq = config.spectrogram.frequencies.min_freq
max_freq = config.spectrogram.frequencies.max_freq
output_samplerate = compute_output_samplerate(config)
return StandardPreprocessor(
audio_pipeline=build_audio_pipeline(config.audio),
spectrogram_pipeline=build_spectrogram_pipeline(
samplerate, config.spectrogram
),
input_samplerate=samplerate,
output_samplerate=output_samplerate,
min_freq=min_freq,
max_freq=max_freq,
)

View File

@ -1,64 +1,34 @@
"""Handles loading and initial preprocessing of audio waveforms.""" """Handles loading and initial preprocessing of audio waveforms."""
from typing import Annotated, List, Literal, Optional, Union from typing import Optional
import numpy as np import numpy as np
import torch import torch
from numpy.typing import DTypeLike from numpy.typing import DTypeLike
from pydantic import Field
from scipy.signal import resample, resample_poly from scipy.signal import resample, resample_poly
from soundevent import audio, data from soundevent import audio, data
from soundfile import LibsndfileError from soundfile import LibsndfileError
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import CenterTensor, PeakNormalize from batdetect2.preprocess.common import CenterTensor, PeakNormalize
from batdetect2.preprocess.config import (
TARGET_SAMPLERATE_HZ,
AudioConfig,
AudioTransform,
ResampleConfig,
)
from batdetect2.typing import AudioLoader from batdetect2.typing import AudioLoader
__all__ = [ __all__ = [
"ResampleConfig",
"AudioConfig",
"SoundEventAudioLoader", "SoundEventAudioLoader",
"build_audio_loader", "build_audio_loader",
"load_file_audio", "load_file_audio",
"load_recording_audio", "load_recording_audio",
"load_clip_audio", "load_clip_audio",
"resample_audio", "resample_audio",
"TARGET_SAMPLERATE_HZ",
"SCALE_RAW_AUDIO",
"DEFAULT_DURATION",
] ]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
SCALE_RAW_AUDIO = False class SoundEventAudioLoader(AudioLoader):
"""Default setting for whether to perform peak normalization."""
DEFAULT_DURATION = None
"""Default setting for target audio duration in seconds."""
class ResampleConfig(BaseConfig):
"""Configuration for audio resampling.
Attributes
----------
samplerate : int, default=256000
The target sample rate in Hz to resample the audio to. Must be > 0.
method : str, default="poly"
The resampling algorithm to use. Options:
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
Generally fast.
- "fourier": Resampling via Fourier method using
`scipy.signal.resample`. May handle non-integer
resampling factors differently.
"""
enabled: bool = True
method: str = "poly"
class SoundEventAudioLoader:
"""Concrete implementation of the `AudioLoader`.""" """Concrete implementation of the `AudioLoader`."""
def __init__( def __init__(
@ -294,19 +264,6 @@ def resample_audio_fourier(
) )
class CenterAudioConfig(BaseConfig):
name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
name: Literal["fix_duration"] = "fix_duration"
duration: float = 0.5
class FixDuration(torch.nn.Module): class FixDuration(torch.nn.Module):
def __init__(self, samplerate: int, duration: float): def __init__(self, samplerate: int, duration: float):
super().__init__() super().__init__()
@ -326,24 +283,6 @@ class FixDuration(torch.nn.Module):
return torch.nn.functional.pad(wav, (0, self.length - length)) return torch.nn.functional.pad(wav, (0, self.length - length))
AudioTransform = Annotated[
Union[
FixDurationConfig,
ScaleAudioConfig,
CenterAudioConfig,
],
Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
"""Configuration for loading and initial audio preprocessing."""
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
transforms: List[AudioTransform] = Field(default_factory=list)
def build_audio_loader( def build_audio_loader(
config: Optional[AudioConfig] = None, config: Optional[AudioConfig] = None,
) -> AudioLoader: ) -> AudioLoader:

View File

@ -0,0 +1,212 @@
from collections.abc import Sequence
from typing import Annotated, List, Literal, Optional, Union
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.core.configs import BaseConfig, load_config
__all__ = [
"load_preprocessing_config",
"CenterAudioConfig",
"ScaleAudioConfig",
"FixDurationConfig",
"ResampleConfig",
"AudioTransform",
"AudioConfig",
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"ScaleAmplitudeConfig",
"SpectralMeanSubstractionConfig",
"ResizeConfig",
"PeakNormalizeConfig",
"SpectrogramTransform",
"SpectrogramConfig",
"PreprocessingConfig",
"TARGET_SAMPLERATE_HZ",
"MIN_FREQ",
"MAX_FREQ",
]
TARGET_SAMPLERATE_HZ = 256_000
"""Default target sample rate in Hz used if resampling is enabled."""
MIN_FREQ = 10_000
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
MAX_FREQ = 120_000
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
class CenterAudioConfig(BaseConfig):
    """Config entry selecting the "center_audio" audio transform.

    Carries no parameters; `name` is the discriminator value that
    `AudioTransform` uses to pick this config type.
    """

    # NOTE(review): presumably maps to a waveform-centering step when the
    # audio pipeline is built - confirm against the pipeline builder.
    name: Literal["center_audio"] = "center_audio"
class ScaleAudioConfig(BaseConfig):
    """Config entry selecting the "scale_audio" audio transform.

    Carries no parameters; `name` is the discriminator value that
    `AudioTransform` uses to pick this config type.
    """

    # NOTE(review): presumably maps to a waveform-scaling (normalisation)
    # step when the audio pipeline is built - confirm against the builder.
    name: Literal["scale_audio"] = "scale_audio"
class FixDurationConfig(BaseConfig):
    """Config entry selecting the "fix_duration" audio transform.

    Attributes
    ----------
    name : Literal["fix_duration"]
        Discriminator value used by `AudioTransform` to pick this config.
    duration : float, default=0.5
        Target duration of the waveform.
    """

    name: Literal["fix_duration"] = "fix_duration"
    # NOTE(review): unit is presumably seconds - confirm against the
    # transform that consumes this value.
    duration: float = 0.5
class ResampleConfig(BaseConfig):
    """Configuration for audio resampling.

    This config does not hold a target sample rate; it only declares the
    two fields below.

    Attributes
    ----------
    enabled : bool, default=True
        Flag for switching resampling on or off.
    method : str, default="poly"
        The resampling algorithm to use. Options:
        - "poly": Polyphase resampling using `scipy.signal.resample_poly`.
          Generally fast.
        - "fourier": Resampling via Fourier method using
          `scipy.signal.resample`. May handle non-integer
          resampling factors differently.
    """

    enabled: bool = True
    # Typed as a plain `str`, so values other than the two documented
    # options are not rejected at validation time.
    method: str = "poly"
# Discriminated union of the audio transform configs. The `name` field of
# each member is the discriminator that selects the concrete config type.
AudioTransform = Annotated[
    Union[
        FixDurationConfig,
        ScaleAudioConfig,
        CenterAudioConfig,
    ],
    Field(discriminator="name"),
]
class AudioConfig(BaseConfig):
    """Configuration for loading and initial audio preprocessing.

    Attributes
    ----------
    samplerate : int, default=TARGET_SAMPLERATE_HZ
        Sample rate in Hz. Must be > 0.
    resample : ResampleConfig, optional
        Resampling settings. Defaults to a default `ResampleConfig`; may
        also be set to None.
    transforms : List[AudioTransform]
        Audio transform configs. Defaults to an empty list.
    """

    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
    # NOTE(review): presumably applied in list order - confirm against the
    # audio pipeline builder.
    transforms: List[AudioTransform] = Field(default_factory=list)
class STFTConfig(BaseConfig):
    """Configuration for the Short-Time Fourier Transform (STFT).

    Attributes
    ----------
    window_duration : float, default=0.002
        Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
        > 0. Determines frequency resolution (longer window = finer frequency
        resolution).
    window_overlap : float, default=0.75
        Fraction of overlap between consecutive STFT windows (e.g., 0.75
        for 75%). Must be >= 0 and < 1. Determines time resolution
        (higher overlap = finer time resolution).
    window_fn : str, default="hann"
        Name of the window function to apply before FFT calculation. Common
        options include "hann", "hamming", "blackman". See
        `scipy.signal.get_window`.
    """

    # The numeric bounds below are enforced at validation time; `window_fn`
    # is a plain `str` and is not checked here.
    window_duration: float = Field(default=0.002, gt=0)
    window_overlap: float = Field(default=0.75, ge=0, lt=1)
    window_fn: str = "hann"
class FrequencyConfig(BaseConfig):
    """Configuration for frequency axis parameters.

    Attributes
    ----------
    max_freq : int, default=MAX_FREQ (120000)
        Maximum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies above this value will be cropped. Must be >= 0.
    min_freq : int, default=MIN_FREQ (10000)
        Minimum frequency in Hz to retain in the spectrogram after STFT.
        Frequencies below this value will be cropped. Must be >= 0.

    Notes
    -----
    No check is made here that `min_freq` is below `max_freq`.
    """

    # Defaults come from the module-level constants so that the exported
    # MIN_FREQ / MAX_FREQ values and the config defaults cannot drift apart.
    max_freq: int = Field(default=MAX_FREQ, ge=0)
    min_freq: int = Field(default=MIN_FREQ, ge=0)
class PcenConfig(BaseConfig):
    """Configuration for Per-Channel Energy Normalization (PCEN).

    Attributes
    ----------
    name : Literal["pcen"]
        Discriminator value used by `SpectrogramTransform` to pick this
        config.
    time_constant : float, default=0.4
    gain : float, default=0.98
    bias : float, default=2
    power : float, default=0.5
    """

    name: Literal["pcen"] = "pcen"
    # NOTE(review): units and exact roles of the four parameters below are
    # not visible here - confirm against the PCEN implementation.
    time_constant: float = 0.4
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5
class ScaleAmplitudeConfig(BaseConfig):
    """Config entry selecting the "scale_amplitude" spectrogram transform.

    Attributes
    ----------
    name : Literal["scale_amplitude"]
        Discriminator value used by `SpectrogramTransform` to pick this
        config.
    scale : {"power", "db"}, default="db"
        Which amplitude scaling to use.
    """

    name: Literal["scale_amplitude"] = "scale_amplitude"
    scale: Literal["power", "db"] = "db"
class SpectralMeanSubstractionConfig(BaseConfig):
    """Config entry selecting the spectral mean subtraction transform.

    Carries no parameters; `name` is the discriminator value that
    `SpectrogramTransform` uses to pick this config type.
    """

    # The "substraction" spelling is part of the discriminator value, so it
    # appears in config files; correcting it would change the accepted input.
    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class ResizeConfig(BaseConfig):
    """Configuration for resizing the spectrogram.

    Attributes
    ----------
    name : Literal["resize_spec"]
        Fixed identifier for this config.
    height : int, default=128
        Target spectrogram height.
    resize_factor : float, default=0.5
        Resize factor.
    """

    name: Literal["resize_spec"] = "resize_spec"
    # NOTE(review): presumably the number of frequency bins after resizing -
    # confirm against the resize implementation.
    height: int = 128
    # NOTE(review): presumably scales the time axis - confirm against the
    # resize implementation.
    resize_factor: float = 0.5
class PeakNormalizeConfig(BaseConfig):
    """Config entry selecting the "peak_normalize" spectrogram transform.

    Carries no parameters; `name` is the discriminator value that
    `SpectrogramTransform` uses to pick this config type.
    """

    name: Literal["peak_normalize"] = "peak_normalize"
# Discriminated union of the spectrogram transform configs. The `name`
# field of each member is the discriminator that selects the concrete type.
# `ResizeConfig` is not a member of this union.
SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
        ScaleAmplitudeConfig,
        SpectralMeanSubstractionConfig,
        PeakNormalizeConfig,
    ],
    Field(discriminator="name"),
]
class SpectrogramConfig(BaseConfig):
    """Configuration for spectrogram generation.

    Attributes
    ----------
    stft : STFTConfig
        STFT settings. Defaults to a default `STFTConfig`.
    frequencies : FrequencyConfig
        Frequency range settings. Defaults to a default `FrequencyConfig`.
    size : ResizeConfig
        Resize settings. Defaults to a default `ResizeConfig`.
    transforms : Sequence[SpectrogramTransform]
        Spectrogram transform configs. Defaults to a `PcenConfig` followed
        by a `SpectralMeanSubstractionConfig`.
    """

    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)
    # The factory builds a fresh list per instance, so the default is never
    # shared between configs.
    # NOTE(review): presumably applied in sequence order - confirm against
    # the spectrogram pipeline builder.
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )
class PreprocessingConfig(BaseConfig):
    """Unified configuration for the audio preprocessing pipeline.

    Aggregates the configuration for both the initial audio processing stage
    and the subsequent spectrogram generation stage. Both fields have default
    factories, so `PreprocessingConfig()` can be built with no arguments.

    Attributes
    ----------
    audio : AudioConfig
        Configuration settings for the audio loading and initial waveform
        processing steps (e.g., resampling, duration adjustment, scaling).
        Defaults to default `AudioConfig` settings if omitted.
    spectrogram : SpectrogramConfig
        Configuration settings for the spectrogram generation process
        (e.g., STFT parameters, frequency cropping, scaling, denoising,
        resizing). Defaults to default `SpectrogramConfig` settings if omitted.
    """

    audio: AudioConfig = Field(default_factory=AudioConfig)
    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    """Load a `PreprocessingConfig` from a configuration file.

    Thin wrapper around `load_config` that fixes the schema to
    `PreprocessingConfig`.

    Parameters
    ----------
    path : PathLike
        Location of the configuration file.
    field : str, optional
        Forwarded unchanged to `load_config`.
        NOTE(review): presumably selects a sub-section of the file; confirm
        against `load_config`.

    Returns
    -------
    PreprocessingConfig
        The configuration object produced by `load_config`.
    """
    return load_config(
        path,
        schema=PreprocessingConfig,
        field=field,
    )

View File

@ -0,0 +1,86 @@
from typing import Optional
import torch
from loguru import logger
from batdetect2.preprocess.audio import build_audio_pipeline
from batdetect2.preprocess.config import PreprocessingConfig
from batdetect2.preprocess.spectrogram import (
_spec_params_from_config,
build_spectrogram_pipeline,
)
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
# Public API of this module.
__all__ = ["StandardPreprocessor", "build_preprocessor"]
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
    """Standard implementation of the `Preprocessor` protocol.

    Wraps an audio pipeline and a spectrogram pipeline and applies them in
    sequence, while exposing the samplerate and frequency metadata it was
    constructed with.
    """

    input_samplerate: int
    output_samplerate: float
    max_freq: float
    min_freq: float

    def __init__(
        self,
        audio_pipeline: torch.nn.Module,
        spectrogram_pipeline: SpectrogramPipeline,
        input_samplerate: int,
        output_samplerate: float,
        max_freq: float,
        min_freq: float,
    ) -> None:
        """Store the two pipelines and the metadata describing them.

        Parameters
        ----------
        audio_pipeline : torch.nn.Module
            Module applied to the input waveform first.
        spectrogram_pipeline : SpectrogramPipeline
            Callable applied to the output of `audio_pipeline`.
        input_samplerate : int
            Stored as `self.input_samplerate`.
        output_samplerate : float
            Stored as `self.output_samplerate`.
        max_freq : float
            Stored as `self.max_freq`.
        min_freq : float
            Stored as `self.min_freq`.
        """
        super().__init__()

        # Pipelines are assigned in processing order, audio first.
        self.audio_pipeline = audio_pipeline
        self.spectrogram_pipeline = spectrogram_pipeline

        # Plain metadata attributes.
        self.input_samplerate = input_samplerate
        self.output_samplerate = output_samplerate
        self.min_freq = min_freq
        self.max_freq = max_freq

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the audio pipeline, then the spectrogram pipeline, on `wav`."""
        processed = self.audio_pipeline(wav)
        spectrogram = self.spectrogram_pipeline(processed)
        return spectrogram
def compute_output_samplerate(config: PreprocessingConfig) -> float:
    """Compute the output samplerate implied by a preprocessing config.

    The audio samplerate is multiplied by the resize factor and divided by
    the STFT hop size obtained from `_spec_params_from_config`.

    Parameters
    ----------
    config : PreprocessingConfig
        Configuration providing `audio.samplerate`, `spectrogram.stft` and
        `spectrogram.size.resize_factor`.

    Returns
    -------
    float
        `samplerate * resize_factor / hop_size`.
    """
    audio_rate = config.audio.samplerate
    stft_params = _spec_params_from_config(audio_rate, config.spectrogram.stft)
    # Only the second element (the hop size) is needed here.
    hop = stft_params[1]
    time_factor = config.spectrogram.size.resize_factor
    return audio_rate * time_factor / hop
def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Factory function to build the standard preprocessor from configuration.

    Parameters
    ----------
    config : PreprocessingConfig, optional
        Preprocessing configuration. When `None`, a default-constructed
        `PreprocessingConfig` is used.

    Returns
    -------
    PreprocessorProtocol
        A `StandardPreprocessor` wired with the audio and spectrogram
        pipelines built from `config`, plus the input/output samplerates
        and frequency bounds taken from it.
    """
    # Check for None explicitly so the default is substituted only when the
    # argument was omitted, never because a config object evaluates falsy.
    if config is None:
        config = PreprocessingConfig()

    # Lazy logging: the YAML dump is only rendered if debug output is enabled.
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    samplerate = config.audio.samplerate
    min_freq = config.spectrogram.frequencies.min_freq
    max_freq = config.spectrogram.frequencies.max_freq
    output_samplerate = compute_output_samplerate(config)

    return StandardPreprocessor(
        audio_pipeline=build_audio_pipeline(config.audio),
        spectrogram_pipeline=build_spectrogram_pipeline(
            samplerate, config.spectrogram
        ),
        input_samplerate=samplerate,
        output_samplerate=output_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )

View File

@ -1,63 +1,37 @@
"""Computes spectrograms from audio waveforms with configurable parameters.""" """Computes spectrograms from audio waveforms with configurable parameters."""
from typing import ( from typing import Callable, Optional
Annotated,
Callable,
List,
Literal,
Optional,
Sequence,
Union,
)
import numpy as np import numpy as np
import torch import torch
import torchaudio import torchaudio
from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.preprocess.common import PeakNormalize from batdetect2.preprocess.common import PeakNormalize
from batdetect2.preprocess.config import (
ScaleAmplitudeConfig,
SpectrogramConfig,
SpectrogramTransform,
STFTConfig,
)
__all__ = [ __all__ = [
"STFTConfig",
"FrequencyConfig",
"PcenConfig",
"SpectrogramConfig",
"build_spectrogram_builder", "build_spectrogram_builder",
"MIN_FREQ", "build_spectrogram_pipeline",
"MAX_FREQ",
] ]
MIN_FREQ = 10_000 def build_spectrogram_builder(
"""Default minimum frequency (Hz) for spectrogram frequency cropping.""" samplerate: int,
conf: STFTConfig,
MAX_FREQ = 120_000 ) -> torch.nn.Module:
"""Default maximum frequency (Hz) for spectrogram frequency cropping.""" n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
class STFTConfig(BaseConfig): hop_length=hop_length,
"""Configuration for the Short-Time Fourier Transform (STFT). window_fn=get_spectrogram_window(conf.window_fn),
center=True,
Attributes power=1,
---------- )
window_duration : float, default=0.002
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
> 0. Determines frequency resolution (longer window = finer frequency
resolution).
window_overlap : float, default=0.75
Fraction of overlap between consecutive STFT windows (e.g., 0.75
for 75%). Must be >= 0 and < 1. Determines time resolution
(higher overlap = finer time resolution).
window_fn : str, default="hann"
Name of the window function to apply before FFT calculation. Common
options include "hann", "hamming", "blackman". See
`scipy.signal.get_window`.
"""
window_duration: float = Field(default=0.002, gt=0)
window_overlap: float = Field(default=0.75, ge=0, lt=1)
window_fn: str = "hann"
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
@ -87,37 +61,6 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
return n_fft, hop_length return n_fft, hop_length
def build_spectrogram_builder(
samplerate: int,
conf: STFTConfig,
) -> torch.nn.Module:
n_fft, hop_length = _spec_params_from_config(samplerate, conf)
return torchaudio.transforms.Spectrogram(
n_fft=n_fft,
hop_length=hop_length,
window_fn=get_spectrogram_window(conf.window_fn),
center=True,
power=1,
)
class FrequencyConfig(BaseConfig):
"""Configuration for frequency axis parameters.
Attributes
----------
max_freq : int, default=120000
Maximum frequency in Hz to retain in the spectrogram after STFT.
Frequencies above this value will be cropped. Must be > 0.
min_freq : int, default=10000
Minimum frequency in Hz to retain in the spectrogram after STFT.
Frequencies below this value will be cropped. Must be >= 0.
"""
max_freq: int = Field(default=120_000, ge=0)
min_freq: int = Field(default=10_000, ge=0)
def _frequency_to_index( def _frequency_to_index(
freq: float, freq: float,
samplerate: int, samplerate: int,
@ -164,16 +107,6 @@ class FrequencyClip(torch.nn.Module):
) )
class PcenConfig(BaseConfig):
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
name: Literal["pcen"] = "pcen"
time_constant: float = 0.4
gain: float = 0.98
bias: float = 2
power: float = 0.5
class PCEN(torch.nn.Module): class PCEN(torch.nn.Module):
def __init__( def __init__(
self, self,
@ -231,11 +164,6 @@ def _compute_smoothing_constant(
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
class ScaleAmplitudeConfig(BaseConfig):
name: Literal["scale_amplitude"] = "scale_amplitude"
scale: Literal["power", "db"] = "db"
class ToPower(torch.nn.Module): class ToPower(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec**2 return spec**2
@ -253,22 +181,12 @@ def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
) )
class SpectralMeanSubstractionConfig(BaseConfig):
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
class SpectralMeanSubstraction(torch.nn.Module): class SpectralMeanSubstraction(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor: def forward(self, spec: torch.Tensor) -> torch.Tensor:
mean = spec.mean(-1, keepdim=True) mean = spec.mean(-1, keepdim=True)
return (spec - mean).clamp(min=0) return (spec - mean).clamp(min=0)
class ResizeConfig(BaseConfig):
name: Literal["resize_spec"] = "resize_spec"
height: int = 128
resize_factor: float = 0.5
class ResizeSpec(torch.nn.Module): class ResizeSpec(torch.nn.Module):
def __init__(self, height: int, time_factor: float): def __init__(self, height: int, time_factor: float):
super().__init__() super().__init__()
@ -295,33 +213,6 @@ class ResizeSpec(torch.nn.Module):
return resized return resized
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"
SpectrogramTransform = Annotated[
Union[
PcenConfig,
ScaleAmplitudeConfig,
SpectralMeanSubstractionConfig,
PeakNormalizeConfig,
],
Field(discriminator="name"),
]
class SpectrogramConfig(BaseConfig):
stft: STFTConfig = Field(default_factory=STFTConfig)
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
size: ResizeConfig = Field(default_factory=ResizeConfig)
transforms: Sequence[SpectrogramTransform] = Field(
default_factory=lambda: [
PcenConfig(),
SpectralMeanSubstractionConfig(),
]
)
def _build_spectrogram_transform_step( def _build_spectrogram_transform_step(
step: SpectrogramTransform, step: SpectrogramTransform,
samplerate: int, samplerate: int,

View File

@ -7,7 +7,7 @@ from loguru import logger
from pydantic import Field, field_validator from pydantic import Field, field_validator
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.conditions import build_sound_event_condition from batdetect2.data.conditions import build_sound_event_condition
from batdetect2.targets.classes import ( from batdetect2.targets.classes import (
DEFAULT_CLASSES, DEFAULT_CLASSES,

View File

@ -3,7 +3,7 @@ from typing import Dict, List, Optional
from pydantic import Field, PrivateAttr, computed_field, model_validator from pydantic import Field, PrivateAttr, computed_field, model_validator
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.data.conditions import ( from batdetect2.data.conditions import (
AllOfConfig, AllOfConfig,
HasAllTagsConfig, HasAllTagsConfig,

View File

@ -26,12 +26,17 @@ import numpy as np
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.core.arrays import spec_to_xarray
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol from batdetect2.typing import (
from batdetect2.typing.targets import Position, ROITargetMapper, Size AudioLoader,
from batdetect2.utils.arrays import spec_to_xarray Position,
PreprocessorProtocol,
ROITargetMapper,
Size,
)
__all__ = [ __all__ = [
"Anchor", "Anchor",

View File

@ -11,11 +11,10 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.train.clips import get_subclip_annotation from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import Augmentation from batdetect2.typing import AudioLoader, Augmentation
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import adjust_width
__all__ = [ __all__ = [
"AugmentationConfig", "AugmentationConfig",

View File

@ -10,10 +10,12 @@ from batdetect2.postprocess import get_raw_predictions
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter from batdetect2.train.logging import get_image_plotter
from batdetect2.typing.evaluate import ClipEvaluation from batdetect2.typing import (
from batdetect2.typing.models import ModelOutput ClipEvaluation,
from batdetect2.typing.postprocess import RawPrediction ModelOutput,
from batdetect2.typing.train import TrainExample RawPrediction,
TrainExample,
)
class ValidationMetrics(Callback): class ValidationMetrics(Callback):

View File

@ -6,8 +6,7 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.configs import BaseConfig from batdetect2.core import BaseConfig, Registry
from batdetect2.data._core import Registry
from batdetect2.typing import ClipperProtocol from batdetect2.typing import ClipperProtocol
DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_TRAIN_CLIP_DURATION = 0.256

View File

@ -3,7 +3,7 @@ from typing import Optional, Union
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
@ -80,7 +80,6 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig): class TrainingConfig(BaseConfig):
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig) train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig) val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig) loss: LossConfig = Field(default_factory=LossConfig)
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig) cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)

View File

@ -5,8 +5,8 @@ from loguru import logger
from soundevent import data from soundevent import data
from torch.utils.data import DataLoader, Dataset from torch.utils.data import DataLoader, Dataset
from batdetect2.plotting.clips import build_audio_loader from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_audio_loader, build_preprocessor
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
RandomAudioSource, RandomAudioSource,
build_augmentations, build_augmentations,
@ -14,10 +14,14 @@ from batdetect2.train.augmentations import (
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.typing import (
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol AudioLoader,
from batdetect2.typing.train import Augmentation, ClipLabeller Augmentation,
from batdetect2.utils.arrays import adjust_width ClipLabeller,
ClipperProtocol,
PreprocessorProtocol,
TrainExample,
)
__all__ = [ __all__ = [
"TrainingDataset", "TrainingDataset",

View File

@ -13,14 +13,10 @@ import torch
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.typing import ( from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
ClipLabeller,
Heatmaps,
TargetProtocol,
)
__all__ = [ __all__ = [
"LabelConfig", "LabelConfig",

View File

@ -18,7 +18,7 @@ from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"

View File

@ -27,7 +27,7 @@ from loguru import logger
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.configs import BaseConfig from batdetect2.core.configs import BaseConfig
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
__all__ = [ __all__ = [

View File

@ -1,4 +1,8 @@
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol from batdetect2.typing.evaluate import (
ClipEvaluation,
MatchEvaluation,
MetricsProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import ( from batdetect2.typing.postprocess import (
BatDetect2Prediction, BatDetect2Prediction,
@ -10,9 +14,11 @@ from batdetect2.typing.preprocess import (
AudioLoader, AudioLoader,
PreprocessorProtocol, PreprocessorProtocol,
SpectrogramBuilder, SpectrogramBuilder,
SpectrogramPipeline,
) )
from batdetect2.typing.targets import ( from batdetect2.typing.targets import (
Position, Position,
ROITargetMapper,
Size, Size,
SoundEventDecoder, SoundEventDecoder,
SoundEventEncoder, SoundEventEncoder,
@ -34,6 +40,7 @@ __all__ = [
"Augmentation", "Augmentation",
"BackboneModel", "BackboneModel",
"BatDetect2Prediction", "BatDetect2Prediction",
"ClipEvaluation",
"ClipLabeller", "ClipLabeller",
"ClipperProtocol", "ClipperProtocol",
"DetectionModel", "DetectionModel",
@ -47,12 +54,14 @@ __all__ = [
"Position", "Position",
"PostprocessorProtocol", "PostprocessorProtocol",
"PreprocessorProtocol", "PreprocessorProtocol",
"ROITargetMapper",
"RawPrediction", "RawPrediction",
"Size", "Size",
"SoundEventDecoder", "SoundEventDecoder",
"SoundEventEncoder", "SoundEventEncoder",
"SoundEventFilter", "SoundEventFilter",
"SpectrogramBuilder", "SpectrogramBuilder",
"SpectrogramPipeline",
"TargetProtocol", "TargetProtocol",
"TrainExample", "TrainExample",
] ]