diff --git a/batdetect2/data/compat.py b/batdetect2/data/compat.py
index 0704c97..2aaf3d3 100644
--- a/batdetect2/data/compat.py
+++ b/batdetect2/data/compat.py
@@ -195,18 +195,29 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation.label),
-            data.Tag(key=event_key, value=annotation.event),
-            data.Tag(key=individual_key, value=str(annotation.individual)),
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=annotation.label,
+            ),
+            data.Tag(
+                term=data.term_from_key(event_key),
+                value=annotation.event,
+            ),
+            data.Tag(
+                term=data.term_from_key(individual_key),
+                value=str(annotation.individual),
+            ),
         ],
     )


 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
 ) -> data.Clip:
     """Convert file annotation to recording."""
+    audio_dir = audio_dir or Path.cwd()
+
     full_path = Path(audio_dir) / file_annotation.id

     if not full_path.exists():
@@ -241,7 +252,11 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key), value=file_annotation.label
+            )
+        ],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@@ -286,9 +301,11 @@ def list_file_annotations(path: PathLike) -> List[Path]:
 def load_annotation_project(
     path: PathLike,
     name: Optional[str] = None,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
 ) -> data.AnnotationProject:
     """Convert annotations to annotation project."""
+    audio_dir = audio_dir or Path.cwd()
+
     paths = list_file_annotations(path)

     if name is None:
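Note: the `key=` → `term=data.term_from_key(...)` change above follows the soundevent 2.2 `Tag` model, where a tag references a `Term` rather than a bare string key. A minimal sketch of the new construction (the label value here is made up):

```python
from soundevent import data

label_key = "class"

# Old style (soundevent < 2.2): data.Tag(key=label_key, value="Myotis myotis")
# New style: the key is first wrapped in a Term object.
tag = data.Tag(term=data.term_from_key(label_key), value="Myotis myotis")
print(tag)
```
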
diff --git a/batdetect2/data/labels.py b/batdetect2/data/labels.py
index 67d8122..8e33e96 100644
--- a/batdetect2/data/labels.py
+++ b/batdetect2/data/labels.py
@@ -3,7 +3,7 @@ from typing import Tuple
 import numpy as np
 import xarray as xr
 from scipy.ndimage import gaussian_filter
-from soundevent import data, geometry, arrays
+from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
 from soundevent.types import ClassMapper

@@ -22,6 +22,8 @@ def generate_heatmaps(
     class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
     position: Positions = "bottom-left",
+    time_scale: float = 1.0,
+    frequency_scale: float = 1.0,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -31,13 +33,6 @@
             "Spectrogram must have time and frequency dimensions."
         )

-    time_duration = arrays.get_dim_width(spec, dim="time")
-    freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
-
-    # Compute the size factors
-    time_scale = 1 / time_duration
-    frequency_scale = 1 / freq_bandwidth
-
     # Initialize heatmaps
     detection_heatmap = xr.zeros_like(spec, dtype=dtype)
     class_heatmap = xr.DataArray(
@@ -92,7 +87,7 @@
     )

     # Get the class name of the sound event
-    class_name = class_mapper.transform(sound_event_annotation)
+    class_name = class_mapper.encode(sound_event_annotation)

     if class_name is None:
         # If the label is None skip the sound event
diff --git a/batdetect2/data/preprocessing.py b/batdetect2/data/preprocessing.py
index f114df7..b46cfbb 100644
--- a/batdetect2/data/preprocessing.py
+++ b/batdetect2/data/preprocessing.py
@@ -1,7 +1,7 @@
 """Module containing functions for preprocessing audio clips."""

-from typing import Optional, Union
 from pathlib import Path
+from typing import Literal, Optional, Union

 import librosa
 import librosa.core.spectrum
@@ -10,7 +10,7 @@ import xarray as xr
 from numpy.typing import DTypeLike
 from pydantic import BaseModel, Field
 from scipy.signal import resample_poly
-from soundevent import audio, data, arrays
+from soundevent import arrays, audio, data
 from soundevent.arrays import operations as ops

 __all__ = [
@@ -34,32 +34,56 @@ DENOISE_SPEC_AVG = True
 MAX_SCALE_SPEC = False


+class ResampleConfig(BaseModel):
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    mode: str = "poly"
+
+
+class AudioConfig(BaseModel):
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    scale: bool = Field(default=SCALE_RAW_AUDIO)
+    center: bool = True
+    duration: Optional[float] = DEFAULT_DURATION
+
+
+class FFTConfig(BaseModel):
+    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_fn: str = "hann"
+
+
+class FrequencyConfig(BaseModel):
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+
+class PcenConfig(BaseModel):
+    time_constant: float = 0.4
+    hop_length: int = 512
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
+class SpecSizeConfig(BaseModel):
+    height: int = SPEC_HEIGHT
+    time_period: float = SPEC_TIME_PERIOD
+
+
+class SpectrogramConfig(BaseModel):
+    fft: FFTConfig = Field(default_factory=FFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    denoise: bool = True
+    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    max_scale: bool = MAX_SCALE_SPEC
+
+
 class PreprocessingConfig(BaseModel):
     """Configuration for preprocessing data."""

-    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-
-    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
-
-    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-
-    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
-
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
-
-    spec_scale: str = Field(default=SPEC_SCALE)
-
-    denoise_spec_avg: bool = DENOISE_SPEC_AVG
-
-    max_scale_spec: bool = MAX_SCALE_SPEC
-
-    duration: Optional[float] = DEFAULT_DURATION
-
-    spec_height: int = SPEC_HEIGHT
-
-    spec_time_period: float = SPEC_TIME_PERIOD
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)

     @classmethod
     def from_file(
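Note: the flat `PreprocessingConfig` is now a thin wrapper around nested audio and spectrogram configs. A sketch of building a non-default configuration with the new models (class names are from this diff; the values are illustrative):

```python
from batdetect2.data.preprocessing import (
    AudioConfig,
    FFTConfig,
    PcenConfig,
    PreprocessingConfig,
    ResampleConfig,
    SpectrogramConfig,
)

# Resample to 192 kHz and use PCEN instead of the default "log" scaling.
config = PreprocessingConfig(
    audio=AudioConfig(resample=ResampleConfig(samplerate=192_000)),
    spectrogram=SpectrogramConfig(
        fft=FFTConfig(window_duration=0.002, window_overlap=0.75),
        scale=PcenConfig(),
    ),
)
print(config.model_dump_json(indent=2))
```
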
@@ -104,7 +128,7 @@ class PreprocessingConfig(BaseModel):

 def preprocess_audio_clip(
     clip: data.Clip,
-    config: PreprocessingConfig = PreprocessingConfig(),
+    config: Optional[PreprocessingConfig] = None,
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.

@@ -121,81 +145,117 @@ def preprocess_audio_clip(
         Preprocessed spectrogram.

     """
-    wav = load_clip_audio(
-        clip,
-        target_sampling_rate=config.target_samplerate,
-        scale=config.scale_audio,
-    )
-
-    spec = compute_spectrogram(
-        wav,
-        fft_win_length=config.fft_win_length,
-        fft_overlap=config.fft_overlap,
-        max_freq=config.max_freq,
-        min_freq=config.min_freq,
-        spec_scale=config.spec_scale,
-        denoise_spec_avg=config.denoise_spec_avg,
-        max_scale_spec=config.max_scale_spec,
-    )
-
-    if config.duration is not None:
-        spec = adjust_spec_duration(clip, spec, config.duration)
-
-    duration = arrays.get_dim_width(spec, dim="time")
-    return ops.resize(
-        spec,
-        time=int(np.ceil(duration / config.spec_time_period)),
-        frequency=config.spec_height,
-        dtype=np.float32,
-    )
-
-
-def adjust_spec_duration(
-    clip: data.Clip,
-    spec: xr.DataArray,
-    duration: float,
-) -> xr.DataArray:
-    current_duration = clip.end_time - clip.start_time
-
-    if current_duration == duration:
-        return spec
-
-    if current_duration > duration:
-        return arrays.crop_dim(
-            spec,
-            dim="time",
-            start=clip.start_time,
-            stop=clip.start_time + duration,
-        )
-
-    return arrays.extend_dim(
-        spec,
-        dim="time",
-        start=clip.start_time,
-        stop=clip.start_time + duration,
-    )
+    config = config or PreprocessingConfig()
+    wav = load_clip_audio(clip, config=config.audio)
+    spec = compute_spectrogram(wav, config=config.spectrogram)
+    return spec


 def load_clip_audio(
     clip: data.Clip,
-    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
-    scale: bool = SCALE_RAW_AUDIO,
+    config: Optional[AudioConfig] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
+    config = config or AudioConfig()
+
     wav = audio.load_clip(clip).sel(channel=0).astype(dtype)

-    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)

-    if scale:
+    if config.resample:
+        wav = resample_audio(
+            wav,
+            samplerate=config.resample.samplerate,
+            dtype=dtype,
+        )
+
+    if config.center:
         wav = ops.center(wav)
+
+    if config.scale:
         wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

     return wav.astype(dtype)


+def compute_spectrogram(
+    wav: xr.DataArray,
+    config: Optional[SpectrogramConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or SpectrogramConfig()
+
+    spec = stft(
+        wav,
+        window_duration=config.fft.window_duration,
+        window_overlap=config.fft.window_overlap,
+        window_fn=config.fft.window_fn,
+        dtype=dtype,
+    )
+
+    spec = crop_spectrogram_frequencies(
+        spec,
+        min_freq=config.frequencies.min_freq,
+        max_freq=config.frequencies.max_freq,
+    )
+
+    spec = scale_spectrogram(spec, scale=config.scale)
+
+    if config.denoise:
+        spec = denoise_spectrogram(spec)
+
+    if config.resize:
+        spec = resize_spectrogram(spec, config=config.resize)
+
+    if config.max_scale:
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+    return spec.astype(dtype)
+
+
+def crop_spectrogram_frequencies(
+    spec: xr.DataArray,
+    min_freq: int = MIN_FREQ_HZ,
+    max_freq: int = MAX_FREQ_HZ,
+) -> xr.DataArray:
+    return arrays.crop_dim(
+        spec,
+        dim="frequency",
+        start=min_freq,
+        stop=max_freq,
+    ).astype(spec.dtype)
+
+
+def adjust_audio_duration(
+    wave: xr.DataArray,
+    duration: float,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    current_duration = end_time - start_time
+
+    if current_duration == duration:
+        return wave
+
+    if current_duration > duration:
+        return arrays.crop_dim(
+            wave,
+            dim="time",
+            start=start_time,
+            stop=start_time + duration,
+        )
+
+    return arrays.extend_dim(
+        wave,
+        dim="time",
+        start=start_time,
+        stop=start_time + duration,
+    )
+
+
 def resample_audio(
     wav: xr.DataArray,
-    target_samplerate: int = TARGET_SAMPLERATE_HZ,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     if "time" not in wav.dims:
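Note: the resampling hunks that follow only rename `target_samplerate` to `samplerate`; the polyphase scheme is unchanged. The up/down factors handed to `scipy.signal.resample_poly` are reduced by their gcd, as in this sketch with made-up rates:

```python
import numpy as np
from scipy.signal import resample_poly

original_samplerate = 500_000  # made-up input rate
samplerate = 256_000           # target rate

gcd = np.gcd(original_samplerate, samplerate)
up, down = samplerate // gcd, original_samplerate // gcd
print(up, down)  # 64, 125: resample by a factor of 64/125

wav = np.random.rand(original_samplerate)  # one second of noise
resampled = resample_poly(wav, up, down)
print(resampled.shape)  # (256000,)
```
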
@@ -207,13 +267,13 @@
     step = arrays.get_dim_step(wav, dim="time")
     original_samplerate = int(1 / step)

-    if original_samplerate == target_samplerate:
+    if original_samplerate == samplerate:
         return wav.astype(dtype)

-    gcd = np.gcd(original_samplerate, target_samplerate)
+    gcd = np.gcd(original_samplerate, samplerate)
     resampled = resample_poly(
         wav.values,
-        target_samplerate // gcd,
+        samplerate // gcd,
         original_samplerate // gcd,
         axis=time_axis,
     )
@@ -225,7 +285,6 @@
         endpoint=False,
         dtype=dtype,
     )
-
     return xr.DataArray(
         data=resampled.astype(dtype),
         dims=wav.dims,
@@ -233,70 +292,35 @@
             **wav.coords,
             "time": arrays.create_time_dim_from_array(
                 resampled_times,
-                samplerate=target_samplerate,
+                samplerate=samplerate,
             ),
         },
         attrs=wav.attrs,
     )


-def compute_spectrogram(
-    wav: xr.DataArray,
-    fft_win_length: float = FFT_WIN_LENGTH_S,
-    fft_overlap: float = FFT_OVERLAP,
-    max_freq: int = MAX_FREQ_HZ,
-    min_freq: int = MIN_FREQ_HZ,
-    spec_scale: str = SPEC_SCALE,
-    denoise_spec_avg: bool = True,
-    max_scale_spec: bool = False,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    spec = gen_mag_spectrogram(
-        wav,
-        window_len=fft_win_length,
-        overlap_perc=fft_overlap,
-        dtype=dtype,
-    )
-
-    spec = arrays.crop_dim(
-        spec,
-        dim="frequency",
-        start=min_freq,
-        stop=max_freq,
-    ).astype(dtype)
-
-    spec = scale_spectrogram(spec, scale=spec_scale)
-
-    if denoise_spec_avg:
-        spec = denoise_spectrogram(spec)
-
-    if max_scale_spec:
-        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
-
-    return spec.astype(dtype)
-
-
-def gen_mag_spectrogram(
+def stft(
     wave: xr.DataArray,
-    window_len: float,
-    overlap_perc: float,
+    window_duration: float,
+    window_overlap: float,
+    window_fn: str = "hann",
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     start_time, end_time = arrays.get_dim_range(wave, dim="time")
     step = arrays.get_dim_step(wave, dim="time")
     sampling_rate = 1 / step

-    hop_len = window_len * (1 - overlap_perc)
-    nfft = int(window_len * sampling_rate)
-    noverlap = int(overlap_perc * nfft)
+    hop_len = window_duration * (1 - window_overlap)
+    nfft = int(window_duration * sampling_rate)
+    noverlap = int(window_overlap * nfft)

-    # compute spec
     spec, _ = librosa.core.spectrum._spectrogram(
-        y=wave.data,
+        y=wave.data.astype(dtype),
         power=1,
         n_fft=nfft,
         hop_length=nfft - noverlap,
         center=False,
+        window=window_fn,
     )

     return xr.DataArray(
@@ -316,7 +340,7 @@
             "time": arrays.create_time_dim_from_array(
                 np.linspace(
                     start_time,
-                    end_time - (window_len - hop_len),
+                    end_time - (window_duration - hop_len),
                     spec.shape[1],
                     endpoint=False,
                     dtype=dtype,
@@ -333,9 +357,7 @@
     )


-def denoise_spectrogram(
-    spec: xr.DataArray,
-) -> xr.DataArray:
+def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
     return xr.DataArray(
         data=(spec - spec.mean("time")).clip(0),
         dims=spec.dims,
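Note: the `stft` rename makes the window parameters explicit (seconds of window, fractional overlap). A sketch of the frame arithmetic it performs, with illustrative values (a 2 ms window at 75% overlap on 256 kHz audio):

```python
# Mirrors the nfft / noverlap / hop computation in `stft` above.
sampling_rate = 256_000
window_duration = 0.002
window_overlap = 0.75

nfft = int(window_duration * sampling_rate)  # 512 samples per frame
noverlap = int(window_overlap * nfft)        # 384 samples shared between frames
hop_length = nfft - noverlap                 # 128 samples between frame starts
print(nfft, noverlap, hop_length)
```
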
@@ -346,35 +368,53 @@ def denoise_spectrogram(

 def scale_spectrogram(
     spec: xr.DataArray,
-    scale: str = SPEC_SCALE,
+    scale: Union[Literal["log"], None, PcenConfig],
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-
-    if scale == "pcen":
-        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
-        return audio.pcen(
-            spec * (2**31),
-            smooth=smoothing_constant,
-        ).astype(dtype)
-
     if scale == "log":
-        return log_scale(spec, dtype=dtype)
+        return scale_log(spec, dtype=dtype)
+
+    if isinstance(scale, PcenConfig):
+        return scale_pcen(
+            spec,
+            time_constant=scale.time_constant,
+            hop_length=scale.hop_length,
+            gain=scale.gain,
+            power=scale.power,
+            bias=scale.bias,
+        )

     return spec


-def log_scale(
+def scale_pcen(
+    spec: xr.DataArray,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+    gain: float = 0.98,
+    bias: float = 2,
+    power: float = 0.5,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    # NOTE: Not sure why the 10 is there
+    t_frames = time_constant * samplerate / (float(hop_length) * 10)
+    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+    return audio.pcen(
+        spec * (2**31),
+        smooth=smoothing_constant,
+        gain=gain,
+        bias=bias,
+        power=power,
+    ).astype(spec.dtype)
+
+
+def scale_log(
     spec: xr.DataArray,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
     nfft = spec.attrs["nfft"]
-    log_scaling = (
-        2.0
-        * (1.0 / samplerate)
-        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
-    )
+    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
     return xr.DataArray(
         data=np.log1p(log_scaling * spec).astype(dtype),
         dims=spec.dims,
@@ -383,10 +423,14 @@
     )


-def get_pcen_smoothing_constant(
-    sr: int,
-    time_constant: float = 0.4,
-    hop_length: int = 512,
-) -> float:
-    t_frames = time_constant * sr / float(hop_length)
-    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+def resize_spectrogram(
+    spec: xr.DataArray,
+    config: SpecSizeConfig,
+) -> xr.DataArray:
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
+        spec,
+        time=int(np.ceil(duration / config.time_period)),
+        frequency=config.height,
+        dtype=np.float32,
+    )
diff --git a/batdetect2/models/readme.md b/batdetect2/models/checkpoints/readme.md
similarity index 100%
rename from batdetect2/models/readme.md
rename to batdetect2/models/checkpoints/readme.md
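Note: `scale_log` collapses the old three-factor product into a single expression. A quick check that the two forms agree (values here are illustrative):

```python
import numpy as np

samplerate, nfft = 256_000, 512
window_sum = (np.abs(np.hanning(nfft)) ** 2).sum()

old = 2.0 * (1.0 / samplerate) * (1.0 / window_sum)
new = 2 / (samplerate * window_sum)
assert np.isclose(old, new)
```
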
diff --git a/batdetect2/models/detectors.py b/batdetect2/models/detectors.py
index 992e04d..716fc9a 100644
--- a/batdetect2/models/detectors.py
+++ b/batdetect2/models/detectors.py
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Optional, Type

 import pytorch_lightning as L
 import torch
@@ -6,11 +6,11 @@ import xarray as xr
 from soundevent import data
 from torch import nn, optim

-from batdetect2.data.preprocessing import (
-    preprocess_audio_clip,
-    PreprocessingConfig,
-)
 from batdetect2.data.labels import ClassMapper
+from batdetect2.data.preprocessing import (
+    PreprocessingConfig,
+    preprocess_audio_clip,
+)
 from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
     PostprocessConfig,
@@ -29,11 +29,14 @@ class DetectorModel(L.LightningModule):
         learning_rate: float = 1e-3,
         input_height: int = 128,
         num_features: int = 32,
-        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
-        postprocessing_config: PostprocessConfig = PostprocessConfig(),
+        preprocessing_config: Optional[PreprocessingConfig] = None,
+        postprocessing_config: Optional[PostprocessConfig] = None,
     ):
         super().__init__()

+        preprocessing_config = preprocessing_config or PreprocessingConfig()
+        postprocessing_config = postprocessing_config or PostprocessConfig()
+
         self.save_hyperparameters()

         self.preprocessing_config = preprocessing_config
diff --git a/batdetect2/models/post_process.py b/batdetect2/models/post_process.py
index df3a47e..97c6052 100644
--- a/batdetect2/models/post_process.py
+++ b/batdetect2/models/post_process.py
@@ -1,10 +1,10 @@
 """Module for postprocessing model outputs."""

 from typing import Callable, List, Tuple, Union

-from pydantic import BaseModel, Field
 import numpy as np
 import torch
+from pydantic import BaseModel, Field
 from soundevent import data
 from torch import nn

@@ -207,7 +207,7 @@ def compute_sound_events_from_outputs(
             ),
             features=[
                 data.Feature(
-                    name=f"batdetect2_{i}",
+                    term=data.term_from_key(f"batdetect2_{i}"),
                     value=value.item(),
                 )
                 for i, value in enumerate(feature)
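Note: several signatures in this diff swap defaults like `config: PreprocessingConfig = PreprocessingConfig()` for `Optional[...] = None` plus a `config = config or Default()` line. The reason is the usual mutable-default pitfall: the old default is built once at function definition time and shared across calls. A toy sketch (the `Config` class here is hypothetical):

```python
from typing import Optional

from pydantic import BaseModel


class Config(BaseModel):
    threshold: float = 0.5


# Evaluated once, at definition time, and shared by every call:
def bad(config: Config = Config()) -> Config:
    return config


# Evaluated per call, so callers never share one instance:
def good(config: Optional[Config] = None) -> Config:
    return config or Config()


assert bad() is bad()        # same object on every call
assert good() is not good()  # fresh object each time
```
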
Using default config.", + stacklevel=1, ) return PreprocessingConfig(**kwargs) diff --git a/batdetect2/utils/audio_utils.py b/batdetect2/utils/audio_utils.py index b4afa31..347e3ff 100644 --- a/batdetect2/utils/audio_utils.py +++ b/batdetect2/utils/audio_utils.py @@ -90,7 +90,7 @@ def generate_spectrogram( np.abs( np.hanning( int(params["fft_win_length"] * sampling_rate) - ) + ).astype(np.float32) ) ** 2 ).sum() diff --git a/batdetect2/utils/detector_utils.py b/batdetect2/utils/detector_utils.py index 9b1716b..9c297c4 100644 --- a/batdetect2/utils/detector_utils.py +++ b/batdetect2/utils/detector_utils.py @@ -409,7 +409,7 @@ def save_results_to_file(results, op_path: str) -> None: def compute_spectrogram( audio: np.ndarray, - sampling_rate: float, + sampling_rate: int, params: SpectrogramParameters, device: torch.device, ) -> Tuple[float, torch.Tensor]: @@ -627,7 +627,7 @@ def process_spectrogram( def _process_audio_array( audio: np.ndarray, - sampling_rate: float, + sampling_rate: int, model: DetectionModel, config: ProcessingConfiguration, device: torch.device, diff --git a/pyproject.toml b/pyproject.toml index 767d36d..9484940 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ "torch>=1.13.1,<2.5.0", "torchaudio>=1.13.1,<2.5.0", "torchvision>=0.14.0", - "soundevent[audio,geometry,plot]>=2.0.1", + "soundevent[audio,geometry,plot]>=2.2", "click>=8.1.7", "netcdf4>=1.6.5", "tqdm>=4.66.2", diff --git a/tests/test_audio_utils.py b/tests/test_audio_utils.py index 1b489bc..86153aa 100644 --- a/tests/test_audio_utils.py +++ b/tests/test_audio_utils.py @@ -94,7 +94,7 @@ def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor( params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS length = int(duration * samplerate) audio = np.random.rand(length) - _, spectrogram, _ = detector_utils.compute_spectrogram( + _, spectrogram = detector_utils.compute_spectrogram( audio, samplerate, params, diff --git a/tests/test_data/test_labels.py b/tests/test_data/test_labels.py new file mode 100644 index 0000000..d2d7488 --- /dev/null +++ b/tests/test_data/test_labels.py @@ -0,0 +1,120 @@ +from pathlib import Path + +import numpy as np +import xarray as xr +from soundevent import data +from soundevent.types import ClassMapper + +from batdetect2.data.labels import generate_heatmaps + +recording = data.Recording( + samplerate=256_000, + duration=1, + channels=1, + time_expansion=1, + hash="asdf98sdf", + path=Path("/path/to/audio.wav"), +) + +clip = data.Clip( + recording=recording, + start_time=0, + end_time=1, +) + + +class Mapper(ClassMapper): + class_labels = ["bat", "cat"] + + def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str: + return "bat" + + def decode(self, label: str) -> list: + return [data.Tag(term=data.term_from_key("species"), value="bat")] + + +def test_generated_heatmaps_have_correct_dimensions(): + spec = xr.DataArray( + data=np.random.rand(100, 100), + dims=["time", "frequency"], + coords={ + "time": np.linspace(0, 100, 100, endpoint=False), + "frequency": np.linspace(0, 100, 100, endpoint=False), + }, + ) + + clip_annotation = data.ClipAnnotation( + clip=clip, + sound_events=[ + data.SoundEventAnnotation( + sound_event=data.SoundEvent( + recording=recording, + geometry=data.BoundingBox( + coordinates=[10, 10, 20, 20], + ), + ), + ) + ], + ) + + class_mapper = Mapper() + + detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( + clip_annotation, + spec, + class_mapper, + ) + + assert 
diff --git a/tests/test_data/test_labels.py b/tests/test_data/test_labels.py
new file mode 100644
index 0000000..d2d7488
--- /dev/null
+++ b/tests/test_data/test_labels.py
@@ -0,0 +1,120 @@
+from pathlib import Path
+
+import numpy as np
+import xarray as xr
+from soundevent import data
+from soundevent.types import ClassMapper
+
+from batdetect2.data.labels import generate_heatmaps
+
+recording = data.Recording(
+    samplerate=256_000,
+    duration=1,
+    channels=1,
+    time_expansion=1,
+    hash="asdf98sdf",
+    path=Path("/path/to/audio.wav"),
+)
+
+clip = data.Clip(
+    recording=recording,
+    start_time=0,
+    end_time=1,
+)
+
+
+class Mapper(ClassMapper):
+    class_labels = ["bat", "cat"]
+
+    def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str:
+        return "bat"
+
+    def decode(self, label: str) -> list:
+        return [data.Tag(term=data.term_from_key("species"), value="bat")]
+
+
+def test_generated_heatmaps_have_correct_dimensions():
+    spec = xr.DataArray(
+        data=np.random.rand(100, 100),
+        dims=["time", "frequency"],
+        coords={
+            "time": np.linspace(0, 100, 100, endpoint=False),
+            "frequency": np.linspace(0, 100, 100, endpoint=False),
+        },
+    )
+
+    clip_annotation = data.ClipAnnotation(
+        clip=clip,
+        sound_events=[
+            data.SoundEventAnnotation(
+                sound_event=data.SoundEvent(
+                    recording=recording,
+                    geometry=data.BoundingBox(
+                        coordinates=[10, 10, 20, 20],
+                    ),
+                ),
+            )
+        ],
+    )
+
+    class_mapper = Mapper()
+
+    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
+        clip_annotation,
+        spec,
+        class_mapper,
+    )
+
+    assert isinstance(detection_heatmap, xr.DataArray)
+    assert detection_heatmap.shape == (100, 100)
+    assert detection_heatmap.dims == ("time", "frequency")
+
+    assert isinstance(class_heatmap, xr.DataArray)
+    assert class_heatmap.shape == (2, 100, 100)
+    assert class_heatmap.dims == ("category", "time", "frequency")
+    assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"]
+
+    assert isinstance(size_heatmap, xr.DataArray)
+    assert size_heatmap.shape == (2, 100, 100)
+    assert size_heatmap.dims == ("dimension", "time", "frequency")
+    assert size_heatmap.coords["dimension"].values.tolist() == [
+        "width",
+        "height",
+    ]
+
+
+def test_generated_heatmap_are_non_zero_at_correct_positions():
+    spec = xr.DataArray(
+        data=np.random.rand(100, 100),
+        dims=["time", "frequency"],
+        coords={
+            "time": np.linspace(0, 100, 100, endpoint=False),
+            "frequency": np.linspace(0, 100, 100, endpoint=False),
+        },
+    )
+
+    clip_annotation = data.ClipAnnotation(
+        clip=clip,
+        sound_events=[
+            data.SoundEventAnnotation(
+                sound_event=data.SoundEvent(
+                    recording=recording,
+                    geometry=data.BoundingBox(
+                        coordinates=[10, 10, 20, 20],
+                    ),
+                ),
+            )
+        ],
+    )
+
+    class_mapper = Mapper()
+    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
+        clip_annotation,
+        spec,
+        class_mapper,
+    )
+    assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
+    assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
+    assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0
+    assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0
+    assert detection_heatmap.sel(time=10, frequency=10) == 1.0
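Note: the migration test below relaxes exact equality to "at least 99.5% of bins agree within 1e-5", since the new pcen implementation differs slightly from the original. The pattern, on toy data:

```python
import numpy as np

a = np.random.rand(100, 100).astype(np.float32)
b = a + np.random.normal(scale=1e-7, size=a.shape).astype(np.float32)

# Fraction of elements that agree within the absolute tolerance.
fraction_close = np.isclose(a, b, atol=1e-5).mean()
assert fraction_close > 0.995
```
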
diff --git a/tests/test_migration/test_preprocessing.py b/tests/test_migration/test_preprocessing.py
index 6027898..3b3a855 100644
--- a/tests/test_migration/test_preprocessing.py
+++ b/tests/test_migration/test_preprocessing.py
@@ -46,8 +46,14 @@
     )
     audio_new = preprocessing.load_clip_audio(
         clip,
-        target_sampling_rate=target_sampling_rate,
-        scale=scale,
+        config=preprocessing.AudioConfig(
+            resample=preprocessing.ResampleConfig(
+                samplerate=target_sampling_rate,
+            ),
+            center=scale,
+            scale=scale,
+            duration=None,
+        ),
         dtype=np.float32,
     )

@@ -73,18 +79,46 @@ def test_spectrogram_generation_hasnt_changed(
     min_freq = 10_000
     max_freq = 120_000
     fft_overlap = 0.75
+
+    scale = None
+    if spec_scale == "log":
+        scale = "log"
+    elif spec_scale == "pcen":
+        scale = preprocessing.PcenConfig()
+
+    config = preprocessing.SpectrogramConfig(
+        fft=preprocessing.FFTConfig(
+            window_overlap=fft_overlap,
+            window_duration=fft_win_length,
+        ),
+        frequencies=preprocessing.FrequencyConfig(
+            min_freq=min_freq,
+            max_freq=max_freq,
+        ),
+        scale=scale,
+        denoise=denoise_spec_avg,
+        resize=None,
+        max_scale=max_scale_spec,
+    )
+
     recording = data.Recording.from_file(
         audio_file,
         time_expansion=time_expansion,
     )
+
     clip = data.Clip(
         recording=recording,
         start_time=0,
         end_time=recording.duration,
     )
+
     audio = preprocessing.load_clip_audio(
         clip,
-        target_sampling_rate=target_sampling_rate,
+        config=preprocessing.AudioConfig(
+            resample=preprocessing.ResampleConfig(
+                samplerate=target_sampling_rate,
+            )
+        ),
     )

     spec_original, _ = audio_utils.generate_spectrogram(
@@ -103,18 +137,19 @@

     new_spec = preprocessing.compute_spectrogram(
         audio,
-        fft_win_length=fft_win_length,
-        fft_overlap=fft_overlap,
-        max_freq=max_freq,
-        min_freq=min_freq,
-        spec_scale=spec_scale,
-        denoise_spec_avg=denoise_spec_avg,
-        max_scale_spec=max_scale_spec,
+        config=config,
         dtype=np.float32,
     )

     assert spec_original.shape == new_spec.shape
     assert spec_original.dtype == new_spec.dtype

+    # Check that the spectrogram content matches within a tolerance of 1e-5
+    # for at least 99.5% of the elements.
+    # NOTE: The pcen function is not the same as the one in the original code,
+    # thus the need for a tolerance, but the values are still very similar.
     # NOTE: The original spectrogram is flipped vertically
-    assert np.isclose(spec_original, np.flipud(new_spec.data)).all()
+    assert (
+        np.isclose(spec_original, np.flipud(new_spec.data), atol=1e-5).mean()
+        > 0.995
+    )
diff --git a/uv.lock b/uv.lock
index 1b310b7..a3c9677 100644
--- a/uv.lock
+++ b/uv.lock
@@ -236,7 +236,7 @@ requires-dist = [
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
-    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.0.1" },
+    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
     { name = "tensorboard", specifier = ">=2.16.2" },
     { name = "torch", specifier = ">=1.13.1,<2.5.0" },
     { name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },