From 1f0fb14d895ba3871fceb0a0d8af0fa646cff97b Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Sat, 16 Nov 2024 21:26:18 +0000
Subject: [PATCH] Minor restructuring

---
 batdetect2/cli/compat.py             |   3 +-
 batdetect2/configs.py                |  15 +
 batdetect2/data/compat.py            |   2 +-
 batdetect2/data/preprocessing.py     | 436 ----------------------------
 batdetect2/models/detectors.py       |   4 +-
 batdetect2/plot.py                   |   9 +-
 batdetect2/preprocess/__init__.py    |  72 +++++
 batdetect2/preprocess/audio.py       | 171 +++++++++++
 batdetect2/preprocess/spectrogram.py | 249 ++++++++++++++++
 batdetect2/terms.py                  |  98 ++++++
 batdetect2/train/augmentations.py    |   9 +-
 batdetect2/{data => train}/labels.py |   7 +-
 batdetect2/train/preprocess.py       | 134 ++++-----
 batdetect2/train/targets.py          | 113 +++++++
 14 files changed, 795 insertions(+), 527 deletions(-)
 create mode 100644 batdetect2/configs.py
 create mode 100644 batdetect2/preprocess/__init__.py
 create mode 100644 batdetect2/preprocess/audio.py
 create mode 100644 batdetect2/preprocess/spectrogram.py
 create mode 100644 batdetect2/terms.py
 rename batdetect2/{data => train}/labels.py (96%)
 create mode 100644 batdetect2/train/targets.py

diff --git a/batdetect2/cli/compat.py b/batdetect2/cli/compat.py
index 1c503b7..b02c283 100644
--- a/batdetect2/cli/compat.py
+++ b/batdetect2/cli/compat.py
@@ -1,12 +1,11 @@
 import click
 
 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file
 
-from batdetect2.cli.base import cli
-
 
 @cli.command()
 @click.argument(
diff --git a/batdetect2/configs.py b/batdetect2/configs.py
new file mode 100644
index 0000000..ab94d82
--- /dev/null
+++ b/batdetect2/configs.py
@@ -0,0 +1,15 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class BaseConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
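+
+
+# Note: ``extra="forbid"`` makes every config subclass reject unknown keys,
+# so typos in config files fail loudly at validation time. Illustrative
+# sketch (``DemoConfig`` is hypothetical, not part of this patch's API):
+#
+#     class DemoConfig(BaseConfig):
+#         samplerate: int = 256_000
+#
+#     DemoConfig(sample_rate=192_000)  # raises pydantic.ValidationError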
"poly" - - -class AudioConfig(BaseModel): - resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig) - scale: bool = Field(default=SCALE_RAW_AUDIO) - center: bool = True - duration: Optional[float] = DEFAULT_DURATION - - -class FFTConfig(BaseModel): - window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0) - window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1) - window_fn: str = "hann" - - -class FrequencyConfig(BaseModel): - max_freq: int = Field(default=MAX_FREQ_HZ, gt=0) - min_freq: int = Field(default=MIN_FREQ_HZ, gt=0) - - -class PcenConfig(BaseModel): - time_constant: float = 0.4 - hop_length: int = 512 - gain: float = 0.98 - bias: float = 2 - power: float = 0.5 - - -class SpecSizeConfig(BaseModel): - height: int = SPEC_HEIGHT - time_period: float = SPEC_TIME_PERIOD - - -class SpectrogramConfig(BaseModel): - fft: FFTConfig = Field(default_factory=FFTConfig) - frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig) - scale: Union[Literal["log"], None, PcenConfig] = "log" - denoise: bool = True - resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig) - max_scale: bool = MAX_SCALE_SPEC - - -class PreprocessingConfig(BaseModel): - """Configuration for preprocessing data.""" - - audio: AudioConfig = Field(default_factory=AudioConfig) - spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig) - - @classmethod - def from_file( - cls, - path: Union[str, Path], - ) -> "PreprocessingConfig": - """Load configuration from a file. - - Parameters - ---------- - path - Path to the configuration file. - - Returns - ------- - PreprocessingConfig - The configuration loaded from the file. - - Raises - ------ - FileNotFoundError - If the configuration file does not exist. - pydantic.ValidationError - If the configuration file is invalid. - """ - path = Path(path) - - if not path.is_file(): - raise FileNotFoundError(f"Config file not found: {path}") - - return cls.model_validate_json(path.read_text()) - - def to_file(self, path: Union[str, Path]) -> None: - """Save configuration to a file.""" - path = Path(path) - - if not path.parent.exists(): - path.parent.mkdir(parents=True) - - path.write_text(self.model_dump_json()) - - -def preprocess_audio_clip( - clip: data.Clip, - config: Optional[PreprocessingConfig] = None, -) -> xr.DataArray: - """Preprocesses audio clip to generate spectrogram. - - Parameters - ---------- - clip - The audio clip to preprocess. - config - Configuration for preprocessing. - - Returns - ------- - xr.DataArray - Preprocessed spectrogram. 
-
-    """
-    config = config or PreprocessingConfig()
-    wav = load_clip_audio(clip, config=config.audio)
-    spec = compute_spectrogram(wav, config=config.spectrogram)
-    return spec
-
-
-def load_clip_audio(
-    clip: data.Clip,
-    config: Optional[AudioConfig] = None,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    config = config or AudioConfig()
-
-    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
-
-    if config.duration is not None:
-        wav = adjust_audio_duration(wav, duration=config.duration)
-
-    if config.resample:
-        wav = resample_audio(
-            wav,
-            samplerate=config.resample.samplerate,
-            dtype=dtype,
-        )
-
-    if config.center:
-        wav = ops.center(wav)
-
-    if config.scale:
-        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
-
-    return wav.astype(dtype)
-
-
-def compute_spectrogram(
-    wav: xr.DataArray,
-    config: Optional[SpectrogramConfig] = None,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    config = config or SpectrogramConfig()
-
-    spec = stft(
-        wav,
-        window_duration=config.fft.window_duration,
-        window_overlap=config.fft.window_overlap,
-        window_fn=config.fft.window_fn,
-        dtype=dtype,
-    )
-
-    spec = crop_spectrogram_frequencies(
-        spec,
-        min_freq=config.frequencies.min_freq,
-        max_freq=config.frequencies.max_freq,
-    )
-
-    spec = scale_spectrogram(spec, scale=config.scale)
-
-    if config.denoise:
-        spec = denoise_spectrogram(spec)
-
-    if config.resize:
-        spec = resize_spectrogram(spec, config=config.resize)
-
-    if config.max_scale:
-        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
-
-    return spec.astype(dtype)
-
-
-def crop_spectrogram_frequencies(
-    spec: xr.DataArray,
-    min_freq: int = MIN_FREQ_HZ,
-    max_freq: int = MAX_FREQ_HZ,
-) -> xr.DataArray:
-    return arrays.crop_dim(
-        spec,
-        dim="frequency",
-        start=min_freq,
-        stop=max_freq,
-    ).astype(spec.dtype)
-
-
-def adjust_audio_duration(
-    wave: xr.DataArray,
-    duration: float,
-) -> xr.DataArray:
-    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    current_duration = end_time - start_time
-
-    if current_duration == duration:
-        return wave
-
-    if current_duration > duration:
-        return arrays.crop_dim(
-            wave,
-            dim="time",
-            start=start_time,
-            stop=start_time + duration,
-        )
-
-    return arrays.extend_dim(
-        wave,
-        dim="time",
-        start=start_time,
-        stop=start_time + duration,
-    )
-
-
-def resample_audio(
-    wav: xr.DataArray,
-    samplerate: int = TARGET_SAMPLERATE_HZ,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    if "time" not in wav.dims:
-        raise ValueError("Audio must have a time dimension")
-
-    time_axis: int = wav.get_axis_num("time")  # type: ignore
-
-    start, stop = arrays.get_dim_range(wav, dim="time")
-    step = arrays.get_dim_step(wav, dim="time")
-    original_samplerate = int(1 / step)
-
-    if original_samplerate == samplerate:
-        return wav.astype(dtype)
-
-    gcd = np.gcd(original_samplerate, samplerate)
-    resampled = resample_poly(
-        wav.values,
-        samplerate // gcd,
-        original_samplerate // gcd,
-        axis=time_axis,
-    )
-
-    resampled_times = np.linspace(
-        start,
-        stop + step,
-        len(resampled),
-        endpoint=False,
-        dtype=dtype,
-    )
-    return xr.DataArray(
-        data=resampled.astype(dtype),
-        dims=wav.dims,
-        coords={
-            **wav.coords,
-            "time": arrays.create_time_dim_from_array(
-                resampled_times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs=wav.attrs,
-    )
-
-
-def stft(
-    wave: xr.DataArray,
-    window_duration: float,
-    window_overlap: float,
-    window_fn: str = "hann",
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    step = arrays.get_dim_step(wave, dim="time")
-    sampling_rate = 1 / step
-
-    hop_len = window_duration * (1 - window_overlap)
-    nfft = int(window_duration * sampling_rate)
-    noverlap = int(window_overlap * nfft)
-
-    spec, _ = librosa.core.spectrum._spectrogram(
-        y=wave.data.astype(dtype),
-        power=1,
-        n_fft=nfft,
-        hop_length=nfft - noverlap,
-        center=False,
-        window=window_fn,
-    )
-
-    return xr.DataArray(
-        data=spec.astype(dtype),
-        dims=["frequency", "time"],
-        coords={
-            "frequency": arrays.create_frequency_dim_from_array(
-                np.linspace(
-                    0,
-                    sampling_rate / 2,
-                    spec.shape[0],
-                    endpoint=False,
-                    dtype=dtype,
-                ),
-                step=sampling_rate / nfft,
-            ),
-            "time": arrays.create_time_dim_from_array(
-                np.linspace(
-                    start_time,
-                    end_time - (window_duration - hop_len),
-                    spec.shape[1],
-                    endpoint=False,
-                    dtype=dtype,
-                ),
-                step=hop_len,
-            ),
-        },
-        attrs={
-            **wave.attrs,
-            "original_samplerate": sampling_rate,
-            "nfft": nfft,
-            "noverlap": noverlap,
-        },
-    )
-
-
-def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
-    return xr.DataArray(
-        data=(spec - spec.mean("time")).clip(0),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
-
-
-def scale_spectrogram(
-    spec: xr.DataArray,
-    scale: Union[Literal["log"], None, PcenConfig],
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    if scale == "log":
-        return scale_log(spec, dtype=dtype)
-
-    if isinstance(scale, PcenConfig):
-        return scale_pcen(
-            spec,
-            time_constant=scale.time_constant,
-            hop_length=scale.hop_length,
-            gain=scale.gain,
-            power=scale.power,
-            bias=scale.bias,
-        )
-
-    return spec
-
-
-def scale_pcen(
-    spec: xr.DataArray,
-    time_constant: float = 0.4,
-    hop_length: int = 512,
-    gain: float = 0.98,
-    bias: float = 2,
-    power: float = 0.5,
-) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-    # NOTE: Not sure why the 10 is there
-    t_frames = time_constant * samplerate / (float(hop_length) * 10)
-    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
-    return audio.pcen(
-        spec * (2**31),
-        smooth=smoothing_constant,
-        gain=gain,
-        bias=bias,
-        power=power,
-    ).astype(spec.dtype)
-
-
-def scale_log(
-    spec: xr.DataArray,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-    nfft = spec.attrs["nfft"]
-    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
-    return xr.DataArray(
-        data=np.log1p(log_scaling * spec).astype(dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
-
-
-def resize_spectrogram(
-    spec: xr.DataArray,
-    config: SpecSizeConfig,
-) -> xr.DataArray:
-    duration = arrays.get_dim_width(spec, dim="time")
-    return ops.resize(
-        spec,
-        time=int(np.ceil(duration / config.time_period)),
-        frequency=config.height,
-        dtype=np.float32,
-    )
diff --git a/batdetect2/models/detectors.py b/batdetect2/models/detectors.py
index 716fc9a..0d5595d 100644
--- a/batdetect2/models/detectors.py
+++ b/batdetect2/models/detectors.py
@@ -9,7 +9,7 @@ from torch import nn, optim
-from batdetect2.data.labels import ClassMapper
-from batdetect2.data.preprocessing import (
+from soundevent.types import ClassMapper
+from batdetect2.preprocess import (
     PreprocessingConfig,
     preprocess_audio_clip,
 )
 from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
diff --git a/batdetect2/plot.py b/batdetect2/plot.py
index 0da9a36..b9e5d4e 100644
--- a/batdetect2/plot.py
+++ b/batdetect2/plot.py
@@ -2,10 +2,10 @@
 
 from typing import List, Optional, Tuple, Union, cast
 
+import matplotlib.ticker as tick
 import numpy as np
 import torch
 from matplotlib import axes, patches
-import matplotlib.ticker as tick
 from matplotlib import pyplot as plt
 
 from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,7 +102,6 @@
 
     return ax
 
-
 def spectrogram_with_detections(
     spec: Union[torch.Tensor, np.ndarray],
     dets: List[Annotation],
@@ -231,11 +230,11 @@
         figsize (Optional[Tuple[int, int]], optional): Figure size.
            Defaults to None. If `ax` is None, this will be used to create
            a new figure of the given size.
-        linewidth (float, optional): Line width of the detection. 
+        linewidth (float, optional): Line width of the detection.
            Defaults to 1.
-        edgecolor (str, optional): Edge color of the detection. 
+        edgecolor (str, optional): Edge color of the detection.
            Defaults to "w", i.e. white.
-        facecolor (str, optional): Face color of the detection. 
+        facecolor (str, optional): Face color of the detection.
            Defaults to "none", i.e. transparent.
         with_name (bool, optional): Whether to plot the name of the
            predicted class next to the detection. Defaults to True.
diff --git a/batdetect2/preprocess/__init__.py b/batdetect2/preprocess/__init__.py
new file mode 100644
index 0000000..4c9560e
--- /dev/null
+++ b/batdetect2/preprocess/__init__.py
@@ -0,0 +1,72 @@
+"""Module containing functions for preprocessing audio clips."""
+
+from typing import Optional
+
+import xarray as xr
+from pydantic import BaseModel, Field
+from soundevent import data
+
+from batdetect2.preprocess.audio import (
+    AudioConfig,
+    ResampleConfig,
+    load_clip_audio,
+)
+from batdetect2.preprocess.spectrogram import (
+    FFTConfig,
+    FrequencyConfig,
+    PcenConfig,
+    SpecSizeConfig,
+    SpectrogramConfig,
+    compute_spectrogram,
+)
+
+__all__ = [
+    "AudioConfig",
+    "ResampleConfig",
+    "SpectrogramConfig",
+    "FFTConfig",
+    "FrequencyConfig",
+    "PcenConfig",
+    "SpecSizeConfig",
+    "PreprocessingConfig",
+    "preprocess_audio_clip",
+]
+
+
+class PreprocessingConfig(BaseModel):
+    """Configuration for preprocessing data."""
+
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
+
+
+def preprocess_audio_clip(
+    clip: data.Clip,
+    config: Optional[PreprocessingConfig] = None,
+) -> xr.DataArray:
+    """Preprocesses audio clip to generate spectrogram.
+
+    Parameters
+    ----------
+    clip
+        The audio clip to preprocess.
+    config
+        Configuration for preprocessing.
+
+    Returns
+    -------
+    xr.DataArray
+        Preprocessed spectrogram.
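+
+    Examples
+    --------
+    Minimal usage sketch; assumes ``clip`` is an existing
+    ``soundevent.data.Clip`` (hypothetical variable):
+
+    >>> config = PreprocessingConfig()
+    >>> spec = preprocess_audio_clip(clip, config=config)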
+
+    """
+    config = config or PreprocessingConfig()
+    wav = load_clip_audio(clip, config=config.audio)
+    return compute_spectrogram(wav, config=config.spectrogram)
diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py
new file mode 100644
index 0000000..9c538d2
--- /dev/null
+++ b/batdetect2/preprocess/audio.py
@@ -0,0 +1,171 @@
+from typing import Optional
+
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from scipy.signal import resample, resample_poly
+from soundevent import arrays, audio, data
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+
+TARGET_SAMPLERATE_HZ = 256_000
+SCALE_RAW_AUDIO = False
+DEFAULT_DURATION = 1
+
+
+class ResampleConfig(BaseConfig):
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    mode: str = "poly"
+
+
+class AudioConfig(BaseConfig):
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    scale: bool = SCALE_RAW_AUDIO
+    center: bool = True
+    duration: Optional[float] = DEFAULT_DURATION
+
+
+def load_clip_audio(
+    clip: data.Clip,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or AudioConfig()
+
+    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+
+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)
+
+    if config.resample:
+        wav = resample_audio(
+            wav,
+            samplerate=config.resample.samplerate,
+            mode=config.resample.mode,
+            dtype=dtype,
+        )
+
+    if config.center:
+        wav = ops.center(wav)
+
+    if config.scale:
+        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
+
+    return wav.astype(dtype)
+
+
+def adjust_audio_duration(
+    wave: xr.DataArray,
+    duration: float,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    current_duration = end_time - start_time
+
+    if current_duration == duration:
+        return wave
+
+    if current_duration > duration:
+        return arrays.crop_dim(
+            wave,
+            dim="time",
+            start=start_time,
+            stop=start_time + duration,
+        )
+
+    return arrays.extend_dim(
+        wave,
+        dim="time",
+        start=start_time,
+        stop=start_time + duration,
+    )
+
+
+def resample_audio(
+    wav: xr.DataArray,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
+    mode: str = "poly",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if "time" not in wav.dims:
+        raise ValueError("Audio must have a time dimension")
+
+    time_axis: int = wav.get_axis_num("time")  # type: ignore
+    step = arrays.get_dim_step(wav, dim="time")
+    original_samplerate = int(1 / step)
+
+    if original_samplerate == samplerate:
+        return wav.astype(dtype)
+
+    if mode == "poly":
+        resampled = resample_audio_poly(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    elif mode == "fourier":
+        resampled = resample_audio_fourier(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    else:
+        raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
+
+    start, stop = arrays.get_dim_range(wav, dim="time")
+    times = np.linspace(
+        start,
+        stop + step,
+        len(resampled),
+        endpoint=False,
+        dtype=dtype,
+    )
+
+    return xr.DataArray(
+        data=resampled.astype(dtype),
+        dims=wav.dims,
+        coords={
+            **wav.coords,
+            "time": arrays.create_time_dim_from_array(
+                times,
+                samplerate=samplerate,
+            ),
+        },
+        attrs=wav.attrs,
+    )
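+
+
+# Illustrative sketch (comment only): choosing the resampling backend via
+# the config. "poly" (the default) uses scipy.signal.resample_poly, while
+# "fourier" uses scipy.signal.resample. Assumes ``clip`` already exists:
+#
+#     config = AudioConfig(resample=ResampleConfig(mode="fourier"))
+#     wav = load_clip_audio(clip, config=config)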
+
+
+def resample_audio_poly(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    gcd = np.gcd(sr_orig, sr_new)
+    return resample_poly(
+        array.values,
+        sr_new // gcd,
+        sr_orig // gcd,
+        axis=axis,
+    )
+
+
+def resample_audio_fourier(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    ratio = sr_new / sr_orig
+    return resample(array.values, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py
new file mode 100644
index 0000000..c0a8e45
--- /dev/null
+++ b/batdetect2/preprocess/spectrogram.py
@@ -0,0 +1,249 @@
+from typing import Literal, Optional, Union
+
+import librosa
+import librosa.core.spectrum
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from soundevent import arrays, audio
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess.audio import DEFAULT_DURATION
+
+FFT_WIN_LENGTH_S = 512 / 256000.0
+FFT_OVERLAP = 0.75
+MAX_FREQ_HZ = 120000
+MIN_FREQ_HZ = 10000
+SPEC_HEIGHT = 128
+SPEC_WIDTH = 256
+SPEC_SCALE = "pcen"
+SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
+DENOISE_SPEC_AVG = True
+MAX_SCALE_SPEC = False
+
+
+class FFTConfig(BaseConfig):
+    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_fn: str = "hann"
+
+
+class FrequencyConfig(BaseConfig):
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+
+class PcenConfig(BaseConfig):
+    time_constant: float = 0.4
+    hop_length: int = 512
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
+class SpecSizeConfig(BaseConfig):
+    height: int = SPEC_HEIGHT
+    time_period: float = SPEC_TIME_PERIOD
+
+
+class SpectrogramConfig(BaseConfig):
+    fft: FFTConfig = Field(default_factory=FFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    denoise: bool = True
+    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    max_scale: bool = MAX_SCALE_SPEC
+
+
+def compute_spectrogram(
+    wav: xr.DataArray,
+    config: Optional[SpectrogramConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or SpectrogramConfig()
+
+    spec = stft(
+        wav,
+        window_duration=config.fft.window_duration,
+        window_overlap=config.fft.window_overlap,
+        window_fn=config.fft.window_fn,
+        dtype=dtype,
+    )
+
+    spec = crop_spectrogram_frequencies(
+        spec,
+        min_freq=config.frequencies.min_freq,
+        max_freq=config.frequencies.max_freq,
+    )
+
+    spec = scale_spectrogram(spec, scale=config.scale)
+
+    if config.denoise:
+        spec = denoise_spectrogram(spec)
+
+    if config.resize:
+        spec = resize_spectrogram(spec, config=config.resize)
+
+    if config.max_scale:
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+    return spec.astype(dtype)
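+
+
+# Illustrative sketch (comment only): swapping the default log scaling for
+# PCEN. The PcenConfig values shown below are simply the field defaults:
+#
+#     config = SpectrogramConfig(scale=PcenConfig(gain=0.98, power=0.5))
+#     spec = compute_spectrogram(wav, config=config)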
+
+
+def crop_spectrogram_frequencies(
+    spec: xr.DataArray,
+    min_freq: int = MIN_FREQ_HZ,
+    max_freq: int = MAX_FREQ_HZ,
+) -> xr.DataArray:
+    return arrays.crop_dim(
+        spec,
+        dim="frequency",
+        start=min_freq,
+        stop=max_freq,
+    ).astype(spec.dtype)
+
+
+def stft(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    window_fn: str = "hann",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    sampling_rate = 1 / step
+
+    hop_len = window_duration * (1 - window_overlap)
+    nfft = int(window_duration * sampling_rate)
+    noverlap = int(window_overlap * nfft)
+
+    spec, _ = librosa.core.spectrum._spectrogram(
+        y=wave.data.astype(dtype),
+        power=1,
+        n_fft=nfft,
+        hop_length=nfft - noverlap,
+        center=False,
+        window=window_fn,
+    )
+
+    return xr.DataArray(
+        data=spec.astype(dtype),
+        dims=["frequency", "time"],
+        coords={
+            "frequency": arrays.create_frequency_dim_from_array(
+                np.linspace(
+                    0,
+                    sampling_rate / 2,
+                    spec.shape[0],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=sampling_rate / nfft,
+            ),
+            "time": arrays.create_time_dim_from_array(
+                np.linspace(
+                    start_time,
+                    end_time - (window_duration - hop_len),
+                    spec.shape[1],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=hop_len,
+            ),
+        },
+        attrs={
+            **wave.attrs,
+            "original_samplerate": sampling_rate,
+            "nfft": nfft,
+            "noverlap": noverlap,
+        },
+    )
+
+
+def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
+    return xr.DataArray(
+        data=(spec - spec.mean("time")).clip(0),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def scale_spectrogram(
+    spec: xr.DataArray,
+    scale: Union[Literal["log"], None, PcenConfig],
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if scale == "log":
+        return scale_log(spec, dtype=dtype)
+
+    if isinstance(scale, PcenConfig):
+        return scale_pcen(
+            spec,
+            time_constant=scale.time_constant,
+            hop_length=scale.hop_length,
+            gain=scale.gain,
+            power=scale.power,
+            bias=scale.bias,
+        )
+
+    return spec
+
+
+def scale_pcen(
+    spec: xr.DataArray,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+    gain: float = 0.98,
+    bias: float = 2,
+    power: float = 0.5,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    # NOTE: Not sure why the 10 is there
+    t_frames = time_constant * samplerate / (float(hop_length) * 10)
+    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+    return audio.pcen(
+        spec * (2**31),
+        smooth=smoothing_constant,
+        gain=gain,
+        bias=bias,
+        power=power,
+    ).astype(spec.dtype)
+
+
+def scale_log(
+    spec: xr.DataArray,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    nfft = spec.attrs["nfft"]
+    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
+    return xr.DataArray(
+        data=np.log1p(log_scaling * spec).astype(dtype),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def resize_spectrogram(
+    spec: xr.DataArray,
+    config: SpecSizeConfig,
+) -> xr.DataArray:
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
+        spec,
+        time=int(np.ceil(duration / config.time_period)),
+        frequency=config.height,
+        dtype=np.float32,
+    )
diff --git a/batdetect2/terms.py b/batdetect2/terms.py
new file mode 100644
index 0000000..8d162fd
--- /dev/null
+++ b/batdetect2/terms.py
@@ -0,0 +1,98 @@
+from inspect import getmembers
+from typing import Optional
+
+from pydantic import BaseModel
+from soundevent import data, terms
+
+__all__ = [
+    "call_type",
+    "individual",
+    "get_term_from_info",
+    "get_tag_from_info",
+    "TermInfo",
+    "TagInfo",
+]
+
+
+class TermInfo(BaseModel):
+    label: Optional[str] = None
+    name: Optional[str] = None
+    uri: Optional[str] = None
+
+
+class TagInfo(BaseModel):
+    value: str
+    label: Optional[str] = None
+    term: Optional[TermInfo] = None
+    key: Optional[str] = None
+
+
+call_type = data.Term(
+    name="soundevent:call_type",
+    label="Call Type",
+    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
+)
+
+individual = data.Term(
+    name="soundevent:individual",
+    label="Individual",
+    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
+)
+
+
+# NOTE: getmembers returns (name, value) pairs, so unpack the values to keep
+# ALL_TERMS a flat list of Term objects.
+ALL_TERMS = [
+    *(term for _, term in getmembers(terms, lambda x: isinstance(x, data.Term))),
+    call_type,
+    individual,
+]
+
+
+def get_term_from_info(term_info: TermInfo) -> data.Term:
+    for term in ALL_TERMS:
+        if term_info.name and term_info.name == term.name:
+            return term
+
+        if term_info.label and term_info.label == term.label:
+            return term
+
+        if term_info.uri and term_info.uri == term.uri:
+            return term
+
+    if term_info.name is None:
+        if term_info.label is None:
+            raise ValueError("At least one of name or label must be provided.")
+
+        term_info.name = (
+            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
+        )
+
+    if term_info.label is None:
+        term_info.label = term_info.name
+
+    return data.Term(
+        name=term_info.name,
+        label=term_info.label,
+        uri=term_info.uri,
+        definition="Unknown",
+    )
+
+
+def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
+    if tag_info.term:
+        term = get_term_from_info(tag_info.term)
+    elif tag_info.key:
+        term = data.term_from_key(tag_info.key)
+    else:
+        raise ValueError("Either term or key must be provided in tag info.")
+
+    return data.Tag(term=term, value=tag_info.value)
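+
+
+# Illustrative sketch (comment only): resolving a tag from config metadata.
+# The value and label below are hypothetical:
+#
+#     tag = get_tag_from_info(
+#         TagInfo(value="Myotis myotis", term=TermInfo(label="Species"))
+#     )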
diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py
index f0b0130..159d1a9 100644
--- a/batdetect2/train/augmentations.py
+++ b/batdetect2/train/augmentations.py
@@ -1,11 +1,8 @@
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional
 
 import numpy as np
 import xarray as xr
-from soundevent import data
-from soundevent.geometry import compute_bounds
-
 
 
 Augmentation = Callable[[xr.Dataset], xr.Dataset]
@@ -223,8 +220,8 @@
     num_masks = np.random.randint(1, max_num_masks + 1)
 
     freq_coord = train_example.coords["frequency"]
-    min_freq = freq_coord.min()
-    max_freq = freq_coord.max()
+    min_freq = float(freq_coord.min())
+    max_freq = float(freq_coord.max())
 
     for _ in range(num_masks):
         mask_size = np.random.uniform(0, max_freq_mask)
diff --git a/batdetect2/data/labels.py b/batdetect2/train/labels.py
similarity index 96%
rename from batdetect2/data/labels.py
rename to batdetect2/train/labels.py
index 8e33e96..db45d3a 100644
--- a/batdetect2/data/labels.py
+++ b/batdetect2/train/labels.py
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import numpy as np
 import xarray as xr
@@ -17,7 +17,7 @@ TARGET_SIGMA = 3.0
 
 
 def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
+    sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
@@ -52,9 +52,8 @@
         },
     )
 
-    for sound_event_annotation in clip_annotation.sound_events:
+    for sound_event_annotation in sound_events:
         geom = sound_event_annotation.sound_event.geometry
-
         if geom is None:
             continue
 
diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py
index 004c414..9dfe755 100644
--- a/batdetect2/train/preprocess.py
+++ b/batdetect2/train/preprocess.py
@@ -1,21 +1,30 @@
 """Module for preprocessing data for training."""
 
 import os
-import warnings
 from functools import partial
 from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence, Union
 
 import xarray as xr
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm
 
-from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
-from batdetect2.data.preprocessing import (
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess import (
     PreprocessingConfig,
     preprocess_audio_clip,
 )
+from batdetect2.train.labels import (
+    TARGET_SIGMA,
+    generate_heatmaps,
+)
+from batdetect2.train.targets import (
+    TargetConfig,
+    build_class_mapper,
+    build_sound_event_filter,
+)
 
 PathLike = Union[Path, str, os.PathLike]
 FilenameFn = Callable[[data.ClipAnnotation], str]
@@ -25,25 +34,44 @@ __all__ = [
 ]
 
 
+class MasksConfig(BaseConfig):
+    sigma: float = TARGET_SIGMA
+
+
+class TrainPreprocessingConfig(BaseConfig):
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    target: TargetConfig = Field(default_factory=TargetConfig)
+    masks: MasksConfig = Field(default_factory=MasksConfig)
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    class_mapper: ClassMapper,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_sigma: float = TARGET_SIGMA,
+    config: Optional[TrainPreprocessingConfig] = None,
 ) -> xr.Dataset:
     """Generate a training example."""
-    preprocessing_config = preprocessing_config or PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()
 
     spectrogram = preprocess_audio_clip(
         clip_annotation.clip,
-        config=preprocessing_config,
+        config=config.preprocessing,
     )
 
+    filter_fn = build_sound_event_filter(
+        include=config.target.include,
+        exclude=config.target.exclude,
+    )
+    selected_events = [
+        event for event in clip_annotation.sound_events if filter_fn(event)
+    ]
+
+    class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=target_sigma,
+        target_sigma=config.masks.sigma,
     )
 
     dataset = xr.Dataset(
@@ -57,8 +85,7 @@
 
     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        preprocessing_configuration=preprocessing_config.model_dump_json(),
-        target_sigma=target_sigma,
+        config=config.model_dump_json(),
         clip_annotation=clip_annotation.model_dump_json(),
     )
 
@@ -78,77 +105,22 @@
     )
 
 
-def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
-    """Load configuration from file."""
-
-    path = Path(path)
-
-    if not path.is_file():
-        warnings.warn(
-            f"Config file not found: {path}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-    try:
-        return PreprocessingConfig.model_validate_json(path.read_text())
-    except ValueError as e:
-        warnings.warn(
-            f"Failed to load config file: {e}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     return f"{clip_annotation.uuid}.nc"
 
 
-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: PathLike,
-    config: PreprocessingConfig,
-    class_mapper: ClassMapper,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-    target_sigma: float = TARGET_SIGMA,
-) -> None:
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        return
-
-    if path.is_file() and replace:
-        path.unlink()
-
-    sample = generate_train_example(
-        clip_annotation,
-        class_mapper,
-        preprocessing_config=config,
-        target_sigma=target_sigma,
-    )
-
-    save_to_file(sample, path)
-
-
 def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
     output_dir: PathLike,
-    class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    config: Optional[PreprocessingConfig] = None,
+    config: Optional[TrainPreprocessingConfig] = None,
     max_workers: Optional[int] = None,
 ) -> None:
     """Preprocess annotations and save to disk."""
     output_dir = Path(output_dir)
 
-    if config is None:
-        config = PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()
 
     if not output_dir.is_dir():
         output_dir.mkdir(parents=True)
@@ -161,13 +133,33 @@
                     preprocess_single_annotation,
                     output_dir=output_dir,
                     config=config,
-                    class_mapper=class_mapper,
                     filename_fn=filename_fn,
                     replace=replace,
-                    target_sigma=target_sigma,
                 ),
                 clip_annotations,
             ),
             total=len(clip_annotations),
         )
     )
+
+
+def preprocess_single_annotation(
+    clip_annotation: data.ClipAnnotation,
+    output_dir: PathLike,
+    config: TrainPreprocessingConfig,
+    filename_fn: FilenameFn = _get_filename,
+    replace: bool = False,
+) -> None:
+    output_dir = Path(output_dir)
+
+    filename = filename_fn(clip_annotation)
+    path = output_dir / filename
+
+    if path.is_file() and not replace:
+        return
+
+    if path.is_file() and replace:
+        path.unlink()
+
+    sample = generate_train_example(clip_annotation, config=config)
+    save_to_file(sample, path)
diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py
new file mode 100644
index 0000000..e88aa04
--- /dev/null
+++ b/batdetect2/train/targets.py
@@ -0,0 +1,113 @@
+from functools import partial
+from typing import Callable, List, Optional, Set
+
+from pydantic import Field
+from soundevent import data
+from soundevent.types import ClassMapper
+
+from batdetect2.configs import BaseConfig
+from batdetect2.terms import TagInfo, get_tag_from_info
+
+
+class TargetConfig(BaseConfig):
+    """Configuration for target generation."""
+
+    classes: List[TagInfo] = Field(default_factory=list)
+
+    include: Optional[List[TagInfo]] = None
+
+    exclude: Optional[List[TagInfo]] = None
+
+
+def build_sound_event_filter(
+    include: Optional[List[TagInfo]] = None,
+    exclude: Optional[List[TagInfo]] = None,
+) -> Callable[[data.SoundEventAnnotation], bool]:
+    include_tags = (
+        {get_tag_from_info(tag) for tag in include} if include else None
+    )
+    exclude_tags = (
+        {get_tag_from_info(tag) for tag in exclude} if exclude else None
+    )
+    return partial(
+        filter_sound_event,
+        include=include_tags,
+        exclude=exclude_tags,
+    )
+
+
+def build_class_mapper(classes: List[TagInfo]) -> ClassMapper:
+    target_tags = [get_tag_from_info(tag) for tag in classes]
+    labels = [tag.label if tag.label else tag.value for tag in classes]
+    return GenericMapper(
+        classes=target_tags,
+        labels=labels,
+    )
+
+
+def filter_sound_event(
+    sound_event_annotation: data.SoundEventAnnotation,
+    include: Optional[Set[data.Tag]] = None,
+    exclude: Optional[Set[data.Tag]] = None,
+) -> bool:
+    tags = set(sound_event_annotation.tags)
+
+    if include is not None and not tags & include:
+        return False
+
+    if exclude is not None and tags & exclude:
+        return False
+
+    return True
+
+
+class GenericMapper(ClassMapper):
+    """Generic class mapper configuration."""
+
+    def __init__(
+        self,
+        classes: List[data.Tag],
+        labels: List[str],
+    ):
+        if not len(classes) == len(labels):
+            raise ValueError("Number of targets and class labels must match.")
+
+        self.targets = set(classes)
+        self.class_labels = labels
+
+        self._mapping = {tag: label for tag, label in zip(classes, labels)}
+        self._inverse_mapping = {
+            label: tag for tag, label in zip(classes, labels)
+        }
+
+    def encode(
+        self,
+        sound_event_annotation: data.SoundEventAnnotation,
+    ) -> Optional[str]:
+        tags = set(sound_event_annotation.tags)
+
+        intersection = tags & self.targets
+        if not intersection:
+            return None
+
+        tag = intersection.pop()
+        return self._mapping[tag]
+
+    def decode(self, label: str) -> List[data.Tag]:
+        if label not in self._inverse_mapping:
+            return []
+        return [self._inverse_mapping[label]]
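+
+
+# Illustrative sketch (comment only): wiring the pieces together. The tag
+# keys and values below are hypothetical:
+#
+#     config = TargetConfig(
+#         classes=[TagInfo(key="class", value="Myotis myotis")],
+#         exclude=[TagInfo(key="event", value="noise")],
+#     )
+#     filter_fn = build_sound_event_filter(
+#         include=config.include,
+#         exclude=config.exclude,
+#     )
+#     mapper = build_class_mapper(config.classes)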