mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

Minor restructuring

This commit is contained in:
parent ee884da8b0
commit 1f0fb14d89
@@ -1,12 +1,11 @@
 import click
 
 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file
 
-from batdetect2.cli.base import cli
-
 
 @cli.command()
 @click.argument(
batdetect2/configs.py (new file, 5 lines)
@@ -0,0 +1,5 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class BaseConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
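Because every configuration model in the restructured code now inherits from BaseConfig, unknown keys are rejected instead of silently ignored. A minimal sketch of the effect (not part of the commit; ExampleConfig and its field are made up for illustration):

from pydantic import ValidationError

from batdetect2.configs import BaseConfig


class ExampleConfig(BaseConfig):
    samplerate: int = 256_000  # hypothetical field


ExampleConfig(samplerate=192_000)  # fine

try:
    ExampleConfig(sample_rate=192_000)  # misspelled key
except ValidationError as err:
    print(err)  # extra="forbid" makes pydantic flag the unknown key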
@@ -9,9 +9,9 @@ import numpy as np
 from pydantic import BaseModel, Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
+from soundevent.types import ClassMapper
 
 from batdetect2 import types
-from batdetect2.data.labels import ClassMapper
 
 PathLike = Union[Path, str, os.PathLike]
 
batdetect2/data/preprocessing.py (deleted, 436 lines)
@@ -1,436 +0,0 @@
(The module "containing functions for preprocessing audio clips" was deleted in its entirety. Its configuration models — ResampleConfig, AudioConfig, FFTConfig, FrequencyConfig, PcenConfig, SpecSizeConfig, SpectrogramConfig, PreprocessingConfig — and the functions preprocess_audio_clip, load_clip_audio, adjust_audio_duration, resample_audio, compute_spectrogram, crop_spectrogram_frequencies, stft, denoise_spectrogram, scale_spectrogram, scale_pcen, scale_log and resize_spectrogram reappear, largely verbatim, in the new batdetect2/preprocess package below; the configs there now derive from BaseConfig, and resample_audio gains a "mode" switch. Only the file round-trip helpers on PreprocessingConfig were dropped without a replacement:)
-    @classmethod
-    def from_file(
-        cls,
-        path: Union[str, Path],
-    ) -> "PreprocessingConfig":
-        """Load configuration from a file.
-
-        Parameters
-        ----------
-        path
-            Path to the configuration file.
-
-        Returns
-        -------
-        PreprocessingConfig
-            The configuration loaded from the file.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the configuration file does not exist.
-        pydantic.ValidationError
-            If the configuration file is invalid.
-        """
-        path = Path(path)
-
-        if not path.is_file():
-            raise FileNotFoundError(f"Config file not found: {path}")
-
-        return cls.model_validate_json(path.read_text())
-
-    def to_file(self, path: Union[str, Path]) -> None:
-        """Save configuration to a file."""
-        path = Path(path)
-
-        if not path.parent.exists():
-            path.parent.mkdir(parents=True)
-
-        path.write_text(self.model_dump_json())
@@ -9,7 +9,7 @@ from torch import nn, optim
 from batdetect2.data.labels import ClassMapper
 from batdetect2.data.preprocessing import (
     PreprocessingConfig,
-    preprocess_audio_clip,
+    preprocess,
 )
 from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
@@ -79,7 +79,7 @@ class DetectorModel(L.LightningModule):
         )
 
     def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
-        return preprocess_audio_clip(
+        return preprocess(
             clip,
             config=self.preprocessing_config,
         )
@@ -2,10 +2,10 @@
 
 from typing import List, Optional, Tuple, Union, cast
 
+import matplotlib.ticker as tick
 import numpy as np
 import torch
 from matplotlib import axes, patches
-import matplotlib.ticker as tick
 from matplotlib import pyplot as plt
 
 from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,7 +102,6 @@ def spectrogram(
     return ax
 
 
-
 def spectrogram_with_detections(
     spec: Union[torch.Tensor, np.ndarray],
     dets: List[Annotation],
batdetect2/preprocess/__init__.py (new file, 64 lines)
@@ -0,0 +1,64 @@
+"""Module containing functions for preprocessing audio clips."""
+
+from typing import Optional
+
+import xarray as xr
+from pydantic import BaseModel, Field
+from soundevent import data
+
+from batdetect2.preprocess.audio import (
+    AudioConfig,
+    ResampleConfig,
+    load_clip_audio,
+)
+from batdetect2.preprocess.spectrogram import (
+    FFTConfig,
+    FrequencyConfig,
+    PcenConfig,
+    SpecSizeConfig,
+    SpectrogramConfig,
+    compute_spectrogram,
+)
+
+__all__ = [
+    "AudioConfig",
+    "ResampleConfig",
+    "SpectrogramConfig",
+    "FFTConfig",
+    "FrequencyConfig",
+    "PcenConfig",
+    "SpecSizeConfig",
+    "PreprocessingConfig",
+    "preprocess_audio_clip",
+]
+
+
+class PreprocessingConfig(BaseModel):
+    """Configuration for preprocessing data."""
+
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
+
+
+def preprocess_audio_clip(
+    clip: data.Clip,
+    config: Optional[PreprocessingConfig] = None,
+) -> xr.DataArray:
+    """Preprocesses audio clip to generate spectrogram.
+
+    Parameters
+    ----------
+    clip
+        The audio clip to preprocess.
+    config
+        Configuration for preprocessing.
+
+    Returns
+    -------
+    xr.DataArray
+        Preprocessed spectrogram.
+
+    """
+    config = config or PreprocessingConfig()
+    wav = load_clip_audio(clip, config=config.audio)
+    return compute_spectrogram(wav, config=config.spectrogram)
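A usage sketch for the new package facade (not part of the commit; the recording path and clip bounds are placeholders):

from soundevent import data

from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip

recording = data.Recording.from_file("example.wav")  # placeholder path
clip = data.Clip(recording=recording, start_time=0.0, end_time=1.0)

config = PreprocessingConfig()  # defaults: resample to 256 kHz, log scale
spec = preprocess_audio_clip(clip, config=config)  # ("frequency", "time") DataArray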
batdetect2/preprocess/audio.py (new file, 162 lines)
@@ -0,0 +1,162 @@
+from typing import Optional
+
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from scipy.signal import resample, resample_poly
+from soundevent import arrays, audio, data
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+
+TARGET_SAMPLERATE_HZ = 256_000
+SCALE_RAW_AUDIO = False
+DEFAULT_DURATION = 1
+
+
+class ResampleConfig(BaseConfig):
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    mode: str = "poly"
+
+
+class AudioConfig(BaseConfig):
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    scale: bool = SCALE_RAW_AUDIO
+    center: bool = True
+    duration: Optional[float] = DEFAULT_DURATION
+
+
+def load_clip_audio(
+    clip: data.Clip,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or AudioConfig()
+
+    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+
+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)
+
+    if config.resample:
+        wav = resample_audio(
+            wav,
+            samplerate=config.resample.samplerate,
+            dtype=dtype,
+        )
+
+    if config.center:
+        wav = ops.center(wav)
+
+    if config.scale:
+        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
+
+    return wav.astype(dtype)
+
+
+def adjust_audio_duration(
+    wave: xr.DataArray,
+    duration: float,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    current_duration = end_time - start_time
+
+    if current_duration == duration:
+        return wave
+
+    if current_duration > duration:
+        return arrays.crop_dim(
+            wave,
+            dim="time",
+            start=start_time,
+            stop=start_time + duration,
+        )
+
+    return arrays.extend_dim(
+        wave,
+        dim="time",
+        start=start_time,
+        stop=start_time + duration,
+    )
+
+
+def resample_audio(
+    wav: xr.DataArray,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
+    mode: str = "poly",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if "time" not in wav.dims:
+        raise ValueError("Audio must have a time dimension")
+
+    time_axis: int = wav.get_axis_num("time")  # type: ignore
+    step = arrays.get_dim_step(wav, dim="time")
+    original_samplerate = int(1 / step)
+
+    if original_samplerate == samplerate:
+        return wav.astype(dtype)
+
+    if mode == "poly":
+        resampled = resample_audio_poly(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    elif mode == "fourier":
+        resampled = resample_audio_fourier(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    else:
+        raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
+
+    start, stop = arrays.get_dim_range(wav, dim="time")
+    times = np.linspace(
+        start,
+        stop + step,
+        len(resampled),
+        endpoint=False,
+        dtype=dtype,
+    )
+
+    return xr.DataArray(
+        data=resampled.astype(dtype),
+        dims=wav.dims,
+        coords={
+            **wav.coords,
+            "time": arrays.create_time_dim_from_array(
+                times,
+                samplerate=samplerate,
+            ),
+        },
+        attrs=wav.attrs,
+    )
+
+
+def resample_audio_poly(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    gcd = np.gcd(sr_orig, sr_new)
+    return resample_poly(
+        array.values,
+        sr_new // gcd,
+        sr_orig // gcd,
+        axis=axis,
+    )
+
+
+def resample_audio_fourier(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    ratio = sr_new / sr_orig
+    return resample(array, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
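The mode switch selects between the two scipy resamplers: polyphase filtering (scipy.signal.resample_poly, the default) works on the reduced rational ratio of the two rates, while "fourier" goes through scipy.signal.resample. A sketch (not part of the commit; wav stands for any time-indexed DataArray returned by load_clip_audio):

from batdetect2.preprocess.audio import resample_audio

# Polyphase: e.g. 384 kHz -> 256 kHz uses up/down factors 2/3,
# since gcd(384000, 256000) = 128000.
wav_poly = resample_audio(wav, samplerate=256_000, mode="poly")

# FFT-domain resampling to the same target rate.
wav_fft = resample_audio(wav, samplerate=256_000, mode="fourier")

# Any other mode string raises NotImplementedError.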
batdetect2/preprocess/spectrogram.py (new file, 242 lines)
@@ -0,0 +1,242 @@
+from typing import Literal, Optional, Union
+
+import librosa
+import librosa.core.spectrum
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from soundevent import arrays, audio
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess.audio import DEFAULT_DURATION
+
+FFT_WIN_LENGTH_S = 512 / 256000.0
+FFT_OVERLAP = 0.75
+MAX_FREQ_HZ = 120000
+MIN_FREQ_HZ = 10000
+SPEC_HEIGHT = 128
+SPEC_WIDTH = 256
+SPEC_SCALE = "pcen"
+SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
+DENOISE_SPEC_AVG = True
+MAX_SCALE_SPEC = False
+
+
+class FFTConfig(BaseConfig):
+    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_fn: str = "hann"
+
+
+class FrequencyConfig(BaseConfig):
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+
+class PcenConfig(BaseConfig):
+    time_constant: float = 0.4
+    hop_length: int = 512
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
+class SpecSizeConfig(BaseConfig):
+    height: int = SPEC_HEIGHT
+    time_period: float = SPEC_TIME_PERIOD
+
+
+class SpectrogramConfig(BaseConfig):
+    fft: FFTConfig = Field(default_factory=FFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    denoise: bool = True
+    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    max_scale: bool = MAX_SCALE_SPEC
+
+
+def compute_spectrogram(
+    wav: xr.DataArray,
+    config: Optional[SpectrogramConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or SpectrogramConfig()
+
+    spec = stft(
+        wav,
+        window_duration=config.fft.window_duration,
+        window_overlap=config.fft.window_overlap,
+        window_fn=config.fft.window_fn,
+        dtype=dtype,
+    )
+
+    spec = crop_spectrogram_frequencies(
+        spec,
+        min_freq=config.frequencies.min_freq,
+        max_freq=config.frequencies.max_freq,
+    )
+
+    spec = scale_spectrogram(spec, scale=config.scale)
+
+    if config.denoise:
+        spec = denoise_spectrogram(spec)
+
+    if config.resize:
+        spec = resize_spectrogram(spec, config=config.resize)
+
+    if config.max_scale:
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+    return spec.astype(dtype)
+
+
+def crop_spectrogram_frequencies(
+    spec: xr.DataArray,
+    min_freq: int = MIN_FREQ_HZ,
+    max_freq: int = MAX_FREQ_HZ,
+) -> xr.DataArray:
+    return arrays.crop_dim(
+        spec,
+        dim="frequency",
+        start=min_freq,
+        stop=max_freq,
+    ).astype(spec.dtype)
+
+
+def stft(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    window_fn: str = "hann",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    sampling_rate = 1 / step
+
+    hop_len = window_duration * (1 - window_overlap)
+    nfft = int(window_duration * sampling_rate)
+    noverlap = int(window_overlap * nfft)
+
+    spec, _ = librosa.core.spectrum._spectrogram(
+        y=wave.data.astype(dtype),
+        power=1,
+        n_fft=nfft,
+        hop_length=nfft - noverlap,
+        center=False,
+        window=window_fn,
+    )
+
+    return xr.DataArray(
+        data=spec.astype(dtype),
+        dims=["frequency", "time"],
+        coords={
+            "frequency": arrays.create_frequency_dim_from_array(
+                np.linspace(
+                    0,
+                    sampling_rate / 2,
+                    spec.shape[0],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=sampling_rate / nfft,
+            ),
+            "time": arrays.create_time_dim_from_array(
+                np.linspace(
+                    start_time,
+                    end_time - (window_duration - hop_len),
+                    spec.shape[1],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=hop_len,
+            ),
+        },
+        attrs={
+            **wave.attrs,
+            "original_samplerate": sampling_rate,
+            "nfft": nfft,
+            "noverlap": noverlap,
+        },
+    )
+
+
+def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
+    return xr.DataArray(
+        data=(spec - spec.mean("time")).clip(0),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def scale_spectrogram(
+    spec: xr.DataArray,
+    scale: Union[Literal["log"], None, PcenConfig],
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if scale == "log":
+        return scale_log(spec, dtype=dtype)
+
+    if isinstance(scale, PcenConfig):
+        return scale_pcen(
+            spec,
+            time_constant=scale.time_constant,
+            hop_length=scale.hop_length,
+            gain=scale.gain,
+            power=scale.power,
+            bias=scale.bias,
+        )
+
+    return spec
+
+
+def scale_pcen(
+    spec: xr.DataArray,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+    gain: float = 0.98,
+    bias: float = 2,
+    power: float = 0.5,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    # NOTE: Not sure why the 10 is there
+    t_frames = time_constant * samplerate / (float(hop_length) * 10)
+    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+    return audio.pcen(
+        spec * (2**31),
+        smooth=smoothing_constant,
+        gain=gain,
+        bias=bias,
+        power=power,
+    ).astype(spec.dtype)
+
+
+def scale_log(
+    spec: xr.DataArray,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    nfft = spec.attrs["nfft"]
+    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
+    return xr.DataArray(
+        data=np.log1p(log_scaling * spec).astype(dtype),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def resize_spectrogram(
+    spec: xr.DataArray,
+    config: SpecSizeConfig,
+) -> xr.DataArray:
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
+        spec,
+        time=int(np.ceil(duration / config.time_period)),
+        frequency=config.height,
+        dtype=np.float32,
+    )
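With the defaults above, the STFT numbers work out as follows (a worked check, not part of the commit), assuming audio already resampled to the 256 kHz target rate:

samplerate = 256_000
window_duration = 512 / 256_000           # FFT_WIN_LENGTH_S = 2 ms
overlap = 0.75                            # FFT_OVERLAP
nfft = int(window_duration * samplerate)  # 512 samples per window
noverlap = int(overlap * nfft)            # 384 samples
hop = nfft - noverlap                     # 128 samples = 0.5 ms per frame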
batdetect2/terms.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from inspect import getmembers
+from typing import Optional
+
+from pydantic import BaseModel
+from soundevent import data, terms
+
+__all__ = [
+    "call_type",
+    "individual",
+    "get_term_from_info",
+    "get_tag_from_info",
+    "TermInfo",
+    "TagInfo",
+]
+
+
+class TermInfo(BaseModel):
+    label: Optional[str]
+    name: Optional[str]
+    uri: Optional[str]
+
+
+class TagInfo(BaseModel):
+    value: str
+    label: Optional[str] = None
+    term: Optional[TermInfo] = None
+    key: Optional[str] = None
+
+
+call_type = data.Term(
+    name="soundevent:call_type",
+    label="Call Type",
+    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
+)
+
+individual = data.Term(
+    name="soundevent:individual",
+    label="Individual",
+    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
+)
+
+
+ALL_TERMS = [
+    *getmembers(terms, lambda x: isinstance(x, data.Term)),
+    call_type,
+    individual,
+]
+
+
+def get_term_from_info(term_info: TermInfo) -> data.Term:
+    for term in ALL_TERMS:
+        if term_info.name and term_info.name == term.name:
+            return term
+
+        if term_info.label and term_info.label == term.label:
+            return term
+
+        if term_info.uri and term_info.uri == term.uri:
+            return term
+
+    if term_info.name is None:
+        if term_info.label is None:
+            raise ValueError("At least one of name or label must be provided.")
+
+        term_info.name = (
+            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
+        )
+
+    if term_info.label is None:
+        term_info.label = term_info.name
+
+    return data.Term(
+        name=term_info.name,
+        label=term_info.label,
+        uri=term_info.uri,
+        definition="Unknown",
+    )
+
+
+def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
+    if tag_info.term:
+        term = get_term_from_info(tag_info.term)
+    elif tag_info.key:
+        term = data.term_from_key(tag_info.key)
+    else:
+        raise ValueError("Either term or key must be provided in tag info.")
+
+    return data.Tag(term=term, value=tag_info.value)
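A sketch of how the info models are meant to be filled in before being resolved to soundevent objects (not part of the commit; the tag value is illustrative):

from batdetect2.terms import TagInfo, TermInfo

# Declarative form of the tag "Call Type: Echolocation", e.g. as it would
# appear in a config file before get_tag_from_info turns it into a data.Tag.
tag_info = TagInfo(
    value="Echolocation",  # illustrative value
    term=TermInfo(name="soundevent:call_type", label="Call Type", uri=None),
)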
@@ -1,11 +1,8 @@
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional
 
 import numpy as np
 import xarray as xr
-from soundevent import data
-from soundevent.geometry import compute_bounds
-
 
 Augmentation = Callable[[xr.Dataset], xr.Dataset]
 
@@ -223,8 +220,8 @@ def mask_frequency(
     num_masks = np.random.randint(1, max_num_masks + 1)
 
     freq_coord = train_example.coords["frequency"]
-    min_freq = freq_coord.min()
-    max_freq = freq_coord.max()
+    min_freq = float(freq_coord.min())
+    max_freq = float(freq_coord.max())
 
     for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_freq_mask)
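The added float(...) casts matter because DataArray.min() and .max() return zero-dimensional DataArrays rather than plain numbers; casting up front keeps the mask-boundary arithmetic in ordinary Python scalars. A quick illustration (not part of the commit):

import numpy as np
import xarray as xr

freq_coord = xr.DataArray(np.linspace(10_000, 120_000, 128), dims="frequency")
lo = freq_coord.min()           # 0-d xr.DataArray, not a number
lo_f = float(freq_coord.min())  # plain Python float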
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Sequence, Tuple
 
 import numpy as np
 import xarray as xr
@@ -17,7 +17,7 @@ TARGET_SIGMA = 3.0
 
 
 def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
+    sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
@@ -52,9 +52,8 @@ def generate_heatmaps(
         },
     )
 
-    for sound_event_annotation in clip_annotation.sound_events:
+    for sound_event_annotation in sound_events:
         geom = sound_event_annotation.sound_event.geometry
 
         if geom is None:
             continue
-
@@ -1,21 +1,30 @@
 """Module for preprocessing data for training."""
 
 import os
-import warnings
 from functools import partial
 from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence, Union
 
 import xarray as xr
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm
 
-from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
-from batdetect2.data.preprocessing import (
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess import (
     PreprocessingConfig,
     preprocess_audio_clip,
 )
+from batdetect2.train.labels import (
+    TARGET_SIGMA,
+    generate_heatmaps,
+)
+from batdetect2.train.targets import (
+    TargetConfig,
+    build_class_mapper,
+    build_sound_event_filter,
+)
 
 PathLike = Union[Path, str, os.PathLike]
 FilenameFn = Callable[[data.ClipAnnotation], str]
@@ -25,25 +34,44 @@ __all__ = [
 ]
 
 
+class MasksConfig(BaseConfig):
+    sigma: float = TARGET_SIGMA
+
+
+class TrainPreprocessingConfig(BaseConfig):
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    target: TargetConfig = Field(default_factory=TargetConfig)
+    masks: MasksConfig = Field(default_factory=MasksConfig)
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    class_mapper: ClassMapper,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_sigma: float = TARGET_SIGMA,
+    config: Optional[TrainPreprocessingConfig] = None,
 ) -> xr.Dataset:
     """Generate a training example."""
-    preprocessing_config = preprocessing_config or PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()
 
     spectrogram = preprocess_audio_clip(
         clip_annotation.clip,
-        config=preprocessing_config,
+        config=config.preprocessing,
     )
 
+    filter_fn = build_sound_event_filter(
+        include=config.target.include,
+        exclude=config.target.exclude,
+    )
+    selected_events = [
+        event for event in clip_annotation.sound_events if filter_fn(event)
+    ]
+
+    class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=target_sigma,
+        target_sigma=config.masks.sigma,
     )
 
     dataset = xr.Dataset(
@@ -57,8 +85,7 @@ def generate_train_example(
 
     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        preprocessing_configuration=preprocessing_config.model_dump_json(),
-        target_sigma=target_sigma,
+        config=config.model_dump_json(),
         clip_annotation=clip_annotation.model_dump_json(),
     )
 
@@ -78,77 +105,22 @@ def save_to_file(
     )
 
 
-def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
-    """Load configuration from file."""
-
-    path = Path(path)
-
-    if not path.is_file():
-        warnings.warn(
-            f"Config file not found: {path}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-    try:
-        return PreprocessingConfig.model_validate_json(path.read_text())
-    except ValueError as e:
-        warnings.warn(
-            f"Failed to load config file: {e}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     return f"{clip_annotation.uuid}.nc"
 
 
-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: PathLike,
-    config: PreprocessingConfig,
-    class_mapper: ClassMapper,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-    target_sigma: float = TARGET_SIGMA,
-) -> None:
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        return
-
-    if path.is_file() and replace:
-        path.unlink()
-
-    sample = generate_train_example(
-        clip_annotation,
-        class_mapper,
-        preprocessing_config=config,
-        target_sigma=target_sigma,
-    )
-
-    save_to_file(sample, path)
-
-
 def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
     output_dir: PathLike,
-    class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    config: Optional[PreprocessingConfig] = None,
+    config: Optional[TrainPreprocessingConfig] = None,
     max_workers: Optional[int] = None,
 ) -> None:
     """Preprocess annotations and save to disk."""
     output_dir = Path(output_dir)
 
-    if config is None:
-        config = PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()
 
     if not output_dir.is_dir():
         output_dir.mkdir(parents=True)
@@ -161,13 +133,33 @@ def preprocess_annotations(
                         preprocess_single_annotation,
                         output_dir=output_dir,
                         config=config,
-                        class_mapper=class_mapper,
                         filename_fn=filename_fn,
                         replace=replace,
-                        target_sigma=target_sigma,
                     ),
                     clip_annotations,
                 ),
                 total=len(clip_annotations),
             )
         )
 
 
+def preprocess_single_annotation(
+    clip_annotation: data.ClipAnnotation,
+    output_dir: PathLike,
+    config: TrainPreprocessingConfig,
+    filename_fn: FilenameFn = _get_filename,
+    replace: bool = False,
+) -> None:
+    output_dir = Path(output_dir)
+
+    filename = filename_fn(clip_annotation)
+    path = output_dir / filename
+
+    if path.is_file() and not replace:
+        return
+
+    if path.is_file() and replace:
+        path.unlink()
+
+    sample = generate_train_example(clip_annotation, config=config)
+    save_to_file(sample, path)
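A sketch of the reworked training-preprocessing entry point (not part of the commit; paths and tag values are placeholders, and the module path batdetect2.train.preprocess is assumed since the file name is not shown in this view):

from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
    TrainPreprocessingConfig,
    preprocess_annotations,
)
from batdetect2.train.targets import TargetConfig

config = TrainPreprocessingConfig(
    target=TargetConfig(
        classes=[TagInfo(value="Myotis myotis", key="species")],  # placeholder
    ),
)

# clip_annotations: a Sequence[data.ClipAnnotation] loaded elsewhere.
preprocess_annotations(
    clip_annotations,
    output_dir="train_examples/",  # one <uuid>.nc file per annotation
    config=config,
    max_workers=4,
)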
batdetect2/train/targets.py (new file, 99 lines)
@@ -0,0 +1,99 @@
+from functools import partial
+from typing import Callable, List, Optional, Set
+
+from pydantic import Field
+from soundevent import data
+from soundevent.types import ClassMapper
+
+from batdetect2.configs import BaseConfig
+from batdetect2.terms import TagInfo, get_tag_from_info
+
+
+class TargetConfig(BaseConfig):
+    """Configuration for target generation."""
+
+    classes: List[TagInfo] = Field(default_factory=list)
+
+    include: Optional[List[TagInfo]] = None
+
+    exclude: Optional[List[TagInfo]] = None
+
+
+def build_sound_event_filter(
+    include: Optional[List[TagInfo]] = None,
+    exclude: Optional[List[TagInfo]] = None,
+) -> Callable[[data.SoundEventAnnotation], bool]:
+    include_tags = (
+        {get_tag_from_info(tag) for tag in include} if include else None
+    )
+    exclude_tags = (
+        {get_tag_from_info(tag) for tag in exclude} if exclude else None
+    )
+    return partial(
+        filter_sound_event,
+        include=include_tags,
+        exclude=exclude_tags,
+    )
+
+
+def build_class_mapper(classes: List[TagInfo]) -> ClassMapper:
+    target_tags = [get_tag_from_info(tag) for tag in classes]
+    labels = [tag.label if tag.label else tag.value for tag in classes]
+    return GenericMapper(
+        classes=target_tags,
+        labels=labels,
+    )
+
+
+def filter_sound_event(
+    sound_event_annotation: data.SoundEventAnnotation,
+    include: Optional[Set[data.Tag]] = None,
+    exclude: Optional[Set[data.Tag]] = None,
+) -> bool:
+    tags = set(sound_event_annotation.tags)
+
+    if include is not None and not tags & include:
+        return False
+
+    if exclude is not None and tags & exclude:
+        return False
+
+    return True
+
+
+class GenericMapper(ClassMapper):
+    """Generic class mapper configuration."""
+
+    def __init__(
+        self,
+        classes: List[data.Tag],
+        labels: List[str],
+    ):
+        if not len(classes) == len(labels):
+            raise ValueError("Number of targets and class labels must match.")
+
+        self.targets = set(classes)
+        self.class_labels = labels
+
+        self._mapping = {tag: label for tag, label in zip(classes, labels)}
+        self._inverse_mapping = {
+            label: tag for tag, label in zip(classes, labels)
+        }
+
+    def encode(
+        self,
+        sound_event_annotation: data.SoundEventAnnotation,
+    ) -> Optional[str]:
+        tags = set(sound_event_annotation.tags)
+
+        intersection = tags & self.targets
+        if not intersection:
+            return None
+
+        tag = intersection.pop()
+        return self._mapping[tag]
+
+    def decode(self, label: str) -> List[data.Tag]:
+        if label not in self._inverse_mapping:
+            return []
+        return [self._inverse_mapping[label]]
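How the include/exclude filter behaves (a sketch, not part of the commit; annotation, bat_tag and noise_tag are stand-ins for objects built elsewhere): an annotation passes when it carries at least one include tag and none of the exclude tags.

from batdetect2.train.targets import filter_sound_event

keep = filter_sound_event(
    annotation,         # a data.SoundEventAnnotation
    include={bat_tag},  # hypothetical data.Tag instances
    exclude={noise_tag},
)
# keep is True only if the annotation has bat_tag and lacks noise_tag.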