mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Restructuring
This commit is contained in:
parent
60e922d565
commit
7d6cba5465
@ -1,10 +1,15 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2 import api
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
|
||||
from batdetect2.types import ProcessingConfiguration
|
||||
from batdetect2.utils.detector_utils import save_results_to_file
|
||||
|
||||
DEFAULT_MODEL_PATH = os.path.join(
|
||||
os.path.dirname(os.path.dirname(__file__)),
|
||||
"models",
|
||||
"checkpoints",
|
||||
"Net2DFast_UK_same.pth.tar",
|
||||
)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@ -74,6 +79,9 @@ def detect(
|
||||
|
||||
Input files should be short in duration e.g. < 30 seconds.
|
||||
"""
|
||||
from batdetect2 import api
|
||||
from batdetect2.utils.detector_utils import save_results_to_file
|
||||
|
||||
click.echo(f"Loading model: {args['model_path']}")
|
||||
model, params = api.load_model(args["model_path"])
|
||||
|
||||
@ -123,7 +131,7 @@ def detect(
|
||||
click.echo(f" {err}")
|
||||
|
||||
|
||||
def print_config(config: ProcessingConfiguration):
|
||||
def print_config(config):
|
||||
"""Print the processing configuration."""
|
||||
click.echo("\nProcessing Configuration:")
|
||||
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
|
||||
|
||||
@ -4,7 +4,6 @@ from typing import Optional
|
||||
import click
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
__all__ = ["data"]
|
||||
|
||||
@ -33,6 +32,8 @@ def summary(
|
||||
field: Optional[str] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
):
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
base_dir = base_dir or Path.cwd()
|
||||
dataset = load_dataset_from_config(
|
||||
dataset_config,
|
||||
|
||||
@ -6,9 +6,6 @@ import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate.evaluate import evaluate
|
||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||
|
||||
__all__ = ["evaluate_command"]
|
||||
|
||||
@ -31,6 +28,10 @@ def evaluate_command(
|
||||
workers: Optional[int] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate.evaluate import evaluate
|
||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
|
||||
@ -6,13 +6,6 @@ import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import (
|
||||
FullTrainingConfig,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
|
||||
__all__ = ["train_command"]
|
||||
|
||||
@ -53,6 +46,14 @@ def train_command(
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import (
|
||||
FullTrainingConfig,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
|
||||
@ -11,7 +11,6 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.types import ClassMapper
|
||||
|
||||
from batdetect2.targets.terms import get_term_from_key
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
AudioLoaderAnnotationGroup,
|
||||
@ -173,18 +172,9 @@ def annotation_to_sound_event_annotation(
|
||||
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
value=annotation["class"],
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(event_key),
|
||||
value=annotation["event"],
|
||||
),
|
||||
data.Tag(
|
||||
term=get_term_from_key(individual_key),
|
||||
value=str(annotation["individual"]),
|
||||
),
|
||||
data.Tag(key=label_key, value=annotation["class"]),
|
||||
data.Tag(key=event_key, value=annotation["event"]),
|
||||
data.Tag(key=individual_key, value=str(annotation["individual"])),
|
||||
],
|
||||
)
|
||||
|
||||
@ -219,17 +209,11 @@ def annotation_to_sound_event_prediction(
|
||||
tags=[
|
||||
data.PredictedTag(
|
||||
score=annotation["class_prob"],
|
||||
tag=data.Tag(
|
||||
term=get_term_from_key(label_key),
|
||||
value=annotation["class"],
|
||||
),
|
||||
tag=data.Tag(key=label_key, value=annotation["class"]),
|
||||
),
|
||||
data.PredictedTag(
|
||||
score=annotation["det_prob"],
|
||||
tag=data.Tag(
|
||||
term=get_term_from_key(event_key),
|
||||
value=annotation["event"],
|
||||
),
|
||||
tag=data.Tag(key=event_key, value=annotation["event"]),
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
16
src/batdetect2/config.py
Normal file
16
src/batdetect2/config.py
Normal file
@ -0,0 +1,16 @@
|
||||
from typing import Literal
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.models.backbones import BackboneConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
|
||||
class BatDetect2Config(BaseConfig):
|
||||
config_version: Literal["v1"] = "v1"
|
||||
|
||||
train: TrainingConfig
|
||||
evaluation: EvaluationConfig
|
||||
model: BackboneConfig
|
||||
preprocess: PreprocessingConfig
|
||||
8
src/batdetect2/core/__init__.py
Normal file
8
src/batdetect2/core/__init__.py
Normal file
@ -0,0 +1,8 @@
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"Registry",
|
||||
]
|
||||
@ -1,6 +1,11 @@
|
||||
import sys
|
||||
from typing import Generic, Protocol, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from typing import ParamSpec
|
||||
else:
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
__all__ = [
|
||||
@ -18,7 +18,7 @@ from uuid import uuid5
|
||||
from pydantic import Field
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data.annotations.types import AnnotatedDataset
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -33,7 +33,7 @@ from loguru import logger
|
||||
from pydantic import Field, ValidationError
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data.annotations.legacy import (
|
||||
FileAnnotation,
|
||||
file_annotation_to_clip,
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from pathlib import Path
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"AnnotatedDataset",
|
||||
|
||||
@ -5,8 +5,8 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.data.annotations import (
|
||||
AnnotatedDataset,
|
||||
AnnotationFormats,
|
||||
|
||||
@ -4,8 +4,8 @@ from typing import Annotated, Dict, List, Literal, Optional, Union
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
|
||||
@ -4,8 +4,8 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.typing.evaluate import AffinityFunction
|
||||
|
||||
affinity_functions: Registry[AffinityFunction, []] = Registry(
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import List, Optional
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.match import MatchConfig, StartTimeMatchConfig
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAPConfig,
|
||||
|
||||
@ -8,8 +8,8 @@ from soundevent.evaluation import compute_affinity
|
||||
from soundevent.evaluation import match_geometries as optimal_match
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.evaluate.affinity import (
|
||||
AffinityConfig,
|
||||
GeometricIOUConfig,
|
||||
|
||||
@ -15,8 +15,8 @@ from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.typing import MetricsProtocol
|
||||
from batdetect2.typing.evaluate import ClipEvaluation
|
||||
|
||||
@ -31,7 +31,7 @@ AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
||||
|
||||
class DetectionAPConfig(BaseConfig):
|
||||
name: Literal["detection_ap"] = "detection_ap"
|
||||
implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||
|
||||
|
||||
def pascal_voc_average_precision(y_true, y_score) -> float:
|
||||
@ -96,7 +96,7 @@ class DetectionAP(MetricsProtocol):
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
||||
return cls(implementation=config.implementation)
|
||||
return cls(implementation=config.ap_implementation)
|
||||
|
||||
|
||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||
@ -104,6 +104,7 @@ metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||
|
||||
class ClassificationAPConfig(BaseConfig):
|
||||
name: Literal["classification_ap"] = "classification_ap"
|
||||
ap_implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
@ -193,6 +194,7 @@ class ClassificationAP(MetricsProtocol):
|
||||
):
|
||||
return cls(
|
||||
class_names,
|
||||
implementation=config.ap_implementation,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
@ -7,8 +7,8 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
|
||||
@ -32,7 +32,7 @@ import torch
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.models.backbones import (
|
||||
Backbone,
|
||||
BackboneConfig,
|
||||
|
||||
@ -25,7 +25,7 @@ import torch.nn.functional as F
|
||||
from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
|
||||
@ -34,7 +34,7 @@ import torch.nn.functional as F
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"ConvBlock",
|
||||
|
||||
@ -20,7 +20,7 @@ import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
SelfAttentionConfig,
|
||||
VerticalConv,
|
||||
|
||||
@ -24,7 +24,7 @@ import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvUpConfig,
|
||||
|
||||
@ -26,7 +26,7 @@ import torch
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
|
||||
@ -8,8 +8,7 @@ from soundevent.plot.tags import TagColorMapper
|
||||
|
||||
from batdetect2.plotting.clip_predictions import plot_prediction
|
||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.evaluate import MatchEvaluation
|
||||
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"plot_matches",
|
||||
|
||||
@ -7,7 +7,7 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.postprocess.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
|
||||
@ -1,176 +1,21 @@
|
||||
"""Main entry point for the BatDetect2 Preprocessing subsystem.
|
||||
"""Main entry point for the BatDetect2 preprocessing subsystem."""
|
||||
|
||||
This package (`batdetect2.preprocessing`) defines and orchestrates the pipeline
|
||||
for converting raw audio input (from files or data objects) into processed
|
||||
spectrograms suitable for input to BatDetect2 models. This ensures consistent
|
||||
data handling between model training and inference.
|
||||
|
||||
The preprocessing pipeline consists of two main stages, configured via nested
|
||||
data structures:
|
||||
1. **Audio Processing (`.audio`)**: Loads audio waveforms and applies initial
|
||||
processing like resampling, duration adjustment, centering, and scaling.
|
||||
Configured via `AudioConfig`.
|
||||
2. **Spectrogram Generation (`.spectrogram`)**: Computes the spectrogram from
|
||||
the processed waveform using STFT, followed by frequency cropping, optional
|
||||
PCEN, amplitude scaling (dB, power, linear), optional denoising, optional
|
||||
resizing, and optional peak normalization. Configured via
|
||||
`SpectrogramConfig`.
|
||||
|
||||
This module provides the primary interface:
|
||||
|
||||
- `PreprocessingConfig`: A unified configuration object holding `AudioConfig`
|
||||
and `SpectrogramConfig`.
|
||||
- `load_preprocessing_config`: Function to load the unified configuration.
|
||||
- `Preprocessor`: A protocol defining the interface for the end-to-end pipeline.
|
||||
- `StandardPreprocessor`: The default implementation of the `Preprocessor`.
|
||||
- `build_preprocessor`: A factory function to create a `StandardPreprocessor`
|
||||
instance from a `PreprocessingConfig`.
|
||||
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess.audio import (
|
||||
DEFAULT_DURATION,
|
||||
SCALE_RAW_AUDIO,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
AudioConfig,
|
||||
ResampleConfig,
|
||||
build_audio_loader,
|
||||
build_audio_pipeline,
|
||||
)
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.preprocess.config import (
|
||||
MAX_FREQ,
|
||||
MIN_FREQ,
|
||||
FrequencyConfig,
|
||||
PcenConfig,
|
||||
SpectrogramConfig,
|
||||
SpectrogramPipeline,
|
||||
STFTConfig,
|
||||
_spec_params_from_config,
|
||||
build_spectrogram_builder,
|
||||
build_spectrogram_pipeline,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
PreprocessingConfig,
|
||||
load_preprocessing_config,
|
||||
)
|
||||
from batdetect2.typing import PreprocessorProtocol
|
||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||
|
||||
__all__ = [
|
||||
"AudioConfig",
|
||||
"DEFAULT_DURATION",
|
||||
"FrequencyConfig",
|
||||
"MAX_FREQ",
|
||||
"MIN_FREQ",
|
||||
"PcenConfig",
|
||||
"PreprocessingConfig",
|
||||
"ResampleConfig",
|
||||
"SCALE_RAW_AUDIO",
|
||||
"STFTConfig",
|
||||
"SpectrogramConfig",
|
||||
"MAX_FREQ",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
"build_audio_loader",
|
||||
"build_spectrogram_builder",
|
||||
"PreprocessingConfig",
|
||||
"load_preprocessing_config",
|
||||
"build_preprocessor",
|
||||
"build_audio_loader",
|
||||
]
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseConfig):
|
||||
"""Unified configuration for the audio preprocessing pipeline.
|
||||
|
||||
Aggregates the configuration for both the initial audio processing stage
|
||||
and the subsequent spectrogram generation stage.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
audio : AudioConfig
|
||||
Configuration settings for the audio loading and initial waveform
|
||||
processing steps (e.g., resampling, duration adjustment, scaling).
|
||||
Defaults to default `AudioConfig` settings if omitted.
|
||||
spectrogram : SpectrogramConfig
|
||||
Configuration settings for the spectrogram generation process
|
||||
(e.g., STFT parameters, frequency cropping, scaling, denoising,
|
||||
resizing). Defaults to default `SpectrogramConfig` settings if omitted.
|
||||
"""
|
||||
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||
|
||||
|
||||
def load_preprocessing_config(
|
||||
path: PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> PreprocessingConfig:
|
||||
return load_config(path, schema=PreprocessingConfig, field=field)
|
||||
|
||||
|
||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
|
||||
"""Standard implementation of the `Preprocessor` protocol."""
|
||||
|
||||
input_samplerate: int
|
||||
output_samplerate: float
|
||||
|
||||
max_freq: float
|
||||
min_freq: float
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
audio_pipeline: torch.nn.Module,
|
||||
spectrogram_pipeline: SpectrogramPipeline,
|
||||
input_samplerate: int,
|
||||
output_samplerate: float,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.audio_pipeline = audio_pipeline
|
||||
self.spectrogram_pipeline = spectrogram_pipeline
|
||||
|
||||
self.max_freq = max_freq
|
||||
self.min_freq = min_freq
|
||||
|
||||
self.input_samplerate = input_samplerate
|
||||
self.output_samplerate = output_samplerate
|
||||
|
||||
def forward(self, wav: torch.Tensor) -> torch.Tensor:
|
||||
wav = self.audio_pipeline(wav)
|
||||
return self.spectrogram_pipeline(wav)
|
||||
|
||||
|
||||
def compute_output_samplerate(config: PreprocessingConfig) -> float:
|
||||
samplerate = config.audio.samplerate
|
||||
_, hop_size = _spec_params_from_config(samplerate, config.spectrogram.stft)
|
||||
factor = config.spectrogram.size.resize_factor
|
||||
return samplerate * factor / hop_size
|
||||
|
||||
|
||||
def build_preprocessor(
|
||||
config: Optional[PreprocessingConfig] = None,
|
||||
) -> PreprocessorProtocol:
|
||||
"""Factory function to build the standard preprocessor from configuration."""
|
||||
config = config or PreprocessingConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building preprocessor with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
samplerate = config.audio.samplerate
|
||||
|
||||
min_freq = config.spectrogram.frequencies.min_freq
|
||||
max_freq = config.spectrogram.frequencies.max_freq
|
||||
|
||||
output_samplerate = compute_output_samplerate(config)
|
||||
|
||||
return StandardPreprocessor(
|
||||
audio_pipeline=build_audio_pipeline(config.audio),
|
||||
spectrogram_pipeline=build_spectrogram_pipeline(
|
||||
samplerate, config.spectrogram
|
||||
),
|
||||
input_samplerate=samplerate,
|
||||
output_samplerate=output_samplerate,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
@ -1,64 +1,34 @@
|
||||
"""Handles loading and initial preprocessing of audio waveforms."""
|
||||
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from numpy.typing import DTypeLike
|
||||
from pydantic import Field
|
||||
from scipy.signal import resample, resample_poly
|
||||
from soundevent import audio, data
|
||||
from soundfile import LibsndfileError
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.common import CenterTensor, PeakNormalize
|
||||
from batdetect2.preprocess.config import (
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
AudioConfig,
|
||||
AudioTransform,
|
||||
ResampleConfig,
|
||||
)
|
||||
from batdetect2.typing import AudioLoader
|
||||
|
||||
__all__ = [
|
||||
"ResampleConfig",
|
||||
"AudioConfig",
|
||||
"SoundEventAudioLoader",
|
||||
"build_audio_loader",
|
||||
"load_file_audio",
|
||||
"load_recording_audio",
|
||||
"load_clip_audio",
|
||||
"resample_audio",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
"SCALE_RAW_AUDIO",
|
||||
"DEFAULT_DURATION",
|
||||
]
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256_000
|
||||
"""Default target sample rate in Hz used if resampling is enabled."""
|
||||
|
||||
SCALE_RAW_AUDIO = False
|
||||
"""Default setting for whether to perform peak normalization."""
|
||||
|
||||
DEFAULT_DURATION = None
|
||||
"""Default setting for target audio duration in seconds."""
|
||||
|
||||
|
||||
class ResampleConfig(BaseConfig):
|
||||
"""Configuration for audio resampling.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
enabled : bool, default=True
|
||||
Whether resampling is applied. The target sample rate in Hz is not a field of this config; it is set by `AudioConfig.samplerate` (must be > 0).
|
||||
method : str, default="poly"
|
||||
The resampling algorithm to use. Options:
|
||||
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
||||
Generally fast.
|
||||
- "fourier": Resampling via Fourier method using
|
||||
`scipy.signal.resample`. May handle non-integer
|
||||
resampling factors differently.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
method: str = "poly"
|
||||
|
||||
|
||||
class SoundEventAudioLoader:
|
||||
class SoundEventAudioLoader(AudioLoader):
|
||||
"""Concrete implementation of the `AudioLoader`."""
|
||||
|
||||
def __init__(
|
||||
@ -294,19 +264,6 @@ def resample_audio_fourier(
|
||||
)
|
||||
|
||||
|
||||
class CenterAudioConfig(BaseConfig):
|
||||
name: Literal["center_audio"] = "center_audio"
|
||||
|
||||
|
||||
class ScaleAudioConfig(BaseConfig):
|
||||
name: Literal["scale_audio"] = "scale_audio"
|
||||
|
||||
|
||||
class FixDurationConfig(BaseConfig):
|
||||
name: Literal["fix_duration"] = "fix_duration"
|
||||
duration: float = 0.5
|
||||
|
||||
|
||||
class FixDuration(torch.nn.Module):
|
||||
def __init__(self, samplerate: int, duration: float):
|
||||
super().__init__()
|
||||
@ -326,24 +283,6 @@ class FixDuration(torch.nn.Module):
|
||||
return torch.nn.functional.pad(wav, (0, self.length - length))
|
||||
|
||||
|
||||
AudioTransform = Annotated[
|
||||
Union[
|
||||
FixDurationConfig,
|
||||
ScaleAudioConfig,
|
||||
CenterAudioConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
class AudioConfig(BaseConfig):
|
||||
"""Configuration for loading and initial audio preprocessing."""
|
||||
|
||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
||||
transforms: List[AudioTransform] = Field(default_factory=list)
|
||||
|
||||
|
||||
def build_audio_loader(
|
||||
config: Optional[AudioConfig] = None,
|
||||
) -> AudioLoader:
|
||||
|
||||
212
src/batdetect2/preprocess/config.py
Normal file
212
src/batdetect2/preprocess/config.py
Normal file
@ -0,0 +1,212 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = [
|
||||
"load_preprocessing_config",
|
||||
"CenterAudioConfig",
|
||||
"ScaleAudioConfig",
|
||||
"FixDurationConfig",
|
||||
"ResampleConfig",
|
||||
"AudioTransform",
|
||||
"AudioConfig",
|
||||
"STFTConfig",
|
||||
"FrequencyConfig",
|
||||
"PcenConfig",
|
||||
"ScaleAmplitudeConfig",
|
||||
"SpectralMeanSubstractionConfig",
|
||||
"ResizeConfig",
|
||||
"PeakNormalizeConfig",
|
||||
"SpectrogramTransform",
|
||||
"SpectrogramConfig",
|
||||
"PreprocessingConfig",
|
||||
"TARGET_SAMPLERATE_HZ",
|
||||
"MIN_FREQ",
|
||||
"MAX_FREQ",
|
||||
]
|
||||
|
||||
TARGET_SAMPLERATE_HZ = 256_000
|
||||
"""Default target sample rate in Hz used if resampling is enabled."""
|
||||
|
||||
MIN_FREQ = 10_000
|
||||
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
|
||||
|
||||
MAX_FREQ = 120_000
|
||||
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
|
||||
|
||||
|
||||
class CenterAudioConfig(BaseConfig):
|
||||
name: Literal["center_audio"] = "center_audio"
|
||||
|
||||
|
||||
class ScaleAudioConfig(BaseConfig):
|
||||
name: Literal["scale_audio"] = "scale_audio"
|
||||
|
||||
|
||||
class FixDurationConfig(BaseConfig):
|
||||
name: Literal["fix_duration"] = "fix_duration"
|
||||
duration: float = 0.5
|
||||
|
||||
|
||||
class ResampleConfig(BaseConfig):
|
||||
"""Configuration for audio resampling.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
enabled : bool, default=True
|
||||
Whether resampling is applied. The target sample rate in Hz is not a field of this config; it is set by `AudioConfig.samplerate` (must be > 0).
|
||||
method : str, default="poly"
|
||||
The resampling algorithm to use. Options:
|
||||
- "poly": Polyphase resampling using `scipy.signal.resample_poly`.
|
||||
Generally fast.
|
||||
- "fourier": Resampling via Fourier method using
|
||||
`scipy.signal.resample`. May handle non-integer
|
||||
resampling factors differently.
|
||||
"""
|
||||
|
||||
enabled: bool = True
|
||||
method: str = "poly"
|
||||
|
||||
|
||||
AudioTransform = Annotated[
|
||||
Union[
|
||||
FixDurationConfig,
|
||||
ScaleAudioConfig,
|
||||
CenterAudioConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
class AudioConfig(BaseConfig):
|
||||
"""Configuration for loading and initial audio preprocessing."""
|
||||
|
||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
|
||||
transforms: List[AudioTransform] = Field(default_factory=list)
|
||||
|
||||
|
||||
class STFTConfig(BaseConfig):
|
||||
"""Configuration for the Short-Time Fourier Transform (STFT).
|
||||
|
||||
Attributes
|
||||
----------
|
||||
window_duration : float, default=0.002
|
||||
Duration of the STFT window in seconds (e.g., 0.002 for 2ms). Must be
|
||||
> 0. Determines frequency resolution (longer window = finer frequency
|
||||
resolution).
|
||||
window_overlap : float, default=0.75
|
||||
Fraction of overlap between consecutive STFT windows (e.g., 0.75
|
||||
for 75%). Must be >= 0 and < 1. Determines time resolution
|
||||
(higher overlap = finer time resolution).
|
||||
window_fn : str, default="hann"
|
||||
Name of the window function to apply before FFT calculation. Common
|
||||
options include "hann", "hamming", "blackman". See
|
||||
`scipy.signal.get_window`.
|
||||
"""
|
||||
|
||||
window_duration: float = Field(default=0.002, gt=0)
|
||||
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
||||
window_fn: str = "hann"
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
|
||||
"""Configuration for frequency axis parameters.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
max_freq : int, default=120000
|
||||
Maximum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies above this value will be cropped. Must be >= 0.
|
||||
min_freq : int, default=10000
|
||||
Minimum frequency in Hz to retain in the spectrogram after STFT.
|
||||
Frequencies below this value will be cropped. Must be >= 0.
|
||||
"""
|
||||
|
||||
max_freq: int = Field(default=120_000, ge=0)
|
||||
min_freq: int = Field(default=10_000, ge=0)
|
||||
|
||||
|
||||
class PcenConfig(BaseConfig):
|
||||
"""Configuration for Per-Channel Energy Normalization (PCEN)."""
|
||||
|
||||
name: Literal["pcen"] = "pcen"
|
||||
time_constant: float = 0.4
|
||||
gain: float = 0.98
|
||||
bias: float = 2
|
||||
power: float = 0.5
|
||||
|
||||
|
||||
class ScaleAmplitudeConfig(BaseConfig):
|
||||
name: Literal["scale_amplitude"] = "scale_amplitude"
|
||||
scale: Literal["power", "db"] = "db"
|
||||
|
||||
|
||||
class SpectralMeanSubstractionConfig(BaseConfig):
|
||||
name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
||||
|
||||
|
||||
class ResizeConfig(BaseConfig):
|
||||
name: Literal["resize_spec"] = "resize_spec"
|
||||
height: int = 128
|
||||
resize_factor: float = 0.5
|
||||
|
||||
|
||||
class PeakNormalizeConfig(BaseConfig):
|
||||
name: Literal["peak_normalize"] = "peak_normalize"
|
||||
|
||||
|
||||
SpectrogramTransform = Annotated[
|
||||
Union[
|
||||
PcenConfig,
|
||||
ScaleAmplitudeConfig,
|
||||
SpectralMeanSubstractionConfig,
|
||||
PeakNormalizeConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
class SpectrogramConfig(BaseConfig):
    """Configuration of the spectrogram generation stage.

    Attributes
    ----------
    stft : STFTConfig
        STFT settings. Defaults to a default `STFTConfig`.
    frequencies : FrequencyConfig
        Frequency band to retain. Defaults to a default `FrequencyConfig`.
    size : ResizeConfig
        Resizing settings. Defaults to a default `ResizeConfig`.
    transforms : Sequence[SpectrogramTransform]
        Transform steps, each selected by its `name` tag. Defaults to PCEN
        followed by spectral mean subtraction.
    """

    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)

    # A factory is used so every instance gets its own fresh default list.
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseConfig):
    """Top-level configuration for the audio preprocessing pipeline.

    Bundles the settings of the two preprocessing stages: the initial
    waveform processing and the spectrogram generation that follows it.

    Attributes
    ----------
    audio : AudioConfig
        Settings for audio loading and initial waveform processing
        (e.g., resampling, duration adjustment, scaling). A default
        `AudioConfig` is used when omitted.
    spectrogram : SpectrogramConfig
        Settings for spectrogram generation (e.g., STFT parameters,
        frequency cropping, scaling, denoising, resizing). A default
        `SpectrogramConfig` is used when omitted.
    """

    audio: AudioConfig = Field(default_factory=AudioConfig)

    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
|
||||
|
||||
|
||||
def load_preprocessing_config(
    path: PathLike,
    field: Optional[str] = None,
) -> PreprocessingConfig:
    """Load a `PreprocessingConfig` through `load_config`.

    Parameters
    ----------
    path : PathLike
        Location of the configuration file, forwarded to `load_config`.
    field : str, optional
        Forwarded unchanged to `load_config` (presumably selects a nested
        section of the file -- confirm against `load_config`).

    Returns
    -------
    PreprocessingConfig
        Whatever `load_config` produces for the `PreprocessingConfig`
        schema.
    """
    preprocessing_config = load_config(
        path,
        schema=PreprocessingConfig,
        field=field,
    )
    return preprocessing_config
|
||||
86
src/batdetect2/preprocess/preprocessor.py
Normal file
86
src/batdetect2/preprocess/preprocessor.py
Normal file
@ -0,0 +1,86 @@
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.preprocess.audio import build_audio_pipeline
|
||||
from batdetect2.preprocess.config import PreprocessingConfig
|
||||
from batdetect2.preprocess.spectrogram import (
|
||||
_spec_params_from_config,
|
||||
build_spectrogram_pipeline,
|
||||
)
|
||||
from batdetect2.typing import PreprocessorProtocol, SpectrogramPipeline
|
||||
|
||||
__all__ = [
|
||||
"StandardPreprocessor",
|
||||
"build_preprocessor",
|
||||
]
|
||||
|
||||
|
||||
class StandardPreprocessor(torch.nn.Module, PreprocessorProtocol):
    """Preprocessor that chains an audio and a spectrogram pipeline.

    The waveform is first passed through `audio_pipeline`; its output is
    then passed through `spectrogram_pipeline`. The sample rates and the
    frequency limits are stored as plain attributes and are not used by
    `forward` itself.
    """

    input_samplerate: int
    output_samplerate: float

    max_freq: float
    min_freq: float

    def __init__(
        self,
        audio_pipeline: torch.nn.Module,
        spectrogram_pipeline: SpectrogramPipeline,
        input_samplerate: int,
        output_samplerate: float,
        max_freq: float,
        min_freq: float,
    ) -> None:
        """Store the two pipelines and the descriptive metadata."""
        super().__init__()

        # Keep this registration order (audio first, spectrogram second):
        # it determines the sub-module order seen by torch.
        self.audio_pipeline = audio_pipeline
        self.spectrogram_pipeline = spectrogram_pipeline

        self.input_samplerate = input_samplerate
        self.output_samplerate = output_samplerate

        self.min_freq = min_freq
        self.max_freq = max_freq

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        """Run the audio pipeline, then the spectrogram pipeline, on `wav`."""
        return self.spectrogram_pipeline(self.audio_pipeline(wav))
|
||||
|
||||
|
||||
def compute_output_samplerate(config: PreprocessingConfig) -> float:
    """Return `samplerate * resize_factor / hop_size` for `config`.

    The audio sample rate comes from `config.audio`, the hop size from the
    STFT settings (via `_spec_params_from_config`) and the resize factor
    from `config.spectrogram.size`.
    """
    audio_rate = config.audio.samplerate
    _n_fft, hop_length = _spec_params_from_config(
        audio_rate,
        config.spectrogram.stft,
    )
    time_scale = config.spectrogram.size.resize_factor
    return audio_rate * time_scale / hop_length
|
||||
|
||||
|
||||
def build_preprocessor(
    config: Optional[PreprocessingConfig] = None,
) -> PreprocessorProtocol:
    """Build a `StandardPreprocessor` from a preprocessing configuration.

    Parameters
    ----------
    config : PreprocessingConfig, optional
        Preprocessing settings. A default `PreprocessingConfig` is used
        when nothing (or a falsy value) is given.

    Returns
    -------
    PreprocessorProtocol
        A `StandardPreprocessor` wired with the audio and spectrogram
        pipelines built from `config`.
    """
    if not config:
        config = PreprocessingConfig()

    # Lazy logging: the YAML dump is only rendered if the message is emitted.
    logger.opt(lazy=True).debug(
        "Building preprocessor with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    samplerate = config.audio.samplerate

    frequencies = config.spectrogram.frequencies
    min_freq = frequencies.min_freq
    max_freq = frequencies.max_freq

    output_samplerate = compute_output_samplerate(config)

    audio_pipeline = build_audio_pipeline(config.audio)
    spectrogram_pipeline = build_spectrogram_pipeline(
        samplerate,
        config.spectrogram,
    )

    return StandardPreprocessor(
        audio_pipeline=audio_pipeline,
        spectrogram_pipeline=spectrogram_pipeline,
        input_samplerate=samplerate,
        output_samplerate=output_samplerate,
        min_freq=min_freq,
        max_freq=max_freq,
    )
|
||||
@ -1,63 +1,37 @@
|
||||
"""Computes spectrograms from audio waveforms with configurable parameters."""
|
||||
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.preprocess.common import PeakNormalize
|
||||
from batdetect2.preprocess.config import (
|
||||
ScaleAmplitudeConfig,
|
||||
SpectrogramConfig,
|
||||
SpectrogramTransform,
|
||||
STFTConfig,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"STFTConfig",
|
||||
"FrequencyConfig",
|
||||
"PcenConfig",
|
||||
"SpectrogramConfig",
|
||||
"build_spectrogram_builder",
|
||||
"MIN_FREQ",
|
||||
"MAX_FREQ",
|
||||
"build_spectrogram_pipeline",
|
||||
]
|
||||
|
||||
|
||||
MIN_FREQ = 10_000
|
||||
"""Default minimum frequency (Hz) for spectrogram frequency cropping."""
|
||||
|
||||
MAX_FREQ = 120_000
|
||||
"""Default maximum frequency (Hz) for spectrogram frequency cropping."""
|
||||
|
||||
|
||||
class STFTConfig(BaseConfig):
    """Settings for the Short-Time Fourier Transform (STFT).

    Attributes
    ----------
    window_duration : float, default=0.002
        Length of the STFT window in seconds (0.002 is 2 ms). Must be > 0.
        A longer window gives finer frequency resolution.
    window_overlap : float, default=0.75
        Fraction by which consecutive windows overlap (0.75 is 75%). Must
        be >= 0 and < 1. A higher overlap gives finer time resolution.
    window_fn : str, default="hann"
        Name of the window function applied before the FFT, e.g. "hann",
        "hamming" or "blackman". See `scipy.signal.get_window`.
    """

    window_duration: float = Field(default=0.002, gt=0)

    window_overlap: float = Field(default=0.75, ge=0, lt=1)

    window_fn: str = "hann"
|
||||
def build_spectrogram_builder(
    samplerate: int,
    conf: STFTConfig,
) -> torch.nn.Module:
    """Create the torchaudio `Spectrogram` module described by `conf`.

    Parameters
    ----------
    samplerate : int
        Sample rate passed to `_spec_params_from_config` to derive the
        FFT size and hop length.
    conf : STFTConfig
        STFT settings; `conf.window_fn` names the window function.

    Returns
    -------
    torch.nn.Module
        A `torchaudio.transforms.Spectrogram` built with `center=True`
        and `power=1`.
    """
    fft_size, hop = _spec_params_from_config(samplerate, conf)
    window = get_spectrogram_window(conf.window_fn)
    return torchaudio.transforms.Spectrogram(
        n_fft=fft_size,
        hop_length=hop,
        window_fn=window,
        center=True,
        power=1,
    )
|
||||
|
||||
|
||||
def get_spectrogram_window(name: str) -> Callable[..., torch.Tensor]:
|
||||
@ -87,37 +61,6 @@ def _spec_params_from_config(samplerate: int, conf: STFTConfig):
|
||||
return n_fft, hop_length
|
||||
|
||||
|
||||
def build_spectrogram_builder(
    samplerate: int,
    conf: STFTConfig,
) -> torch.nn.Module:
    """Create the torchaudio `Spectrogram` module described by `conf`.

    Parameters
    ----------
    samplerate : int
        Sample rate passed to `_spec_params_from_config` to derive the
        FFT size and hop length.
    conf : STFTConfig
        STFT settings; `conf.window_fn` names the window function.

    Returns
    -------
    torch.nn.Module
        A `torchaudio.transforms.Spectrogram` built with `center=True`
        and `power=1`.
    """
    fft_size, hop = _spec_params_from_config(samplerate, conf)
    window = get_spectrogram_window(conf.window_fn)
    return torchaudio.transforms.Spectrogram(
        n_fft=fft_size,
        hop_length=hop,
        window_fn=window,
        center=True,
        power=1,
    )
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
    """Frequency band retained when cropping the spectrogram.

    Attributes
    ----------
    max_freq : int, default=120000
        Upper limit in Hz of the band kept after the STFT. Frequencies
        above this value are cropped. Must be > 0.
    min_freq : int, default=10000
        Lower limit in Hz of the band kept after the STFT. Frequencies
        below this value are cropped. Must be >= 0.
    """

    # `gt=0` enforces the documented "> 0" contract; the previous `ge=0`
    # constraint also admitted a 0 Hz upper limit.
    max_freq: int = Field(default=120_000, gt=0)
    min_freq: int = Field(default=10_000, ge=0)
|
||||
|
||||
|
||||
def _frequency_to_index(
|
||||
freq: float,
|
||||
samplerate: int,
|
||||
@ -164,16 +107,6 @@ class FrequencyClip(torch.nn.Module):
|
||||
)
|
||||
|
||||
|
||||
class PcenConfig(BaseConfig):
    """Settings for the Per-Channel Energy Normalization (PCEN) step.

    Attributes
    ----------
    name : Literal["pcen"]
        Fixed tag identifying this step; it is the discriminator value
        used by the `SpectrogramTransform` union.
    time_constant : float, default=0.4
        PCEN time constant (presumably in seconds -- confirm against the
        PCEN implementation).
    gain : float, default=0.98
        PCEN gain parameter.
    bias : float, default=2
        PCEN bias parameter.
    power : float, default=0.5
        PCEN power (exponent) parameter.
    """

    name: Literal["pcen"] = "pcen"

    time_constant: float = 0.4
    gain: float = 0.98
    bias: float = 2
    power: float = 0.5
|
||||
|
||||
|
||||
class PCEN(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@ -231,11 +164,6 @@ def _compute_smoothing_constant(
|
||||
return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||
|
||||
|
||||
class ScaleAmplitudeConfig(BaseConfig):
    """Settings for the amplitude scaling step.

    Attributes
    ----------
    name : Literal["scale_amplitude"]
        Fixed tag identifying this step; it is the discriminator value
        used by the `SpectrogramTransform` union.
    scale : {"power", "db"}, default="db"
        Which amplitude scale to apply.
    """

    name: Literal["scale_amplitude"] = "scale_amplitude"

    scale: Literal["power", "db"] = "db"
|
||||
|
||||
|
||||
class ToPower(torch.nn.Module):
    """Module that squares its input element-wise."""

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Return `spec` raised to the power of two, element by element."""
        return torch.pow(spec, 2)
|
||||
@ -253,22 +181,12 @@ def _build_amplitude_scaler(conf: ScaleAmplitudeConfig) -> torch.nn.Module:
|
||||
)
|
||||
|
||||
|
||||
class SpectralMeanSubstractionConfig(BaseConfig):
    """Settings for the spectral mean subtraction step.

    The step has no tunable parameters; `name` is the fixed tag that
    selects it inside the `SpectrogramTransform` union.
    """

    name: Literal["spectral_mean_substraction"] = "spectral_mean_substraction"
|
||||
|
||||
|
||||
class SpectralMeanSubstraction(torch.nn.Module):
    """Subtract the mean over the last dimension and clip negatives to zero."""

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Return `spec` minus its last-dimension mean, floored at 0.

        The mean is taken with `keepdim=True` so it broadcasts against
        `spec` (the last dimension is presumably time -- confirm with the
        caller).
        """
        centered = spec - spec.mean(-1, keepdim=True)
        return centered.clamp(min=0)
|
||||
|
||||
|
||||
class ResizeConfig(BaseConfig):
    """Settings for resizing the spectrogram.

    Attributes
    ----------
    name : Literal["resize_spec"]
        Fixed tag identifying this step.
    height : int, default=128
        Target height of the resized spectrogram (presumably the number
        of frequency bins -- TODO confirm).
    resize_factor : float, default=0.5
        Scaling factor applied when resizing (presumably along the time
        axis -- TODO confirm).
    """

    name: Literal["resize_spec"] = "resize_spec"

    height: int = 128
    resize_factor: float = 0.5
|
||||
|
||||
|
||||
class ResizeSpec(torch.nn.Module):
|
||||
def __init__(self, height: int, time_factor: float):
|
||||
super().__init__()
|
||||
@ -295,33 +213,6 @@ class ResizeSpec(torch.nn.Module):
|
||||
return resized
|
||||
|
||||
|
||||
class PeakNormalizeConfig(BaseConfig):
    """Settings for the peak normalization step.

    The step has no tunable parameters; `name` is the fixed tag that
    selects it inside the `SpectrogramTransform` union.
    """

    name: Literal["peak_normalize"] = "peak_normalize"
|
||||
|
||||
|
||||
# Tagged union of the configurable spectrogram transform steps. Pydantic
# picks the concrete config class from the value of each entry's `name`
# field (the discriminator). `ResizeConfig` is deliberately not listed here.
SpectrogramTransform = Annotated[
    Union[
        PcenConfig,
        ScaleAmplitudeConfig,
        SpectralMeanSubstractionConfig,
        PeakNormalizeConfig,
    ],
    Field(discriminator="name"),
]
|
||||
|
||||
|
||||
class SpectrogramConfig(BaseConfig):
    """Configuration of the spectrogram generation stage.

    Attributes
    ----------
    stft : STFTConfig
        STFT settings. Defaults to a default `STFTConfig`.
    frequencies : FrequencyConfig
        Frequency band to retain. Defaults to a default `FrequencyConfig`.
    size : ResizeConfig
        Resizing settings. Defaults to a default `ResizeConfig`.
    transforms : Sequence[SpectrogramTransform]
        Transform steps, each selected by its `name` tag. Defaults to PCEN
        followed by spectral mean subtraction.
    """

    stft: STFTConfig = Field(default_factory=STFTConfig)
    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
    size: ResizeConfig = Field(default_factory=ResizeConfig)

    # A factory is used so every instance gets its own fresh default list.
    transforms: Sequence[SpectrogramTransform] = Field(
        default_factory=lambda: [
            PcenConfig(),
            SpectralMeanSubstractionConfig(),
        ]
    )
|
||||
|
||||
|
||||
def _build_spectrogram_transform_step(
|
||||
step: SpectrogramTransform,
|
||||
samplerate: int,
|
||||
|
||||
@ -7,7 +7,7 @@ from loguru import logger
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.data.conditions import build_sound_event_condition
|
||||
from batdetect2.targets.classes import (
|
||||
DEFAULT_CLASSES,
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Dict, List, Optional
|
||||
from pydantic import Field, PrivateAttr, computed_field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data.conditions import (
|
||||
AllOfConfig,
|
||||
HasAllTagsConfig,
|
||||
|
||||
@ -26,12 +26,17 @@ import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.arrays import spec_to_xarray
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.preprocess.audio import build_audio_loader
|
||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.typing.targets import Position, ROITargetMapper, Size
|
||||
from batdetect2.utils.arrays import spec_to_xarray
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Position,
|
||||
PreprocessorProtocol,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Anchor",
|
||||
|
||||
@ -11,11 +11,10 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import scale_geometry, shift_geometry
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.train.clips import get_subclip_annotation
|
||||
from batdetect2.typing import Augmentation
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
from batdetect2.utils.arrays import adjust_width
|
||||
from batdetect2.typing import AudioLoader, Augmentation
|
||||
|
||||
__all__ = [
|
||||
"AugmentationConfig",
|
||||
|
||||
@ -10,10 +10,12 @@ from batdetect2.postprocess import get_raw_predictions
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.train.logging import get_image_plotter
|
||||
from batdetect2.typing.evaluate import ClipEvaluation
|
||||
from batdetect2.typing.models import ModelOutput
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing.train import TrainExample
|
||||
from batdetect2.typing import (
|
||||
ClipEvaluation,
|
||||
ModelOutput,
|
||||
RawPrediction,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
|
||||
@ -6,8 +6,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||
|
||||
@ -3,7 +3,7 @@ from typing import Optional, Union
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
@ -80,7 +80,6 @@ class OptimizerConfig(BaseConfig):
|
||||
class TrainingConfig(BaseConfig):
|
||||
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
||||
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
||||
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
||||
|
||||
@ -5,8 +5,8 @@ from loguru import logger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
from batdetect2.plotting.clips import build_audio_loader
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.core.arrays import adjust_width
|
||||
from batdetect2.preprocess import build_audio_loader, build_preprocessor
|
||||
from batdetect2.train.augmentations import (
|
||||
RandomAudioSource,
|
||||
build_augmentations,
|
||||
@ -14,10 +14,14 @@ from batdetect2.train.augmentations import (
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.typing import ClipperProtocol, TrainExample
|
||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||
from batdetect2.typing.train import Augmentation, ClipLabeller
|
||||
from batdetect2.utils.arrays import adjust_width
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
Augmentation,
|
||||
ClipLabeller,
|
||||
ClipperProtocol,
|
||||
PreprocessorProtocol,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TrainingDataset",
|
||||
|
||||
@ -13,14 +13,10 @@ import torch
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||
from batdetect2.typing import (
|
||||
ClipLabeller,
|
||||
Heatmaps,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing import ClipLabeller, Heatmaps, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"LabelConfig",
|
||||
|
||||
@ -18,7 +18,7 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
|
||||
|
||||
@ -27,7 +27,7 @@ from loguru import logger
|
||||
from pydantic import Field
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.typing import Losses, LossProtocol, ModelOutput, TrainExample
|
||||
|
||||
__all__ = [
|
||||
|
||||
@ -1,4 +1,8 @@
|
||||
from batdetect2.typing.evaluate import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.typing.evaluate import (
|
||||
ClipEvaluation,
|
||||
MatchEvaluation,
|
||||
MetricsProtocol,
|
||||
)
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
@ -10,9 +14,11 @@ from batdetect2.typing.preprocess import (
|
||||
AudioLoader,
|
||||
PreprocessorProtocol,
|
||||
SpectrogramBuilder,
|
||||
SpectrogramPipeline,
|
||||
)
|
||||
from batdetect2.typing.targets import (
|
||||
Position,
|
||||
ROITargetMapper,
|
||||
Size,
|
||||
SoundEventDecoder,
|
||||
SoundEventEncoder,
|
||||
@ -34,6 +40,7 @@ __all__ = [
|
||||
"Augmentation",
|
||||
"BackboneModel",
|
||||
"BatDetect2Prediction",
|
||||
"ClipEvaluation",
|
||||
"ClipLabeller",
|
||||
"ClipperProtocol",
|
||||
"DetectionModel",
|
||||
@ -47,12 +54,14 @@ __all__ = [
|
||||
"Position",
|
||||
"PostprocessorProtocol",
|
||||
"PreprocessorProtocol",
|
||||
"ROITargetMapper",
|
||||
"RawPrediction",
|
||||
"Size",
|
||||
"SoundEventDecoder",
|
||||
"SoundEventEncoder",
|
||||
"SoundEventFilter",
|
||||
"SpectrogramBuilder",
|
||||
"SpectrogramPipeline",
|
||||
"TargetProtocol",
|
||||
"TrainExample",
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user