batdetect2 (mirror of https://github.com/macaodha/batdetect2.git)

Commit: Ensure train inputs are almost equal
commit 36c90a600f, parent 1f0fb14d89
batdetect2/compat/params.py (new file, 148 lines)

@@ -0,0 +1,148 @@
from batdetect2.preprocess import (
    AudioConfig,
    FFTConfig,
    FrequencyConfig,
    PcenConfig,
    PreprocessingConfig,
    ResampleConfig,
    SpecSizeConfig,
    SpectrogramConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
    HeatmapsConfig,
    TargetConfig,
    TrainPreprocessingConfig,
)


def get_spectrogram_scale(scale: str):
    if scale == "pcen":
        return PcenConfig()
    if scale == "log":
        return "log"
    return None


def get_preprocessing_config(params: dict) -> PreprocessingConfig:
    return PreprocessingConfig(
        audio=AudioConfig(
            resample=ResampleConfig(
                samplerate=params["target_samp_rate"],
                mode="poly",
            ),
            scale=params["scale_raw_audio"],
            center=params["scale_raw_audio"],
            duration=None,
        ),
        spectrogram=SpectrogramConfig(
            fft=FFTConfig(
                window_duration=params["fft_win_length"],
                window_overlap=params["fft_overlap"],
                window_fn="hann",
            ),
            frequencies=FrequencyConfig(
                min_freq=params["min_freq"],
                max_freq=params["max_freq"],
            ),
            scale=get_spectrogram_scale(params["spec_scale"]),
            denoise=params["denoise_spec_avg"],
            size=SpecSizeConfig(
                height=params["spec_height"],
                resize_factor=params["resize_factor"],
            ),
            max_scale=params["max_scale_spec"],
        ),
    )


def get_training_preprocessing_config(
    params: dict,
) -> TrainPreprocessingConfig:
    generic = params["generic_class"][0]
    preprocessing = get_preprocessing_config(params)

    freq_bin_width, time_bin_width = get_spectrogram_resolution(
        preprocessing.spectrogram
    )

    return TrainPreprocessingConfig(
        preprocessing=preprocessing,
        target=TargetConfig(
            classes=[
                TagInfo(key="class", value=class_name, label=class_name)
                for class_name in params["class_names"]
            ],
            generic_class=TagInfo(
                key="class",
                value=generic,
                label=generic,
            ),
            include=[
                TagInfo(key="event", value=event)
                for event in params["events_of_interest"]
            ],
            exclude=[
                TagInfo(key="class", value=value)
                for value in params["classes_to_ignore"]
            ],
        ),
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1 / time_bin_width,
            frequency_scale=1 / freq_bin_width,
            sigma=params["target_sigma"],
        ),
    )


# 'standardize_classs_names_ip',
# 'convert_to_genus',
# 'genus_mapping',
# 'standardize_classs_names',
# 'genus_names',

# ['data_dir',
#  'ann_dir',
#  'train_split',
#  'model_name',
#  'num_filters',
#  'experiment',
#  'model_file_name',
#  'op_im_dir',
#  'op_im_dir_test',
#  'notes',
#  'spec_divide_factor',
#  'detection_overlap',
#  'ignore_start_end',
#  'detection_threshold',
#  'nms_kernel_size',
#  'nms_top_k_per_sec',
#  'aug_prob',
#  'augment_at_train',
#  'augment_at_train_combine',
#  'echo_max_delay',
#  'stretch_squeeze_delta',
#  'mask_max_time_perc',
#  'mask_max_freq_perc',
#  'spec_amp_scaling',
#  'aug_sampling_rates',
#  'train_loss',
#  'det_loss_weight',
#  'size_loss_weight',
#  'class_loss_weight',
#  'individual_loss_weight',
#  'emb_dim',
#  'lr',
#  'batch_size',
#  'num_workers',
#  'num_epochs',
#  'num_eval_epochs',
#  'device',
#  'save_test_image_during_train',
#  'save_test_image_after_train',
#  'train_sets',
#  'test_sets',
#  'class_inv_freq',
#  'ip_height']
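For context, a minimal sketch of how this compat layer can be driven from a legacy parameter dict. The keys are the ones the functions above read; the values are illustrative assumptions, not values from the repository:

# Hypothetical legacy params dict; values are illustrative only.
from batdetect2.compat.params import get_training_preprocessing_config

params = {
    "target_samp_rate": 256_000,
    "scale_raw_audio": False,
    "fft_win_length": 512 / 256_000,
    "fft_overlap": 0.75,
    "min_freq": 10_000,
    "max_freq": 120_000,
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "spec_height": 128,
    "resize_factor": 0.5,
    "max_scale_spec": False,
    "generic_class": ["Bat"],
    "class_names": ["Myotis myotis", "Eptesicus serotinus"],
    "events_of_interest": ["Echolocation"],
    "classes_to_ignore": ["Unknown"],
    "target_sigma": 3.0,
}

config = get_training_preprocessing_config(params)

The resulting config is what gets passed to generate_train_example alongside a clip annotation, as the regression test added later in this commit shows.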
@@ -27,6 +27,28 @@ class AudioConfig(BaseConfig):
     duration: Optional[float] = DEFAULT_DURATION


+def load_file_audio(
+    path: data.PathLike,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    recording = data.Recording.from_file(path)
+    return load_recording_audio(recording, config=config, dtype=dtype)
+
+
+def load_recording_audio(
+    recording: data.Recording,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    clip = data.Clip(
+        recording=recording,
+        start_time=0,
+        end_time=recording.duration,
+    )
+    return load_clip_audio(clip, config=config, dtype=dtype)
+
+
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
@@ -10,29 +10,17 @@ from soundevent import arrays, audio
 from soundevent.arrays import operations as ops

 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.audio import DEFAULT_DURATION
-
-FFT_WIN_LENGTH_S = 512 / 256000.0
-FFT_OVERLAP = 0.75
-MAX_FREQ_HZ = 120000
-MIN_FREQ_HZ = 10000
-SPEC_HEIGHT = 128
-SPEC_WIDTH = 256
-SPEC_SCALE = "pcen"
-SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
-DENOISE_SPEC_AVG = True
-MAX_SCALE_SPEC = False


 class FFTConfig(BaseConfig):
-    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_duration: float = Field(default=0.002, gt=0)
+    window_overlap: float = Field(default=0.75, ge=0, lt=1)
     window_fn: str = "hann"


 class FrequencyConfig(BaseConfig):
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+    max_freq: int = Field(default=120_000, gt=0)
+    min_freq: int = Field(default=10_000, gt=0)


 class PcenConfig(BaseConfig):
@@ -44,17 +32,20 @@ class PcenConfig(BaseConfig):


 class SpecSizeConfig(BaseConfig):
-    height: int = SPEC_HEIGHT
-    time_period: float = SPEC_TIME_PERIOD
+    height: int = 256
+    resize_factor: Optional[float] = 0.5
+    divide_factor: Optional[int] = 32


 class SpectrogramConfig(BaseConfig):
     fft: FFTConfig = Field(default_factory=FFTConfig)
     frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    scale: Union[Literal["log"], None, PcenConfig] = Field(
+        default_factory=PcenConfig
+    )
+    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
     denoise: bool = True
-    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
-    max_scale: bool = MAX_SCALE_SPEC
+    max_scale: bool = False


 def compute_spectrogram(
@@ -64,6 +55,16 @@ def compute_spectrogram(
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()

+    if config.size and config.size.divide_factor:
+        # Need to pad the audio to make sure the spectrogram has a
+        # width compatible with the divide factor
+        wav = pad_audio(
+            wav,
+            window_duration=config.fft.window_duration,
+            window_overlap=config.fft.window_overlap,
+            divide_factor=config.size.divide_factor,
+        )
+
     spec = stft(
         wav,
         window_duration=config.fft.window_duration,
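A brief aside on the padding step above: divide_factor defaults to 32 = 2^5, presumably so the spectrogram width survives repeated 2x downsampling in the detection network (that rationale is my reading, not stated in the diff). A toy check of the invariant:

import numpy as np

# Assumption: a width padded up to a multiple of 32 can be halved five
# times without remainder, which is what a five-level encoder needs.
width = int(np.ceil(997 / 32) * 32)  # 997 frames -> padded to 1024
for _ in range(5):
    assert width % 2 == 0
    width //= 2
assert width == 32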
@@ -83,8 +84,12 @@ def compute_spectrogram(
     if config.denoise:
         spec = denoise_spectrogram(spec)

-    if config.resize:
-        spec = resize_spectrogram(spec, config=config.resize)
+    if config.size:
+        spec = resize_spectrogram(
+            spec,
+            height=config.size.height,
+            resize_factor=config.size.resize_factor,
+        )

     if config.max_scale:
         spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@@ -94,8 +99,8 @@ def compute_spectrogram(

 def crop_spectrogram_frequencies(
     spec: xr.DataArray,
-    min_freq: int = MIN_FREQ_HZ,
-    max_freq: int = MAX_FREQ_HZ,
+    min_freq: int = 10_000,
+    max_freq: int = 120_000,
 ) -> xr.DataArray:
     return arrays.crop_dim(
         spec,
@@ -116,9 +121,10 @@ def stft(
     step = arrays.get_dim_step(wave, dim="time")
     sampling_rate = 1 / step

-    hop_len = window_duration * (1 - window_overlap)
     nfft = int(window_duration * sampling_rate)
     noverlap = int(window_overlap * nfft)
+    hop_len = nfft - noverlap
+    hop_duration = hop_len / sampling_rate

     spec, _ = librosa.core.spectrum._spectrogram(
         y=wave.data.astype(dtype),
@@ -146,12 +152,12 @@ def stft(
         "time": arrays.create_time_dim_from_array(
             np.linspace(
                 start_time,
-                end_time - (window_duration - hop_len),
+                end_time - (window_duration - hop_duration),
                 spec.shape[1],
                 endpoint=False,
                 dtype=dtype,
             ),
-            step=hop_len,
+            step=hop_duration,
         ),
     },
     attrs={
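An aside on the hop-length change in these two hunks: librosa frames the signal in whole samples, so deriving the hop from nfft and noverlap matches the actual frame spacing, while the old seconds-based formula can disagree once the window is rounded to an integer number of samples. A minimal sketch (the sampling rate here is an illustrative assumption):

sampling_rate = 250_001          # assumed, for illustration only
window_duration = 0.002          # seconds
window_overlap = 0.75

nfft = int(window_duration * sampling_rate)       # 500 samples
noverlap = int(window_overlap * nfft)             # 375 samples
hop_len = nfft - noverlap                         # 125 samples
hop_duration = hop_len / sampling_rate            # ~0.000499998 s

old_hop = window_duration * (1 - window_overlap)  # 0.0005 s exactly
# The tiny per-frame difference accumulates across the time coordinate.
assert abs(hop_duration - old_hop) > 0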
@@ -202,7 +208,6 @@ def scale_pcen(
     power: float = 0.5,
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
-    # NOTE: Not sure why the 10 is there
     t_frames = time_constant * samplerate / (float(hop_length) * 10)
     smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
     return audio.pcen(
@@ -231,12 +236,114 @@ def scale_log(

 def resize_spectrogram(
     spec: xr.DataArray,
-    config: SpecSizeConfig,
+    height: int = 128,
+    resize_factor: Optional[float] = 0.5,
 ) -> xr.DataArray:
-    duration = arrays.get_dim_width(spec, dim="time")
+    resize_factor = resize_factor or 1
+    current_width = spec.sizes["time"]
     return ops.resize(
         spec,
-        time=int(np.ceil(duration / config.time_period)),
-        frequency=config.height,
+        time=int(resize_factor * current_width),
+        frequency=int(resize_factor * height),
         dtype=np.float32,
     )


+def adjust_spectrogram_width(
+    spec: xr.DataArray,
+    divide_factor: int = 32,
+    time_period: float = 0.001,
+) -> xr.DataArray:
+    time_width = spec.sizes["time"]
+
+    if time_width % divide_factor == 0:
+        return spec
+
+    target_size = int(
+        np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
+    )
+    extra_duration = (target_size - time_width) * time_period
+    _, stop = arrays.get_dim_range(spec, dim="time")
+    resized = ops.extend_dim(
+        spec,
+        dim="time",
+        stop=stop + extra_duration,
+    )
+    return resized
+
+
+def pad_audio(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    divide_factor: int = 32,
+) -> xr.DataArray:
+    current_duration = arrays.get_dim_width(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    samplerate = int(1 / step)
+
+    estimated_spec_width = duration_to_spec_width(
+        current_duration,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+
+    if estimated_spec_width % divide_factor == 0:
+        return wave
+
+    target_spec_width = int(
+        np.ceil(estimated_spec_width / divide_factor) * divide_factor
+    )
+    target_samples = spec_width_to_samples(
+        target_spec_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+    return ops.adjust_dim_width(
+        wave,
+        dim="time",
+        width=target_samples,
+        position="start",
+    )
+
+
+def duration_to_spec_width(
+    duration: float,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    samples = int(duration * samplerate)
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    width = (samples - fft_len + hop_len) / hop_len
+    return int(np.floor(width))
+
+
+def spec_width_to_samples(
+    width: int,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    return width * hop_len + fft_len - hop_len
+
+
+def get_spectrogram_resolution(
+    config: SpectrogramConfig,
+) -> tuple[float, float]:
+    max_freq = config.frequencies.max_freq
+    min_freq = config.frequencies.min_freq
+    assert config.size is not None

+    spec_height = config.size.height
+    resize_factor = config.size.resize_factor or 1
+    freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
+    hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap)
+    return freq_bin_width, hop_duration / resize_factor
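To make the width arithmetic above concrete, a worked round-trip with default-looking numbers (chosen for illustration):

from batdetect2.preprocess.spectrogram import (
    duration_to_spec_width,
    spec_width_to_samples,
)

samplerate = 256_000
window_duration = 512 / 256_000  # 0.002 s -> fft_len = 512 samples
window_overlap = 0.75            # -> hop_len = 512 - 384 = 128 samples

# A 0.5 s clip is 128_000 samples:
# width = (128_000 - 512 + 128) / 128 = 997 spectrogram frames.
width = duration_to_spec_width(
    0.5,
    samplerate=samplerate,
    window_duration=window_duration,
    window_overlap=window_overlap,
)
assert width == 997

# The inverse maps 997 frames back to exactly 128_000 samples:
# 997 * 128 + 512 - 128 = 128_000.
assert spec_width_to_samples(
    width,
    samplerate=samplerate,
    window_duration=window_duration,
    window_overlap=window_overlap,
) == 128_000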
@@ -22,9 +22,9 @@ class TermInfo(BaseModel):

 class TagInfo(BaseModel):
     value: str
-    label: Optional[str] = None
     term: Optional[TermInfo] = None
+    key: Optional[str] = None
+    label: Optional[str] = None


 call_type = data.Term(
@@ -7,23 +7,29 @@ from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
 from soundevent.types import ClassMapper

+from batdetect2.configs import BaseConfig
+
 __all__ = [
     "ClassMapper",
     "generate_heatmaps",
 ]


-TARGET_SIGMA = 3.0
+class HeatmapsConfig(BaseConfig):
+    position: Positions = "bottom-left"
+    sigma: float = 3.0
+    time_scale: float = 1000.0
+    frequency_scale: float = 1 / 859.375


 def generate_heatmaps(
     sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
+    target_sigma: float = 3.0,
     position: Positions = "bottom-left",
-    time_scale: float = 1.0,
-    frequency_scale: float = 1.0,
+    time_scale: float = 1000.0,
+    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -39,7 +45,7 @@ def generate_heatmaps(
         data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": class_mapper.class_labels,
+            "category": [*class_mapper.class_labels],
             **spec.coords,
         },
     )
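An aside on the new defaults: 1 / 859.375 and 1000.0 fall out of get_spectrogram_resolution with the defaults used elsewhere in this diff (0.002 s window, 0.75 overlap, a 10-120 kHz band, height 256, resize factor 0.5); the scales convert Hz and seconds into resized spectrogram bins. A quick sanity check:

import math

freq_bin_width = (120_000 - 10_000) / (256 * 0.5)  # 859.375 Hz per bin
time_bin_width = 0.002 * (1 - 0.75) / 0.5          # 0.001 s per frame
assert math.isclose(1 / freq_bin_width, 1 / 859.375)
assert math.isclose(1 / time_bin_width, 1000.0)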
@@ -14,12 +14,10 @@ from tqdm.auto import tqdm
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import (
     PreprocessingConfig,
-    preprocess_audio_clip,
+    compute_spectrogram,
+    load_clip_audio,
 )
-from batdetect2.train.labels import (
-    TARGET_SIGMA,
-    generate_heatmaps,
-)
+from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
 from batdetect2.train.targets import (
     TargetConfig,
     build_class_mapper,
@@ -34,16 +32,12 @@ __all__ = [
 ]


-class MasksConfig(BaseConfig):
-    sigma: float = TARGET_SIGMA
-
-
 class TrainPreprocessingConfig(BaseConfig):
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
     target: TargetConfig = Field(default_factory=TargetConfig)
-    masks: MasksConfig = Field(default_factory=MasksConfig)
+    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)


 def generate_train_example(
@@ -53,9 +47,14 @@ def generate_train_example(
     """Generate a training example."""
     config = config or TrainPreprocessingConfig()

-    spectrogram = preprocess_audio_clip(
+    wave = load_clip_audio(
         clip_annotation.clip,
-        config=config.preprocessing,
+        config=config.preprocessing.audio,
+    )
+
+    spectrogram = compute_spectrogram(
+        wave,
+        config=config.preprocessing.spectrogram,
     )

     filter_fn = build_sound_event_filter(
@@ -65,17 +64,24 @@ def generate_train_example(
     selected_events = [
         event for event in clip_annotation.sound_events if filter_fn(event)
     ]

     class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
         selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=config.masks.sigma,
+        target_sigma=config.heatmaps.sigma,
+        position=config.heatmaps.position,
+        time_scale=config.heatmaps.time_scale,
+        frequency_scale=config.heatmaps.frequency_scale,
     )

     dataset = xr.Dataset(
         {
+            # NOTE: Need to rename the time dimension to avoid conflicts with
+            # the spectrogram time dimension, otherwise xarray will interpolate
+            # the spectrogram and the heatmaps to the same temporal resolution
+            # as the waveform.
+            "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
             "detection": detection_heatmap,
             "class": class_heatmap,
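The NOTE in this hunk is worth unpacking: putting two arrays that share a differently-sampled "time" dimension into one xarray Dataset triggers automatic alignment. A toy demonstration (made-up arrays, not the project's):

import numpy as np
import xarray as xr

wave = xr.DataArray(
    np.zeros(8), coords={"time": np.linspace(0, 1, 8)}, dims="time"
)
spec = xr.DataArray(
    np.zeros(4), coords={"time": np.linspace(0, 1, 4)}, dims="time"
)

# Sharing the "time" dim makes xarray align both variables onto one
# axis (the union of the coordinates), padding each with NaNs where
# the other's time stamps fall.
merged = xr.Dataset({"audio": wave, "spectrogram": spec})
print(merged.sizes["time"])  # 10: neither array keeps its own grid

# Renaming keeps each resolution intact.
safe = xr.Dataset(
    {"audio": wave.rename({"time": "audio_time"}), "spectrogram": spec}
)
print(dict(safe.sizes))  # {'audio_time': 8, 'time': 4}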
@@ -13,9 +13,9 @@ class TargetConfig(BaseConfig):
     """Configuration for target generation."""

     classes: List[TagInfo] = Field(default_factory=list)
+    generic_class: Optional[TagInfo] = None

     include: Optional[List[TagInfo]] = None

     exclude: Optional[List[TagInfo]] = None
@@ -73,7 +73,7 @@ class GenericMapper(ClassMapper):
             raise ValueError("Number of targets and class labels must match.")

         self.targets = set(classes)
-        self.class_labels = labels
+        self.class_labels = list(dict.fromkeys(labels))

         self._mapping = {tag: label for tag, label in zip(classes, labels)}
         self._inverse_mapping = {
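A quick illustration of the deduplication idiom adopted here: dict.fromkeys keeps the first occurrence of each label and preserves insertion order, unlike set() (labels below are illustrative):

labels = ["Myotis", "Pipistrellus", "Myotis", "Eptesicus"]
assert list(dict.fromkeys(labels)) == ["Myotis", "Pipistrellus", "Eptesicus"]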
@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.2",
+    "soundevent[audio,geometry,plot]>=2.3",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@@ -1,7 +1,11 @@
 import uuid
 from pathlib import Path
-from typing import List
+from typing import Callable, List, Optional

+import numpy as np
 import pytest
+import soundfile as sf
 from soundevent import data


 @pytest.fixture
@@ -19,6 +23,13 @@ def example_audio_dir(example_data_dir: Path) -> Path:
     return example_audio_dir


+@pytest.fixture
+def example_anns_dir(example_data_dir: Path) -> Path:
+    example_anns_dir = example_data_dir / "anns"
+    assert example_anns_dir.exists()
+    return example_anns_dir
+
+
 @pytest.fixture
 def example_audio_files(example_audio_dir: Path) -> List[Path]:
     audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
@@ -38,3 +49,61 @@ def contrib_dir(data_dir) -> Path:
     dir = data_dir / "contrib"
     assert dir.exists()
     return dir
+
+
+@pytest.fixture
+def wav_factory(tmp_path: Path):
+    def _wav_factory(
+        path: Optional[Path] = None,
+        duration: float = 0.3,
+        channels: int = 1,
+        samplerate: int = 441_000,
+        bit_depth: int = 16,
+    ) -> Path:
+        path = path or tmp_path / f"{uuid.uuid4()}.wav"
+        frames = int(samplerate * duration)
+        shape = (frames, channels)
+        subtype = f"PCM_{bit_depth}"
+
+        if bit_depth == 16:
+            dtype = np.int16
+        elif bit_depth == 32:
+            dtype = np.int32
+        else:
+            raise ValueError(f"Unsupported bit depth: {bit_depth}")
+
+        wav = np.random.uniform(
+            low=np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=shape,
+        ).astype(dtype)
+        sf.write(str(path), wav, samplerate, subtype=subtype)
+        return path
+
+    return _wav_factory
+
+
+@pytest.fixture
+def recording_factory(wav_factory: Callable[..., Path]):
+    def _recording_factory(
+        tags: Optional[list[data.Tag]] = None,
+        path: Optional[Path] = None,
+        recording_id: Optional[uuid.UUID] = None,
+        duration: float = 1,
+        channels: int = 1,
+        samplerate: int = 44100,
+        time_expansion: float = 1,
+    ) -> data.Recording:
+        path = path or wav_factory(
+            duration=duration,
+            channels=channels,
+            samplerate=samplerate,
+        )
+        return data.Recording.from_file(
+            path=path,
+            uuid=recording_id or uuid.uuid4(),
+            time_expansion=time_expansion,
+            tags=tags or [],
+        )
+
+    return _recording_factory
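For reference, a sketch of how the two new fixtures compose in a test. This is a hypothetical test, not part of the commit; it assumes soundevent's Recording exposes samplerate and duration attributes:

import pytest


def test_recording_factory_roundtrip(recording_factory):
    # recording_factory writes a random PCM wav via wav_factory and
    # wraps it in a soundevent Recording with the requested metadata.
    recording = recording_factory(duration=0.5, samplerate=256_000)
    assert recording.samplerate == 256_000
    assert recording.duration == pytest.approx(0.5)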
tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz (new binary file, not shown)
tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz (new binary file, not shown)
tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz (new binary file, not shown)
@@ -4,7 +4,7 @@ import numpy as np
 import pytest
 from soundevent import data

-from batdetect2.data import preprocessing
+from batdetect2 import preprocess
 from batdetect2.utils import audio_utils

 ROOT_DIR = Path(__file__).parent.parent.parent
@@ -44,10 +44,10 @@ def test_audio_loading_hasnt_changed(
         target_samp_rate=target_sampling_rate,
         scale=scale,
     )
-    audio_new = preprocessing.load_clip_audio(
+    audio_new = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             ),
             center=scale,
@@ -84,20 +84,20 @@ def test_spectrogram_generation_hasnt_changed(
     if spec_scale == "log":
         scale = "log"
     elif spec_scale == "pcen":
-        scale = preprocessing.PcenConfig()
+        scale = preprocess.PcenConfig()

-    config = preprocessing.SpectrogramConfig(
-        fft=preprocessing.FFTConfig(
+    config = preprocess.SpectrogramConfig(
+        fft=preprocess.FFTConfig(
             window_overlap=fft_overlap,
             window_duration=fft_win_length,
         ),
-        frequencies=preprocessing.FrequencyConfig(
+        frequencies=preprocess.FrequencyConfig(
             min_freq=min_freq,
             max_freq=max_freq,
         ),
         scale=scale,
         denoise=denoise_spec_avg,
-        resize=None,
+        size=None,
         max_scale=max_scale_spec,
     )
@@ -112,10 +112,10 @@ def test_spectrogram_generation_hasnt_changed(
         end_time=recording.duration,
     )

-    audio = preprocessing.load_clip_audio(
+    audio = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             )
         ),
@@ -135,7 +135,7 @@ def test_spectrogram_generation_hasnt_changed(
         ),
     )

-    new_spec = preprocessing.compute_spectrogram(
+    new_spec = preprocess.compute_spectrogram(
         audio,
         config=config,
         dtype=np.float32,
tests/test_migration/test_training.py (new file, 75 lines)

@@ -0,0 +1,75 @@
import json
from pathlib import Path
from typing import List

import numpy as np
import pytest

from batdetect2.compat.data import load_annotation_project
from batdetect2.compat.params import get_training_preprocessing_config
from batdetect2.train.preprocess import generate_train_example


@pytest.fixture
def regression_dir(data_dir: Path) -> Path:
    dir = data_dir / "regression"
    assert dir.exists()
    return dir


def test_can_generate_similar_training_inputs(
    example_audio_dir: Path,
    example_audio_files: List[Path],
    example_anns_dir: Path,
    regression_dir: Path,
):
    old_parameters = json.loads((regression_dir / "params.json").read_text())
    config = get_training_preprocessing_config(old_parameters)

    for audio_file in example_audio_files:
        example_file = regression_dir / f"{audio_file.name}.npz"

        dataset = np.load(example_file)

        spec = dataset["spec"][0]
        detection_mask = dataset["detection_mask"][0]
        size_mask = dataset["size_mask"]
        class_mask = dataset["class_mask"]

        project = load_annotation_project(
            example_anns_dir,
            audio_dir=example_audio_dir,
        )

        clip_annotation = next(
            ann
            for ann in project.clip_annotations
            if ann.clip.recording.path == audio_file
        )

        new_dataset = generate_train_example(clip_annotation, config)
        new_spec = new_dataset["spectrogram"].values
        new_detection_mask = new_dataset["detection"].values
        new_size_mask = new_dataset["size"].values
        new_class_mask = new_dataset["class"].values

        assert spec.shape == new_spec.shape
        assert detection_mask.shape == new_detection_mask.shape
        assert size_mask.shape == new_size_mask.shape
        assert class_mask.shape[1:] == new_class_mask.shape[1:]
        assert class_mask.shape[0] == new_class_mask.shape[0] + 1

        x_new, y_new = np.nonzero(new_size_mask.max(axis=0))
        x_orig, y_orig = np.nonzero(np.flipud(size_mask.max(axis=0)))

        assert (x_new == x_orig).all()

        # NOTE: a difference of 1 pixel is due to discrepancies on how
        # frequency bins are interpreted. Shouldn't be an issue
        assert (y_new == y_orig + 1).all()

        width_new, height_new = new_size_mask[:, x_new, y_new]
        width_orig, height_orig = np.flip(size_mask, axis=1)[:, x_orig, y_orig]

        assert (np.floor(width_new) == width_orig).all()
        assert (np.ceil(height_new) == height_orig).all()
tests/test_preprocessing/__init__.py (new empty file)

tests/test_preprocessing/test_audio.py (new empty file)

tests/test_preprocessing/test_spectrogram.py (new file, 141 lines)

@@ -0,0 +1,141 @@
import math
from pathlib import Path
from typing import Callable

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from soundevent import arrays

from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
    FFTConfig,
    FrequencyConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    compute_spectrogram,
    duration_to_spec_width,
    get_spectrogram_resolution,
    pad_audio,
    spec_width_to_samples,
    stft,
)


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
)
def test_can_estimate_correctly_spectrogram_width_from_duration(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=duration, samplerate=samplerate)
    audio = load_file_audio(
        path,
        # NOTE: Dont resample nor adjust duration to test if the width
        # estimation works on all scenarios
        config=AudioConfig(resample=None, duration=None),
    )
    spectrogram = stft(audio, window_duration, window_overlap)

    spec_width = duration_to_spec_width(
        duration,
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
    )
    assert spectrogram.sizes["time"] == spec_width

    rebuilt_duration = (
        spec_width_to_samples(
            spec_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
        )
        / samplerate
    )

    assert (
        abs(duration - rebuilt_duration)
        < (1 - window_overlap) * window_duration
    )


@settings(
    suppress_health_check=[HealthCheck.function_scoped_fixture],
    deadline=400,
)
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
    divide_factor=st.integers(min_value=16, max_value=64),
)
def test_can_pad_audio_to_adjust_spectrogram_width(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    divide_factor: int,
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=duration, samplerate=samplerate)

    audio = load_file_audio(
        path,
        # NOTE: Dont resample nor adjust duration to test if the width
        # estimation works on all scenarios
        config=AudioConfig(resample=None, duration=None),
    )

    audio = pad_audio(
        audio,
        window_duration=window_duration,
        window_overlap=window_overlap,
        divide_factor=divide_factor,
    )

    spectrogram = stft(audio, window_duration, window_overlap)
    assert spectrogram.sizes["time"] % divide_factor == 0


def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=0.2, samplerate=256_000)

    audio = load_file_audio(
        path,
        # NOTE: Dont resample nor adjust duration to test if the width
        # estimation works on all scenarios
        config=AudioConfig(resample=None, duration=None),
    )

    config = SpectrogramConfig(
        fft=FFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )

    spec = compute_spectrogram(audio, config=config)

    freq_res, time_res = get_spectrogram_resolution(config)

    assert math.isclose(
        arrays.get_dim_step(spec, dim="frequency"),
        freq_res,
        rel_tol=0.1,
    )
    assert math.isclose(
        arrays.get_dim_step(spec, dim="time"),
        time_res,
        rel_tol=0.1,
    )
uv.lock (generated, 8 lines changed)
@@ -236,7 +236,7 @@ requires-dist = [
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
-    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
+    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },
     { name = "tensorboard", specifier = ">=2.16.2" },
     { name = "torch", specifier = ">=1.13.1,<2.5.0" },
     { name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },
@@ -2679,15 +2679,15 @@ wheels = [

 [[package]]
 name = "soundevent"
-version = "2.2.0"
+version = "2.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "email-validator" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/5a/f9/77d723df4d8d3a390d32c07325a08bfc669d5fb55e88b98181b5793e7333/soundevent-2.2.0.tar.gz", hash = "sha256:a87a97c8e4bfdadec4b6edc128470919ef9240a344cf924d2ac21f8c6b50acf1", size = 8715229 }
+sdist = { url = "https://files.pythonhosted.org/packages/ff/51/83093cabe9ada781a0f7a78f82cc04162d005755b2f0ca3fdcb4ecd47a01/soundevent-2.3.0.tar.gz", hash = "sha256:b75d7674578a52bf196619f8b4b3d9170f2ca321d165ceb45916579a549c3e76", size = 8716539 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/09/d8/ce6e2830d47fc3d24db264b94384fa7cdcb8069fd7547d55b7a83857e730/soundevent-2.2.0-py3-none-any.whl", hash = "sha256:c40913c15fc697a82a02df5f62d18ad3b77bfae80a9a5d54c47bc1377d3b4d7c", size = 144188 },
+    { url = "https://files.pythonhosted.org/packages/ee/67/4c2d881f9b4a0b453dbee91e119e1e48df0fc92de2cc3062fcd8ad0a7e6b/soundevent-2.3.0-py3-none-any.whl", hash = "sha256:f7c74b1d73a347ebe843187c93130dc8af3214add95e6bc485f64944bea0d690", size = 145513 },
 ]

 [package.optional-dependencies]