diff --git a/batdetect2/data/preprocessing.py b/batdetect2/compat/__init__.py
similarity index 100%
rename from batdetect2/data/preprocessing.py
rename to batdetect2/compat/__init__.py
diff --git a/batdetect2/data/compat.py b/batdetect2/compat/data.py
similarity index 100%
rename from batdetect2/data/compat.py
rename to batdetect2/compat/data.py
diff --git a/batdetect2/compat/params.py b/batdetect2/compat/params.py
new file mode 100644
index 0000000..acb811f
--- /dev/null
+++ b/batdetect2/compat/params.py
@@ -0,0 +1,148 @@
+from batdetect2.preprocess import (
+    AudioConfig,
+    FFTConfig,
+    FrequencyConfig,
+    PcenConfig,
+    PreprocessingConfig,
+    ResampleConfig,
+    SpecSizeConfig,
+    SpectrogramConfig,
+)
+from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
+from batdetect2.terms import TagInfo
+from batdetect2.train.preprocess import (
+    HeatmapsConfig,
+    TargetConfig,
+    TrainPreprocessingConfig,
+)
+
+
+def get_spectrogram_scale(scale: str):
+    if scale == "pcen":
+        return PcenConfig()
+    if scale == "log":
+        return "log"
+    return None
+
+
+def get_preprocessing_config(params: dict) -> PreprocessingConfig:
+    return PreprocessingConfig(
+        audio=AudioConfig(
+            resample=ResampleConfig(
+                samplerate=params["target_samp_rate"],
+                mode="poly",
+            ),
+            scale=params["scale_raw_audio"],
+            center=params["scale_raw_audio"],
+            duration=None,
+        ),
+        spectrogram=SpectrogramConfig(
+            fft=FFTConfig(
+                window_duration=params["fft_win_length"],
+                window_overlap=params["fft_overlap"],
+                window_fn="hann",
+            ),
+            frequencies=FrequencyConfig(
+                min_freq=params["min_freq"],
+                max_freq=params["max_freq"],
+            ),
+            scale=get_spectrogram_scale(params["spec_scale"]),
+            denoise=params["denoise_spec_avg"],
+            size=SpecSizeConfig(
+                height=params["spec_height"],
+                resize_factor=params["resize_factor"],
+            ),
+            max_scale=params["max_scale_spec"],
+        ),
+    )
+
+
+def get_training_preprocessing_config(
+    params: dict,
+) -> TrainPreprocessingConfig:
+    generic = params["generic_class"][0]
+    preprocessing = get_preprocessing_config(params)
+
+    freq_bin_width, time_bin_width = get_spectrogram_resolution(
+        preprocessing.spectrogram
+    )
+
+    return TrainPreprocessingConfig(
+        preprocessing=preprocessing,
+        target=TargetConfig(
+            classes=[
+                TagInfo(key="class", value=class_name, label=class_name)
+                for class_name in params["class_names"]
+            ],
+            generic_class=TagInfo(
+                key="class",
+                value=generic,
+                label=generic,
+            ),
+            include=[
+                TagInfo(key="event", value=event)
+                for event in params["events_of_interest"]
+            ],
+            exclude=[
+                TagInfo(key="class", value=value)
+                for value in params["classes_to_ignore"]
+            ],
+        ),
+        heatmaps=HeatmapsConfig(
+            position="bottom-left",
+            time_scale=1 / time_bin_width,
+            frequency_scale=1 / freq_bin_width,
+            sigma=params["target_sigma"],
+        ),
+    )
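+
+
+# Minimal usage sketch (assumes `params` was loaded from a legacy
+# `params.json`, as done in tests/test_migration/test_training.py; the
+# remaining legacy keys are listed below for reference):
+#
+#     import json
+#
+#     with open("params.json") as fp:
+#         params = json.load(fp)
+#
+#     config = get_training_preprocessing_config(params)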
+
+
+# 'standardize_classs_names_ip',
+# 'convert_to_genus',
+# 'genus_mapping',
+# 'standardize_classs_names',
+# 'genus_names',
+
+# ['data_dir',
+# 'ann_dir',
+# 'train_split',
+# 'model_name',
+# 'num_filters',
+# 'experiment',
+# 'model_file_name',
+# 'op_im_dir',
+# 'op_im_dir_test',
+# 'notes',
+# 'spec_divide_factor',
+# 'detection_overlap',
+# 'ignore_start_end',
+# 'detection_threshold',
+# 'nms_kernel_size',
+# 'nms_top_k_per_sec',
+# 'aug_prob',
+# 'augment_at_train',
+# 'augment_at_train_combine',
+# 'echo_max_delay',
+# 'stretch_squeeze_delta',
+# 'mask_max_time_perc',
+# 'mask_max_freq_perc',
+# 'spec_amp_scaling',
+# 'aug_sampling_rates',
+# 'train_loss',
+# 'det_loss_weight',
+# 'size_loss_weight',
+# 'class_loss_weight',
+# 'individual_loss_weight',
+# 'emb_dim',
+# 'lr',
+# 'batch_size',
+# 'num_workers',
+# 'num_epochs',
+# 'num_eval_epochs',
+# 'device',
+# 'save_test_image_during_train',
+# 'save_test_image_after_train',
+# 'train_sets',
+# 'test_sets',
+# 'class_inv_freq',
+# 'ip_height']
diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py
index 9c538d2..2a7bff1 100644
--- a/batdetect2/preprocess/audio.py
+++ b/batdetect2/preprocess/audio.py
@@ -27,6 +27,28 @@ class AudioConfig(BaseConfig):
     duration: Optional[float] = DEFAULT_DURATION
 
 
+def load_file_audio(
+    path: data.PathLike,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    recording = data.Recording.from_file(path)
+    return load_recording_audio(recording, config=config, dtype=dtype)
+
+
+def load_recording_audio(
+    recording: data.Recording,
+    config: Optional[AudioConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    clip = data.Clip(
+        recording=recording,
+        start_time=0,
+        end_time=recording.duration,
+    )
+    return load_clip_audio(clip, config=config, dtype=dtype)
+
+
 def load_clip_audio(
     clip: data.Clip,
     config: Optional[AudioConfig] = None,
diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py
index c0a8e45..2026785 100644
--- a/batdetect2/preprocess/spectrogram.py
+++ b/batdetect2/preprocess/spectrogram.py
@@ -10,29 +10,17 @@ from soundevent import arrays, audio
 from soundevent.arrays import operations as ops
 
 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.audio import DEFAULT_DURATION
-
-FFT_WIN_LENGTH_S = 512 / 256000.0
-FFT_OVERLAP = 0.75
-MAX_FREQ_HZ = 120000
-MIN_FREQ_HZ = 10000
-SPEC_HEIGHT = 128
-SPEC_WIDTH = 256
-SPEC_SCALE = "pcen"
-SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
-DENOISE_SPEC_AVG = True
-MAX_SCALE_SPEC = False
 
 
 class FFTConfig(BaseConfig):
-    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_duration: float = Field(default=0.002, gt=0)
+    window_overlap: float = Field(default=0.75, ge=0, lt=1)
     window_fn: str = "hann"
 
 
 class FrequencyConfig(BaseConfig):
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+    max_freq: int = Field(default=120_000, gt=0)
+    min_freq: int = Field(default=10_000, gt=0)
 
 
 class PcenConfig(BaseConfig):
@@ -44,17 +32,20 @@ class PcenConfig(BaseConfig):
 
 
 class SpecSizeConfig(BaseConfig):
-    height: int = SPEC_HEIGHT
-    time_period: float = SPEC_TIME_PERIOD
+    height: int = 256
+    resize_factor: Optional[float] = 0.5
+    divide_factor: Optional[int] = 32
 
 
 class SpectrogramConfig(BaseConfig):
     fft: FFTConfig = Field(default_factory=FFTConfig)
     frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
-    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    scale: Union[Literal["log"], None, PcenConfig] = Field(
+        default_factory=PcenConfig
+    )
+    size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
     denoise: bool = True
-    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
-    max_scale: bool = MAX_SCALE_SPEC
+    max_scale: bool = False
 
 
 def compute_spectrogram(
@@ -64,6 +55,16 @@ def compute_spectrogram(
 ) -> xr.DataArray:
     config = config or SpectrogramConfig()
 
+    if config.size and config.size.divide_factor:
+        # Need to pad the audio to make sure the spectrogram has a
+        # width compatible with the divide factor
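+        # (assumption: the divide factor matches the model's total
+        # downsampling, e.g. an encoder that halves the time axis five
+        # times needs a width divisible by 2**5 = 32 for the upsampled
+        # feature maps to line up again)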
+        wav = pad_audio(
+            wav,
+            window_duration=config.fft.window_duration,
+            window_overlap=config.fft.window_overlap,
+            divide_factor=config.size.divide_factor,
+        )
+
     spec = stft(
         wav,
         window_duration=config.fft.window_duration,
@@ -83,8 +84,12 @@ def compute_spectrogram(
     if config.denoise:
         spec = denoise_spectrogram(spec)
 
-    if config.resize:
-        spec = resize_spectrogram(spec, config=config.resize)
+    if config.size:
+        spec = resize_spectrogram(
+            spec,
+            height=config.size.height,
+            resize_factor=config.size.resize_factor,
+        )
 
     if config.max_scale:
         spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
@@ -94,8 +99,8 @@ def compute_spectrogram(
 
 def crop_spectrogram_frequencies(
     spec: xr.DataArray,
-    min_freq: int = MIN_FREQ_HZ,
-    max_freq: int = MAX_FREQ_HZ,
+    min_freq: int = 10_000,
+    max_freq: int = 120_000,
 ) -> xr.DataArray:
     return arrays.crop_dim(
         spec,
@@ -116,9 +121,10 @@ def stft(
     step = arrays.get_dim_step(wave, dim="time")
     sampling_rate = 1 / step
 
-    hop_len = window_duration * (1 - window_overlap)
     nfft = int(window_duration * sampling_rate)
     noverlap = int(window_overlap * nfft)
+    hop_len = nfft - noverlap
+    hop_duration = hop_len / sampling_rate
 
     spec, _ = librosa.core.spectrum._spectrogram(
         y=wave.data.astype(dtype),
@@ -146,12 +152,12 @@ def stft(
             "time": arrays.create_time_dim_from_array(
                 np.linspace(
                     start_time,
-                    end_time - (window_duration - hop_len),
+                    end_time - (window_duration - hop_duration),
                     spec.shape[1],
                     endpoint=False,
                     dtype=dtype,
                 ),
-                step=hop_len,
+                step=hop_duration,
             ),
         },
         attrs={
@@ -202,7 +208,6 @@ def scale_pcen(
     power: float = 0.5,
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
-    # NOTE: Not sure why the 10 is there
     t_frames = time_constant * samplerate / (float(hop_length) * 10)
     smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
     return audio.pcen(
@@ -231,12 +236,114 @@ def scale_log(
 
 def resize_spectrogram(
     spec: xr.DataArray,
-    config: SpecSizeConfig,
+    height: int = 128,
+    resize_factor: Optional[float] = 0.5,
 ) -> xr.DataArray:
-    duration = arrays.get_dim_width(spec, dim="time")
+    resize_factor = resize_factor or 1
+    current_width = spec.sizes["time"]
     return ops.resize(
         spec,
-        time=int(np.ceil(duration / config.time_period)),
-        frequency=config.height,
+        time=int(resize_factor * current_width),
+        frequency=int(resize_factor * height),
         dtype=np.float32,
     )
+
+
+def adjust_spectrogram_width(
+    spec: xr.DataArray,
+    divide_factor: int = 32,
+    time_period: float = 0.001,
+) -> xr.DataArray:
+    time_width = spec.sizes["time"]
+
+    if time_width % divide_factor == 0:
+        return spec
+
+    target_size = int(
+        np.ceil(spec.sizes["time"] / divide_factor) * divide_factor
+    )
+    extra_duration = (target_size - time_width) * time_period
+    _, stop = arrays.get_dim_range(spec, dim="time")
+    resized = ops.extend_dim(
+        spec,
+        dim="time",
+        stop=stop + extra_duration,
+    )
+    return resized
+
+
+def pad_audio(
+    wave: xr.DataArray,
+    window_duration: float,
+    window_overlap: float,
+    divide_factor: int = 32,
+) -> xr.DataArray:
+    current_duration = arrays.get_dim_width(wave, dim="time")
+    step = arrays.get_dim_step(wave, dim="time")
+    samplerate = int(1 / step)
+
+    estimated_spec_width = duration_to_spec_width(
+        current_duration,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+
+    if estimated_spec_width % divide_factor == 0:
+        return wave
+
+    target_spec_width = int(
+        np.ceil(estimated_spec_width / divide_factor) * divide_factor
+    )
+    target_samples = spec_width_to_samples(
+        target_spec_width,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+    return ops.adjust_dim_width(
+        wave,
+        dim="time",
+        width=target_samples,
+        position="start",
+    )
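+
+
+# NOTE (assumption): the two helpers below mirror librosa's center=False
+# framing, where an STFT over `samples` samples yields
+# floor((samples - fft_len) / hop_len) + 1 frames. Worked example:
+# samples=1_000, fft_len=512, hop_len=128 gives floor(488 / 128) + 1 = 4
+# frames, and spec_width_to_samples(4, ...) maps back to
+# 4 * 128 + 512 - 128 = 896 samples (the 104 trailing samples are lost).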
+
+
+def duration_to_spec_width(
+    duration: float,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    samples = int(duration * samplerate)
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    width = (samples - fft_len + hop_len) / hop_len
+    return int(np.floor(width))
+
+
+def spec_width_to_samples(
+    width: int,
+    samplerate: int,
+    window_duration: float,
+    window_overlap: float,
+) -> int:
+    fft_len = int(window_duration * samplerate)
+    fft_overlap = int(window_overlap * fft_len)
+    hop_len = fft_len - fft_overlap
+    return width * hop_len + fft_len - hop_len
+
+
+def get_spectrogram_resolution(
+    config: SpectrogramConfig,
+) -> tuple[float, float]:
+    max_freq = config.frequencies.max_freq
+    min_freq = config.frequencies.min_freq
+
+    assert config.size is not None
+
+    spec_height = config.size.height
+    resize_factor = config.size.resize_factor or 1
+    freq_bin_width = (max_freq - min_freq) / (spec_height * resize_factor)
+    hop_duration = config.fft.window_duration * (
+        1 - config.fft.window_overlap
+    )
+    return freq_bin_width, hop_duration / resize_factor
diff --git a/batdetect2/terms.py b/batdetect2/terms.py
index 8d162fd..e60e3f2 100644
--- a/batdetect2/terms.py
+++ b/batdetect2/terms.py
@@ -22,9 +22,9 @@ class TermInfo(BaseModel):
 
 class TagInfo(BaseModel):
     value: str
-    label: Optional[str] = None
     term: Optional[TermInfo] = None
     key: Optional[str] = None
+    label: Optional[str] = None
 
 
 call_type = data.Term(
diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py
index db45d3a..a1dc340 100644
--- a/batdetect2/train/labels.py
+++ b/batdetect2/train/labels.py
@@ -7,23 +7,29 @@ from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
 from soundevent.types import ClassMapper
 
+from batdetect2.configs import BaseConfig
+
 __all__ = [
     "ClassMapper",
    "generate_heatmaps",
 ]
 
-TARGET_SIGMA = 3.0
+
+class HeatmapsConfig(BaseConfig):
+    position: Positions = "bottom-left"
+    sigma: float = 3.0
+    time_scale: float = 1000.0
+    frequency_scale: float = 1 / 859.375
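+    # NOTE: these defaults mirror the default spectrogram resolution
+    # (see get_spectrogram_resolution): one time bin spans
+    # 0.002 s * (1 - 0.75) / 0.5 = 1 ms, so time_scale = 1000, and one
+    # frequency bin spans (120_000 - 10_000) / (256 * 0.5) = 859.375 Hz,
+    # so frequency_scale = 1 / 859.375.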
 
 
 def generate_heatmaps(
     sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
     class_mapper: ClassMapper,
-    target_sigma: float = TARGET_SIGMA,
+    target_sigma: float = 3.0,
     position: Positions = "bottom-left",
-    time_scale: float = 1.0,
-    frequency_scale: float = 1.0,
+    time_scale: float = 1000.0,
+    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -39,7 +45,7 @@ def generate_heatmaps(
         data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": class_mapper.class_labels,
+            "category": [*class_mapper.class_labels],
             **spec.coords,
         },
     )
diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py
index 9dfe755..309d49d 100644
--- a/batdetect2/train/preprocess.py
+++ b/batdetect2/train/preprocess.py
@@ -14,12 +14,10 @@ from tqdm.auto import tqdm
 from batdetect2.configs import BaseConfig
 from batdetect2.preprocess import (
     PreprocessingConfig,
-    preprocess_audio_clip,
-)
-from batdetect2.train.labels import (
-    TARGET_SIGMA,
-    generate_heatmaps,
+    compute_spectrogram,
+    load_clip_audio,
 )
+from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
 from batdetect2.train.targets import (
     TargetConfig,
     build_class_mapper,
@@ -34,16 +32,12 @@ __all__ = [
 ]
 
 
-class MasksConfig(BaseConfig):
-    sigma: float = TARGET_SIGMA
-
-
 class TrainPreprocessingConfig(BaseConfig):
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
     target: TargetConfig = Field(default_factory=TargetConfig)
-    masks: MasksConfig = Field(default_factory=MasksConfig)
+    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
 
 
 def generate_train_example(
@@ -53,9 +47,14 @@ def generate_train_example(
     """Generate a training example."""
     config = config or TrainPreprocessingConfig()
 
-    spectrogram = preprocess_audio_clip(
+    wave = load_clip_audio(
         clip_annotation.clip,
-        config=config.preprocessing,
+        config=config.preprocessing.audio,
+    )
+
+    spectrogram = compute_spectrogram(
+        wave,
+        config=config.preprocessing.spectrogram,
     )
 
     filter_fn = build_sound_event_filter(
@@ -65,17 +64,24 @@ def generate_train_example(
     selected_events = [
         event for event in clip_annotation.sound_events if filter_fn(event)
     ]
-
     class_mapper = build_class_mapper(config.target.classes)
     detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
         selected_events,
         spectrogram,
         class_mapper,
-        target_sigma=config.masks.sigma,
+        target_sigma=config.heatmaps.sigma,
+        position=config.heatmaps.position,
+        time_scale=config.heatmaps.time_scale,
+        frequency_scale=config.heatmaps.frequency_scale,
     )
 
     dataset = xr.Dataset(
         {
+            # NOTE: Need to rename the time dimension to avoid conflicts
+            # with the spectrogram time dimension; otherwise xarray will
+            # interpolate the spectrogram and the heatmaps to the same
+            # temporal resolution as the waveform.
+            "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
             "detection": detection_heatmap,
             "class": class_heatmap,
diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py
index e88aa04..dc23259 100644
--- a/batdetect2/train/targets.py
+++ b/batdetect2/train/targets.py
@@ -13,9 +13,9 @@ class TargetConfig(BaseConfig):
     """Configuration for target generation."""
 
     classes: List[TagInfo] = Field(default_factory=list)
+    generic_class: Optional[TagInfo] = None
 
     include: Optional[List[TagInfo]] = None
-
     exclude: Optional[List[TagInfo]] = None
@@ -73,7 +73,7 @@ class GenericMapper(ClassMapper):
             raise ValueError("Number of targets and class labels must match.")
 
         self.targets = set(classes)
-        self.class_labels = labels
+        self.class_labels = list(dict.fromkeys(labels))
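+        # NOTE: several tags may share a label; dict.fromkeys removes the
+        # duplicates while preserving first-seen order (presumably so the
+        # heatmap "category" coordinate stays unique).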
 
         self._mapping = {tag: label for tag, label in zip(classes, labels)}
         self._inverse_mapping = {
diff --git a/pyproject.toml b/pyproject.toml
index 9484940..fa3a339 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.2",
+    "soundevent[audio,geometry,plot]>=2.3",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
diff --git a/tests/conftest.py b/tests/conftest.py
index fbebc98..2c7e5a7 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,7 +1,11 @@
+import uuid
 from pathlib import Path
-from typing import List
+from typing import Callable, List, Optional
 
+import numpy as np
 import pytest
+import soundfile as sf
+from soundevent import data
 
 
 @pytest.fixture
@@ -19,6 +23,13 @@ def example_audio_dir(example_data_dir: Path) -> Path:
     return example_audio_dir
 
 
+@pytest.fixture
+def example_anns_dir(example_data_dir: Path) -> Path:
+    example_anns_dir = example_data_dir / "anns"
+    assert example_anns_dir.exists()
+    return example_anns_dir
+
+
 @pytest.fixture
 def example_audio_files(example_audio_dir: Path) -> List[Path]:
     audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
@@ -38,3 +49,61 @@ def contrib_dir(data_dir) -> Path:
     dir = data_dir / "contrib"
     assert dir.exists()
     return dir
+
+
+@pytest.fixture
+def wav_factory(tmp_path: Path):
+    def _wav_factory(
+        path: Optional[Path] = None,
+        duration: float = 0.3,
+        channels: int = 1,
+        samplerate: int = 441_000,
+        bit_depth: int = 16,
+    ) -> Path:
+        path = path or tmp_path / f"{uuid.uuid4()}.wav"
+        frames = int(samplerate * duration)
+        shape = (frames, channels)
+        subtype = f"PCM_{bit_depth}"
+
+        if bit_depth == 16:
+            dtype = np.int16
+        elif bit_depth == 32:
+            dtype = np.int32
+        else:
+            raise ValueError(f"Unsupported bit depth: {bit_depth}")
+
+        wav = np.random.uniform(
+            low=np.iinfo(dtype).min,
+            high=np.iinfo(dtype).max,
+            size=shape,
+        ).astype(dtype)
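+        # wav is full-scale white noise; the tests built on this fixture
+        # only exercise durations, samplerates and array shapes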
+        sf.write(str(path), wav, samplerate, subtype=subtype)
+        return path
+
+    return _wav_factory
+
+
+@pytest.fixture
+def recording_factory(wav_factory: Callable[..., Path]):
+    def _recording_factory(
+        tags: Optional[list[data.Tag]] = None,
+        path: Optional[Path] = None,
+        recording_id: Optional[uuid.UUID] = None,
+        duration: float = 1,
+        channels: int = 1,
+        samplerate: int = 44100,
+        time_expansion: float = 1,
+    ) -> data.Recording:
+        path = path or wav_factory(
+            duration=duration,
+            channels=channels,
+            samplerate=samplerate,
+        )
+        return data.Recording.from_file(
+            path=path,
+            uuid=recording_id or uuid.uuid4(),
+            time_expansion=time_expansion,
+            tags=tags or [],
+        )
+
+    return _recording_factory
diff --git a/tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz b/tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz
new file mode 100644
index 0000000..a66a088
Binary files /dev/null and b/tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz differ
diff --git a/tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz b/tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz
new file mode 100644
index 0000000..56884ba
Binary files /dev/null and b/tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz differ
diff --git a/tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz b/tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz
new file mode 100644
index 0000000..7a3f27c
Binary files /dev/null and b/tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz differ
diff --git a/tests/test_migration/test_preprocessing.py b/tests/test_migration/test_preprocessing.py
index 3b3a855..d9ee89a 100644
--- a/tests/test_migration/test_preprocessing.py
+++ b/tests/test_migration/test_preprocessing.py
@@ -4,7 +4,7 @@ import numpy as np
 import pytest
 from soundevent import data
 
-from batdetect2.data import preprocessing
+from batdetect2 import preprocess
 from batdetect2.utils import audio_utils
 
 ROOT_DIR = Path(__file__).parent.parent.parent
@@ -44,10 +44,10 @@ def test_audio_loading_hasnt_changed(
         target_samp_rate=target_sampling_rate,
         scale=scale,
     )
-    audio_new = preprocessing.load_clip_audio(
+    audio_new = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             ),
             center=scale,
@@ -84,20 +84,20 @@ def test_spectrogram_generation_hasnt_changed(
     if spec_scale == "log":
         scale = "log"
     elif spec_scale == "pcen":
-        scale = preprocessing.PcenConfig()
+        scale = preprocess.PcenConfig()
 
-    config = preprocessing.SpectrogramConfig(
-        fft=preprocessing.FFTConfig(
+    config = preprocess.SpectrogramConfig(
+        fft=preprocess.FFTConfig(
             window_overlap=fft_overlap,
             window_duration=fft_win_length,
         ),
-        frequencies=preprocessing.FrequencyConfig(
+        frequencies=preprocess.FrequencyConfig(
             min_freq=min_freq,
             max_freq=max_freq,
        ),
         scale=scale,
         denoise=denoise_spec_avg,
-        resize=None,
+        size=None,
         max_scale=max_scale_spec,
     )
@@ -112,10 +112,10 @@ def test_spectrogram_generation_hasnt_changed(
         end_time=recording.duration,
     )
 
-    audio = preprocessing.load_clip_audio(
+    audio = preprocess.load_clip_audio(
         clip,
-        config=preprocessing.AudioConfig(
-            resample=preprocessing.ResampleConfig(
+        config=preprocess.AudioConfig(
+            resample=preprocess.ResampleConfig(
                 samplerate=target_sampling_rate,
             )
         ),
@@ -135,7 +135,7 @@ def test_spectrogram_generation_hasnt_changed(
         ),
     )
 
-    new_spec = preprocessing.compute_spectrogram(
+    new_spec = preprocess.compute_spectrogram(
         audio,
         config=config,
         dtype=np.float32,
diff --git a/tests/test_migration/test_training.py b/tests/test_migration/test_training.py
new file mode 100644
index 0000000..646831e
--- /dev/null
+++ b/tests/test_migration/test_training.py
@@ -0,0 +1,75 @@
+import json
+from pathlib import Path
+from typing import List
+
+import numpy as np
+import pytest
+
+from batdetect2.compat.data import load_annotation_project
+from batdetect2.compat.params import get_training_preprocessing_config
+from batdetect2.train.preprocess import generate_train_example
+
+
+@pytest.fixture
+def regression_dir(data_dir: Path) -> Path:
+    dir = data_dir / "regression"
+    assert dir.exists()
+    return dir
+
+
+def test_can_generate_similar_training_inputs(
+    example_audio_dir: Path,
+    example_audio_files: List[Path],
+    example_anns_dir: Path,
+    regression_dir: Path,
+):
+    old_parameters = json.loads((regression_dir / "params.json").read_text())
+    config = get_training_preprocessing_config(old_parameters)
+
+    for audio_file in example_audio_files:
+        example_file = regression_dir / f"{audio_file.name}.npz"
+
+        dataset = np.load(example_file)
+
+        spec = dataset["spec"][0]
+        detection_mask = dataset["detection_mask"][0]
+        size_mask = dataset["size_mask"]
+        class_mask = dataset["class_mask"]
+
+        project = load_annotation_project(
+            example_anns_dir,
+            audio_dir=example_audio_dir,
+        )
+
+        clip_annotation = next(
+            ann
+            for ann in project.clip_annotations
+            if ann.clip.recording.path == audio_file
+        )
+
+        new_dataset = generate_train_example(clip_annotation, config)
+        new_spec = new_dataset["spectrogram"].values
+        new_detection_mask = new_dataset["detection"].values
+        new_size_mask = new_dataset["size"].values
+        new_class_mask = new_dataset["class"].values
+
+        assert spec.shape == new_spec.shape
+        assert detection_mask.shape == new_detection_mask.shape
+        assert size_mask.shape == new_size_mask.shape
+        assert class_mask.shape[1:] == new_class_mask.shape[1:]
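+        # (the legacy class mask presumably carries the generic class as
+        # an extra row, which the new pipeline tracks separately via
+        # TargetConfig.generic_class)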
+        assert class_mask.shape[0] == new_class_mask.shape[0] + 1
+
+        x_new, y_new = np.nonzero(new_size_mask.max(axis=0))
+        x_orig, y_orig = np.nonzero(np.flipud(size_mask.max(axis=0)))
+
+        assert (x_new == x_orig).all()
+
+        # NOTE: a difference of 1 pixel is due to discrepancies in how
+        # frequency bins are interpreted. Shouldn't be an issue.
+        assert (y_new == y_orig + 1).all()
+
+        width_new, height_new = new_size_mask[:, x_new, y_new]
+        width_orig, height_orig = np.flip(size_mask, axis=1)[:, x_orig, y_orig]
+
+        assert (np.floor(width_new) == width_orig).all()
+        assert (np.ceil(height_new) == height_orig).all()
diff --git a/tests/test_preprocessing/__init__.py b/tests/test_preprocessing/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_preprocessing/test_audio.py b/tests/test_preprocessing/test_audio.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py
new file mode 100644
index 0000000..5eb235f
--- /dev/null
+++ b/tests/test_preprocessing/test_spectrogram.py
@@ -0,0 +1,141 @@
+import math
+from pathlib import Path
+from typing import Callable
+
+from hypothesis import HealthCheck, given, settings
+from hypothesis import strategies as st
+from soundevent import arrays
+
+from batdetect2.preprocess.audio import AudioConfig, load_file_audio
+from batdetect2.preprocess.spectrogram import (
+    FFTConfig,
+    FrequencyConfig,
+    SpecSizeConfig,
+    SpectrogramConfig,
+    compute_spectrogram,
+    duration_to_spec_width,
+    get_spectrogram_resolution,
+    pad_audio,
+    spec_width_to_samples,
+    stft,
+)
+
+
+@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
+@given(
+    duration=st.floats(min_value=0.1, max_value=1.0),
+    window_duration=st.floats(min_value=0.001, max_value=0.01),
+    window_overlap=st.floats(min_value=0.2, max_value=0.9),
+    samplerate=st.integers(min_value=256_000, max_value=512_000),
+)
+def test_can_estimate_correctly_spectrogram_width_from_duration(
+    duration: float,
+    window_duration: float,
+    window_overlap: float,
+    samplerate: int,
+    wav_factory: Callable[..., Path],
+):
+    path = wav_factory(duration=duration, samplerate=samplerate)
+    audio = load_file_audio(
+        path,
+        # NOTE: Don't resample or adjust the duration, so that the width
+        # estimation is tested in all scenarios
+        config=AudioConfig(resample=None, duration=None),
+    )
+    spectrogram = stft(audio, window_duration, window_overlap)
+
+    spec_width = duration_to_spec_width(
+        duration,
+        samplerate=samplerate,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+    )
+    assert spectrogram.sizes["time"] == spec_width
+
+    rebuilt_duration = (
+        spec_width_to_samples(
+            spec_width,
+            samplerate=samplerate,
+            window_duration=window_duration,
+            window_overlap=window_overlap,
+        )
+        / samplerate
+    )
+
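+    # the round trip can be off by less than one hop: trailing samples
+    # that don't fill a complete window are dropped by the STFT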
+    assert (
+        abs(duration - rebuilt_duration)
+        < (1 - window_overlap) * window_duration
+    )
+
+
+@settings(
+    suppress_health_check=[HealthCheck.function_scoped_fixture],
+    deadline=400,
+)
+@given(
+    duration=st.floats(min_value=0.1, max_value=1.0),
+    window_duration=st.floats(min_value=0.001, max_value=0.01),
+    window_overlap=st.floats(min_value=0.2, max_value=0.9),
+    samplerate=st.integers(min_value=256_000, max_value=512_000),
+    divide_factor=st.integers(min_value=16, max_value=64),
+)
+def test_can_pad_audio_to_adjust_spectrogram_width(
+    duration: float,
+    window_duration: float,
+    window_overlap: float,
+    samplerate: int,
+    divide_factor: int,
+    wav_factory: Callable[..., Path],
+):
+    path = wav_factory(duration=duration, samplerate=samplerate)
+
+    audio = load_file_audio(
+        path,
+        # NOTE: Don't resample or adjust the duration, so that the width
+        # estimation is tested in all scenarios
+        config=AudioConfig(resample=None, duration=None),
+    )
+
+    audio = pad_audio(
+        audio,
+        window_duration=window_duration,
+        window_overlap=window_overlap,
+        divide_factor=divide_factor,
+    )
+
+    spectrogram = stft(audio, window_duration, window_overlap)
+    assert spectrogram.sizes["time"] % divide_factor == 0
+
+
+def test_can_estimate_spectrogram_resolution(
+    wav_factory: Callable[..., Path],
+):
+    path = wav_factory(duration=0.2, samplerate=256_000)
+
+    audio = load_file_audio(
+        path,
+        # NOTE: Don't resample or adjust the duration, so that the
+        # resolution estimate is checked against the raw file
+        config=AudioConfig(resample=None, duration=None),
+    )
+
+    config = SpectrogramConfig(
+        fft=FFTConfig(),
+        size=SpecSizeConfig(height=256, resize_factor=0.5),
+        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
+    )
+
+    spec = compute_spectrogram(audio, config=config)
+
+    freq_res, time_res = get_spectrogram_resolution(config)
+
+    assert math.isclose(
+        arrays.get_dim_step(spec, dim="frequency"),
+        freq_res,
+        rel_tol=0.1,
+    )
+    assert math.isclose(
+        arrays.get_dim_step(spec, dim="time"),
+        time_res,
+        rel_tol=0.1,
+    )
diff --git a/uv.lock b/uv.lock
index a3c9677..40d2b66 100644
--- a/uv.lock
+++ b/uv.lock
@@ -236,7 +236,7 @@ requires-dist = [
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
-    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
+    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },
     { name = "tensorboard", specifier = ">=2.16.2" },
     { name = "torch", specifier = ">=1.13.1,<2.5.0" },
     { name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },
@@ -2679,15 +2679,15 @@ wheels = [
 
 [[package]]
 name = "soundevent"
-version = "2.2.0"
+version = "2.3.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "email-validator" },
     { name = "pydantic" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/5a/f9/77d723df4d8d3a390d32c07325a08bfc669d5fb55e88b98181b5793e7333/soundevent-2.2.0.tar.gz", hash = "sha256:a87a97c8e4bfdadec4b6edc128470919ef9240a344cf924d2ac21f8c6b50acf1", size = 8715229 }
+sdist = { url = "https://files.pythonhosted.org/packages/ff/51/83093cabe9ada781a0f7a78f82cc04162d005755b2f0ca3fdcb4ecd47a01/soundevent-2.3.0.tar.gz", hash = "sha256:b75d7674578a52bf196619f8b4b3d9170f2ca321d165ceb45916579a549c3e76", size = 8716539 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/09/d8/ce6e2830d47fc3d24db264b94384fa7cdcb8069fd7547d55b7a83857e730/soundevent-2.2.0-py3-none-any.whl", hash = "sha256:c40913c15fc697a82a02df5f62d18ad3b77bfae80a9a5d54c47bc1377d3b4d7c", size = 144188 },
+    { url = "https://files.pythonhosted.org/packages/ee/67/4c2d881f9b4a0b453dbee91e119e1e48df0fc92de2cc3062fcd8ad0a7e6b/soundevent-2.3.0-py3-none-any.whl", hash = "sha256:f7c74b1d73a347ebe843187c93130dc8af3214add95e6bc485f64944bea0d690", size = 145513 },
 ]
 
 [package.optional-dependencies]