mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

Minor restructuring

This commit is contained in:
parent ee884da8b0
commit 1f0fb14d89
@@ -1,12 +1,11 @@
 import click

 from batdetect2 import api
+from batdetect2.cli.base import cli
 from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file

-from batdetect2.cli.base import cli
-

 @cli.command()
 @click.argument(
 5  batdetect2/configs.py  Normal file
@@ -0,0 +1,5 @@
+from pydantic import BaseModel, ConfigDict
+
+
+class BaseConfig(BaseModel):
+    model_config = ConfigDict(extra="forbid")
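Context note (not part of the commit): with extra="forbid", any unrecognized key in a config file raises a validation error instead of being silently dropped. A minimal sketch, assuming only the BaseConfig above; ExampleConfig is hypothetical:

from pydantic import ValidationError

from batdetect2.configs import BaseConfig

class ExampleConfig(BaseConfig):
    samplerate: int = 256_000

ExampleConfig(samplerate=192_000)  # ok
try:
    ExampleConfig(sample_rate=192_000)  # misspelled key
except ValidationError as err:
    print(err)  # pydantic reports "Extra inputs are not permitted"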
@@ -9,9 +9,9 @@ import numpy as np
 from pydantic import BaseModel, Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
-from soundevent.types import ClassMapper

 from batdetect2 import types
+from batdetect2.data.labels import ClassMapper

 PathLike = Union[Path, str, os.PathLike]

@@ -1,436 +0,0 @@
-"""Module containing functions for preprocessing audio clips."""
-
-from pathlib import Path
-from typing import Literal, Optional, Union
-
-import librosa
-import librosa.core.spectrum
-import numpy as np
-import xarray as xr
-from numpy.typing import DTypeLike
-from pydantic import BaseModel, Field
-from scipy.signal import resample_poly
-from soundevent import arrays, audio, data
-from soundevent.arrays import operations as ops
-
-__all__ = [
-    "PreprocessingConfig",
-    "preprocess_audio_clip",
-]
-
-
-TARGET_SAMPLERATE_HZ = 256000
-SCALE_RAW_AUDIO = False
-FFT_WIN_LENGTH_S = 512 / 256000.0
-FFT_OVERLAP = 0.75
-MAX_FREQ_HZ = 120000
-MIN_FREQ_HZ = 10000
-DEFAULT_DURATION = 1
-SPEC_HEIGHT = 128
-SPEC_WIDTH = 256
-SPEC_SCALE = "pcen"
-SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
-DENOISE_SPEC_AVG = True
-MAX_SCALE_SPEC = False
-
-
-class ResampleConfig(BaseModel):
-    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-    mode: str = "poly"
-
-
-class AudioConfig(BaseModel):
-    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
-    scale: bool = Field(default=SCALE_RAW_AUDIO)
-    center: bool = True
-    duration: Optional[float] = DEFAULT_DURATION
-
-
-class FFTConfig(BaseModel):
-    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
-    window_fn: str = "hann"
-
-
-class FrequencyConfig(BaseModel):
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
-
-
-class PcenConfig(BaseModel):
-    time_constant: float = 0.4
-    hop_length: int = 512
-    gain: float = 0.98
-    bias: float = 2
-    power: float = 0.5
-
-
-class SpecSizeConfig(BaseModel):
-    height: int = SPEC_HEIGHT
-    time_period: float = SPEC_TIME_PERIOD
-
-
-class SpectrogramConfig(BaseModel):
-    fft: FFTConfig = Field(default_factory=FFTConfig)
-    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    scale: Union[Literal["log"], None, PcenConfig] = "log"
-    denoise: bool = True
-    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
-    max_scale: bool = MAX_SCALE_SPEC
-
-
-class PreprocessingConfig(BaseModel):
-    """Configuration for preprocessing data."""
-
-    audio: AudioConfig = Field(default_factory=AudioConfig)
-    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
-
-    @classmethod
-    def from_file(
-        cls,
-        path: Union[str, Path],
-    ) -> "PreprocessingConfig":
-        """Load configuration from a file.
-
-        Parameters
-        ----------
-        path
-            Path to the configuration file.
-
-        Returns
-        -------
-        PreprocessingConfig
-            The configuration loaded from the file.
-
-        Raises
-        ------
-        FileNotFoundError
-            If the configuration file does not exist.
-        pydantic.ValidationError
-            If the configuration file is invalid.
-
-        """
-        path = Path(path)
-
-        if not path.is_file():
-            raise FileNotFoundError(f"Config file not found: {path}")
-
-        return cls.model_validate_json(path.read_text())
-
-    def to_file(self, path: Union[str, Path]) -> None:
-        """Save configuration to a file."""
-        path = Path(path)
-
-        if not path.parent.exists():
-            path.parent.mkdir(parents=True)
-
-        path.write_text(self.model_dump_json())
-
-
-def preprocess_audio_clip(
-    clip: data.Clip,
-    config: Optional[PreprocessingConfig] = None,
-) -> xr.DataArray:
-    """Preprocesses audio clip to generate spectrogram.
-
-    Parameters
-    ----------
-    clip
-        The audio clip to preprocess.
-    config
-        Configuration for preprocessing.
-
-    Returns
-    -------
-    xr.DataArray
-        Preprocessed spectrogram.
-
-    """
-    config = config or PreprocessingConfig()
-    wav = load_clip_audio(clip, config=config.audio)
-    spec = compute_spectrogram(wav, config=config.spectrogram)
-    return spec
-
-
-def load_clip_audio(
-    clip: data.Clip,
-    config: Optional[AudioConfig] = None,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    config = config or AudioConfig()
-
-    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
-
-    if config.duration is not None:
-        wav = adjust_audio_duration(wav, duration=config.duration)
-
-    if config.resample:
-        wav = resample_audio(
-            wav,
-            samplerate=config.resample.samplerate,
-            dtype=dtype,
-        )
-
-    if config.center:
-        wav = ops.center(wav)
-
-    if config.scale:
-        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
-
-    return wav.astype(dtype)
-
-
-def compute_spectrogram(
-    wav: xr.DataArray,
-    config: Optional[SpectrogramConfig] = None,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    config = config or SpectrogramConfig()
-
-    spec = stft(
-        wav,
-        window_duration=config.fft.window_duration,
-        window_overlap=config.fft.window_overlap,
-        window_fn=config.fft.window_fn,
-        dtype=dtype,
-    )
-
-    spec = crop_spectrogram_frequencies(
-        spec,
-        min_freq=config.frequencies.min_freq,
-        max_freq=config.frequencies.max_freq,
-    )
-
-    spec = scale_spectrogram(spec, scale=config.scale)
-
-    if config.denoise:
-        spec = denoise_spectrogram(spec)
-
-    if config.resize:
-        spec = resize_spectrogram(spec, config=config.resize)
-
-    if config.max_scale:
-        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
-
-    return spec.astype(dtype)
-
-
-def crop_spectrogram_frequencies(
-    spec: xr.DataArray,
-    min_freq: int = MIN_FREQ_HZ,
-    max_freq: int = MAX_FREQ_HZ,
-) -> xr.DataArray:
-    return arrays.crop_dim(
-        spec,
-        dim="frequency",
-        start=min_freq,
-        stop=max_freq,
-    ).astype(spec.dtype)
-
-
-def adjust_audio_duration(
-    wave: xr.DataArray,
-    duration: float,
-) -> xr.DataArray:
-    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    current_duration = end_time - start_time
-
-    if current_duration == duration:
-        return wave
-
-    if current_duration > duration:
-        return arrays.crop_dim(
-            wave,
-            dim="time",
-            start=start_time,
-            stop=start_time + duration,
-        )
-
-    return arrays.extend_dim(
-        wave,
-        dim="time",
-        start=start_time,
-        stop=start_time + duration,
-    )
-
-
-def resample_audio(
-    wav: xr.DataArray,
-    samplerate: int = TARGET_SAMPLERATE_HZ,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    if "time" not in wav.dims:
-        raise ValueError("Audio must have a time dimension")
-
-    time_axis: int = wav.get_axis_num("time")  # type: ignore
-
-    start, stop = arrays.get_dim_range(wav, dim="time")
-    step = arrays.get_dim_step(wav, dim="time")
-    original_samplerate = int(1 / step)
-
-    if original_samplerate == samplerate:
-        return wav.astype(dtype)
-
-    gcd = np.gcd(original_samplerate, samplerate)
-    resampled = resample_poly(
-        wav.values,
-        samplerate // gcd,
-        original_samplerate // gcd,
-        axis=time_axis,
-    )
-
-    resampled_times = np.linspace(
-        start,
-        stop + step,
-        len(resampled),
-        endpoint=False,
-        dtype=dtype,
-    )
-    return xr.DataArray(
-        data=resampled.astype(dtype),
-        dims=wav.dims,
-        coords={
-            **wav.coords,
-            "time": arrays.create_time_dim_from_array(
-                resampled_times,
-                samplerate=samplerate,
-            ),
-        },
-        attrs=wav.attrs,
-    )
-
-
-def stft(
-    wave: xr.DataArray,
-    window_duration: float,
-    window_overlap: float,
-    window_fn: str = "hann",
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    start_time, end_time = arrays.get_dim_range(wave, dim="time")
-    step = arrays.get_dim_step(wave, dim="time")
-    sampling_rate = 1 / step
-
-    hop_len = window_duration * (1 - window_overlap)
-    nfft = int(window_duration * sampling_rate)
-    noverlap = int(window_overlap * nfft)
-
-    spec, _ = librosa.core.spectrum._spectrogram(
-        y=wave.data.astype(dtype),
-        power=1,
-        n_fft=nfft,
-        hop_length=nfft - noverlap,
-        center=False,
-        window=window_fn,
-    )
-
-    return xr.DataArray(
-        data=spec.astype(dtype),
-        dims=["frequency", "time"],
-        coords={
-            "frequency": arrays.create_frequency_dim_from_array(
-                np.linspace(
-                    0,
-                    sampling_rate / 2,
-                    spec.shape[0],
-                    endpoint=False,
-                    dtype=dtype,
-                ),
-                step=sampling_rate / nfft,
-            ),
-            "time": arrays.create_time_dim_from_array(
-                np.linspace(
-                    start_time,
-                    end_time - (window_duration - hop_len),
-                    spec.shape[1],
-                    endpoint=False,
-                    dtype=dtype,
-                ),
-                step=hop_len,
-            ),
-        },
-        attrs={
-            **wave.attrs,
-            "original_samplerate": sampling_rate,
-            "nfft": nfft,
-            "noverlap": noverlap,
-        },
-    )
-
-
-def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
-    return xr.DataArray(
-        data=(spec - spec.mean("time")).clip(0),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
-
-
-def scale_spectrogram(
-    spec: xr.DataArray,
-    scale: Union[Literal["log"], None, PcenConfig],
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    if scale == "log":
-        return scale_log(spec, dtype=dtype)
-
-    if isinstance(scale, PcenConfig):
-        return scale_pcen(
-            spec,
-            time_constant=scale.time_constant,
-            hop_length=scale.hop_length,
-            gain=scale.gain,
-            power=scale.power,
-            bias=scale.bias,
-        )
-
-    return spec
-
-
-def scale_pcen(
-    spec: xr.DataArray,
-    time_constant: float = 0.4,
-    hop_length: int = 512,
-    gain: float = 0.98,
-    bias: float = 2,
-    power: float = 0.5,
-) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-    # NOTE: Not sure why the 10 is there
-    t_frames = time_constant * samplerate / (float(hop_length) * 10)
-    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
-    return audio.pcen(
-        spec * (2**31),
-        smooth=smoothing_constant,
-        gain=gain,
-        bias=bias,
-        power=power,
-    ).astype(spec.dtype)
-
-
-def scale_log(
-    spec: xr.DataArray,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-    nfft = spec.attrs["nfft"]
-    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
-    return xr.DataArray(
-        data=np.log1p(log_scaling * spec).astype(dtype),
-        dims=spec.dims,
-        coords=spec.coords,
-        attrs=spec.attrs,
-    )
-
-
-def resize_spectrogram(
-    spec: xr.DataArray,
-    config: SpecSizeConfig,
-) -> xr.DataArray:
-    duration = arrays.get_dim_width(spec, dim="time")
-    return ops.resize(
-        spec,
-        time=int(np.ceil(duration / config.time_period)),
-        frequency=config.height,
-        dtype=np.float32,
-    )
@@ -9,7 +9,7 @@ from torch import nn, optim
 from batdetect2.data.labels import ClassMapper
 from batdetect2.data.preprocessing import (
     PreprocessingConfig,
-    preprocess_audio_clip,
+    preprocess,
 )
 from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
@@ -79,7 +79,7 @@ class DetectorModel(L.LightningModule):
         )

     def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
-        return preprocess_audio_clip(
+        return preprocess(
             clip,
             config=self.preprocessing_config,
         )

@@ -2,10 +2,10 @@

 from typing import List, Optional, Tuple, Union, cast

+import matplotlib.ticker as tick
 import numpy as np
 import torch
 from matplotlib import axes, patches
-import matplotlib.ticker as tick
 from matplotlib import pyplot as plt

 from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
@@ -102,7 +102,6 @@ def spectrogram(
     return ax

-

 def spectrogram_with_detections(
     spec: Union[torch.Tensor, np.ndarray],
     dets: List[Annotation],
 64  batdetect2/preprocess/__init__.py  Normal file
@@ -0,0 +1,64 @@
+"""Module containing functions for preprocessing audio clips."""
+
+from typing import Optional
+
+import xarray as xr
+from pydantic import BaseModel, Field
+from soundevent import data
+
+from batdetect2.preprocess.audio import (
+    AudioConfig,
+    ResampleConfig,
+    load_clip_audio,
+)
+from batdetect2.preprocess.spectrogram import (
+    FFTConfig,
+    FrequencyConfig,
+    PcenConfig,
+    SpecSizeConfig,
+    SpectrogramConfig,
+    compute_spectrogram,
+)
+
+__all__ = [
+    "AudioConfig",
+    "ResampleConfig",
+    "SpectrogramConfig",
+    "FFTConfig",
+    "FrequencyConfig",
+    "PcenConfig",
+    "SpecSizeConfig",
+    "PreprocessingConfig",
+    "preprocess_audio_clip",
+]
+
+
+class PreprocessingConfig(BaseModel):
+    """Configuration for preprocessing data."""
+
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)
+
+
+def preprocess_audio_clip(
+    clip: data.Clip,
+    config: Optional[PreprocessingConfig] = None,
+) -> xr.DataArray:
+    """Preprocesses audio clip to generate spectrogram.
+
+    Parameters
+    ----------
+    clip
+        The audio clip to preprocess.
+    config
+        Configuration for preprocessing.
+
+    Returns
+    -------
+    xr.DataArray
+        Preprocessed spectrogram.
+
+    """
+    config = config or PreprocessingConfig()
+    wav = load_clip_audio(clip, config=config.audio)
+    return compute_spectrogram(wav, config=config.spectrogram)
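Context note (not part of the commit): a hedged sketch of calling the new top-level entry point, assuming a clip built with soundevent; the file path and clip bounds are made up:

from soundevent import data

from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip

recording = data.Recording.from_file("example.wav")  # hypothetical file
clip = data.Clip(recording=recording, start_time=0.0, end_time=1.0)

spec = preprocess_audio_clip(clip, config=PreprocessingConfig())
print(spec.dims)  # ("frequency", "time")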
 162  batdetect2/preprocess/audio.py  Normal file
@@ -0,0 +1,162 @@
+from typing import Optional
+
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from scipy.signal import resample, resample_poly
+from soundevent import arrays, audio, data
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+
+TARGET_SAMPLERATE_HZ = 256_000
+SCALE_RAW_AUDIO = False
+DEFAULT_DURATION = 1
+
+
+class ResampleConfig(BaseConfig):
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    mode: str = "poly"
+
+
+class AudioConfig(BaseConfig):
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    scale: bool = SCALE_RAW_AUDIO
+    center: bool = True
+    duration: Optional[float] = DEFAULT_DURATION
+
+
+def load_clip_audio(
+    clip: data.Clip,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or AudioConfig()
+
+    wav = audio.load_clip(clip).sel(channel=0).astype(dtype)
+
+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)
+
+    if config.resample:
+        wav = resample_audio(
+            wav,
+            samplerate=config.resample.samplerate,
+            dtype=dtype,
+        )
+
+    if config.center:
+        wav = ops.center(wav)
+
+    if config.scale:
+        wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))
+
+    return wav.astype(dtype)
+
+
+def adjust_audio_duration(
+    wave: xr.DataArray,
+    duration: float,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    current_duration = end_time - start_time
+
+    if current_duration == duration:
+        return wave
+
+    if current_duration > duration:
+        return arrays.crop_dim(
+            wave,
+            dim="time",
+            start=start_time,
+            stop=start_time + duration,
+        )
+
+    return arrays.extend_dim(
+        wave,
+        dim="time",
+        start=start_time,
+        stop=start_time + duration,
+    )
+
+
+def resample_audio(
+    wav: xr.DataArray,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
+    mode: str = "poly",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if "time" not in wav.dims:
+        raise ValueError("Audio must have a time dimension")
+
+    time_axis: int = wav.get_axis_num("time")  # type: ignore
+    step = arrays.get_dim_step(wav, dim="time")
+    original_samplerate = int(1 / step)
+
+    if original_samplerate == samplerate:
+        return wav.astype(dtype)
+
+    if mode == "poly":
+        resampled = resample_audio_poly(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    elif mode == "fourier":
+        resampled = resample_audio_fourier(
+            wav,
+            sr_orig=original_samplerate,
+            sr_new=samplerate,
+            axis=time_axis,
+        )
+    else:
+        raise NotImplementedError(f"Resampling mode '{mode}' not implemented")
+
+    start, stop = arrays.get_dim_range(wav, dim="time")
+    times = np.linspace(
+        start,
+        stop + step,
+        len(resampled),
+        endpoint=False,
+        dtype=dtype,
+    )
+
+    return xr.DataArray(
+        data=resampled.astype(dtype),
+        dims=wav.dims,
+        coords={
+            **wav.coords,
+            "time": arrays.create_time_dim_from_array(
+                times,
+                samplerate=samplerate,
+            ),
+        },
+        attrs=wav.attrs,
+    )
+
+
+def resample_audio_poly(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    gcd = np.gcd(sr_orig, sr_new)
+    return resample_poly(
+        array.values,
+        sr_new // gcd,
+        sr_orig // gcd,
+        axis=axis,
+    )
+
+
+def resample_audio_fourier(
+    array: xr.DataArray,
+    sr_orig: int,
+    sr_new: int,
+    axis: int = -1,
+) -> np.ndarray:
+    ratio = sr_new / sr_orig
+    return resample(array, int(array.shape[axis] * ratio), axis=axis)  # type: ignore
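Context note (not part of the commit): the two resampling backends wrapped above can be compared directly; a sketch using plain NumPy input, with made-up sample rates:

import numpy as np
from scipy.signal import resample, resample_poly

sr_orig, sr_new = 441_000, 256_000
wav = np.random.randn(sr_orig)  # one second of fake audio

gcd = np.gcd(sr_orig, sr_new)
poly = resample_poly(wav, sr_new // gcd, sr_orig // gcd)  # polyphase FIR, as in mode="poly"
four = resample(wav, int(len(wav) * sr_new / sr_orig))    # FFT-based, as in mode="fourier"
print(len(poly), len(four))  # both 256000 samples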
 242  batdetect2/preprocess/spectrogram.py  Normal file
@@ -0,0 +1,242 @@
+from typing import Literal, Optional, Union
+
+import librosa
+import librosa.core.spectrum
+import numpy as np
+import xarray as xr
+from numpy.typing import DTypeLike
+from pydantic import Field
+from soundevent import arrays, audio
+from soundevent.arrays import operations as ops
+
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess.audio import DEFAULT_DURATION
+
+FFT_WIN_LENGTH_S = 512 / 256000.0
+FFT_OVERLAP = 0.75
+MAX_FREQ_HZ = 120000
+MIN_FREQ_HZ = 10000
+SPEC_HEIGHT = 128
+SPEC_WIDTH = 256
+SPEC_SCALE = "pcen"
+SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
+DENOISE_SPEC_AVG = True
+MAX_SCALE_SPEC = False
+
+
+class FFTConfig(BaseConfig):
+    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_fn: str = "hann"
+
+
+class FrequencyConfig(BaseConfig):
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+
+class PcenConfig(BaseConfig):
+    time_constant: float = 0.4
+    hop_length: int = 512
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
+class SpecSizeConfig(BaseConfig):
+    height: int = SPEC_HEIGHT
+    time_period: float = SPEC_TIME_PERIOD
+
+
+class SpectrogramConfig(BaseConfig):
+    fft: FFTConfig = Field(default_factory=FFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    denoise: bool = True
+    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    max_scale: bool = MAX_SCALE_SPEC
+
+
+def compute_spectrogram(
+    wav: xr.DataArray,
+    config: Optional[SpectrogramConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or SpectrogramConfig()
+
+    spec = stft(
+        wav,
+        window_duration=config.fft.window_duration,
+        window_overlap=config.fft.window_overlap,
+        window_fn=config.fft.window_fn,
+        dtype=dtype,
+    )
+
+    spec = crop_spectrogram_frequencies(
+        spec,
+        min_freq=config.frequencies.min_freq,
+        max_freq=config.frequencies.max_freq,
+    )
+
+    spec = scale_spectrogram(spec, scale=config.scale)
+
+    if config.denoise:
+        spec = denoise_spectrogram(spec)
+
+    if config.resize:
+        spec = resize_spectrogram(spec, config=config.resize)
+
+    if config.max_scale:
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+    return spec.astype(dtype)
+
+
+def crop_spectrogram_frequencies(
+    spec: xr.DataArray,
+    min_freq: int = MIN_FREQ_HZ,
+    max_freq: int = MAX_FREQ_HZ,
+) -> xr.DataArray:
+    return arrays.crop_dim(
+        spec,
+        dim="frequency",
+        start=min_freq,
+        stop=max_freq,
+    ).astype(spec.dtype)
+
+
+def stft(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    window_fn: str = "hann",
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    sampling_rate = 1 / step
+
+    hop_len = window_duration * (1 - window_overlap)
+    nfft = int(window_duration * sampling_rate)
+    noverlap = int(window_overlap * nfft)
+
+    spec, _ = librosa.core.spectrum._spectrogram(
+        y=wave.data.astype(dtype),
+        power=1,
+        n_fft=nfft,
+        hop_length=nfft - noverlap,
+        center=False,
+        window=window_fn,
+    )
+
+    return xr.DataArray(
+        data=spec.astype(dtype),
+        dims=["frequency", "time"],
+        coords={
+            "frequency": arrays.create_frequency_dim_from_array(
+                np.linspace(
+                    0,
+                    sampling_rate / 2,
+                    spec.shape[0],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=sampling_rate / nfft,
+            ),
+            "time": arrays.create_time_dim_from_array(
+                np.linspace(
+                    start_time,
+                    end_time - (window_duration - hop_len),
+                    spec.shape[1],
+                    endpoint=False,
+                    dtype=dtype,
+                ),
+                step=hop_len,
+            ),
+        },
+        attrs={
+            **wave.attrs,
+            "original_samplerate": sampling_rate,
+            "nfft": nfft,
+            "noverlap": noverlap,
+        },
+    )
+
+
+def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
+    return xr.DataArray(
+        data=(spec - spec.mean("time")).clip(0),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def scale_spectrogram(
+    spec: xr.DataArray,
+    scale: Union[Literal["log"], None, PcenConfig],
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    if scale == "log":
+        return scale_log(spec, dtype=dtype)
+
+    if isinstance(scale, PcenConfig):
+        return scale_pcen(
+            spec,
+            time_constant=scale.time_constant,
+            hop_length=scale.hop_length,
+            gain=scale.gain,
+            power=scale.power,
+            bias=scale.bias,
+        )
+
+    return spec
+
+
+def scale_pcen(
+    spec: xr.DataArray,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+    gain: float = 0.98,
+    bias: float = 2,
+    power: float = 0.5,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    # NOTE: Not sure why the 10 is there
+    t_frames = time_constant * samplerate / (float(hop_length) * 10)
+    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+    return audio.pcen(
+        spec * (2**31),
+        smooth=smoothing_constant,
+        gain=gain,
+        bias=bias,
+        power=power,
+    ).astype(spec.dtype)
+
+
+def scale_log(
+    spec: xr.DataArray,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    nfft = spec.attrs["nfft"]
+    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
+    return xr.DataArray(
+        data=np.log1p(log_scaling * spec).astype(dtype),
+        dims=spec.dims,
+        coords=spec.coords,
+        attrs=spec.attrs,
+    )
+
+
+def resize_spectrogram(
+    spec: xr.DataArray,
+    config: SpecSizeConfig,
+) -> xr.DataArray:
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
+        spec,
+        time=int(np.ceil(duration / config.time_period)),
+        frequency=config.height,
+        dtype=np.float32,
+    )
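Context note (not part of the commit): because SpectrogramConfig.scale is a union of "log", None, and PcenConfig, switching scaling methods is a config edit rather than a code edit; a minimal sketch:

from batdetect2.preprocess.spectrogram import PcenConfig, SpectrogramConfig

log_cfg = SpectrogramConfig(scale="log")          # log1p amplitude scaling
pcen_cfg = SpectrogramConfig(scale=PcenConfig())  # per-channel energy normalization
raw_cfg = SpectrogramConfig(scale=None)           # unscaled magnitudes

print(pcen_cfg.model_dump_json())  # round-trips through JSON configs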
 88  batdetect2/terms.py  Normal file
@@ -0,0 +1,88 @@
+from inspect import getmembers
+from typing import Optional
+
+from pydantic import BaseModel
+from soundevent import data, terms
+
+__all__ = [
+    "call_type",
+    "individual",
+    "get_term_from_info",
+    "get_tag_from_info",
+    "TermInfo",
+    "TagInfo",
+]
+
+
+class TermInfo(BaseModel):
+    label: Optional[str]
+    name: Optional[str]
+    uri: Optional[str]
+
+
+class TagInfo(BaseModel):
+    value: str
+    label: Optional[str] = None
+    term: Optional[TermInfo] = None
+    key: Optional[str] = None
+
+
+call_type = data.Term(
+    name="soundevent:call_type",
+    label="Call Type",
+    definition="A broad categorization of animal vocalizations based on their intended function or purpose (e.g., social, distress, mating, territorial, echolocation).",
+)
+
+individual = data.Term(
+    name="soundevent:individual",
+    label="Individual",
+    definition="An id for an individual animal. In the context of bioacoustic annotation, this term is used to label vocalizations that are attributed to a specific individual.",
+)
+
+
+ALL_TERMS = [
+    *getmembers(terms, lambda x: isinstance(x, data.Term)),
+    call_type,
+    individual,
+]
+
+
+def get_term_from_info(term_info: TermInfo) -> data.Term:
+    for term in ALL_TERMS:
+        if term_info.name and term_info.name == term.name:
+            return term
+
+        if term_info.label and term_info.label == term.label:
+            return term
+
+        if term_info.uri and term_info.uri == term.uri:
+            return term
+
+    if term_info.name is None:
+        if term_info.label is None:
+            raise ValueError("At least one of name or label must be provided.")
+
+        term_info.name = (
+            f"soundevent:{term_info.label.lower().replace(' ', '_')}"
+        )
+
+    if term_info.label is None:
+        term_info.label = term_info.name
+
+    return data.Term(
+        name=term_info.name,
+        label=term_info.label,
+        uri=term_info.uri,
+        definition="Unknown",
+    )
+
+
+def get_tag_from_info(tag_info: TagInfo) -> data.Tag:
+    if tag_info.term:
+        term = get_term_from_info(tag_info.term)
+    elif tag_info.key:
+        term = data.term_from_key(tag_info.key)
+    else:
+        raise ValueError("Either term or key must be provided in tag info.")
+
+    return data.Tag(term=term, value=tag_info.value)
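Side note (commentary, not part of the diff): inspect.getmembers yields (name, object) pairs, so ALL_TERMS as committed mixes tuples with Term objects, and term.name in the lookup loop would fail on the tuples. A sketch of the lookup with the pairs unpacked:

from inspect import getmembers

from soundevent import data, terms

known_terms = [
    obj for _, obj in getmembers(terms, lambda x: isinstance(x, data.Term))
]
# Resolving a term by label, as get_term_from_info does:
match = next((t for t in known_terms if "name" in t.label.lower()), None)
print(match.name if match else "not found")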
@@ -1,11 +1,8 @@
 from functools import wraps
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional

 import numpy as np
 import xarray as xr
 from soundevent import data
-from soundevent.geometry import compute_bounds
-
-

 Augmentation = Callable[[xr.Dataset], xr.Dataset]
@@ -223,8 +220,8 @@ def mask_frequency(
     num_masks = np.random.randint(1, max_num_masks + 1)

     freq_coord = train_example.coords["frequency"]
-    min_freq = freq_coord.min()
-    max_freq = freq_coord.max()
+    min_freq = float(freq_coord.min())
+    max_freq = float(freq_coord.max())

     for _ in range(num_masks):
         mask_size = np.random.uniform(0, max_freq_mask)
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Sequence, Tuple

 import numpy as np
 import xarray as xr
@@ -17,7 +17,7 @@ TARGET_SIGMA = 3.0


 def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
+    sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
@@ -52,9 +52,8 @@ def generate_heatmaps(
         },
     )
-
-    for sound_event_annotation in clip_annotation.sound_events:
+    for sound_event_annotation in sound_events:
         geom = sound_event_annotation.sound_event.geometry

         if geom is None:
             continue

@@ -1,21 +1,30 @@
 """Module for preprocessing data for training."""

 import os
-import warnings
 from functools import partial
 from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence, Union

 import xarray as xr
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm

-from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
-from batdetect2.data.preprocessing import (
+from batdetect2.configs import BaseConfig
+from batdetect2.preprocess import (
     PreprocessingConfig,
     preprocess_audio_clip,
 )
+from batdetect2.train.labels import (
+    TARGET_SIGMA,
+    generate_heatmaps,
+)
+from batdetect2.train.targets import (
+    TargetConfig,
+    build_class_mapper,
+    build_sound_event_filter,
+)

 PathLike = Union[Path, str, os.PathLike]
 FilenameFn = Callable[[data.ClipAnnotation], str]
@@ -25,25 +34,44 @@ __all__ = [
 ]


+class MasksConfig(BaseConfig):
+    sigma: float = TARGET_SIGMA
+
+
+class TrainPreprocessingConfig(BaseConfig):
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    target: TargetConfig = Field(default_factory=TargetConfig)
+    masks: MasksConfig = Field(default_factory=MasksConfig)
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    class_mapper: ClassMapper,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_sigma: float = TARGET_SIGMA,
+    config: Optional[TrainPreprocessingConfig] = None,
 ) -> xr.Dataset:
     """Generate a training example."""
-    preprocessing_config = preprocessing_config or PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()

     spectrogram = preprocess_audio_clip(
         clip_annotation.clip,
-        config=preprocessing_config,
+        config=config.preprocessing,
     )

+    filter_fn = build_sound_event_filter(
+        include=config.target.include,
+        exclude=config.target.exclude,
+    )
+    selected_events = [
+        event for event in clip_annotation.sound_events if filter_fn(event)
+    ]
+
+    class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        clip_annotation,
+        selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=target_sigma,
+        target_sigma=config.masks.sigma,
     )

     dataset = xr.Dataset(
@@ -57,8 +85,7 @@ def generate_train_example(

     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        preprocessing_configuration=preprocessing_config.model_dump_json(),
-        target_sigma=target_sigma,
+        config=config.model_dump_json(),
         clip_annotation=clip_annotation.model_dump_json(),
     )

@@ -78,77 +105,22 @@ def save_to_file(
     )


-def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
-    """Load configuration from file."""
-
-    path = Path(path)
-
-    if not path.is_file():
-        warnings.warn(
-            f"Config file not found: {path}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-    try:
-        return PreprocessingConfig.model_validate_json(path.read_text())
-    except ValueError as e:
-        warnings.warn(
-            f"Failed to load config file: {e}. Using default config.",
-            stacklevel=1,
-        )
-        return PreprocessingConfig(**kwargs)
-
-
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     return f"{clip_annotation.uuid}.nc"


-def preprocess_single_annotation(
-    clip_annotation: data.ClipAnnotation,
-    output_dir: PathLike,
-    config: PreprocessingConfig,
-    class_mapper: ClassMapper,
-    filename_fn: FilenameFn = _get_filename,
-    replace: bool = False,
-    target_sigma: float = TARGET_SIGMA,
-) -> None:
-    output_dir = Path(output_dir)
-
-    filename = filename_fn(clip_annotation)
-    path = output_dir / filename
-
-    if path.is_file() and not replace:
-        return
-
-    if path.is_file() and replace:
-        path.unlink()
-
-    sample = generate_train_example(
-        clip_annotation,
-        class_mapper,
-        preprocessing_config=config,
-        target_sigma=target_sigma,
-    )
-
-    save_to_file(sample, path)
-
-
 def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
     output_dir: PathLike,
-    class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    config: Optional[PreprocessingConfig] = None,
+    config: Optional[TrainPreprocessingConfig] = None,
    max_workers: Optional[int] = None,
 ) -> None:
     """Preprocess annotations and save to disk."""
     output_dir = Path(output_dir)

-    if config is None:
-        config = PreprocessingConfig()
+    config = config or TrainPreprocessingConfig()

     if not output_dir.is_dir():
         output_dir.mkdir(parents=True)
@@ -161,13 +133,33 @@ def preprocess_annotations(
                     preprocess_single_annotation,
                     output_dir=output_dir,
                     config=config,
-                    class_mapper=class_mapper,
                     filename_fn=filename_fn,
                     replace=replace,
-                    target_sigma=target_sigma,
                 ),
                 clip_annotations,
             ),
             total=len(clip_annotations),
         )
     )
+
+
+def preprocess_single_annotation(
+    clip_annotation: data.ClipAnnotation,
+    output_dir: PathLike,
+    config: TrainPreprocessingConfig,
+    filename_fn: FilenameFn = _get_filename,
+    replace: bool = False,
+) -> None:
+    output_dir = Path(output_dir)
+
+    filename = filename_fn(clip_annotation)
+    path = output_dir / filename
+
+    if path.is_file() and not replace:
+        return
+
+    if path.is_file() and replace:
+        path.unlink()
+
+    sample = generate_train_example(clip_annotation, config=config)
+    save_to_file(sample, path)
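Context note (not part of the diff): the new TrainPreprocessingConfig bundles the previously separate class_mapper, preprocessing_config, and target_sigma arguments; a hedged sketch of the updated call, assuming these hunks belong to the train preprocessing module and with made-up tag values:

from batdetect2.terms import TagInfo
from batdetect2.train.targets import TargetConfig

# Assumed module path for the hunks above:
from batdetect2.train.preprocess import (
    MasksConfig,
    TrainPreprocessingConfig,
    preprocess_annotations,
)

config = TrainPreprocessingConfig(
    target=TargetConfig(classes=[TagInfo(key="species", value="Myotis myotis")]),
    masks=MasksConfig(sigma=3.0),
)
# clip_annotations: a Sequence[data.ClipAnnotation] loaded elsewhere
preprocess_annotations(clip_annotations, "train_examples/", config=config)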
 99  batdetect2/train/targets.py  Normal file
@@ -0,0 +1,99 @@
+from functools import partial
+from typing import Callable, List, Optional, Set
+
+from pydantic import Field
+from soundevent import data
+from soundevent.types import ClassMapper
+
+from batdetect2.configs import BaseConfig
+from batdetect2.terms import TagInfo, get_tag_from_info
+
+
+class TargetConfig(BaseConfig):
+    """Configuration for target generation."""
+
+    classes: List[TagInfo] = Field(default_factory=list)
+
+    include: Optional[List[TagInfo]] = None
+
+    exclude: Optional[List[TagInfo]] = None
+
+
+def build_sound_event_filter(
+    include: Optional[List[TagInfo]] = None,
+    exclude: Optional[List[TagInfo]] = None,
+) -> Callable[[data.SoundEventAnnotation], bool]:
+    include_tags = (
+        {get_tag_from_info(tag) for tag in include} if include else None
+    )
+    exclude_tags = (
+        {get_tag_from_info(tag) for tag in exclude} if exclude else None
+    )
+    return partial(
+        filter_sound_event,
+        include=include_tags,
+        exclude=exclude_tags,
+    )
+
+
+def build_class_mapper(classes: List[TagInfo]) -> ClassMapper:
+    target_tags = [get_tag_from_info(tag) for tag in classes]
+    labels = [tag.label if tag.label else tag.value for tag in classes]
+    return GenericMapper(
+        classes=target_tags,
+        labels=labels,
+    )
+
+
+def filter_sound_event(
+    sound_event_annotation: data.SoundEventAnnotation,
+    include: Optional[Set[data.Tag]] = None,
+    exclude: Optional[Set[data.Tag]] = None,
+) -> bool:
+    tags = set(sound_event_annotation.tags)
+
+    if include is not None and not tags & include:
+        return False
+
+    if exclude is not None and tags & exclude:
+        return False
+
+    return True
+
+
+class GenericMapper(ClassMapper):
+    """Generic class mapper configuration."""
+
+    def __init__(
+        self,
+        classes: List[data.Tag],
+        labels: List[str],
+    ):
+        if not len(classes) == len(labels):
+            raise ValueError("Number of targets and class labels must match.")
+
+        self.targets = set(classes)
+        self.class_labels = labels
+
+        self._mapping = {tag: label for tag, label in zip(classes, labels)}
+        self._inverse_mapping = {
+            label: tag for tag, label in zip(classes, labels)
+        }
+
+    def encode(
+        self,
+        sound_event_annotation: data.SoundEventAnnotation,
+    ) -> Optional[str]:
+        tags = set(sound_event_annotation.tags)
+
+        intersection = tags & self.targets
+        if not intersection:
+            return None
+
+        tag = intersection.pop()
+        return self._mapping[tag]
+
+    def decode(self, label: str) -> List[data.Tag]:
+        if label not in self._inverse_mapping:
+            return []
+        return [self._inverse_mapping[label]]
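Context note (not part of the diff): include/exclude filters and the class mapper are both driven by TagInfo entries; a minimal sketch of building a filter, with made-up tag values:

from batdetect2.terms import TagInfo
from batdetect2.train.targets import build_sound_event_filter

# Keep annotations tagged as echolocation, drop anything tagged as noise.
filter_fn = build_sound_event_filter(
    include=[TagInfo(key="event", value="Echolocation")],
    exclude=[TagInfo(key="event", value="Noise")],
)
# filter_fn(annotation) -> bool for any data.SoundEventAnnotation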