diff --git a/src/batdetect2/cli/compat.py b/src/batdetect2/cli/compat.py index b02c283..856be3a 100644 --- a/src/batdetect2/cli/compat.py +++ b/src/batdetect2/cli/compat.py @@ -1,10 +1,15 @@ +import os + import click -from batdetect2 import api from batdetect2.cli.base import cli -from batdetect2.detector.parameters import DEFAULT_MODEL_PATH -from batdetect2.types import ProcessingConfiguration -from batdetect2.utils.detector_utils import save_results_to_file + +DEFAULT_MODEL_PATH = os.path.join( + os.path.dirname(os.path.dirname(__file__)), + "models", + "checkpoints", + "Net2DFast_UK_same.pth.tar", +) @cli.command() @@ -74,6 +79,9 @@ def detect( Input files should be short in duration e.g. < 30 seconds. """ + from batdetect2 import api + from batdetect2.utils.detector_utils import save_results_to_file + click.echo(f"Loading model: {args['model_path']}") model, params = api.load_model(args["model_path"]) @@ -123,7 +131,7 @@ def detect( click.echo(f" {err}") -def print_config(config: ProcessingConfiguration): +def print_config(config): """Print the processing configuration.""" click.echo("\nProcessing Configuration:") click.echo(f"Time Expansion Factor: {config.get('time_expansion')}") diff --git a/src/batdetect2/cli/data.py b/src/batdetect2/cli/data.py index 94b7dac..f824211 100644 --- a/src/batdetect2/cli/data.py +++ b/src/batdetect2/cli/data.py @@ -4,7 +4,6 @@ from typing import Optional import click from batdetect2.cli.base import cli -from batdetect2.data import load_dataset_from_config __all__ = ["data"] @@ -33,6 +32,8 @@ def summary( field: Optional[str] = None, base_dir: Optional[Path] = None, ): + from batdetect2.data import load_dataset_from_config + base_dir = base_dir or Path.cwd() dataset = load_dataset_from_config( dataset_config, diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 172ef3f..20f30b6 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -6,9 +6,6 @@ import click from loguru import logger from batdetect2.cli.base import cli -from batdetect2.data import load_dataset_from_config -from batdetect2.evaluate.evaluate import evaluate -from batdetect2.train.lightning import load_model_from_checkpoint __all__ = ["evaluate_command"] @@ -31,6 +28,10 @@ def evaluate_command( workers: Optional[int] = None, verbose: int = 0, ): + from batdetect2.data import load_dataset_from_config + from batdetect2.evaluate.evaluate import evaluate + from batdetect2.train.lightning import load_model_from_checkpoint + logger.remove() if verbose == 0: log_level = "WARNING" diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index af2c1b6..2562a48 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -6,13 +6,6 @@ import click from loguru import logger from batdetect2.cli.base import cli -from batdetect2.data import load_dataset_from_config -from batdetect2.targets import load_target_config -from batdetect2.train import ( - FullTrainingConfig, - load_full_training_config, - train, -) __all__ = ["train_command"] @@ -53,6 +46,14 @@ def train_command( run_name: Optional[str] = None, verbose: int = 0, ): + from batdetect2.data import load_dataset_from_config + from batdetect2.targets import load_target_config + from batdetect2.train import ( + FullTrainingConfig, + load_full_training_config, + train, + ) + logger.remove() if verbose == 0: log_level = "WARNING" diff --git a/src/batdetect2/compat/data.py b/src/batdetect2/compat/data.py index 598be50..6473eb6 100644 --- a/src/batdetect2/compat/data.py +++ b/src/batdetect2/compat/data.py @@ -11,7 +11,6 @@ from soundevent import data from soundevent.geometry import compute_bounds from soundevent.types import ClassMapper -from batdetect2.targets.terms import get_term_from_key from batdetect2.types import ( Annotation, AudioLoaderAnnotationGroup, @@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation( uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"), sound_event=sound_event, tags=[ - data.Tag( - term=get_term_from_key(label_key), - value=annotation["class"], - ), - data.Tag( - term=get_term_from_key(event_key), - value=annotation["event"], - ), - data.Tag( - term=get_term_from_key(individual_key), - value=str(annotation["individual"]), - ), + data.Tag(key=label_key, value=annotation["class"]), + data.Tag(key=event_key, value=annotation["event"]), + data.Tag(key=individual_key, value=str(annotation["individual"])), ], ) @@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction( tags=[ data.PredictedTag( score=annotation["class_prob"], - tag=data.Tag( - term=get_term_from_key(label_key), - value=annotation["class"], - ), + tag=data.Tag(key=label_key, value=annotation["class"]), ), data.PredictedTag( score=annotation["det_prob"], - tag=data.Tag( - term=get_term_from_key(event_key), - value=annotation["event"], - ), + tag=data.Tag(key=event_key, value=annotation["event"]), ), ], ) diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py new file mode 100644 index 0000000..0c49c55 --- /dev/null +++ b/src/batdetect2/config.py @@ -0,0 +1,16 @@ +from typing import Literal + +from batdetect2.core import BaseConfig +from batdetect2.evaluate.config import EvaluationConfig +from batdetect2.models.backbones import BackboneConfig +from batdetect2.preprocess import PreprocessingConfig +from batdetect2.train.config import TrainingConfig + + +class BatDetect2Config(BaseConfig): + config_version: Literal["v1"] = "v1" + + train: TrainingConfig + evaluation: EvaluationConfig + model: BackboneConfig + preprocess: PreprocessingConfig diff --git a/src/batdetect2/core/__init__.py b/src/batdetect2/core/__init__.py new file mode 100644 index 0000000..62730e8 --- /dev/null +++ b/src/batdetect2/core/__init__.py @@ -0,0 +1,8 @@ +from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.core.registries import Registry + +__all__ = [ + "BaseConfig", + "load_config", + "Registry", +] diff --git a/src/batdetect2/utils/arrays.py b/src/batdetect2/core/arrays.py similarity index 100% rename from src/batdetect2/utils/arrays.py rename to src/batdetect2/core/arrays.py diff --git a/src/batdetect2/configs.py b/src/batdetect2/core/configs.py similarity index 100% rename from src/batdetect2/configs.py rename to src/batdetect2/core/configs.py diff --git a/src/batdetect2/data/_core.py b/src/batdetect2/core/registries.py similarity index 93% rename from src/batdetect2/data/_core.py rename to src/batdetect2/core/registries.py index 6b16dd8..2535b82 100644 --- a/src/batdetect2/data/_core.py +++ b/src/batdetect2/core/registries.py @@ -1,7 +1,12 @@ +import sys from typing import Generic, Protocol, Type, TypeVar from pydantic import BaseModel -from typing_extensions import ParamSpec + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec __all__ = [ "Registry", diff --git a/src/batdetect2/data/annotations/aoef.py b/src/batdetect2/data/annotations/aoef.py index f57393b..748dac8 100644 --- a/src/batdetect2/data/annotations/aoef.py +++ b/src/batdetect2/data/annotations/aoef.py @@ -18,7 +18,7 @@ from uuid import uuid5 from pydantic import Field from soundevent import data, io -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.data.annotations.types import AnnotatedDataset __all__ = [ diff --git a/src/batdetect2/data/annotations/batdetect2.py b/src/batdetect2/data/annotations/batdetect2.py index 1d17d81..5982f42 100644 --- a/src/batdetect2/data/annotations/batdetect2.py +++ b/src/batdetect2/data/annotations/batdetect2.py @@ -33,7 +33,7 @@ from loguru import logger from pydantic import Field, ValidationError from soundevent import data -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.data.annotations.legacy import ( FileAnnotation, file_annotation_to_clip, diff --git a/src/batdetect2/data/annotations/types.py b/src/batdetect2/data/annotations/types.py index 73adf83..74e769b 100644 --- a/src/batdetect2/data/annotations/types.py +++ b/src/batdetect2/data/annotations/types.py @@ -1,6 +1,6 @@ from pathlib import Path -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig __all__ = [ "AnnotatedDataset", diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index b3d8fea..54c082b 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -5,8 +5,8 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry SoundEventCondition = Callable[[data.SoundEventAnnotation], bool] diff --git a/src/batdetect2/data/datasets.py b/src/batdetect2/data/datasets.py index f1b5117..db7f728 100644 --- a/src/batdetect2/data/datasets.py +++ b/src/batdetect2/data/datasets.py @@ -25,7 +25,7 @@ from loguru import logger from pydantic import Field from soundevent import data, io -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.data.annotations import ( AnnotatedDataset, AnnotationFormats, diff --git a/src/batdetect2/data/transforms.py b/src/batdetect2/data/transforms.py index d57f567..62dd1c4 100644 --- a/src/batdetect2/data/transforms.py +++ b/src/batdetect2/data/transforms.py @@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry from batdetect2.data.conditions import ( SoundEventCondition, SoundEventConditionConfig, diff --git a/src/batdetect2/evaluate/affinity.py b/src/batdetect2/evaluate/affinity.py index fe753bc..5a2ab91 100644 --- a/src/batdetect2/evaluate/affinity.py +++ b/src/batdetect2/evaluate/affinity.py @@ -4,8 +4,8 @@ from pydantic import Field from soundevent import data from soundevent.evaluation import compute_affinity -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry from batdetect2.typing.evaluate import AffinityFunction affinity_functions: Registry[AffinityFunction, []] = Registry( diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 324c948..3a02265 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -3,7 +3,7 @@ from typing import List, Optional from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig from batdetect2.evaluate.metrics import ( ClassificationAPConfig, diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 6df3a36..2d67c13 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry from batdetect2.evaluate.affinity import ( AffinityConfig, GeometricIOUConfig, diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index 9c48f37..6b52f7a 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -15,8 +15,8 @@ from pydantic import Field from sklearn import metrics from sklearn.preprocessing import label_binarize -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry from batdetect2.typing import MetricsProtocol from batdetect2.typing.evaluate import ClipEvaluation @@ -31,7 +31,7 @@ AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"] class DetectionAPConfig(BaseConfig): name: Literal["detection_ap"] = "detection_ap" - implementation: AveragePrecisionImplementation = "pascal_voc" + ap_implementation: AveragePrecisionImplementation = "pascal_voc" def pascal_voc_average_precision(y_true, y_score) -> float: @@ -96,7 +96,7 @@ class DetectionAP(MetricsProtocol): @classmethod def from_config(cls, config: DetectionAPConfig, class_names: List[str]): - return cls(implementation=config.implementation) + return cls(implementation=config.ap_implementation) metrics_registry.register(DetectionAPConfig, DetectionAP) @@ -104,6 +104,7 @@ metrics_registry.register(DetectionAPConfig, DetectionAP) class ClassificationAPConfig(BaseConfig): name: Literal["classification_ap"] = "classification_ap" + ap_implementation: AveragePrecisionImplementation = "pascal_voc" include: Optional[List[str]] = None exclude: Optional[List[str]] = None @@ -193,6 +194,7 @@ class ClassificationAP(MetricsProtocol): ): return cls( class_names, + implementation=config.ap_implementation, include=config.include, exclude=config.exclude, ) diff --git a/src/batdetect2/evaluate/plots.py b/src/batdetect2/evaluate/plots.py index 0a398c7..ae921ec 100644 --- a/src/batdetect2/evaluate/plots.py +++ b/src/batdetect2/evaluate/plots.py @@ -7,8 +7,8 @@ import matplotlib.pyplot as plt import pandas as pd from pydantic import Field -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.preprocess import PreprocessingConfig, build_preprocessor diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py index 404d5ca..1bc7c14 100644 --- a/src/batdetect2/models/__init__.py +++ b/src/batdetect2/models/__init__.py @@ -32,7 +32,7 @@ import torch from pydantic import Field from soundevent.data import PathLike -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.models.backbones import ( Backbone, BackboneConfig, diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index 7fc377a..cf5f3b8 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -25,7 +25,7 @@ import torch.nn.functional as F from soundevent import data from torch import nn -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.models.bottleneck import ( DEFAULT_BOTTLENECK_CONFIG, BottleneckConfig, diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index ebc380e..9acd36d 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -34,7 +34,7 @@ import torch.nn.functional as F from pydantic import Field from torch import nn -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig __all__ = [ "ConvBlock", diff --git a/src/batdetect2/models/bottleneck.py b/src/batdetect2/models/bottleneck.py index 22d1647..253e702 100644 --- a/src/batdetect2/models/bottleneck.py +++ b/src/batdetect2/models/bottleneck.py @@ -20,7 +20,7 @@ import torch from pydantic import Field from torch import nn -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.models.blocks import ( SelfAttentionConfig, VerticalConv, diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py index 18133ac..dd74270 100644 --- a/src/batdetect2/models/decoder.py +++ b/src/batdetect2/models/decoder.py @@ -24,7 +24,7 @@ import torch from pydantic import Field from torch import nn -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.models.blocks import ( ConvConfig, FreqCoordConvUpConfig, diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py index 27b8853..e7da745 100644 --- a/src/batdetect2/models/encoder.py +++ b/src/batdetect2/models/encoder.py @@ -26,7 +26,7 @@ import torch from pydantic import Field from torch import nn -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.models.blocks import ( ConvConfig, FreqCoordConvDownConfig, diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py index ccbe718..1fc7c73 100644 --- a/src/batdetect2/plotting/matches.py +++ b/src/batdetect2/plotting/matches.py @@ -8,8 +8,7 @@ from soundevent.plot.tags import TagColorMapper from batdetect2.plotting.clip_predictions import plot_prediction from batdetect2.plotting.clips import AudioLoader, plot_clip -from batdetect2.preprocess import PreprocessorProtocol -from batdetect2.typing.evaluate import MatchEvaluation +from batdetect2.typing import MatchEvaluation, PreprocessorProtocol __all__ = [ "plot_matches", diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index e1be16b..58ceafc 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -7,7 +7,7 @@ from loguru import logger from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.postprocess.decoding import ( DEFAULT_CLASSIFICATION_THRESHOLD, convert_raw_prediction_to_sound_event_prediction, diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py index 7da0725..4118a5f 100644 --- a/src/batdetect2/preprocess/__init__.py +++ b/src/batdetect2/preprocess/__init__.py @@ -1,176 +1,21 @@ -"""Main entry point for the BatDetect2 Preprocessing subsystem. +"""Main entry point for the BatDetect2 preprocessing subsystem.""" -This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline -for converting raw audio input (from files or data objects) into processed -spectrograms suitable for input to BatDetect2 models. This ensures consistent -data handling between model training and inference. - -The preprocessing pipeline consists of two main stages, configured via nested -data structures: -1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial - processing like resampling, duration adjustment, centering, and scaling. - Configured via `AudioConfig`. -2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from - the processed waveform using STFT, followed by frequency cropping, optional - PCEN, amplitude scaling (dB, power, linear), optional denoising, optional - resizing, and optional peak normalization. Configured via - `SpectrogramConfig`. - -This module provides the primary interface: - -- `PreprocessingConfig`: A unified configuration object holding `AudioConfig` - and `SpectrogramConfig`. -- `load_preprocessing_config`: Function to load the unified configuration. -- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline. -- `StandardPreprocessor`: The default implementation of the `Preprocessor`. -- `build_preprocessor`: A factory function to create a `StandardPreprocessor` - instance from a `PreprocessingConfig`. - -""" - -from typing import Optional - -import torch -from loguru import logger -from pydantic import Field -from soundevent.data import PathLike - -from batdetect2.configs import BaseConfig, load_config -from batdetect2.preprocess.audio import ( - DEFAULT_DURATION, - SCALE_RAW_AUDIO, - TARGET_SAMPLERATE_HZ, - AudioConfig, - ResampleConfig, - build_audio_loader, - build_audio_pipeline, -) -from batdetect2.preprocess.spectrogram import ( +from batdetect2.preprocess.audio import build_audio_loader +from batdetect2.preprocess.config import ( MAX_FREQ, MIN_FREQ, - FrequencyConfig, - PcenConfig, - SpectrogramConfig, - SpectrogramPipeline, - STFTConfig, - _spec_params_from_config, - build_spectrogram_builder, - build_spectrogram_pipeline, + TARGET_SAMPLERATE_HZ, + PreprocessingConfig, + load_preprocessing_config, ) -from batdetect2.typing import PreprocessorProtocol +from batdetect2.preprocess.preprocessor import build_preprocessor __all__ = [ - "AudioConfig", - "DEFAULT_DURATION", - "FrequencyConfig", - "MAX_FREQ", "MIN_FREQ", - "PcenConfig", - "PreprocessingConfig", - "ResampleConfig", - "SCALE_RAW_AUDIO", - "STFTConfig", - "SpectrogramConfig", + "MAX_FREQ", "TARGET_SAMPLERATE_HZ", - "build_audio_loader", - "build_spectrogram_builder", + "PreprocessingConfig", "load_preprocessing_config", + "build_preprocessor", + "build_audio_loader", ] - - -class PreprocessingConfig(BaseConfig): - """Unified configuration for the audio preprocessing pipeline. - - Aggregates the configuration for both the initial audio processing stage - and the subsequent spectrogram generation stage. - - Attributes - ---------- - audio : AudioConfig - Configuration settings for the audio loading and initial waveform - processing steps (e.g., resampling, duration adjustment, scaling). - Defaults to default `AudioConfig` settings if omitted. - spectrogram : SpectrogramConfig - Configuration settings for the spectrogram generation process - (e.g., STFT parameters, frequency cropping, scaling, denoising, - resizing). Defaults to default `SpectrogramConfig` settings if omitted. - """ - - audio: AudioConfig = Field(default_factory=AudioConfig) - spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) - - -def load_preprocessing_config( - path: PathLike, - field: Optional[str] = None, -) -> PreprocessingConfig: - return load_config(path, schema=PreprocessingConfig, field=field) - - -class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): - """Standard implementation of the `Preprocessor` protocol.""" - - input_samplerate: int - output_samplerate: float - - max_freq: float - min_freq: float - - def __init__( - self, - audio_pipeline: torch.nn.Module, - spectrogram_pipeline: SpectrogramPipeline, - input_samplerate: int, - output_samplerate: float, - max_freq: float, - min_freq: float, - ) -> None: - super().__init__() - self.audio_pipeline = audio_pipeline - self.spectrogram_pipeline = spectrogram_pipeline - - self.max_freq = max_freq - self.min_freq = min_freq - - self.input_samplerate = input_samplerate - self.output_samplerate = output_samplerate - - def forward(self, wav: torch.Tensor) -> torch.Tensor: - wav = self.audio_pipeline(wav) - return self.spectrogram_pipeline(wav) - - -def compute_output_samplerate(config: PreprocessingConfig) -> float: - samplerate = config.audio.samplerate - _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft) - factor = config.spectrogram.size.resize_factor - return samplerate * factor / hop_size - - -def build_preprocessor( - config: Optional[PreprocessingConfig] = None, -) -> PreprocessorProtocol: - """Factory function to build the standard preprocessor from configuration.""" - config = config or PreprocessingConfig() - logger.opt(lazy=True).debug( - "Building preprocessor with config: \n{}", - lambda: config.to_yaml_string(), - ) - - samplerate = config.audio.samplerate - - min_freq = config.spectrogram.frequencies.min_freq - max_freq = config.spectrogram.frequencies.max_freq - - output_samplerate = compute_output_samplerate(config) - - return StandardPreprocessor( - audio_pipeline=build_audio_pipeline(config.audio), - spectrogram_pipeline=build_spectrogram_pipeline( - samplerate, config.spectrogram - ), - input_samplerate=samplerate, - output_samplerate=output_samplerate, - min_freq=min_freq, - max_freq=max_freq, - ) diff --git a/src/batdetect2/preprocess/audio.py b/src/batdetect2/preprocess/audio.py index c5c72e2..9ca5984 100644 --- a/src/batdetect2/preprocess/audio.py +++ b/src/batdetect2/preprocess/audio.py @@ -1,64 +1,34 @@ """Handles loading and initial preprocessing of audio waveforms.""" -from typing import Annotated, List, Literal, Optional, Union +from typing import Optional import numpy as np import torch from numpy.typing import DTypeLike -from pydantic import Field from scipy.signal import resample, resample_poly from soundevent import audio, data from soundfile import LibsndfileError -from batdetect2.configs import BaseConfig from batdetect2.preprocess.common import CenterTensor, PeakNormalize +from batdetect2.preprocess.config import ( + TARGET_SAMPLERATE_HZ, + AudioConfig, + AudioTransform, + ResampleConfig, +) from batdetect2.typing import AudioLoader __all__ = [ - "ResampleConfig", - "AudioConfig", "SoundEventAudioLoader", "build_audio_loader", "load_file_audio", "load_recording_audio", "load_clip_audio", "resample_audio", - "TARGET_SAMPLERATE_HZ", - "SCALE_RAW_AUDIO", - "DEFAULT_DURATION", ] -TARGET_SAMPLERATE_HZ = 256_000 -"""Default target sample rate in Hz used if resampling is enabled.""" -SCALE_RAW_AUDIO = False -"""Default setting for whether to perform peak normalization.""" - -DEFAULT_DURATION = None -"""Default setting for target audio duration in seconds.""" - - -class ResampleConfig(BaseConfig): - """Configuration for audio resampling. - - Attributes - ---------- - samplerate : int, default=256000 - The target sample rate in Hz to resample the audio to. Must be > 0. - method : str, default="poly" - The resampling algorithm to use. Options: - - "poly": Polyphase resampling using `scipy.signal.resample_poly`. - Generally fast. - - "fourier": Resampling via Fourier method using - `scipy.signal.resample`. May handle non-integer - resampling factors differently. - """ - - enabled: bool = True - method: str = "poly" - - -class SoundEventAudioLoader: +class SoundEventAudioLoader(AudioLoader): """Concrete implementation of the `AudioLoader`.""" def __init__( @@ -294,19 +264,6 @@ def resample_audio_fourier( ) -class CenterAudioConfig(BaseConfig): - name: Literal["center_audio"] = "center_audio" - - -class ScaleAudioConfig(BaseConfig): - name: Literal["scale_audio"] = "scale_audio" - - -class FixDurationConfig(BaseConfig): - name: Literal["fix_duration"] = "fix_duration" - duration: float = 0.5 - - class FixDuration(torch.nn.Module): def __init__(self, samplerate: int, duration: float): super().__init__() @@ -326,24 +283,6 @@ class FixDuration(torch.nn.Module): return torch.nn.functional.pad(wav, (0, self.length - length)) -AudioTransform = Annotated[ - Union[ - FixDurationConfig, - ScaleAudioConfig, - CenterAudioConfig, - ], - Field(discriminator="name"), -] - - -class AudioConfig(BaseConfig): - """Configuration for loading and initial audio preprocessing.""" - - samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) - resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig) - transforms: List[AudioTransform] = Field(default_factory=list) - - def build_audio_loader( config: Optional[AudioConfig] = None, ) -> AudioLoader: diff --git a/src/batdetect2/preprocess/config.py b/src/batdetect2/preprocess/config.py new file mode 100644 index 0000000..b60c067 --- /dev/null +++ b/src/batdetect2/preprocess/config.py @@ -0,0 +1,212 @@ +from collections.abc import Sequence +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import Field +from soundevent.data import PathLike + +from batdetect2.core.configs import BaseConfig, load_config + +__all__ = [ + "load_preprocessing_config", + "CenterAudioConfig", + "ScaleAudioConfig", + "FixDurationConfig", + "ResampleConfig", + "AudioTransform", + "AudioConfig", + "STFTConfig", + "FrequencyConfig", + "PcenConfig", + "ScaleAmplitudeConfig", + "SpectralMeanSubstractionConfig", + "ResizeConfig", + "PeakNormalizeConfig", + "SpectrogramTransform", + "SpectrogramConfig", + "PreprocessingConfig", + "TARGET_SAMPLERATE_HZ", + "MIN_FREQ", + "MAX_FREQ", +] + +TARGET_SAMPLERATE_HZ = 256_000 +"""Default target sample rate in Hz used if resampling is enabled.""" + +MIN_FREQ = 10_000 +"""Default minimum frequency (Hz) for spectrogram frequency cropping.""" + +MAX_FREQ = 120_000 +"""Default maximum frequency (Hz) for spectrogram frequency cropping.""" + + +class CenterAudioConfig(BaseConfig): + name: Literal["center_audio"] = "center_audio" + + +class ScaleAudioConfig(BaseConfig): + name: Literal["scale_audio"] = "scale_audio" + + +class FixDurationConfig(BaseConfig): + name: Literal["fix_duration"] = "fix_duration" + duration: float = 0.5 + + +class ResampleConfig(BaseConfig): + """Configuration for audio resampling. + + Attributes + ---------- + samplerate : int, default=256000 + The target sample rate in Hz to resample the audio to. Must be > 0. + method : str, default="poly" + The resampling algorithm to use. Options: + - "poly": Polyphase resampling using `scipy.signal.resample_poly`. + Generally fast. + - "fourier": Resampling via Fourier method using + `scipy.signal.resample`. May handle non-integer + resampling factors differently. + """ + + enabled: bool = True + method: str = "poly" + + +AudioTransform = Annotated[ + Union[ + FixDurationConfig, + ScaleAudioConfig, + CenterAudioConfig, + ], + Field(discriminator="name"), +] + + +class AudioConfig(BaseConfig): + """Configuration for loading and initial audio preprocessing.""" + + samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0) + resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig) + transforms: List[AudioTransform] = Field(default_factory=list) + + +class STFTConfig(BaseConfig): + """Configuration for the Short-Time Fourier Transform (STFT). + + Attributes + ---------- + window_duration : float, default=0.002 + Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be + > 0. Determines frequency resolution (longer window = finer frequency + resolution). + window_overlap : float, default=0.75 + Fraction of overlap between consecutive STFT windows (e.g., 0.75 + for 75%). Must be >= 0 and < 1. Determines time resolution + (higher overlap = finer time resolution). + window_fn : str, default="hann" + Name of the window function to apply before FFT calculation. Common + options include "hann", "hamming", "blackman". See + `scipy.signal.get_window`. + """ + + window_duration: float = Field(default=0.002, gt=0) + window_overlap: float = Field(default=0.75, ge=0, lt=1) + window_fn: str = "hann" + + +class FrequencyConfig(BaseConfig): + """Configuration for frequency axis parameters. + + Attributes + ---------- + max_freq : int, default=120000 + Maximum frequency in Hz to retain in the spectrogram after STFT. + Frequencies above this value will be cropped. Must be > 0. + min_freq : int, default=10000 + Minimum frequency in Hz to retain in the spectrogram after STFT. + Frequencies below this value will be cropped. Must be >= 0. + """ + + max_freq: int = Field(default=120_000, ge=0) + min_freq: int = Field(default=10_000, ge=0) + + +class PcenConfig(BaseConfig): + """Configuration for Per-Channel Energy Normalization (PCEN).""" + + name: Literal["pcen"] = "pcen" + time_constant: float = 0.4 + gain: float = 0.98 + bias: float = 2 + power: float = 0.5 + + +class ScaleAmplitudeConfig(BaseConfig): + name: Literal["scale_amplitude"] = "scale_amplitude" + scale: Literal["power", "db"] = "db" + + +class SpectralMeanSubstractionConfig(BaseConfig): + name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" + + +class ResizeConfig(BaseConfig): + name: Literal["resize_spec"] = "resize_spec" + height: int = 128 + resize_factor: float = 0.5 + + +class PeakNormalizeConfig(BaseConfig): + name: Literal["peak_normalize"] = "peak_normalize" + + +SpectrogramTransform = Annotated[ + Union[ + PcenConfig, + ScaleAmplitudeConfig, + SpectralMeanSubstractionConfig, + PeakNormalizeConfig, + ], + Field(discriminator="name"), +] + + +class SpectrogramConfig(BaseConfig): + stft: STFTConfig = Field(default_factory=STFTConfig) + frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) + size: ResizeConfig = Field(default_factory=ResizeConfig) + transforms: Sequence[SpectrogramTransform] = Field( + default_factory=lambda: [ + PcenConfig(), + SpectralMeanSubstractionConfig(), + ] + ) + + +class PreprocessingConfig(BaseConfig): + """Unified configuration for the audio preprocessing pipeline. + + Aggregates the configuration for both the initial audio processing stage + and the subsequent spectrogram generation stage. + + Attributes + ---------- + audio : AudioConfig + Configuration settings for the audio loading and initial waveform + processing steps (e.g., resampling, duration adjustment, scaling). + Defaults to default `AudioConfig` settings if omitted. + spectrogram : SpectrogramConfig + Configuration settings for the spectrogram generation process + (e.g., STFT parameters, frequency cropping, scaling, denoising, + resizing). Defaults to default `SpectrogramConfig` settings if omitted. + """ + + audio: AudioConfig = Field(default_factory=AudioConfig) + spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) + + +def load_preprocessing_config( + path: PathLike, + field: Optional[str] = None, +) -> PreprocessingConfig: + return load_config(path, schema=PreprocessingConfig, field=field) diff --git a/src/batdetect2/preprocess/preprocessor.py b/src/batdetect2/preprocess/preprocessor.py new file mode 100644 index 0000000..e2ef27d --- /dev/null +++ b/src/batdetect2/preprocess/preprocessor.py @@ -0,0 +1,86 @@ +from typing import Optional + +import torch +from loguru import logger + +from batdetect2.preprocess.audio import build_audio_pipeline +from batdetect2.preprocess.config import PreprocessingConfig +from batdetect2.preprocess.spectrogram import ( + _spec_params_from_config, + build_spectrogram_pipeline, +) +from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline + +__all__ = [ + "StandardPreprocessor", + "build_preprocessor", +] + + +class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol): + """Standard implementation of the `Preprocessor` protocol.""" + + input_samplerate: int + output_samplerate: float + + max_freq: float + min_freq: float + + def __init__( + self, + audio_pipeline: torch.nn.Module, + spectrogram_pipeline: SpectrogramPipeline, + input_samplerate: int, + output_samplerate: float, + max_freq: float, + min_freq: float, + ) -> None: + super().__init__() + self.audio_pipeline = audio_pipeline + self.spectrogram_pipeline = spectrogram_pipeline + + self.max_freq = max_freq + self.min_freq = min_freq + + self.input_samplerate = input_samplerate + self.output_samplerate = output_samplerate + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + wav = self.audio_pipeline(wav) + return self.spectrogram_pipeline(wav) + + +def compute_output_samplerate(config: PreprocessingConfig) -> float: + samplerate = config.audio.samplerate + _, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft) + factor = config.spectrogram.size.resize_factor + return samplerate * factor / hop_size + + +def build_preprocessor( + config: Optional[PreprocessingConfig] = None, +) -> PreprocessorProtocol: + """Factory function to build the standard preprocessor from configuration.""" + config = config or PreprocessingConfig() + logger.opt(lazy=True).debug( + "Building preprocessor with config: \n{}", + lambda: config.to_yaml_string(), + ) + + samplerate = config.audio.samplerate + + min_freq = config.spectrogram.frequencies.min_freq + max_freq = config.spectrogram.frequencies.max_freq + + output_samplerate = compute_output_samplerate(config) + + return StandardPreprocessor( + audio_pipeline=build_audio_pipeline(config.audio), + spectrogram_pipeline=build_spectrogram_pipeline( + samplerate, config.spectrogram + ), + input_samplerate=samplerate, + output_samplerate=output_samplerate, + min_freq=min_freq, + max_freq=max_freq, + ) diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index 02bad20..2fa3938 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -1,63 +1,37 @@ """Computes spectrograms from audio waveforms with configurable parameters.""" -from typing import ( - Annotated, - Callable, - List, - Literal, - Optional, - Sequence, - Union, -) +from typing import Callable, Optional import numpy as np import torch import torchaudio -from pydantic import Field -from batdetect2.configs import BaseConfig from batdetect2.preprocess.common import PeakNormalize +from batdetect2.preprocess.config import ( + ScaleAmplitudeConfig, + SpectrogramConfig, + SpectrogramTransform, + STFTConfig, +) __all__ = [ - "STFTConfig", - "FrequencyConfig", - "PcenConfig", - "SpectrogramConfig", "build_spectrogram_builder", - "MIN_FREQ", - "MAX_FREQ", + "build_spectrogram_pipeline", ] -MIN_FREQ = 10_000 -"""Default minimum frequency (Hz) for spectrogram frequency cropping.""" - -MAX_FREQ = 120_000 -"""Default maximum frequency (Hz) for spectrogram frequency cropping.""" - - -class STFTConfig(BaseConfig): - """Configuration for the Short-Time Fourier Transform (STFT). - - Attributes - ---------- - window_duration : float, default=0.002 - Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be - > 0. Determines frequency resolution (longer window = finer frequency - resolution). - window_overlap : float, default=0.75 - Fraction of overlap between consecutive STFT windows (e.g., 0.75 - for 75%). Must be >= 0 and < 1. Determines time resolution - (higher overlap = finer time resolution). - window_fn : str, default="hann" - Name of the window function to apply before FFT calculation. Common - options include "hann", "hamming", "blackman". See - `scipy.signal.get_window`. - """ - - window_duration: float = Field(default=0.002, gt=0) - window_overlap: float = Field(default=0.75, ge=0, lt=1) - window_fn: str = "hann" +def build_spectrogram_builder( + samplerate: int, + conf: STFTConfig, +) -> torch.nn.Module: + n_fft, hop_length = _spec_params_from_config(samplerate, conf) + return torchaudio.transforms.Spectrogram( + n_fft=n_fft, + hop_length=hop_length, + window_fn=get_spectrogram_window(conf.window_fn), + center=True, + power=1, + ) def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]: @@ -87,37 +61,6 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig): return n_fft, hop_length -def build_spectrogram_builder( - samplerate: int, - conf: STFTConfig, -) -> torch.nn.Module: - n_fft, hop_length = _spec_params_from_config(samplerate, conf) - return torchaudio.transforms.Spectrogram( - n_fft=n_fft, - hop_length=hop_length, - window_fn=get_spectrogram_window(conf.window_fn), - center=True, - power=1, - ) - - -class FrequencyConfig(BaseConfig): - """Configuration for frequency axis parameters. - - Attributes - ---------- - max_freq : int, default=120000 - Maximum frequency in Hz to retain in the spectrogram after STFT. - Frequencies above this value will be cropped. Must be > 0. - min_freq : int, default=10000 - Minimum frequency in Hz to retain in the spectrogram after STFT. - Frequencies below this value will be cropped. Must be >= 0. - """ - - max_freq: int = Field(default=120_000, ge=0) - min_freq: int = Field(default=10_000, ge=0) - - def _frequency_to_index( freq: float, samplerate: int, @@ -164,16 +107,6 @@ class FrequencyClip(torch.nn.Module): ) -class PcenConfig(BaseConfig): - """Configuration for Per-Channel Energy Normalization (PCEN).""" - - name: Literal["pcen"] = "pcen" - time_constant: float = 0.4 - gain: float = 0.98 - bias: float = 2 - power: float = 0.5 - - class PCEN(torch.nn.Module): def __init__( self, @@ -231,11 +164,6 @@ def _compute_smoothing_constant( return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2) -class ScaleAmplitudeConfig(BaseConfig): - name: Literal["scale_amplitude"] = "scale_amplitude" - scale: Literal["power", "db"] = "db" - - class ToPower(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: return spec**2 @@ -253,22 +181,12 @@ def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module: ) -class SpectralMeanSubstractionConfig(BaseConfig): - name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction" - - class SpectralMeanSubstraction(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: mean = spec.mean(-1, keepdim=True) return (spec - mean).clamp(min=0) -class ResizeConfig(BaseConfig): - name: Literal["resize_spec"] = "resize_spec" - height: int = 128 - resize_factor: float = 0.5 - - class ResizeSpec(torch.nn.Module): def __init__(self, height: int, time_factor: float): super().__init__() @@ -295,33 +213,6 @@ class ResizeSpec(torch.nn.Module): return resized -class PeakNormalizeConfig(BaseConfig): - name: Literal["peak_normalize"] = "peak_normalize" - - -SpectrogramTransform = Annotated[ - Union[ - PcenConfig, - ScaleAmplitudeConfig, - SpectralMeanSubstractionConfig, - PeakNormalizeConfig, - ], - Field(discriminator="name"), -] - - -class SpectrogramConfig(BaseConfig): - stft: STFTConfig = Field(default_factory=STFTConfig) - frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) - size: ResizeConfig = Field(default_factory=ResizeConfig) - transforms: Sequence[SpectrogramTransform] = Field( - default_factory=lambda: [ - PcenConfig(), - SpectralMeanSubstractionConfig(), - ] - ) - - def _build_spectrogram_transform_step( step: SpectrogramTransform, samplerate: int, diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py index 0384b7a..fae0507 100644 --- a/src/batdetect2/targets/__init__.py +++ b/src/batdetect2/targets/__init__.py @@ -7,7 +7,7 @@ from loguru import logger from pydantic import Field, field_validator from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.data.conditions import build_sound_event_condition from batdetect2.targets.classes import ( DEFAULT_CLASSES, diff --git a/src/batdetect2/targets/classes.py b/src/batdetect2/targets/classes.py index b277d5e..bfb4eeb 100644 --- a/src/batdetect2/targets/classes.py +++ b/src/batdetect2/targets/classes.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from pydantic import Field, PrivateAttr, computed_field, model_validator from soundevent import data -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.data.conditions import ( AllOfConfig, HasAllTagsConfig, diff --git a/src/batdetect2/targets/rois.py b/src/batdetect2/targets/rois.py index 5a81089..495c981 100644 --- a/src/batdetect2/targets/rois.py +++ b/src/batdetect2/targets/rois.py @@ -26,12 +26,17 @@ import numpy as np from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig +from batdetect2.core.arrays import spec_to_xarray +from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess.audio import build_audio_loader -from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol -from batdetect2.typing.targets import Position, ROITargetMapper, Size -from batdetect2.utils.arrays import spec_to_xarray +from batdetect2.typing import ( + AudioLoader, + Position, + PreprocessorProtocol, + ROITargetMapper, + Size, +) __all__ = [ "Anchor", diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 5ec9437..c958c9d 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -11,11 +11,10 @@ from pydantic import Field from soundevent import data from soundevent.geometry import scale_geometry, shift_geometry -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.arrays import adjust_width +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.train.clips import get_subclip_annotation -from batdetect2.typing import Augmentation -from batdetect2.typing.preprocess import AudioLoader -from batdetect2.utils.arrays import adjust_width +from batdetect2.typing import AudioLoader, Augmentation __all__ = [ "AugmentationConfig", diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 7607ce0..3a0a9a0 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -10,10 +10,12 @@ from batdetect2.postprocess import get_raw_predictions from batdetect2.train.dataset import ValidationDataset from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import get_image_plotter -from batdetect2.typing.evaluate import ClipEvaluation -from batdetect2.typing.models import ModelOutput -from batdetect2.typing.postprocess import RawPrediction -from batdetect2.typing.train import TrainExample +from batdetect2.typing import ( + ClipEvaluation, + ModelOutput, + RawPrediction, + TrainExample, +) class ValidationMetrics(Callback): diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index a91fc49..667b038 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -6,8 +6,7 @@ from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds, intervals_overlap -from batdetect2.configs import BaseConfig -from batdetect2.data._core import Registry +from batdetect2.core import BaseConfig, Registry from batdetect2.typing import ClipperProtocol DEFAULT_TRAIN_CLIP_DURATION = 0.256 diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 7d4fe9f..66aa2b5 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -3,7 +3,7 @@ from typing import Optional, Union from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.evaluate import EvaluationConfig from batdetect2.models import ModelConfig from batdetect2.train.augmentations import ( @@ -80,7 +80,6 @@ class OptimizerConfig(BaseConfig): class TrainingConfig(BaseConfig): train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig) val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) loss: LossConfig = Field(default_factory=LossConfig) cliping: RandomClipConfig = Field(default_factory=RandomClipConfig) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index ced2028..03fb8b3 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -5,8 +5,8 @@ from loguru import logger from soundevent import data from torch.utils.data import DataLoader, Dataset -from batdetect2.plotting.clips import build_audio_loader -from batdetect2.preprocess import build_preprocessor +from batdetect2.core.arrays import adjust_width +from batdetect2.preprocess import build_audio_loader, build_preprocessor from batdetect2.train.augmentations import ( RandomAudioSource, build_augmentations, @@ -14,10 +14,14 @@ from batdetect2.train.augmentations import ( from batdetect2.train.clips import build_clipper from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.labels import build_clip_labeler -from batdetect2.typing import ClipperProtocol, TrainExample -from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol -from batdetect2.typing.train import Augmentation, ClipLabeller -from batdetect2.utils.arrays import adjust_width +from batdetect2.typing import ( + AudioLoader, + Augmentation, + ClipLabeller, + ClipperProtocol, + PreprocessorProtocol, + TrainExample, +) __all__ = [ "TrainingDataset", diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 868738b..58c6a37 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -13,14 +13,10 @@ import torch from loguru import logger from soundevent import data -from batdetect2.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.targets import build_targets, iterate_encoded_sound_events -from batdetect2.typing import ( - ClipLabeller, - Heatmaps, - TargetProtocol, -) +from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol __all__ = [ "LabelConfig", diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py index fb6a36f..66344d1 100644 --- a/src/batdetect2/train/logging.py +++ b/src/batdetect2/train/logging.py @@ -18,7 +18,7 @@ from loguru import logger from pydantic import Field from soundevent import data -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index ea5b91a..e4ecd27 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -27,7 +27,7 @@ from loguru import logger from pydantic import Field from torch import nn -from batdetect2.configs import BaseConfig +from batdetect2.core.configs import BaseConfig from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample __all__ = [ diff --git a/src/batdetect2/typing/__init__.py b/src/batdetect2/typing/__init__.py index a9ef09e..c51b4e7 100644 --- a/src/batdetect2/typing/__init__.py +++ b/src/batdetect2/typing/__init__.py @@ -1,4 +1,8 @@ -from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol +from batdetect2.typing.evaluate import ( + ClipEvaluation, + MatchEvaluation, + MetricsProtocol, +) from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput from batdetect2.typing.postprocess import ( BatDetect2Prediction, @@ -10,9 +14,11 @@ from batdetect2.typing.preprocess import ( AudioLoader, PreprocessorProtocol, SpectrogramBuilder, + SpectrogramPipeline, ) from batdetect2.typing.targets import ( Position, + ROITargetMapper, Size, SoundEventDecoder, SoundEventEncoder, @@ -34,6 +40,7 @@ __all__ = [ "Augmentation", "BackboneModel", "BatDetect2Prediction", + "ClipEvaluation", "ClipLabeller", "ClipperProtocol", "DetectionModel", @@ -47,12 +54,14 @@ __all__ = [ "Position", "PostprocessorProtocol", "PreprocessorProtocol", + "ROITargetMapper", "RawPrediction", "Size", "SoundEventDecoder", "SoundEventEncoder", "SoundEventFilter", "SpectrogramBuilder", + "SpectrogramPipeline", "TargetProtocol", "TrainExample", ]