From 0bb0caddea4df3f72f0d862da1ee546ffaf11015 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 17:06:17 +0100 Subject: [PATCH] Update augmentations --- src/batdetect2/plotting/__init__.py | 16 +- src/batdetect2/plotting/common.py | 14 +- src/batdetect2/plotting/heatmaps.py | 103 ++- src/batdetect2/train/augmentations.py | 850 ++++++++----------------- src/batdetect2/train/clips.py | 178 ++---- src/batdetect2/train/preprocess.py | 50 +- src/batdetect2/typing/preprocess.py | 4 +- src/batdetect2/typing/train.py | 47 +- src/batdetect2/utils/arrays.py | 77 +-- tests/conftest.py | 7 +- tests/test_train/test_augmentations.py | 62 +- tests/test_train/test_clips.py | 143 ----- tests/test_train/test_preprocessing.py | 93 +-- tests/test_utils/test_arrays.py | 48 +- 14 files changed, 620 insertions(+), 1072 deletions(-) diff --git a/src/batdetect2/plotting/__init__.py b/src/batdetect2/plotting/__init__.py index eab0a16..acf14fb 100644 --- a/src/batdetect2/plotting/__init__.py +++ b/src/batdetect2/plotting/__init__.py @@ -1,6 +1,11 @@ from batdetect2.plotting.clip_annotations import plot_clip_annotation from batdetect2.plotting.clip_predictions import plot_clip_prediction from batdetect2.plotting.clips import plot_clip +from batdetect2.plotting.common import plot_spectrogram +from batdetect2.plotting.heatmaps import ( + plot_classification_heatmap, + plot_detection_heatmap, +) from batdetect2.plotting.matches import ( plot_cross_trigger_match, plot_false_negative_match, @@ -13,9 +18,12 @@ __all__ = [ "plot_clip", "plot_clip_annotation", "plot_clip_prediction", - "plot_matches", - "plot_false_positive_match", - "plot_true_positive_match", - "plot_false_negative_match", "plot_cross_trigger_match", + "plot_false_negative_match", + "plot_false_positive_match", + "plot_matches", + "plot_spectrogram", + "plot_true_positive_match", + "plot_detection_heatmap", + "plot_classification_heatmap", ] diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index 0e5f003..a1b1b93 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import matplotlib.pyplot as plt +import torch from matplotlib import axes __all__ = [ @@ -12,7 +13,7 @@ __all__ = [ def create_ax( ax: Optional[axes.Axes] = None, - figsize: Tuple[int, int] = (10, 10), + figsize: Optional[Tuple[int, int]] = None, **kwargs, ) -> axes.Axes: """Create a new axis if none is provided""" @@ -20,3 +21,14 @@ def create_ax( _, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore return ax # type: ignore + + +def plot_spectrogram( + spec: torch.Tensor, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + cmap="gray", +) -> axes.Axes: + ax = create_ax(ax=ax, figsize=figsize) + ax.pcolormesh(spec.numpy(), cmap=cmap) + return ax diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index a3df74a..29f261b 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -1,26 +1,115 @@ """Plot heatmaps""" -from typing import Optional, Tuple +from typing import List, Optional, Tuple, Union -import xarray as xr -from matplotlib import axes +import numpy as np +import torch +from matplotlib import axes, patches +from matplotlib.cm import get_cmap +from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba from batdetect2.plotting.common import create_ax -def plot_heatmap( - heatmap: xr.DataArray, +def plot_detection_heatmap( + heatmap: 
Union[torch.Tensor, np.ndarray],
     ax: Optional[axes.Axes] = None,
     figsize: Tuple[int, int] = (10, 10),
+    threshold: Optional[float] = None,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "jet",
+    color: Optional[str] = None,
 ) -> axes.Axes:
     ax = create_ax(ax, figsize=figsize)
 
+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if threshold is not None:
+        heatmap = np.ma.masked_where(
+            heatmap < threshold,
+            heatmap,
+        )
+
+    if color is not None:
+        cmap = create_colormap(color)
+
     ax.pcolormesh(
-        heatmap.time,
-        heatmap.frequency,
         heatmap,
         vmax=1,
         vmin=0,
+        cmap=cmap,
+        alpha=alpha,
     )
 
     return ax
+
+
+def plot_classification_heatmap(
+    heatmap: Union[torch.Tensor, np.ndarray],
+    ax: Optional[axes.Axes] = None,
+    figsize: Tuple[int, int] = (10, 10),
+    class_names: Optional[List[str]] = None,
+    threshold: Optional[float] = 0.1,
+    alpha: float = 1,
+    cmap: Union[str, Colormap] = "tab20",
+) -> axes.Axes:
+    ax = create_ax(ax, figsize=figsize)
+
+    if isinstance(heatmap, torch.Tensor):
+        heatmap = heatmap.numpy()
+
+    if heatmap.ndim == 4:
+        heatmap = heatmap[0]
+
+    if heatmap.ndim != 3:
+        raise ValueError("Expecting a 3-dimensional array")
+
+    num_classes = heatmap.shape[0]
+
+    if class_names is None:
+        class_names = [f"class_{i}" for i in range(num_classes)]
+
+    if len(class_names) != num_classes:
+        raise ValueError("Inconsistent number of class names")
+
+    if not isinstance(cmap, Colormap):
+        cmap = get_cmap(cmap)
+
+    handles = []
+
+    for index, class_heatmap in enumerate(heatmap):
+        class_name = class_names[index]
+
+        color = cmap(index / num_classes)
+
+        max_value = class_heatmap.max()
+
+        if max_value == 0:
+            continue
+
+        if threshold is not None:
+            class_heatmap = np.ma.masked_where(
+                class_heatmap < threshold,
+                class_heatmap,
+            )
+
+        ax.pcolormesh(
+            class_heatmap,
+            vmax=1,
+            vmin=0,
+            cmap=create_colormap(color),
+            alpha=alpha,
+        )
+
+        handles.append(patches.Patch(color=color, label=class_name))
+
+    ax.legend(handles=handles)
+    return ax
+
+
+def create_colormap(
+    color: Union[str, Tuple[float, float, float, float]],
+) -> Colormap:
+    (r, g, b, a) = to_rgba(color)
+    return LinearSegmentedColormap.from_list(
+        "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
+    )
diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py
index c779d7e..20c1de6 100644
--- a/src/batdetect2/train/augmentations.py
+++ b/src/batdetect2/train/augmentations.py
@@ -1,39 +1,17 @@
-"""Applies data augmentation techniques to BatDetect2 training examples.
+"""Applies data augmentation techniques to BatDetect2 training examples."""
 
-This module provides various functions and configurable components for applying
-data augmentation to the training examples (`xr.Dataset` containing audio,
-spectrogram, and target heatmaps) generated by the
-`batdetect2.train.preprocess` module.
-
-Data augmentation artificially increases the diversity of the training data by
-applying random transformations, which generally helps improve the robustness
-and generalization performance of trained models.
-
-Augmentations included:
-- Time-based: `select_subclip`, `warp_spectrogram`, `mask_time`.
-- Amplitude/Noise-based: `mix_examples`, `add_echo`, `scale_volume`,
-  `mask_frequency`.
-
-Some augmentations modify the audio waveform and require recomputing the
-spectrogram using the `PreprocessorProtocol`, while others operate directly
-on the spectrogram or target heatmaps. 
The entire augmentation pipeline can be -configured using the `AugmentationsConfig` class, specifying a sequence of -augmentation steps, each with its own parameters and application probability. -The `build_augmentations` function constructs the final augmentation callable -from this configuration. -""" - -from functools import partial -from typing import Annotated, Callable, List, Literal, Optional, Union +import warnings +from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union import numpy as np -import xarray as xr +import torch from loguru import logger from pydantic import Field -from soundevent import arrays, data +from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.typing import Augmentation, PreprocessorProtocol +from batdetect2.typing.train import PreprocessedExample from batdetect2.utils.arrays import adjust_width __all__ = [ @@ -57,11 +35,8 @@ __all__ = [ "warp_spectrogram", ] -ExampleSource = Callable[[], xr.Dataset] -"""Type alias for a function that returns a training example (`xr.Dataset`). - -Used by the `mix_examples` augmentation to fetch another example to mix with. -""" +ExampleSource = Callable[[], PreprocessedExample] +"""Type alias for a function that returns a training example""" class MixAugmentationConfig(BaseConfig): @@ -80,56 +55,18 @@ class MixAugmentationConfig(BaseConfig): def mix_examples( - example: xr.Dataset, - other: xr.Dataset, + example: PreprocessedExample, + other: PreprocessedExample, preprocessor: PreprocessorProtocol, - weight: Optional[float] = None, - min_weight: float = 0.3, - max_weight: float = 0.7, -) -> xr.Dataset: - """Combine two training examples using MixUp augmentation. + weight: float, +) -> PreprocessedExample: + """Combine two training examples.""" + audio1 = example.audio + audio2 = adjust_width(other.audio, audio1.shape[-1]) - Performs a weighted linear combination of the audio waveforms from two - examples (`example` and `other`). The spectrogram is then *recomputed* - from the mixed audio using the provided `preprocessor`. Target heatmaps - (detection, class) are combined by taking the element-wise maximum. Target - size heatmaps are combined by element-wise addition. + combined = weight * audio1 + (1 - weight) * audio2 - Parameters - ---------- - example : xr.Dataset - The primary training example. - other : xr.Dataset - The second training example to mix with `example`. - preprocessor : PreprocessorProtocol - The preprocessor used to recompute the spectrogram from mixed audio. - Ensures consistency with original preprocessing. - weight : float, optional - The mixing weight (lambda) applied to the primary `example`. The weight - for `other` will be `(1 - weight)`. If None, a random weight is chosen - uniformly between `min_weight` and `max_weight`. - min_weight : float, default=0.3 - Minimum value for the random weight lambda. - max_weight : float, default=0.7 - Maximum value for the random weight lambda. Must be >= `min_weight`. - - Returns - ------- - xr.Dataset - A new dataset representing the mixed example, with combined audio, - recomputed spectrogram, and combined target heatmaps. Attributes from - the primary `example` are preserved. 
- """ - if weight is None: - weight = np.random.uniform(min_weight, max_weight) - - audio1 = example["audio"] - audio2 = adjust_width(other["audio"].values, len(audio1)) - - with xr.set_options(keep_attrs=True): - combined = weight * audio1 + (1 - weight) * audio2 - - spectrogram = preprocessor.process_numpy(combined.data) + spectrogram = preprocessor(combined) # NOTE: The subclip's spectrogram might be slightly longer than the # spectrogram computed from the subclip's audio. This is due to a @@ -137,38 +74,30 @@ def mix_examples( # spectrogram parameters to precisely determine the corresponding audio # samples. To work around this, we pad the computed spectrogram with zeros # as needed. - previous_width = len(example["time"]) + previous_width = example.spectrogram.shape[-1] spectrogram = adjust_width(spectrogram, previous_width) - detection_heatmap = xr.apply_ufunc( - np.maximum, - example["detection"], - adjust_width(other["detection"].values, previous_width), + detection_heatmap = torch.maximum( + example.detection_heatmap, + adjust_width(other.detection_heatmap, previous_width), ) - class_heatmap = xr.apply_ufunc( - np.maximum, - example["class"], - adjust_width(other["class"].values, previous_width), + class_heatmap = torch.maximum( + example.class_heatmap, + adjust_width(other.class_heatmap, previous_width), ) - size_heatmap = example["size"] + adjust_width( - other["size"].values, previous_width + size_heatmap = torch.maximum( + example.size_heatmap, + adjust_width(other.size_heatmap, previous_width), ) - return xr.Dataset( - { - "audio": combined, - "spectrogram": xr.DataArray( - data=spectrogram, - dims=example["spectrogram"].dims, - coords=example["spectrogram"].coords, - ), - "detection": detection_heatmap, - "class": class_heatmap, - "size": size_heatmap, - }, - attrs=example.attrs, + return PreprocessedExample( + audio=combined, + spectrogram=spectrogram, + detection_heatmap=detection_heatmap, + class_heatmap=class_heatmap, + size_heatmap=size_heatmap, ) @@ -185,61 +114,45 @@ class EchoAugmentationConfig(BaseConfig): max_weight: float = 1.0 +class AddEcho(torch.nn.Module): + def __init__( + self, + preprocessor: PreprocessorProtocol, + min_weight: float = 0.1, + max_weight: float = 1.0, + max_delay: float = 0.005, + ): + super().__init__() + self.preprocessor = preprocessor + self.min_weight = min_weight + self.max_weight = max_weight + self.max_delay = max_delay + + def forward(self, example: PreprocessedExample) -> PreprocessedExample: + delay = np.random.uniform(0, self.max_delay) + weight = np.random.uniform(self.min_weight, self.max_weight) + return add_echo( + example, + preprocessor=self.preprocessor, + delay=delay, + weight=weight, + ) + + def add_echo( - example: xr.Dataset, + example: PreprocessedExample, preprocessor: PreprocessorProtocol, - delay: Optional[float] = None, - weight: Optional[float] = None, - min_weight: float = 0.1, - max_weight: float = 1.0, - max_delay: float = 0.005, -) -> xr.Dataset: - """Add a synthetic echo to the audio waveform. + delay: float, + weight: float, +) -> PreprocessedExample: + """Add a synthetic echo to the audio waveform.""" - Creates an echo by adding a delayed and attenuated version of the original - audio waveform back onto itself. The spectrogram is then recomputed from - the modified audio. Target heatmaps remain unchanged. 
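+    # The echo is an attenuated copy of the call arriving `delay` seconds
+    # later: y[n] = x[n] + weight * x[n - delay_steps], where the shifted
+    # copy is zero-padded at the start and trimmed to the original length.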
+    audio = example.audio
+    delay_steps = int(preprocessor.samplerate * delay)
+    audio_delay = adjust_width(
+        torch.nn.functional.pad(audio, (delay_steps, 0)),
+        audio.shape[-1],
+    )
 
-    Parameters
-    ----------
-    example : xr.Dataset
-        The input training example containing 'audio', 'spectrogram', etc.
-    preprocessor : PreprocessorProtocol
-        Preprocessor used to recompute the spectrogram from the modified audio.
-    delay : float, optional
-        The delay time in seconds for the echo. If None, chosen randomly
-        between 0 and `max_delay`.
-    weight : float, optional
-        The relative amplitude (weight) of the echo compared to the original.
-        If None, chosen randomly between `min_weight` and `max_weight`.
-    min_weight : float, default=0.1
-        Minimum value for the random echo weight.
-    max_weight : float, default=1.0
-        Maximum value for the random echo weight. Must be >= `min_weight`.
-    max_delay : float, default=0.005
-        Maximum value for the random echo delay in seconds. Must be >= 0.
-
-    Returns
-    -------
-    xr.Dataset
-        A new dataset with the echo added to the 'audio' variable and the
-        'spectrogram' variable recomputed. Other variables (targets, attrs)
-        are copied from the original example.
-    """
-    if delay is None:
-        delay = np.random.uniform(0, max_delay)
-
-    if weight is None:
-        weight = np.random.uniform(min_weight, max_weight)
-
-    audio = example["audio"]
-    step = arrays.get_dim_step(audio, "audio_time")
-    audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
-
-    with xr.set_options(keep_attrs=True):
-        audio = audio + weight * audio_delay
-
-    spectrogram = preprocessor.process_numpy(audio.data)
+    audio = audio + weight * audio_delay
+    spectrogram = preprocessor(audio)
 
     # NOTE: The subclip's spectrogram might be slightly longer than the
     # spectrogram computed from the subclip's audio. This is due to a
    # limitation in using the available
    # spectrogram parameters to precisely determine the corresponding audio
    # samples. To work around this, we pad the computed spectrogram with zeros
    # as needed.
    spectrogram = adjust_width(
         spectrogram,
-        example["spectrogram"].sizes["time"],
+        example.spectrogram.shape[-1],
     )
 
-    return example.assign(
+    return PreprocessedExample(
         audio=audio,
-        spectrogram=xr.DataArray(
-            data=spectrogram,
-            dims=example["spectrogram"].dims,
-            coords=example["spectrogram"].coords,
-            attrs=example["spectrogram"].attrs,
-        ),
+        spectrogram=spectrogram,
+        detection_heatmap=example.detection_heatmap,
+        class_heatmap=example.class_heatmap,
+        size_heatmap=example.size_heatmap,
     )
 
 
 class VolumeAugmentationConfig(BaseConfig):
     augmentation_type: Literal["scale_volume"] = "scale_volume"
     min_scaling: float = 0.0
     max_scaling: float = 2.0
 
 
+class ScaleVolume(torch.nn.Module):
+    def __init__(self, min_scaling: float, max_scaling: float):
+        super().__init__()
+        self.min_scaling = min_scaling
+        self.max_scaling = max_scaling
+
+    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
+        factor = np.random.uniform(self.min_scaling, self.max_scaling)
+        return scale_volume(example, factor=factor)
+
+
 def scale_volume(
-    example: xr.Dataset,
-    factor: Optional[float] = None,
-    max_scaling: float = 2,
-    min_scaling: float = 0,
-) -> xr.Dataset:
-    """Scale the amplitude of the spectrogram by a random factor.
-
-    Multiplies the entire spectrogram DataArray by a scaling factor chosen
-    uniformly between `min_scaling` and `max_scaling`. This simulates changes
-    in recording volume or distance to the sound source directly in the
-    spectrogram domain. Audio and target heatmaps are unchanged.
-
-    Parameters
-    ----------
-    example : xr.Dataset
-        The input training example containing 'spectrogram'.
-    factor : float, optional
-        The scaling factor to apply. If None, chosen randomly between
-        `min_scaling` and `max_scaling`.
-    min_scaling : float, default=0.0
-        Minimum value for the random scaling factor. Must be non-negative.
-    max_scaling : float, default=2.0
-        Maximum value for the random scaling factor. Must be >= `min_scaling`.
-
-    Returns
-    -------
-    xr.Dataset
-        A new dataset with the 'spectrogram' variable scaled.
-
-    Raises
-    ------
-    ValueError
-        If `min_scaling` > `max_scaling` or if `min_scaling` is negative.
-    """
-    if factor is None:
-        factor = np.random.uniform(min_scaling, max_scaling)
-
-    with xr.set_options(keep_attrs=True):
-        return example.assign(spectrogram=example["spectrogram"] * factor)
+    example: PreprocessedExample,
+    factor: float,
+) -> PreprocessedExample:
+    """Scale the amplitude of the spectrogram by a random factor."""
+    return PreprocessedExample(
+        audio=example.audio,
+        size_heatmap=example.size_heatmap,
+        class_heatmap=example.class_heatmap,
+        detection_heatmap=example.detection_heatmap,
+        spectrogram=example.spectrogram * factor,
+    )
 
 
 class WarpAugmentationConfig(BaseConfig):
     augmentation_type: Literal["warp"] = "warp"
     delta: float = 0.04
 
 
+class WarpSpectrogram(torch.nn.Module):
+    def __init__(self, delta: float = 0.04) -> None:
+        super().__init__()
+        self.delta = delta
+
+    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
+        factor = np.random.uniform(1 - self.delta, 1 + self.delta)
+        return warp_spectrogram(example, factor=factor)
+
+
 def warp_spectrogram(
-    example: xr.Dataset,
-    factor: Optional[float] = None,
-    delta: float = 0.04,
-) -> xr.Dataset:
-    """Apply time warping by resampling the time axis.
+    example: PreprocessedExample, factor: float
+) -> PreprocessedExample:
+    """Apply time warping by resampling the time axis."""
+    target_shape = example.spectrogram.shape
+    new_width = int(target_shape[-1] * factor)
 
-    Stretches or compresses the time axis of the spectrogram and all target
-    heatmaps by a random `factor` (chosen between `1-delta` and `1+delta`).
-    Uses linear interpolation for the spectrogram and nearest neighbor for the
-    discrete-like target heatmaps. Updates the 'time' coordinate accordingly.
-    The audio waveform is not modified.
-
-    Parameters
-    ----------
-    example : xr.Dataset
-        Input training example with 'spectrogram', 'detection', 'class',
-        'size', and 'time' coordinates.
-    factor : float, optional
-        The warping factor. If None, chosen randomly between `1-delta` and
-        `1+delta`. Values > 1 stretch time, < 1 compress time.
-    delta : float, default=0.04
-        Controls the range `[1-delta, 1+delta]` for the random warp factor.
-        Must be >= 0 and < 1.
-
-    Returns
-    -------
-    xr.Dataset
-        A new dataset with time-warped spectrogram and target heatmaps, and
-        updated 'time' coordinates. 
- """ - if factor is None: - factor = np.random.uniform(1 - delta, 1 + delta) - - start_time, end_time = arrays.get_dim_range( - example, # type: ignore - "time", - ) - duration = end_time - start_time - - new_time = np.linspace( - start_time, - start_time + duration * factor, - example.time.size, + spectrogram = ( + torch.nn.functional.interpolate( + adjust_width(example.spectrogram, new_width) + .unsqueeze(0) + .unsqueeze(0), + size=target_shape, + mode="bilinear", + ) + .squeeze(0) + .squeeze(0) ) - with xr.set_options(keep_attrs=True): - spectrogram = ( - example["spectrogram"] - .interp( - coords={"time": new_time}, - method="linear", - kwargs=dict( - fill_value=0, - ), - ) - .clip(min=0) + detection = ( + torch.nn.functional.interpolate( + adjust_width(example.detection_heatmap, new_width) + .unsqueeze(0) + .unsqueeze(0), + size=target_shape, + mode="nearest", ) + .squeeze(0) + .squeeze(0) + ) - detection = example["detection"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) + classification = torch.nn.functional.interpolate( + adjust_width(example.class_heatmap, new_width).unsqueeze(1), + size=target_shape, + mode="nearest", + ).squeeze(1) - classification = example["class"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) + size = torch.nn.functional.interpolate( + adjust_width(example.size_heatmap, new_width).unsqueeze(1), + size=target_shape, + mode="nearest", + ).squeeze(1) - size = example["size"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) - - return example.assign( - { - "time": new_time, - "spectrogram": spectrogram, - "detection": detection, - "class": classification, - "size": size, - } - ) - - -def mask_axis( - array: xr.DataArray, - dim: str, - start: float, - end: float, - mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean, -) -> xr.DataArray: - """Mask values along a specified dimension. - - Sets values in the DataArray to `mask_value` where the coordinate along - `dim` falls within the range (`start`, `end`). Values outside this range - are kept. Used as a helper for time/frequency masking. - - Parameters - ---------- - array : xr.DataArray - The input DataArray (e.g., spectrogram). - dim : str - The name of the dimension along which to mask - (e.g., "time", "frequency"). - start : float - The starting coordinate value for the mask range. - end : float - The ending coordinate value for the mask range (exclusive, typically). - Values >= start and < end will be masked. - mask_value : float or Callable[[xr.DataArray], float], default=np.mean - The value to use for masking. Can be a fixed float (e.g., 0.0) or a - callable (like `np.mean`, `np.median`) that computes the value from - the input `array`. - - Returns - ------- - xr.DataArray - The DataArray with the specified range along `dim` masked. - - Raises - ------ - ValueError - If `dim` is not found in the array's dimensions or coordinates. 
- """ - if dim not in array.dims: - raise ValueError(f"Axis {dim} not found in array") - - coord = array.coords[dim] - condition = (coord < start) | (coord > end) - - if callable(mask_value): - mask_value = mask_value(array) - - with xr.set_options(keep_attrs=True): - return array.where(condition, other=mask_value) + return PreprocessedExample( + audio=example.audio, + size_heatmap=size, + class_heatmap=classification, + detection_heatmap=detection, + spectrogram=spectrogram, + ) class TimeMaskAugmentationConfig(BaseConfig): @@ -474,48 +283,52 @@ class TimeMaskAugmentationConfig(BaseConfig): max_masks: int = 3 +class MaskTime(torch.nn.Module): + def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None: + super().__init__() + self.max_perc = max_perc + self.max_masks = max_masks + + def forward(self, example: PreprocessedExample) -> PreprocessedExample: + num_masks = np.random.randint(1, self.max_masks + 1) + width = example.spectrogram.shape[-1] + + mask_size = np.random.randint( + low=0, + high=int(self.max_perc * width), + size=num_masks, + ) + mask_start = np.random.randint( + low=0, + high=width - mask_size, + size=num_masks, + ) + masks = [ + (start, start + size) for start, size in zip(mask_start, mask_size) + ] + return mask_time(example, masks) + + def mask_time( - example: xr.Dataset, - max_perc: float = 0.05, - max_mask: int = 3, -) -> xr.Dataset: - """Apply random time masking (SpecAugment) to the spectrogram. + example: PreprocessedExample, + masks: List[Tuple[int, int]], +) -> PreprocessedExample: + """Apply time masking to the spectrogram.""" - Randomly selects a number of time intervals (up to `max_masks`) and masks - (sets to the mean value) the spectrogram within those intervals. The width - of each mask is chosen randomly up to `max_perc` of the total duration. - Only the 'spectrogram' variable is modified. + for start, end in masks: + example.spectrogram[:, start:end] = example.spectrogram.mean() + example.class_heatmap[:, :, start:end] = 0 + example.size_heatmap[:, :, start:end] = 0 + example.detection_heatmap[:, start:end] = 0 - Parameters - ---------- - example : xr.Dataset - Input training example containing 'spectrogram' and 'time' coordinate. - max_perc : float, default=0.05 - Maximum width of a single mask as a fraction of total duration. - max_masks : int, default=3 - Maximum number of time masks to apply. - - Returns - ------- - xr.Dataset - Dataset with time masking applied to the 'spectrogram'. 
- """ - num_masks = np.random.randint(1, max_mask + 1) - start_time, end_time = arrays.get_dim_range( - example, # type: ignore - "time", + return PreprocessedExample( + audio=example.audio, + size_heatmap=example.size_heatmap, + class_heatmap=example.class_heatmap, + detection_heatmap=example.detection_heatmap, + spectrogram=example.spectrogram, ) - spectrogram = example["spectrogram"] - for _ in range(num_masks): - mask_size = np.random.uniform(0, max_perc) * (end_time - start_time) - start = np.random.uniform(start_time, end_time - mask_size) - end = start + mask_size - spectrogram = mask_axis(spectrogram, "time", start, end) - - with xr.set_options(keep_attrs=True): - return example.assign(spectrogram=spectrogram) - class FrequencyMaskAugmentationConfig(BaseConfig): augmentation_type: Literal["mask_freq"] = "mask_freq" @@ -524,54 +337,51 @@ class FrequencyMaskAugmentationConfig(BaseConfig): max_masks: int = 3 +class MaskFrequency(torch.nn.Module): + def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None: + super().__init__() + self.max_perc = max_perc + self.max_masks = max_masks + + def forward(self, example: PreprocessedExample) -> PreprocessedExample: + num_masks = np.random.randint(1, self.max_masks + 1) + height = example.spectrogram.shape[-2] + + mask_size = np.random.randint( + low=0, + high=int(self.max_perc * height), + size=num_masks, + ) + mask_start = np.random.randint( + low=0, + high=height - mask_size, + size=num_masks, + ) + masks = [ + (start, start + size) for start, size in zip(mask_start, mask_size) + ] + return mask_frequency(example, masks) + + def mask_frequency( - example: xr.Dataset, - max_perc: float = 0.10, - max_masks: int = 3, -) -> xr.Dataset: - """Apply random frequency masking (SpecAugment) to the spectrogram. + example: PreprocessedExample, + masks: List[Tuple[int, int]], +) -> PreprocessedExample: + """Apply frequency masking to the spectrogram.""" + for start, end in masks: + example.spectrogram[start:end, :] = example.spectrogram.mean() + example.class_heatmap[:, start:end, :] = 0 + example.size_heatmap[:, start:end, :] = 0 + example.detection_heatmap[start:end, :] = 0 - Randomly selects a number of frequency intervals (up to `max_masks`) and - masks (sets to the mean value) the spectrogram within those intervals. The - height of each mask is chosen randomly up to `max_perc` of the total - frequency range. Only the 'spectrogram' variable is modified. - - Parameters - ---------- - example : xr.Dataset - Input training example containing 'spectrogram' and 'frequency' - coordinate. - max_perc : float, default=0.10 - Maximum height of a single mask as a fraction of total frequency range. - max_masks : int, default=3 - Maximum number of frequency masks to apply. - - Returns - ------- - xr.Dataset - Dataset with frequency masking applied to the 'spectrogram'. - - Raises - ------ - ValueError - If frequency coordinate info is missing or invalid. 
- """ - num_masks = np.random.randint(1, max_masks + 1) - min_freq, max_freq = arrays.get_dim_range( - example, # type: ignore - "frequency", + return PreprocessedExample( + audio=example.audio, + size_heatmap=example.size_heatmap, + class_heatmap=example.class_heatmap, + detection_heatmap=example.detection_heatmap, + spectrogram=example.spectrogram, ) - spectrogram = example["spectrogram"] - for _ in range(num_masks): - mask_size = np.random.uniform(0, max_perc) * (max_freq - min_freq) - start = np.random.uniform(min_freq, max_freq - mask_size) - end = start + mask_size - spectrogram = mask_axis(spectrogram, "frequency", start, end) - - with xr.set_options(keep_attrs=True): - return example.assign(spectrogram=spectrogram) - AugmentationConfig = Annotated[ Union[ @@ -588,23 +398,12 @@ AugmentationConfig = Annotated[ class AugmentationsConfig(BaseConfig): - """Configuration for a sequence of data augmentations. - - Attributes - ---------- - steps : List[AugmentationConfig] - An ordered list of configuration objects, each defining a single - augmentation step (e.g., MixAugmentationConfig, - TimeMaskAugmentationConfig). Each step's configuration must include an - `augmentation_type` field and a `probability` field, along with - type-specific parameters. The augmentations will be applied - (probabilistically) in the sequence defined by this list. - """ + """Configuration for a sequence of data augmentations.""" steps: List[AugmentationConfig] = Field(default_factory=list) -class MaybeApply: +class MaybeApply(torch.nn.Module): """Applies an augmentation function probabilistically.""" def __init__( @@ -621,10 +420,11 @@ class MaybeApply: probability : float, default=0.5 The probability (0.0 to 1.0) of applying the augmentation. """ + super().__init__() self.augmentation = augmentation self.probability = probability - def __call__(self, example: xr.Dataset) -> xr.Dataset: + def __call__(self, example: PreprocessedExample) -> PreprocessedExample: """Apply the wrapped augmentation with configured probability. Parameters @@ -643,25 +443,8 @@ class MaybeApply: return self.augmentation(example) -class AudioMixer: - """Callable class for MixUp augmentation, handling example fetching. - - Wraps the `mix_examples` logic, providing the necessary `example_source` - to fetch a second example dynamically when called. - - Parameters - ---------- - min_weight : float - Minimum mixing weight (lambda) for the primary example. - max_weight : float - Maximum mixing weight (lambda) for the primary example. - example_source : ExampleSource (Callable[[], xr.Dataset]) - A function that, when called, returns another random training example - dataset (`xr.Dataset`) to be mixed with the input example. - preprocessor : PreprocessorProtocol - The preprocessor needed to recompute the spectrogram after mixing - audio. - """ +class AudioMixer(torch.nn.Module): + """Callable class for MixUp augmentation, handling example fetching.""" def __init__( self, @@ -671,32 +454,21 @@ class AudioMixer: preprocessor: PreprocessorProtocol, ): """Initialize the AudioMixer.""" + super().__init__() self.min_weight = min_weight self.example_source = example_source self.max_weight = max_weight self.preprocessor = preprocessor - def __call__(self, example: xr.Dataset) -> xr.Dataset: - """Fetch another example and perform mixup. - - Parameters - ---------- - example : xr.Dataset - The primary input training example. - - Returns - ------- - xr.Dataset - The mixed training example. 
Returns the original example if - fetching the second example fails. - """ + def __call__(self, example: PreprocessedExample) -> PreprocessedExample: + """Fetch another example and perform mixup.""" other = self.example_source() + weight = np.random.uniform(self.min_weight, self.max_weight) return mix_examples( example, other, self.preprocessor, - min_weight=self.min_weight, - max_weight=self.max_weight, + weight=weight, ) @@ -704,48 +476,16 @@ def build_augmentation_from_config( config: AugmentationConfig, preprocessor: PreprocessorProtocol, example_source: Optional[ExampleSource] = None, -) -> Augmentation: - """Factory function to build a single augmentation from its config. - - Takes a configuration object for one augmentation step (which includes the - `augmentation_type` discriminator) and returns the corresponding functional - augmentation callable (e.g., a `partial` function or a callable class like - `AudioMixer`). - - Parameters - ---------- - config : AugmentationConfig - Configuration object for a single augmentation (e.g., instance of - `MixAugmentationConfig`, `EchoAugmentationConfig`, etc.). - preprocessor : PreprocessorProtocol - The preprocessor object, required by augmentations that modify audio - and need to recompute the spectrogram (e.g., mixup, echo). - example_source : ExampleSource, optional - A callable that provides other training examples. Required only if - the configuration includes `MixAugmentationConfig` (`augmentation_type` - is "mix_audio"). - - Returns - ------- - Augmentation (Callable[[xr.Dataset], xr.Dataset]) - A callable function that takes a training example (`xr.Dataset`) and - returns a potentially augmented version. - - Raises - ------ - ValueError - If `config.augmentation_type` is "mix_audio" but `example_source` is - None. - NotImplementedError - If `config.augmentation_type` does not match any known augmentation - type. - """ +) -> Optional[Augmentation]: + """Factory function to build a single augmentation from its config.""" if config.augmentation_type == "mix_audio": if example_source is None: - raise ValueError( + warnings.warn( "Mix audio augmentation ('mix_audio') requires an " - "'example_source' callable to be provided." + "'example_source' callable to be provided.", + stacklevel=2, ) + return None return AudioMixer( example_source=example_source, @@ -755,8 +495,7 @@ def build_augmentation_from_config( ) if config.augmentation_type == "add_echo": - return partial( - add_echo, + return AddEcho( preprocessor=preprocessor, max_delay=config.max_delay, min_weight=config.min_weight, @@ -764,28 +503,24 @@ def build_augmentation_from_config( ) if config.augmentation_type == "scale_volume": - return partial( - scale_volume, + return ScaleVolume( max_scaling=config.max_scaling, min_scaling=config.min_scaling, ) if config.augmentation_type == "warp": - return partial( - warp_spectrogram, + return WarpSpectrogram( delta=config.delta, ) if config.augmentation_type == "mask_time": - return partial( - mask_time, + return MaskTime( max_perc=config.max_perc, - max_mask=config.max_masks, + max_masks=config.max_masks, ) if config.augmentation_type == "mask_freq": - return partial( - mask_frequency, + return MaskFrequency( max_perc=config.max_perc, max_masks=config.max_masks, ) @@ -813,38 +548,7 @@ def build_augmentations( config: Optional[AugmentationsConfig] = None, example_source: Optional[ExampleSource] = None, ) -> Augmentation: - """Build a composite augmentation pipeline function from configuration. 
- - Takes the overall `AugmentationsConfig` (containing a list of individual - augmentation steps), builds the callable function for each step using - `build_augmentation_from_config`, wraps each function with `MaybeApply` - to handle its application probability, and returns a single - `Augmentation` function that applies the entire sequence. - - Parameters - ---------- - config : AugmentationsConfig - The configuration object detailing the sequence of augmentation steps. - preprocessor : PreprocessorProtocol - The preprocessor object, needed for audio-modifying augmentations. - example_source : ExampleSource, optional - A callable providing other examples, required if 'mix_audio' is used. - - Returns - ------- - Augmentation (Callable[[xr.Dataset], xr.Dataset]) - A single callable function that takes a training example (`xr.Dataset`) - and applies the configured sequence of augmentations probabilistically, - returning the augmented example. Returns the original example if - `config.steps` is empty. - - Raises - ------ - ValueError - If 'mix_audio' is configured but `example_source` is not provided. - NotImplementedError - If an unknown `augmentation_type` is encountered in `config.steps`. - """ + """Build a composite augmentation pipeline function from configuration.""" config = config or DEFAULT_AUGMENTATION_CONFIG logger.opt(lazy=True).debug( @@ -860,6 +564,10 @@ def build_augmentations( preprocessor=preprocessor, example_source=example_source, ) + + if augmentation is None: + continue + augmentations.append( MaybeApply( augmentation=augmentation, @@ -867,51 +575,11 @@ def build_augmentations( ) ) - return partial(_apply_augmentations, augmentations=augmentations) + return torch.nn.Sequential(*augmentations) def load_augmentation_config( path: data.PathLike, field: Optional[str] = None ) -> AugmentationsConfig: - """Load the augmentations configuration from a file. - - Reads a configuration file (YAML) and validates it against the - `AugmentationsConfig` schema, potentially extracting data from a nested - field. - - Parameters - ---------- - path : PathLike - Path to the configuration file. - field : str, optional - Dot-separated path to a nested section within the file containing the - augmentations configuration (e.g., "training.augmentations"). If None, - the entire file content is used. - - Returns - ------- - AugmentationsConfig - The loaded and validated augmentations configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - yaml.YAMLError - If the file content is not valid YAML. - pydantic.ValidationError - If the loaded config data does not conform to `AugmentationsConfig`. - KeyError, TypeError - If `field` specifies an invalid path. 
- """ + """Load the augmentations configuration from a file.""" return load_config(path, schema=AugmentationsConfig, field=field) - - -def _apply_augmentations( - example: xr.Dataset, - augmentations: List[Augmentation], -): - """Apply a sequence of augmentation functions to an example.""" - for augmentation in augmentations: - example = augmentation(example) - return example diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 090a6a5..acf6a95 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -1,12 +1,12 @@ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple import numpy as np -import xarray as xr from loguru import logger -from soundevent import arrays from batdetect2.configs import BaseConfig from batdetect2.typing import ClipperProtocol +from batdetect2.typing.train import PreprocessedExample +from batdetect2.utils.arrays import adjust_width DEFAULT_TRAIN_CLIP_DURATION = 0.513 DEFAULT_MAX_EMPTY_CLIP = 0.1 @@ -32,40 +32,23 @@ class Clipper(ClipperProtocol): self.max_empty = max_empty def extract_clip( - self, example: xr.Dataset - ) -> Tuple[xr.Dataset, float, float]: - step = arrays.get_dim_step( - example.spectrogram, - dim=arrays.Dimensions.time.value, - ) - duration = ( - arrays.get_dim_width( - example.spectrogram, - dim=arrays.Dimensions.time.value, - ) - + step - ) - + self, example: PreprocessedExample + ) -> Tuple[PreprocessedExample, float, float]: start_time = 0 + duration = example.audio.shape[-1] / self.samplerate + if self.random: start_time = np.random.uniform( -self.max_empty, duration - self.duration + self.max_empty, ) - subclip = select_subclip( - example, - start=start_time, - span=self.duration, - dim="time", - ) - return ( select_subclip( - subclip, + example, start=start_time, - span=self.duration, - dim="audio_time", + duration=self.duration, + samplerate=self.samplerate, ), start_time, start_time + self.duration, @@ -73,6 +56,7 @@ class Clipper(ClipperProtocol): def build_clipper( + samplerate: int, config: Optional[ClipingConfig] = None, random: Optional[bool] = None, ) -> ClipperProtocol: @@ -82,6 +66,7 @@ def build_clipper( lambda: config.to_yaml_string(), ) return Clipper( + samplerate=samplerate, duration=config.duration, max_empty=config.max_empty, random=config.random if random else False, @@ -89,106 +74,43 @@ def build_clipper( def select_subclip( - dataset: xr.Dataset, - span: float, + example: PreprocessedExample, start: float, - fill_value: float = 0, - dim: str = "time", -) -> xr.Dataset: - width = _compute_expected_width( - dataset, # type: ignore - span, - dim=dim, - ) - - coord = dataset.coords[dim] - - if len(coord) == width: - return dataset - - new_coords, start_pad, end_pad, dim_slice = _extract_coordinate( - coord, start, span - ) - - data_vars = {} - for name, data_array in dataset.data_vars.items(): - if dim not in data_array.dims: - data_vars[name] = data_array - continue - - if width == data_array.sizes[dim]: - data_vars[name] = data_array - continue - - sliced = data_array.isel({dim: dim_slice}).data - - if start_pad > 0 or end_pad > 0: - padding = [ - [0, 0] if other_dim != dim else [start_pad, end_pad] - for other_dim in data_array.dims - ] - sliced = np.pad(sliced, padding, constant_values=fill_value) - - data_vars[name] = xr.DataArray( - data=sliced, - dims=data_array.dims, - coords={**data_array.coords, dim: new_coords}, - attrs=data_array.attrs, - ) - - return xr.Dataset(data_vars=data_vars, attrs=dataset.attrs) - - -def _extract_coordinate( - 
coord: xr.DataArray, - start: float, - span: float, -) -> Tuple[xr.Variable, int, int, slice]: - step = arrays.get_dim_step(coord, str(coord.name)) - - current_width = len(coord) - expected_width = int(np.floor(span / step)) - - coord_start = float(coord[0]) - offset = start - coord_start - - start_index = int(np.floor(offset / step)) - end_index = start_index + expected_width - - if start_index > current_width: - raise ValueError("Requested span does not overlap with current range") - - if end_index < 0: - raise ValueError("Requested span does not overlap with current range") - - corrected_start = float(start_index * step) - corrected_end = float(end_index * step) - - start_index_offset = max(0, -start_index) - end_index_offset = max(0, end_index - current_width) - - sl = slice( - start_index if start_index >= 0 else None, - end_index if end_index < current_width else None, - ) - - return ( - arrays.create_range_dim( - str(coord.name), - start=corrected_start, - stop=corrected_end, - step=step, - ), - start_index_offset, - end_index_offset, - sl, - ) - - -def _compute_expected_width( - array: Union[xr.DataArray, xr.Dataset], duration: float, - dim: str, -) -> int: - step = arrays.get_dim_step(array, dim) # type: ignore - return int(np.floor(duration / step)) + samplerate: float, + fill_value: float = 0, +) -> PreprocessedExample: + audio_width = int(np.floor(duration * samplerate)) + audio_start = int(np.floor(start * samplerate)) + + audio = adjust_width( + example.audio[audio_start : audio_start + audio_width], + audio_width, + value=fill_value, + ) + + audio_duration = example.audio.shape[-1] / samplerate + spec_sr = example.spectrogram.shape[-1] / audio_duration + + spec_start = int(np.floor(start * spec_sr)) + spec_width = int(np.floor(duration * spec_sr)) + + return PreprocessedExample( + audio=audio, + spectrogram=adjust_width( + example.spectrogram[:, spec_start : spec_start + spec_width], + spec_width, + ), + class_heatmap=adjust_width( + example.class_heatmap[:, :, spec_start : spec_start + spec_width], + spec_width, + ), + detection_heatmap=adjust_width( + example.detection_heatmap[:, spec_start : spec_start + spec_width], + spec_width, + ), + size_heatmap=adjust_width( + example.size_heatmap[:, :, spec_start : spec_start + spec_width], + spec_width, + ), + ) diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py index 3a26c59..ebdd489 100644 --- a/src/batdetect2/train/preprocess.py +++ b/src/batdetect2/train/preprocess.py @@ -22,7 +22,7 @@ includes utilities for parallel processing using `multiprocessing`. 
import os from pathlib import Path -from typing import Callable, Dict, Optional, Sequence +from typing import Callable, Optional, Sequence, TypedDict import numpy as np import torch @@ -98,17 +98,25 @@ def preprocess_dataset( ) +class Example(TypedDict): + audio: torch.Tensor + spectrogram: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + + def generate_train_example( clip_annotation: data.ClipAnnotation, audio_loader: AudioLoader, preprocessor: PreprocessorProtocol, labeller: ClipLabeller, -) -> Dict[str, torch.Tensor]: +) -> PreprocessedExample: """Generate a complete training example for one annotation.""" wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip)) spectrogram = preprocessor(wave) heatmaps = labeller(clip_annotation, spectrogram) - return dict( + return PreprocessedExample( audio=wave, spectrogram=spectrogram, detection_heatmap=heatmaps.detection, @@ -138,8 +146,14 @@ class PreprocessingDataset(torch.utils.data.Dataset): preprocessor=self.preprocessor, labeller=self.labeller, ) - example["idx"] = idx - return example + return { + "idx": idx, + "spectrogram": example.spectrogram, + "audio": example.audio, + "class_heatmap": example.class_heatmap, + "size_heatmap": example.size_heatmap, + "detection_heatmap": example.detection_heatmap, + } def __len__(self) -> int: return len(self.clips) @@ -147,16 +161,17 @@ class PreprocessingDataset(torch.utils.data.Dataset): def _save_example_to_file( example: PreprocessedExample, + clip_annotation: data.ClipAnnotation, path: data.PathLike, ) -> None: np.savez_compressed( path, - audio=example.audio, - spectrogram=example.spectrogram, - detection_heatmap=example.detection_heatmap, - class_heatmap=example.class_heatmap, - size_heatmap=example.size_heatmap, - clip_annotation=example.clip_annotation, + audio=example.audio.numpy(), + spectrogram=example.spectrogram.numpy(), + detection_heatmap=example.detection_heatmap.numpy(), + class_heatmap=example.class_heatmap.numpy(), + size_heatmap=example.size_heatmap.numpy(), + clip_annotation=clip_annotation, ) @@ -211,11 +226,10 @@ def preprocess_annotations( filename = filename_fn(clip_annotation) path = output_dir / filename example = PreprocessedExample( - clip_annotation=clip_annotation, - spectrogram=batch["spectrogram"].numpy(), - audio=batch["audio"].numpy(), - class_heatmap=batch["class_heatmap"].numpy(), - size_heatmap=batch["size_heatmap"].numpy(), - detection_heatmap=batch["detection_heatmap"].numpy(), + spectrogram=batch["spectrogram"], + audio=batch["audio"], + class_heatmap=batch["class_heatmap"], + size_heatmap=batch["size_heatmap"], + detection_heatmap=batch["detection_heatmap"], ) - _save_example_to_file(example, path) + _save_example_to_file(example, clip_annotation, path) diff --git a/src/batdetect2/typing/preprocess.py b/src/batdetect2/typing/preprocess.py index 9f02ab8..584f739 100644 --- a/src/batdetect2/typing/preprocess.py +++ b/src/batdetect2/typing/preprocess.py @@ -148,6 +148,8 @@ class PreprocessorProtocol(Protocol): min_freq: float + samplerate: int + audio_pipeline: AudioPipeline spectrogram_pipeline: SpectrogramPipeline @@ -155,4 +157,4 @@ class PreprocessorProtocol(Protocol): def __call__(self, wav: torch.Tensor) -> torch.Tensor: ... 
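+    # NOTE: implementations are expected to map a waveform tensor to a
+    # spectrogram tensor; `process_numpy` below is a thin NumPy adapter
+    # around this call.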
def process_numpy(self, wav: np.ndarray) -> np.ndarray: - return self(torch.tensor(wav)).numpy()[0, 0] + return self(torch.tensor(wav)).numpy() diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py index 79868d9..7720a27 100644 --- a/src/batdetect2/typing/train.py +++ b/src/batdetect2/typing/train.py @@ -1,8 +1,6 @@ from typing import Callable, NamedTuple, Protocol, Tuple -import numpy as np import torch -import xarray as xr from soundevent import data from batdetect2.typing.models import ModelOutput @@ -19,24 +17,7 @@ __all__ = [ class Heatmaps(NamedTuple): - """Structure holding the generated heatmap targets. - - Attributes - ---------- - detection : xr.DataArray - Heatmap indicating the probability of sound event presence. Typically - smoothed with a Gaussian kernel centered on event reference points. - Shape matches the input spectrogram. Values normalized [0, 1]. - classes : xr.DataArray - Heatmap indicating the probability of specific class presence. Has an - additional 'category' dimension corresponding to the target class - names. Each category slice is typically smoothed with a Gaussian - kernel. Values normalized [0, 1] per category. - size : xr.DataArray - Heatmap encoding the size (width, height) of detected events. Has an - additional 'dimension' coordinate ('width', 'height'). Values represent - scaled dimensions placed at the event reference points. - """ + """Structure holding the generated heatmap targets.""" detection: torch.Tensor classes: torch.Tensor @@ -44,12 +25,20 @@ class Heatmaps(NamedTuple): class PreprocessedExample(NamedTuple): - audio: np.ndarray - spectrogram: np.ndarray - detection_heatmap: np.ndarray - class_heatmap: np.ndarray - size_heatmap: np.ndarray - clip_annotation: data.ClipAnnotation + audio: torch.Tensor + spectrogram: torch.Tensor + detection_heatmap: torch.Tensor + class_heatmap: torch.Tensor + size_heatmap: torch.Tensor + + def copy(self): + return PreprocessedExample( + audio=self.audio.clone(), + spectrogram=self.spectrogram.clone(), + detection_heatmap=self.detection_heatmap.clone(), + size_heatmap=self.size_heatmap.clone(), + class_heatmap=self.class_heatmap.clone(), + ) ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps] @@ -60,7 +49,7 @@ spectrogram, applies all configured filtering, transformation, and encoding steps, and returns the final `Heatmaps` used for model training. """ -Augmentation = Callable[[xr.Dataset], xr.Dataset] +Augmentation = Callable[[PreprocessedExample], PreprocessedExample] class TrainExample(NamedTuple): @@ -108,5 +97,5 @@ class LossProtocol(Protocol): class ClipperProtocol(Protocol): def extract_clip( - self, example: xr.Dataset - ) -> Tuple[xr.Dataset, float, float]: ... + self, example: PreprocessedExample + ) -> Tuple[PreprocessedExample, float, float]: ... 
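+
+
+# NOTE: A minimal sketch of how these protocols compose during training.
+# The names `annotation`, `loader`, `labeller`, `clipper`, and `config` are
+# illustrative; any `PreprocessorProtocol` implementation works:
+#
+#     example = generate_train_example(annotation, loader, preprocessor, labeller)
+#     augment = build_augmentations(preprocessor, config)
+#     augmented = augment(example)  # PreprocessedExample -> PreprocessedExample
+#     subclip, start, end = clipper.extract_clip(augmented)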
diff --git a/src/batdetect2/utils/arrays.py b/src/batdetect2/utils/arrays.py index 60a8bd3..7a46dd7 100644 --- a/src/batdetect2/utils/arrays.py +++ b/src/batdetect2/utils/arrays.py @@ -1,4 +1,5 @@ import numpy as np +import torch import xarray as xr @@ -35,77 +36,40 @@ def spec_to_xarray( ) -def audio_to_xarray( - wav: np.ndarray, - start_time: float, - end_time: float, - time_axis: str = "time", -) -> xr.DataArray: - if wav.ndim != 1: - raise ValueError("Input numpy audio array should be 1-dimensional") - - return xr.DataArray( - data=wav, - dims=[time_axis], - coords={ - time_axis: np.linspace( - start_time, - end_time, - len(wav), - endpoint=False, - ), - }, - ) - - def extend_width( - array: np.ndarray, + tensor: torch.Tensor, extra: int, axis: int = -1, value: float = 0, -) -> np.ndarray: - dims = len(array.shape) - axis = axis % dims - pad = [[0, 0] if index != axis else [0, extra] for index in range(dims)] - return np.pad( - array, +) -> torch.Tensor: + dims = len(tensor.shape) + axis = dims - axis % dims - 1 + pad = [0 for _ in range(2 * dims)] + pad[2 * axis + 1] = extra + return torch.nn.functional.pad( + tensor, pad, mode="constant", - constant_values=value, + value=value, ) -def make_width_divisible( - array: np.ndarray, - factor: int, - axis: int = -1, - value: float = 0, -) -> np.ndarray: - width = array.shape[axis] - - if width % factor == 0: - return array - - extra = (-width) % factor - return extend_width(array, extra, axis=axis, value=value) - - def adjust_width( - array: np.ndarray, + tensor: torch.Tensor, width: int, axis: int = -1, value: float = 0, -) -> np.ndarray: - dims = len(array.shape) +) -> torch.Tensor: + dims = len(tensor.shape) axis = axis % dims - current_width = array.shape[axis] + current_width = tensor.shape[axis] if current_width == width: - return array + return tensor if current_width < width: return extend_width( - array, + tensor, extra=width - current_width, axis=axis, value=value, @@ -115,11 +79,4 @@ def adjust_width( slice(None, None) if index != axis else slice(None, width) for index in range(dims) ] - return array[tuple(slices)] - - -def iterate_over_array(array: xr.DataArray): - dim_name = array.dims[0] - coords = array.coords[dim_name] - for value, coord in zip(array.values, coords.values): - yield coord, float(value) + return tensor[tuple(slices)] diff --git a/tests/conftest.py b/tests/conftest.py index 0f1f806..49bea5b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -431,8 +431,13 @@ def sample_targets( @pytest.fixture def sample_labeller( sample_targets: TargetProtocol, + sample_preprocessor: PreprocessorProtocol, ) -> ClipLabeller: - return build_clip_labeler(sample_targets) + return build_clip_labeler( + sample_targets, + min_freq=sample_preprocessor.min_freq, + max_freq=sample_preprocessor.max_freq, + ) @pytest.fixture diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index beffc32..99d4640 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -2,6 +2,7 @@ from collections.abc import Callable import numpy as np import pytest +import torch import xarray as xr from soundevent import arrays, data @@ -42,12 +43,17 @@ def test_mix_examples( labeller=sample_labeller, ) - mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor) + mixed = mix_examples( + example1, + example2, + weight=0.3, + preprocessor=sample_preprocessor, + ) - assert mixed["spectrogram"].shape == example1["spectrogram"].shape - assert 
mixed["detection"].shape == example1["detection"].shape - assert mixed["size"].shape == example1["size"].shape - assert mixed["class"].shape == example1["class"].shape + assert mixed.spectrogram.shape == example1.spectrogram.shape + assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape + assert mixed.size_heatmap.shape == example1.size_heatmap.shape + assert mixed.class_heatmap.shape == example1.class_heatmap.shape @pytest.mark.parametrize("duration1", [0.1, 0.4, 0.7]) @@ -82,13 +88,17 @@ def test_mix_examples_of_different_durations( labeller=sample_labeller, ) - mixed = mix_examples(example1, example2, preprocessor=sample_preprocessor) + mixed = mix_examples( + example1, + example2, + weight=0.3, + preprocessor=sample_preprocessor, + ) - # Check the spectrogram has the expected duration - step = arrays.get_dim_step(mixed["spectrogram"], "time") - start, stop = arrays.get_dim_range(mixed["spectrogram"], "time") - assert start == 0 - assert np.isclose(stop + step, duration1, atol=2 * step) + assert mixed.spectrogram.shape == example1.spectrogram.shape + assert mixed.detection_heatmap.shape == example1.detection_heatmap.shape + assert mixed.size_heatmap.shape == example1.size_heatmap.shape + assert mixed.class_heatmap.shape == example1.class_heatmap.shape def test_add_echo( @@ -107,12 +117,32 @@ def test_add_echo( preprocessor=sample_preprocessor, labeller=sample_labeller, ) - with_echo = add_echo(original, preprocessor=sample_preprocessor) + with_echo = add_echo( + original, + preprocessor=sample_preprocessor, + delay=0.1, + weight=0.3, + ) - assert with_echo["spectrogram"].shape == original["spectrogram"].shape - xr.testing.assert_identical(with_echo["size"], original["size"]) - xr.testing.assert_identical(with_echo["class"], original["class"]) - xr.testing.assert_identical(with_echo["detection"], original["detection"]) + assert with_echo.spectrogram.shape == original.spectrogram.shape + torch.testing.assert_close( + with_echo.size_heatmap, + original.size_heatmap, + atol=0, + rtol=0, + ) + torch.testing.assert_close( + with_echo.class_heatmap, + original.class_heatmap, + atol=0, + rtol=0, + ) + torch.testing.assert_close( + with_echo.detection_heatmap, + original.detection_heatmap, + atol=0, + rtol=0, + ) def test_selected_random_subclip_has_the_correct_width( diff --git a/tests/test_train/test_clips.py b/tests/test_train/test_clips.py index 95ac37e..3ab7661 100644 --- a/tests/test_train/test_clips.py +++ b/tests/test_train/test_clips.py @@ -3,7 +3,6 @@ import pytest import xarray as xr from batdetect2.train.clips import ( - Clipper, _compute_expected_width, select_subclip, ) @@ -322,145 +321,3 @@ def test_select_subclip_no_overlap_raises_error(long_dataset): start=-1.0 * CLIP_DURATION - 1.0, dim="time", ) - - -def test_clipper_non_random(long_dataset, exact_dataset, short_dataset): - clipper = Clipper(duration=CLIP_DURATION, random=False) - - for ds in [long_dataset, exact_dataset, short_dataset]: - clip, _, _ = clipper.extract_clip(ds) - expected_spec_width = _compute_expected_width( - ds, CLIP_DURATION, "time" - ) - expected_audio_width = _compute_expected_width( - ds, CLIP_DURATION, "audio_time" - ) - - assert clip.dims["time"] == expected_spec_width - assert clip.dims["audio_time"] == expected_audio_width - assert clip.spectrogram.shape[1] == expected_spec_width - assert clip.audio.shape[0] == expected_audio_width - - assert clip.time.min() >= -1 / SPEC_SAMPLERATE - assert clip.audio_time.min() >= -1 / AUDIO_SAMPLERATE - - time_span = clip.time.max() - 
clip.time.min() - audio_span = clip.audio_time.max() - clip.audio_time.min() - assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE) - assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE) - - -def test_clipper_random(long_dataset): - seed = 42 - np.random.seed(seed) - clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY) - clip1, _, _ = clipper.extract_clip(long_dataset) - - np.random.seed(seed + 1) - clip2, _, _ = clipper.extract_clip(long_dataset) - - expected_spec_width = _compute_expected_width( - long_dataset, CLIP_DURATION, "time" - ) - expected_audio_width = _compute_expected_width( - long_dataset, CLIP_DURATION, "audio_time" - ) - - for clip in [clip1, clip2]: - assert clip.dims["time"] == expected_spec_width - assert clip.dims["audio_time"] == expected_audio_width - assert clip.spectrogram.shape[1] == expected_spec_width - assert clip.audio.shape[0] == expected_audio_width - - assert not np.isclose(clip1.time.min(), clip2.time.min()) - assert not np.isclose(clip1.audio_time.min(), clip2.audio_time.min()) - - for clip in [clip1, clip2]: - time_span = clip.time.max() - clip.time.min() - audio_span = clip.audio_time.max() - clip.audio_time.min() - assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE) - assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE) - - max_start_time = ( - (long_dataset.time.max() - long_dataset.time.min()) - - CLIP_DURATION - + MAX_EMPTY - ) - assert clip1.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE - assert clip2.time.min() <= max_start_time + 1 / SPEC_SAMPLERATE - - -def test_clipper_random_max_empty_effect(long_dataset): - """Check that max_empty influences the possible start times.""" - seed = 123 - data_duration = long_dataset.time.max() - long_dataset.time.min() - - np.random.seed(seed) - clipper0 = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.0) - max_start_time0 = data_duration - CLIP_DURATION - start_times0 = [] - - for _ in range(20): - clip, _, _ = clipper0.extract_clip(long_dataset) - start_times0.append(clip.time.min().item()) - - assert all( - st <= max_start_time0 + 1 / SPEC_SAMPLERATE for st in start_times0 - ) - assert any(st > 0.1 for st in start_times0) - - np.random.seed(seed) - clipper_pos = Clipper(duration=CLIP_DURATION, random=True, max_empty=0.2) - max_start_time_pos = data_duration - CLIP_DURATION + 0.2 - start_times_pos = [] - for _ in range(20): - clip, _, _ = clipper_pos.extract_clip(long_dataset) - start_times_pos.append(clip.time.min().item()) - assert all( - st <= max_start_time_pos + 1 / SPEC_SAMPLERATE - for st in start_times_pos - ) - - assert any(st > max_start_time0 + 1e-6 for st in start_times_pos) - - -def test_clipper_short_dataset_random(short_dataset): - clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY) - clip, _, _ = clipper.extract_clip(short_dataset) - - expected_spec_width = _compute_expected_width( - short_dataset, CLIP_DURATION, "time" - ) - expected_audio_width = _compute_expected_width( - short_dataset, CLIP_DURATION, "audio_time" - ) - - assert clip.sizes["time"] == expected_spec_width - assert clip.sizes["audio_time"] == expected_audio_width - assert clip["spectrogram"].shape[1] == expected_spec_width - assert clip["audio"].shape[0] == expected_audio_width - - assert np.any(clip.spectrogram == 0) - assert np.any(clip.audio == 0) - - -def test_clipper_exact_dataset_random(exact_dataset): - clipper = Clipper(duration=CLIP_DURATION, random=True, max_empty=MAX_EMPTY) - clip, _, 
_ = clipper.extract_clip(exact_dataset) - - expected_spec_width = _compute_expected_width( - exact_dataset, CLIP_DURATION, "time" - ) - expected_audio_width = _compute_expected_width( - exact_dataset, CLIP_DURATION, "audio_time" - ) - - assert clip.dims["time"] == expected_spec_width - assert clip.dims["audio_time"] == expected_audio_width - assert clip.spectrogram.shape[1] == expected_spec_width - assert clip.audio.shape[0] == expected_audio_width - - time_span = clip.time.max() - clip.time.min() - audio_span = clip.audio_time.max() - clip.audio_time.min() - assert np.isclose(time_span, CLIP_DURATION, atol=1 / SPEC_SAMPLERATE) - assert np.isclose(audio_span, CLIP_DURATION, atol=1 / AUDIO_SAMPLERATE) diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py index 83b05c5..bf52eed 100644 --- a/tests/test_train/test_preprocessing.py +++ b/tests/test_train/test_preprocessing.py @@ -1,6 +1,4 @@ import pytest -import torch -import xarray as xr from soundevent import data from soundevent.terms import get_term @@ -10,6 +8,7 @@ from batdetect2.targets import build_targets, load_target_config from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.preprocess import generate_train_example from batdetect2.typing import ModelOutput +from batdetect2.typing.preprocess import AudioLoader @pytest.fixture @@ -35,6 +34,8 @@ def build_from_config( labeller = build_clip_labeler( targets=targets, config=labels_config, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, ) postprocessor = build_postprocessor( targets, @@ -48,62 +49,8 @@ def build_from_config( return build -# TODO: better name -def test_generated_train_example_has_expected_outputs( - build_from_config, - recording, -): - yaml_content = """ - labels: - targets: - roi: - name: anchor_bbox - anchor: bottom-left - classes: - classes: - - name: pippip - tags: - - key: species - value: Pipistrellus pipistrellus - generic_class: - - key: order - value: Chiroptera - preprocessing: - postprocessing: - """ - _, preprocessor, labeller, _ = build_from_config(yaml_content) - - geometry = data.BoundingBox(coordinates=[0.1, 12_000, 0.2, 18_000]) - se1 = data.SoundEventAnnotation( - sound_event=data.SoundEvent(recording=recording, geometry=geometry), - tags=[ - data.Tag(key="species", value="Pipistrellus pipistrellus"), # type: ignore - ], - ) - clip_annotation = data.ClipAnnotation( - clip=data.Clip(start_time=0, end_time=0.5, recording=recording), - sound_events=[se1], - ) - - encoded = generate_train_example(clip_annotation, preprocessor, labeller) - - assert isinstance(encoded, xr.Dataset) - assert "audio" in encoded - assert "spectrogram" in encoded - assert "detection" in encoded - assert "class" in encoded - assert "size" in encoded - - spec_shape = encoded["spectrogram"].shape - assert len(spec_shape) == 2 - - height, width = spec_shape - assert encoded["detection"].shape == (height, width) - assert encoded["class"].shape == (1, height, width) - assert encoded["size"].shape == (2, height, width) - - def test_encoding_decoding_roundtrip_recovers_object( + sample_audio_loader: AudioLoader, build_from_config, recording, ): @@ -136,13 +83,17 @@ def test_encoding_decoding_roundtrip_recovers_object( clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) - encoded = generate_train_example(clip_annotation, preprocessor, labeller) + encoded = generate_train_example( + clip_annotation, 
sample_audio_loader, preprocessor, labeller + ) predictions = postprocessor.get_predictions( ModelOutput( - detection_probs=torch.tensor([[encoded["detection"].data]]), - size_preds=torch.tensor([encoded["size"].data]), - class_probs=torch.tensor([encoded["class"].data]), - features=torch.tensor([[encoded["spectrogram"].data]]), + detection_probs=encoded["detection_heatmap"] + .unsqueeze(0) + .unsqueeze(0), + size_preds=encoded["size_heatmap"].unsqueeze(0), + class_probs=encoded["class_heatmap"].unsqueeze(0), + features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0), ), [clip], )[0] @@ -185,6 +136,7 @@ def test_encoding_decoding_roundtrip_recovers_object( def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( + sample_audio_loader: AudioLoader, build_from_config, recording, ): @@ -222,13 +174,20 @@ def test_encoding_decoding_roundtrip_recovers_object_with_roi_override( clip = data.Clip(start_time=0, end_time=0.5, recording=recording) clip_annotation = data.ClipAnnotation(clip=clip, sound_events=[se1]) - encoded = generate_train_example(clip_annotation, preprocessor, labeller) + encoded = generate_train_example( + clip_annotation, + sample_audio_loader, + preprocessor, + labeller, + ) predictions = postprocessor.get_predictions( ModelOutput( - detection_probs=torch.tensor([[encoded["detection"].data]]), - size_preds=torch.tensor([encoded["size"].data]), - class_probs=torch.tensor([encoded["class"].data]), - features=torch.tensor([[encoded["spectrogram"].data]]), + detection_probs=encoded["detection_heatmap"] + .unsqueeze(0) + .unsqueeze(0), + size_preds=encoded["size_heatmap"].unsqueeze(0), + class_probs=encoded["class_heatmap"].unsqueeze(0), + features=encoded["spectrogram"].unsqueeze(0).unsqueeze(0), ), [clip], )[0] diff --git a/tests/test_utils/test_arrays.py b/tests/test_utils/test_arrays.py index a316390..be9d04c 100644 --- a/tests/test_utils/test_arrays.py +++ b/tests/test_utils/test_arrays.py @@ -1,23 +1,59 @@ -import numpy as np +import torch from batdetect2.utils.arrays import adjust_width, extend_width def test_extend_width(): - array = np.random.random([1, 1, 128, 100]) - + array = torch.rand([1, 1, 128, 100]) extended = extend_width(array, 100) - assert extended.shape == (1, 1, 128, 200) + extended = extend_width(array, 100, axis=0) + assert extended.shape == (101, 1, 128, 100) + + extended = extend_width(array, 100, axis=1) + assert extended.shape == (1, 101, 128, 100) + + extended = extend_width(array, 100, axis=2) + assert extended.shape == (1, 1, 228, 100) + + extended = extend_width(array, 100, axis=3) + assert extended.shape == (1, 1, 128, 200) + + extended = extend_width(array, 100, axis=-2) + assert extended.shape == (1, 1, 228, 100) + + +def test_extends_with_value(): + array = torch.rand([1, 1, 128, 100]) + extended = extend_width(array, 100, value=-1) + torch.testing.assert_close( + extended[:, :, :, 100:], + torch.ones_like(array) * -1, + rtol=0, + atol=0, + ) + def test_can_adjust_short_width(): - array = np.random.random([1, 1, 128, 100]) + array = torch.rand([1, 1, 128, 100]) extended = adjust_width(array, 512) assert extended.shape == (1, 1, 128, 512) + extended = adjust_width(array, 512, axis=0) + assert extended.shape == (512, 1, 128, 100) + + extended = adjust_width(array, 512, axis=1) + assert extended.shape == (1, 512, 128, 100) + + extended = adjust_width(array, 512, axis=2) + assert extended.shape == (1, 1, 512, 100) + + extended = adjust_width(array, 512, axis=3) + assert extended.shape == (1, 1, 128, 512) + def 
test_can_adjust_long_width(): - array = np.random.random([1, 1, 128, 512]) + array = torch.rand([1, 1, 128, 512]) extended = adjust_width(array, 256) assert extended.shape == (1, 1, 128, 256)
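

For reference, a minimal sketch of the index arithmetic behind the new torch-based extend_width: torch.nn.functional.pad consumes (left, right) pairs starting from the LAST dimension, so padding the right-hand side of a given axis means writing the extra width into slot 2 * (dims - axis - 1) + 1. The tensor shape below is illustrative, not taken from the test suite.

import torch
import torch.nn.functional as F

x = torch.rand(1, 1, 128, 100)

# Same remapping as extend_width in src/batdetect2/utils/arrays.py:
# convert a front-indexed axis into F.pad's reversed pad layout.
dims = x.ndim                          # 4
axis = 2 % dims                        # pad the frequency axis (index 2)
pad = [0] * (2 * dims)
pad[2 * (dims - axis - 1) + 1] = 100   # pad the right side of that axis
y = F.pad(x, pad, mode="constant", value=0.0)
assert y.shape == (1, 1, 228, 100)

# adjust_width builds on this: it pads (via extend_width) when the tensor
# is too narrow and slices when it is too wide, so the result always has
# exactly the requested size along `axis`.
from batdetect2.utils.arrays import adjust_width

assert adjust_width(x, 512, axis=-1).shape == (1, 1, 128, 512)
assert adjust_width(x, 64, axis=-1).shape == (1, 1, 128, 64)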
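
The round-trip tests in tests/test_train/test_preprocessing.py feed the encoded heatmaps straight into ModelOutput, using unsqueeze to restore the batch and channel axes a forward pass would normally produce. A shape-only sketch of that bookkeeping follows; the 128 x 512 sizes are illustrative assumptions, not values fixed by this patch.

import torch

# Single-example heatmaps as produced by the labeller: detection and
# spectrogram are (freq, time); size is (2, freq, time) for the two size
# dimensions; class is (num_classes, freq, time).
detection_heatmap = torch.rand(128, 512)
size_heatmap = torch.rand(2, 128, 512)
class_heatmap = torch.rand(1, 128, 512)
spectrogram = torch.rand(128, 512)

# ModelOutput fields are batched, so prepend batch (and channel) axes:
detection_probs = detection_heatmap.unsqueeze(0).unsqueeze(0)  # (1, 1, F, T)
size_preds = size_heatmap.unsqueeze(0)                         # (1, 2, F, T)
class_probs = class_heatmap.unsqueeze(0)                       # (1, C, F, T)
features = spectrogram.unsqueeze(0).unsqueeze(0)               # (1, 1, F, T)

assert detection_probs.shape == (1, 1, 128, 512)
assert size_preds.shape == (1, 2, 128, 512)
assert class_probs.shape == (1, 1, 128, 512)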