Ensure train inputs are almost equal

mbsantiago 2024-11-18 18:10:58 +00:00
parent 1f0fb14d89
commit 36c90a600f
20 changed files with 648 additions and 74 deletions

148
batdetect2/compat/params.py Normal file
View File

@@ -0,0 +1,148 @@
from batdetect2.preprocess import (
    AudioConfig,
    FFTConfig,
    FrequencyConfig,
    PcenConfig,
    PreprocessingConfig,
    ResampleConfig,
    SpecSizeConfig,
    SpectrogramConfig,
)
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
from batdetect2.terms import TagInfo
from batdetect2.train.preprocess import (
    HeatmapsConfig,
    TargetConfig,
    TrainPreprocessingConfig,
)


def get_spectrogram_scale(scale: str):
    if scale == "pcen":
        return PcenConfig()
    if scale == "log":
        return "log"
    return None


def get_preprocessing_config(params: dict) -> PreprocessingConfig:
    return PreprocessingConfig(
        audio=AudioConfig(
            resample=ResampleConfig(
                samplerate=params["target_samp_rate"],
                mode="poly",
            ),
            scale=params["scale_raw_audio"],
            center=params["scale_raw_audio"],
            duration=None,
        ),
        spectrogram=SpectrogramConfig(
            fft=FFTConfig(
                window_duration=params["fft_win_length"],
                window_overlap=params["fft_overlap"],
                window_fn="hann",
            ),
            frequencies=FrequencyConfig(
                min_freq=params["min_freq"],
                max_freq=params["max_freq"],
            ),
            scale=get_spectrogram_scale(params["spec_scale"]),
            denoise=params["denoise_spec_avg"],
            size=SpecSizeConfig(
                height=params["spec_height"],
                resize_factor=params["resize_factor"],
            ),
            max_scale=params["max_scale_spec"],
        ),
    )


def get_training_preprocessing_config(
    params: dict,
) -> TrainPreprocessingConfig:
    generic = params["generic_class"][0]
    preprocessing = get_preprocessing_config(params)
    freq_bin_width, time_bin_width = get_spectrogram_resolution(
        preprocessing.spectrogram
    )

    return TrainPreprocessingConfig(
        preprocessing=preprocessing,
        target=TargetConfig(
            classes=[
                TagInfo(key="class", value=class_name, label=class_name)
                for class_name in params["class_names"]
            ],
            generic_class=TagInfo(
                key="class",
                value=generic,
                label=generic,
            ),
            include=[
                TagInfo(key="event", value=event)
                for event in params["events_of_interest"]
            ],
            exclude=[
                TagInfo(key="class", value=value)
                for value in params["classes_to_ignore"]
            ],
        ),
        heatmaps=HeatmapsConfig(
            position="bottom-left",
            time_scale=1 / time_bin_width,
            frequency_scale=1 / freq_bin_width,
            sigma=params["target_sigma"],
        ),
    )
# 'standardize_classs_names_ip',
# 'convert_to_genus',
# 'genus_mapping',
# 'standardize_classs_names',
# 'genus_names',
# ['data_dir',
# 'ann_dir',
# 'train_split',
# 'model_name',
# 'num_filters',
# 'experiment',
# 'model_file_name',
# 'op_im_dir',
# 'op_im_dir_test',
# 'notes',
# 'spec_divide_factor',
# 'detection_overlap',
# 'ignore_start_end',
# 'detection_threshold',
# 'nms_kernel_size',
# 'nms_top_k_per_sec',
# 'aug_prob',
# 'augment_at_train',
# 'augment_at_train_combine',
# 'echo_max_delay',
# 'stretch_squeeze_delta',
# 'mask_max_time_perc',
# 'mask_max_freq_perc',
# 'spec_amp_scaling',
# 'aug_sampling_rates',
# 'train_loss',
# 'det_loss_weight',
# 'size_loss_weight',
# 'class_loss_weight',
# 'individual_loss_weight',
# 'emb_dim',
# 'lr',
# 'batch_size',
# 'num_workers',
# 'num_epochs',
# 'num_eval_epochs',
# 'device',
# 'save_test_image_during_train',
# 'save_test_image_after_train',
# 'train_sets',
# 'test_sets',
# 'class_inv_freq',
# 'ip_height']
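
A minimal sketch of driving this compat layer from a legacy params dict. The keys mirror the ones read above; the values are illustrative defaults, not a real training configuration:

legacy_params = {
    "target_samp_rate": 256_000,
    "scale_raw_audio": False,
    "fft_win_length": 512 / 256_000,
    "fft_overlap": 0.75,
    "min_freq": 10_000,
    "max_freq": 120_000,
    "spec_scale": "pcen",
    "denoise_spec_avg": True,
    "spec_height": 128,
    "resize_factor": 0.5,
    "max_scale_spec": False,
    "class_names": ["Myotis myotis"],
    "generic_class": ["Bat"],
    "events_of_interest": ["Echolocation"],
    "classes_to_ignore": [""],
    "target_sigma": 3.0,
}
config = get_training_preprocessing_config(legacy_params)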

View File

@@ -27,6 +27,28 @@ class AudioConfig(BaseConfig):
     duration: Optional[float] = DEFAULT_DURATION
 
 
+def load_file_audio(
+    path: data.PathLike,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    recording = data.Recording.from_file(path)
+    return load_recording_audio(recording, config=config, dtype=dtype)
+
+
+def load_recording_audio(
+    recording: data.Recording,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    clip = data.Clip(
+        recording=recording,
+        start_time=0,
+        end_time=recording.duration,
+    )
+    return load_clip_audio(clip, config=config, dtype=dtype)
+
+
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
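
Both new helpers funnel into load_clip_audio, so loading a file is just loading a clip that spans the whole recording. A minimal usage sketch (the path is a placeholder, not part of the commit):

from batdetect2.preprocess.audio import AudioConfig, load_file_audio

wave = load_file_audio("example.wav", config=AudioConfig())
print(wave.sizes["time"])  # number of samples in the loaded clip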

View File

@@ -10,29 +10,17 @@ from soundevent import arrays, audio
 from soundevent.arrays import operations as ops
 
 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.audio import DEFAULT_DURATION
-
-FFT_WIN_LENGTH_S = 512 / 256000.0
-FFT_OVERLAP = 0.75
-MAX_FREQ_HZ = 120000
-MIN_FREQ_HZ = 10000
-SPEC_HEIGHT = 128
-SPEC_WIDTH = 256
-SPEC_SCALE = "pcen"
-SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
-DENOISE_SPEC_AVG = True
-MAX_SCALE_SPEC = False
 
 
 class FFTConfig(BaseConfig):
-    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_duration: float = Field(default=0.002, gt=0)
+    window_overlap: float = Field(default=0.75, ge=0, lt=1)
     window_fn: str = "hann"
 
 
 class FrequencyConfig(BaseConfig):
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+    max_freq: int = Field(default=120_000, gt=0)
+    min_freq: int = Field(default=10_000, gt=0)
 
 
 class PcenConfig(BaseConfig):
@@ -44,17 +32,20 @@ class PcenConfig(BaseConfig):
 class SpecSizeConfig(BaseConfig):
-    height: int = SPEC_HEIGHT
-    time_period: float = SPEC_TIME_PERIOD
+    height: int = 256
+    resize_factor: Optional[float] = 0.5
+    divide_factor: Optional[int] = 32
 
 
 class SpectrogramConfig(BaseConfig):
     fft: FFTConfig = Field(default_factory=FFTConfig)
     frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    scale: Union[Literal["log"], None, PcenConfig] = Field(
+        default_factory=PcenConfig
+    )
+    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
     denoise: bool = True
-    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
-    max_scale: bool = MAX_SCALE_SPEC
+    max_scale: bool = False
 
 
 def compute_spectrogram(
@@ -64,6 +55,16 @@ def compute_spectrogram(
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()
 
+    if config.size and config.size.divide_factor:
+        # Need to pad the audio to make sure the spectrogram has a
+        # width compatible with the divide factor
+        wav = pad_audio(
+            wav,
+            window_duration=config.fft.window_duration,
+            window_overlap=config.fft.window_overlap,
+            divide_factor=config.size.divide_factor,
+        )
+
     spec = stft(
         wav,
         window_duration=config.fft.window_duration,
@@ -83,8 +84,12 @@ compute_spectrogram(
     if config.denoise:
         spec = denoise_spectrogram(spec)
 
-    if config.resize:
-        spec = resize_spectrogram(spec, config=config.resize)
+    if config.size:
+        spec = resize_spectrogram(
+            spec,
+            height=config.size.height,
+            resize_factor=config.size.resize_factor,
+        )
 
     if config.max_scale:
         spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@@ -94,8 +99,8 @@ def compute_spectrogram(
 def crop_spectrogram_frequencies(
     spec: xr.DataArray,
-    min_freq: int = MIN_FREQ_HZ,
-    max_freq: int = MAX_FREQ_HZ,
+    min_freq: int = 10_000,
+    max_freq: int = 120_000,
 ) -> xr.DataArray:
     return arrays.crop_dim(
         spec,
@@ -116,9 +121,10 @@ def stft(
     step = arrays.get_dim_step(wave, dim="time")
     sampling_rate = 1 / step
 
-    hop_len = window_duration * (1 - window_overlap)
     nfft = int(window_duration * sampling_rate)
     noverlap = int(window_overlap * nfft)
+    hop_len = nfft - noverlap
+    hop_duration = hop_len / sampling_rate
 
     spec, _ = librosa.core.spectrum._spectrogram(
         y=wave.data.astype(dtype),
@@ -146,12 +152,12 @@ def stft(
             "time": arrays.create_time_dim_from_array(
                 np.linspace(
                     start_time,
-                    end_time - (window_duration - hop_len),
+                    end_time - (window_duration - hop_duration),
                     spec.shape[1],
                     endpoint=False,
                     dtype=dtype,
                 ),
-                step=hop_len,
+                step=hop_duration,
             ),
         },
         attrs={
@@ -202,7 +208,6 @@ def scale_pcen(
     power: float = 0.5,
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
-    # NOTE: Not sure why the 10 is there
     t_frames = time_constant * samplerate / (float(hop_length) * 10)
     smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
     return audio.pcen(
@@ -231,12 +236,114 @@
 def resize_spectrogram(
     spec: xr.DataArray,
-    config: SpecSizeConfig,
+    height: int = 128,
+    resize_factor: Optional[float] = 0.5,
 ) -> xr.DataArray:
-    duration = arrays.get_dim_width(spec, dim="time")
+    resize_factor = resize_factor or 1
+    current_width = spec.sizes["time"]
     return ops.resize(
         spec,
-        time=int(np.ceil(duration / config.time_period)),
-        frequency=config.height,
+        time=int(resize_factor * current_width),
+        frequency=int(resize_factor * height),
         dtype=np.float32,
     )
+
+
+def adjust_spectrogram_width(
+    spec: xr.DataArray,
+    divide_factor: int = 32,
+    time_period: float = 0.001,
+) -> xr.DataArray:
+    time_width = spec.sizes["time"]
+
+    if time_width % divide_factor == 0:
+        return spec
+
+    target_size = int(
+        np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
+    )
+    extra_duration = (target_size - time_width) * time_period
+    _, stop = arrays.get_dim_range(spec, dim="time")
+    resized = ops.extend_dim(
+        spec,
+        dim="time",
+        stop=stop + extra_duration,
+    )
+    return resized
+
+
+def pad_audio(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    divide_factor: int = 32,
+) -> xr.DataArray:
+    current_duration = arrays.get_dim_width(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    samplerate = int(1 / step)
+
+    estimated_spec_width = duration_to_spec_width(
+        current_duration,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+
+    if estimated_spec_width % divide_factor == 0:
+        return wave
+
+    target_spec_width = int(
+        np.ceil(estimated_spec_width / divide_factor) * divide_factor
+    )
+    target_samples = spec_width_to_samples(
+        target_spec_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+
+    return ops.adjust_dim_width(
+        wave,
+        dim="time",
+        width=target_samples,
+        position="start",
+    )
+
+
+def duration_to_spec_width(
+    duration: float,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    samples = int(duration * samplerate)
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    width = (samples - fft_len + hop_len) / hop_len
+    return int(np.floor(width))
+
+
+def spec_width_to_samples(
+    width: int,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    return width * hop_len + fft_len - hop_len
+
+
+def get_spectrogram_resolution(
+    config: SpectrogramConfig,
+) -> tuple[float, float]:
+    max_freq = config.frequencies.max_freq
+    min_freq = config.frequencies.min_freq
+
+    assert config.size is not None
+    spec_height = config.size.height
+    resize_factor = config.size.resize_factor or 1
+    freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
+    hop_duration = config.fft.window_duration * (
+        1 - config.fft.window_overlap
+    )
+    return freq_bin_width, hop_duration / resize_factor
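
A quick worked check of the integer hop computation and the width round trip, using the default 2 ms window, 75% overlap, and a 256 kHz clip (a sketch on top of the helpers above, not part of the commit):

samplerate = 256_000
window_duration = 512 / 256_000  # 2 ms window -> nfft = 512 samples
window_overlap = 0.75            # noverlap = 384 -> hop_len = 128 samples

nfft = int(window_duration * samplerate)  # 512
noverlap = int(window_overlap * nfft)     # 384
hop_len = nfft - noverlap                 # 128, an exact sample count
hop_duration = hop_len / samplerate       # 0.5 ms

# A 0.5 s clip has 128_000 samples, so the expected width is
# floor((128_000 - 512 + 128) / 128) = 997 frames, and mapping 997 frames
# back recovers exactly 128_000 samples.
width = duration_to_spec_width(
    0.5,
    samplerate=samplerate,
    window_duration=window_duration,
    window_overlap=window_overlap,
)
assert width == 997
assert spec_width_to_samples(
    width,
    samplerate=samplerate,
    window_duration=window_duration,
    window_overlap=window_overlap,
) == 128_000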

View File

@@ -22,9 +22,9 @@ class TermInfo(BaseModel):
 class TagInfo(BaseModel):
     value: str
+    label: Optional[str] = None
     term: Optional[TermInfo] = None
     key: Optional[str] = None
-    label: Optional[str] = None
 
 
 call_type = data.Term(

View File

@@ -7,23 +7,29 @@ from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
 from soundevent.types import ClassMapper
 
+from batdetect2.configs import BaseConfig
+
 __all__ = [
     "ClassMapper",
     "generate_heatmaps",
 ]
 
-TARGET_SIGMA = 3.0
+
+class HeatmapsConfig(BaseConfig):
+    position: Positions = "bottom-left"
+    sigma: float = 3.0
+    time_scale: float = 1000.0
+    frequency_scale: float = 1 / 859.375
 
 
 def generate_heatmaps(
     sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
+    target_sigma: float = 3.0,
     position: Positions = "bottom-left",
-    time_scale: float = 1.0,
-    frequency_scale: float = 1.0,
+    time_scale: float = 1000.0,
+    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -39,7 +45,7 @@ def generate_heatmaps(
         data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": class_mapper.class_labels,
+            "category": [*class_mapper.class_labels],
             **spec.coords,
         },
     )
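
The new defaults appear to encode the default spectrogram geometry; this reading is an inference, not stated in the diff:

# (120_000 Hz - 10_000 Hz) / 128 bins = 859.375 Hz per bin, where 128 is the
# default resized height (256 * 0.5). Multiplying a frequency in Hz by
# frequency_scale therefore yields a bin offset, and time_scale = 1000.0
# reads as seconds -> milliseconds. The compat layer instead derives both
# scales from get_spectrogram_resolution.
assert (120_000 - 10_000) / (256 * 0.5) == 859.375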

View File

@@ -14,12 +14,10 @@ from tqdm.auto import tqdm
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import (
     PreprocessingConfig,
-    preprocess_audio_clip,
-)
-from batdetect2.train.labels import (
-    TARGET_SIGMA,
-    generate_heatmaps,
+    compute_spectrogram,
+    load_clip_audio,
 )
+from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
 from batdetect2.train.targets import (
     TargetConfig,
     build_class_mapper,
@@ -34,16 +32,12 @@
 ]
 
 
-class MasksConfig(BaseConfig):
-    sigma: float = TARGET_SIGMA
-
-
 class TrainPreprocessingConfig(BaseConfig):
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
     target: TargetConfig = Field(default_factory=TargetConfig)
-    masks: MasksConfig = Field(default_factory=MasksConfig)
+    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
 
 
 def generate_train_example(
@@ -53,9 +47,14 @@ def generate_train_example(
     """Generate a training example."""
     config = config or TrainPreprocessingConfig()
 
-    spectrogram = preprocess_audio_clip(
+    wave = load_clip_audio(
         clip_annotation.clip,
-        config=config.preprocessing,
+        config=config.preprocessing.audio,
+    )
+
+    spectrogram = compute_spectrogram(
+        wave,
+        config=config.preprocessing.spectrogram,
     )
 
     filter_fn = build_sound_event_filter(
@@ -65,17 +64,24 @@ def generate_train_example(
     selected_events = [
         event for event in clip_annotation.sound_events if filter_fn(event)
     ]
 
     class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
         selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=config.masks.sigma,
+        target_sigma=config.heatmaps.sigma,
+        position=config.heatmaps.position,
+        time_scale=config.heatmaps.time_scale,
+        frequency_scale=config.heatmaps.frequency_scale,
     )
 
     dataset = xr.Dataset(
         {
+            # NOTE: Need to rename the time dimension to avoid conflicts with
+            # the spectrogram time dimension, otherwise xarray will interpolate
+            # the spectrogram and the heatmaps to the same temporal resolution
+            # as the waveform.
+            "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
             "detection": detection_heatmap,
             "class": class_heatmap,

View File

@@ -13,9 +13,9 @@ class TargetConfig(BaseConfig):
     """Configuration for target generation."""
 
     classes: List[TagInfo] = Field(default_factory=list)
+    generic_class: Optional[TagInfo] = None
     include: Optional[List[TagInfo]] = None
     exclude: Optional[List[TagInfo]] = None
@@ -73,7 +73,7 @@ class GenericMapper(ClassMapper):
             raise ValueError("Number of targets and class labels must match.")
 
         self.targets = set(classes)
-        self.class_labels = labels
+        self.class_labels = list(dict.fromkeys(labels))
 
         self._mapping = {tag: label for tag, label in zip(classes, labels)}
         self._inverse_mapping = {
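
list(dict.fromkeys(labels)) deduplicates while preserving first-seen order, which a plain set would not guarantee:

labels = ["bat", "noise", "bat", "insect"]
assert list(dict.fromkeys(labels)) == ["bat", "noise", "insect"]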

View File

@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.2",
+    "soundevent[audio,geometry,plot]>=2.3",
     "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",

View File

@@ -1,7 +1,11 @@
+import uuid
 from pathlib import Path
-from typing import List
+from typing import Callable, List, Optional
 
+import numpy as np
 import pytest
+import soundfile as sf
+from soundevent import data
 
 
 @pytest.fixture
@@ -19,6 +23,13 @@ def example_audio_dir(example_data_dir: Path) -> Path:
     return example_audio_dir
 
 
+@pytest.fixture
+def example_anns_dir(example_data_dir: Path) -> Path:
+    example_anns_dir = example_data_dir / "anns"
+    assert example_anns_dir.exists()
+    return example_anns_dir
+
+
 @pytest.fixture
 def example_audio_files(example_audio_dir: Path) -> List[Path]:
     audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
@@ -38,3 +49,61 @@ def contrib_dir(data_dir) -> Path:
     dir = data_dir / "contrib"
     assert dir.exists()
     return dir
+
+
+@pytest.fixture
+def wav_factory(tmp_path: Path):
+    def _wav_factory(
+        path: Optional[Path] = None,
+        duration: float = 0.3,
+        channels: int = 1,
+        samplerate: int = 441_000,
+        bit_depth: int = 16,
+    ) -> Path:
+        path = path or tmp_path / f"{uuid.uuid4()}.wav"
+        frames = int(samplerate * duration)
+        shape = (frames, channels)
+        subtype = f"PCM_{bit_depth}"
+
+        if bit_depth == 16:
+            dtype = np.int16
+        elif bit_depth == 32:
+            dtype = np.int32
+        else:
+            raise ValueError(f"Unsupported bit depth: {bit_depth}")
+
+        wav = np.random.uniform(
+            low=np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=shape,
+        ).astype(dtype)
+        sf.write(str(path), wav, samplerate, subtype=subtype)
+        return path
+
+    return _wav_factory
+
+
+@pytest.fixture
+def recording_factory(wav_factory: Callable[..., Path]):
+    def _recording_factory(
+        tags: Optional[list[data.Tag]] = None,
+        path: Optional[Path] = None,
+        recording_id: Optional[uuid.UUID] = None,
+        duration: float = 1,
+        channels: int = 1,
+        samplerate: int = 44100,
+        time_expansion: float = 1,
+    ) -> data.Recording:
+        path = path or wav_factory(
+            duration=duration,
+            channels=channels,
+            samplerate=samplerate,
+        )
+        return data.Recording.from_file(
+            path=path,
+            uuid=recording_id or uuid.uuid4(),
+            time_expansion=time_expansion,
+            tags=tags or [],
+        )
+
+    return _recording_factory
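
A hypothetical test using the new fixtures (the test name and values are illustrative, not from the commit):

def test_recording_has_requested_duration(recording_factory):
    # recording_factory writes a random-noise WAV via wav_factory and
    # wraps it in a soundevent Recording
    recording = recording_factory(duration=0.5, samplerate=256_000)
    assert abs(recording.duration - 0.5) < 1e-6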

View File

@@ -4,7 +4,7 @@ import numpy as np
 import pytest
 from soundevent import data
 
-from batdetect2.data import preprocessing
+from batdetect2 import preprocess
 from batdetect2.utils import audio_utils
 
 ROOT_DIR = Path(__file__).parent.parent.parent
@@ -44,10 +44,10 @@ def test_audio_loading_hasnt_changed(
         target_samp_rate=target_sampling_rate,
         scale=scale,
     )
-    audio_new = preprocessing.load_clip_audio(
+    audio_new = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             ),
             center=scale,
@@ -84,20 +84,20 @@ def test_spectrogram_generation_hasnt_changed(
     if spec_scale == "log":
         scale = "log"
     elif spec_scale == "pcen":
-        scale = preprocessing.PcenConfig()
+        scale = preprocess.PcenConfig()
 
-    config = preprocessing.SpectrogramConfig(
-        fft=preprocessing.FFTConfig(
+    config = preprocess.SpectrogramConfig(
+        fft=preprocess.FFTConfig(
             window_overlap=fft_overlap,
             window_duration=fft_win_length,
         ),
-        frequencies=preprocessing.FrequencyConfig(
+        frequencies=preprocess.FrequencyConfig(
             min_freq=min_freq,
             max_freq=max_freq,
         ),
         scale=scale,
         denoise=denoise_spec_avg,
-        resize=None,
+        size=None,
         max_scale=max_scale_spec,
     )
@@ -112,10 +112,10 @@ def test_spectrogram_generation_hasnt_changed(
         end_time=recording.duration,
     )
 
-    audio = preprocessing.load_clip_audio(
+    audio = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             )
         ),
@@ -135,7 +135,7 @@ def test_spectrogram_generation_hasnt_changed(
         ),
     )
 
-    new_spec = preprocessing.compute_spectrogram(
+    new_spec = preprocess.compute_spectrogram(
         audio,
         config=config,
         dtype=np.float32,

View File

@@ -0,0 +1,75 @@
import json
from pathlib import Path
from typing import List

import numpy as np
import pytest

from batdetect2.compat.data import load_annotation_project
from batdetect2.compat.params import get_training_preprocessing_config
from batdetect2.train.preprocess import generate_train_example


@pytest.fixture
def regression_dir(data_dir: Path) -> Path:
    dir = data_dir / "regression"
    assert dir.exists()
    return dir


def test_can_generate_similar_training_inputs(
    example_audio_dir: Path,
    example_audio_files: List[Path],
    example_anns_dir: Path,
    regression_dir: Path,
):
    old_parameters = json.loads((regression_dir / "params.json").read_text())
    config = get_training_preprocessing_config(old_parameters)

    for audio_file in example_audio_files:
        example_file = regression_dir / f"{audio_file.name}.npz"
        dataset = np.load(example_file)

        spec = dataset["spec"][0]
        detection_mask = dataset["detection_mask"][0]
        size_mask = dataset["size_mask"]
        class_mask = dataset["class_mask"]

        project = load_annotation_project(
            example_anns_dir,
            audio_dir=example_audio_dir,
        )
        clip_annotation = next(
            ann
            for ann in project.clip_annotations
            if ann.clip.recording.path == audio_file
        )

        new_dataset = generate_train_example(clip_annotation, config)
        new_spec = new_dataset["spectrogram"].values
        new_detection_mask = new_dataset["detection"].values
        new_size_mask = new_dataset["size"].values
        new_class_mask = new_dataset["class"].values

        assert spec.shape == new_spec.shape
        assert detection_mask.shape == new_detection_mask.shape
        assert size_mask.shape == new_size_mask.shape
        assert class_mask.shape[1:] == new_class_mask.shape[1:]
        assert class_mask.shape[0] == new_class_mask.shape[0] + 1

        x_new, y_new = np.nonzero(new_size_mask.max(axis=0))
        x_orig, y_orig = np.nonzero(np.flipud(size_mask.max(axis=0)))
        assert (x_new == x_orig).all()
        # NOTE: a difference of 1 pixel is due to discrepancies in how
        # frequency bins are interpreted. Shouldn't be an issue
        assert (y_new == y_orig + 1).all()

        width_new, height_new = new_size_mask[:, x_new, y_new]
        width_orig, height_orig = np.flip(size_mask, axis=1)[:, x_orig, y_orig]
        assert (np.floor(width_new) == width_orig).all()
        assert (np.ceil(height_new) == height_orig).all()

View File

@@ -0,0 +1,141 @@
import math
from pathlib import Path
from typing import Callable

from hypothesis import HealthCheck, given, settings
from hypothesis import strategies as st
from soundevent import arrays

from batdetect2.preprocess.audio import AudioConfig, load_file_audio
from batdetect2.preprocess.spectrogram import (
    FFTConfig,
    FrequencyConfig,
    SpecSizeConfig,
    SpectrogramConfig,
    compute_spectrogram,
    duration_to_spec_width,
    get_spectrogram_resolution,
    pad_audio,
    spec_width_to_samples,
    stft,
)


@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
)
def test_can_estimate_correctly_spectrogram_width_from_duration(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=duration, samplerate=samplerate)
    audio = load_file_audio(
        path,
        # NOTE: Don't resample nor adjust duration to test if the width
        # estimation works in all scenarios
        config=AudioConfig(resample=None, duration=None),
    )
    spectrogram = stft(audio, window_duration, window_overlap)

    spec_width = duration_to_spec_width(
        duration,
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
    )
    assert spectrogram.sizes["time"] == spec_width

    rebuilt_duration = (
        spec_width_to_samples(
            spec_width,
            samplerate=samplerate,
            window_duration=window_duration,
            window_overlap=window_overlap,
        )
        / samplerate
    )
    assert (
        abs(duration - rebuilt_duration)
        < (1 - window_overlap) * window_duration
    )


@settings(
    suppress_health_check=[HealthCheck.function_scoped_fixture],
    deadline=400,
)
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
    divide_factor=st.integers(min_value=16, max_value=64),
)
def test_can_pad_audio_to_adjust_spectrogram_width(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    divide_factor: int,
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=duration, samplerate=samplerate)
    audio = load_file_audio(
        path,
        # NOTE: Don't resample nor adjust duration to test if the width
        # estimation works in all scenarios
        config=AudioConfig(resample=None, duration=None),
    )
    audio = pad_audio(
        audio,
        window_duration=window_duration,
        window_overlap=window_overlap,
        divide_factor=divide_factor,
    )
    spectrogram = stft(audio, window_duration, window_overlap)
    assert spectrogram.sizes["time"] % divide_factor == 0


def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
    path = wav_factory(duration=0.2, samplerate=256_000)
    audio = load_file_audio(
        path,
        # NOTE: Don't resample nor adjust duration to test if the width
        # estimation works in all scenarios
        config=AudioConfig(resample=None, duration=None),
    )
    config = SpectrogramConfig(
        fft=FFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )
    spec = compute_spectrogram(audio, config=config)
    freq_res, time_res = get_spectrogram_resolution(config)

    assert math.isclose(
        arrays.get_dim_step(spec, dim="frequency"),
        freq_res,
        rel_tol=0.1,
    )
    assert math.isclose(
        arrays.get_dim_step(spec, dim="time"),
        time_res,
        rel_tol=0.1,
    )

8
uv.lock generated
View File

@@ -236,7 +236,7 @@ requires-dist = [
    { name = "pytorch-lightning", specifier = ">=2.2.2" },
    { name = "scikit-learn", specifier = ">=1.2.2" },
    { name = "scipy", specifier = ">=1.10.1" },
-   { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
+   { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },
    { name = "tensorboard", specifier = ">=2.16.2" },
    { name = "torch", specifier = ">=1.13.1,<2.5.0" },
    { name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },
@@ -2679,15 +2679,15 @@ wheels = [
 [[package]]
 name = "soundevent"
-version = "2.2.0"
+version = "2.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "email-validator" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/5a/f9/77d723df4d8d3a390d32c07325a08bfc669d5fb55e88b98181b5793e7333/soundevent-2.2.0.tar.gz", hash = "sha256:a87a97c8e4bfdadec4b6edc128470919ef9240a344cf924d2ac21f8c6b50acf1", size = 8715229 }
+sdist = { url = "https://files.pythonhosted.org/packages/ff/51/83093cabe9ada781a0f7a78f82cc04162d005755b2f0ca3fdcb4ecd47a01/soundevent-2.3.0.tar.gz", hash = "sha256:b75d7674578a52bf196619f8b4b3d9170f2ca321d165ceb45916579a549c3e76", size = 8716539 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/09/d8/ce6e2830d47fc3d24db264b94384fa7cdcb8069fd7547d55b7a83857e730/soundevent-2.2.0-py3-none-any.whl", hash = "sha256:c40913c15fc697a82a02df5f62d18ad3b77bfae80a9a5d54c47bc1377d3b4d7c", size = 144188 },
+    { url = "https://files.pythonhosted.org/packages/ee/67/4c2d881f9b4a0b453dbee91e119e1e48df0fc92de2cc3062fcd8ad0a7e6b/soundevent-2.3.0-py3-none-any.whl", hash = "sha256:f7c74b1d73a347ebe843187c93130dc8af3214add95e6bc485f64944bea0d690", size = 145513 },
 ]
 
 [package.optional-dependencies]