From d8cf1db19f475ba9168c43a7d139ebef5ae261b9 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 24 Apr 2025 00:20:38 +0100 Subject: [PATCH] Fix augmentations --- batdetect2/train/__init__.py | 2 - batdetect2/train/augmentations.py | 100 ++++++++++++++++-------------- batdetect2/train/preprocess.py | 1 + 3 files changed, 55 insertions(+), 48 deletions(-) diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index be4c1b0..692de6a 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config from batdetect2.train.dataset import ( LabeledDataset, RandomExampleSource, - SubclipConfig, TrainExample, list_preprocessed_files, ) @@ -37,7 +36,6 @@ __all__ = [ "LabeledDataset", "LossFunction", "RandomExampleSource", - "SubclipConfig", "TimeMaskAugmentationConfig", "TrainExample", "TrainerConfig", diff --git a/batdetect2/train/augmentations.py b/batdetect2/train/augmentations.py index f8d3fd0..6930c97 100644 --- a/batdetect2/train/augmentations.py +++ b/batdetect2/train/augmentations.py @@ -126,7 +126,8 @@ def mix_examples( audio1 = example["audio"] audio2 = adjust_width(other["audio"].values, len(audio1)) - combined = weight * audio1 + (1 - weight) * audio2 + with xr.set_options(keep_attrs=True): + combined = weight * audio1 + (1 - weight) * audio2 spectrogram = preprocessor.compute_spectrogram( combined.rename({"audio_time": "time"}) @@ -227,7 +228,6 @@ def add_echo( 'spectrogram' variable recomputed. Other variables (targets, attrs) are copied from the original example. """ - if delay is None: delay = np.random.uniform(0, max_delay) @@ -237,7 +237,9 @@ def add_echo( audio = example["audio"] step = arrays.get_dim_step(audio, "audio_time") audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0) - audio = audio + weight * audio_delay + + with xr.set_options(keep_attrs=True): + audio = audio + weight * audio_delay spectrogram = preprocessor.compute_spectrogram( audio.rename({"audio_time": "time"}), @@ -260,6 +262,7 @@ def add_echo( data=spectrogram, dims=example["spectrogram"].dims, coords=example["spectrogram"].coords, + attrs=example["spectrogram"].attrs, ), ) @@ -311,7 +314,8 @@ def scale_volume( if factor is None: factor = np.random.uniform(min_scaling, max_scaling) - return example.assign(spectrogram=example["spectrogram"] * factor) + with xr.set_options(keep_attrs=True): + return example.assign(spectrogram=example["spectrogram"] * factor) class WarpAugmentationConfig(BaseConfig): @@ -366,51 +370,52 @@ def warp_spectrogram( example.time.size, ) - spectrogram = ( - example["spectrogram"] - .interp( - coords={"time": new_time}, - method="linear", + with xr.set_options(keep_attrs=True): + spectrogram = ( + example["spectrogram"] + .interp( + coords={"time": new_time}, + method="linear", + kwargs=dict( + fill_value=0, + ), + ) + .clip(min=0) + ) + + detection = example["detection"].interp( + time=new_time, + method="nearest", kwargs=dict( fill_value=0, ), ) - .clip(min=0) - ) - detection = example["detection"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) + classification = example["class"].interp( + time=new_time, + method="nearest", + kwargs=dict( + fill_value=0, + ), + ) - classification = example["class"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) + size = example["size"].interp( + time=new_time, + method="nearest", + kwargs=dict( + fill_value=0, + ), + ) - size = example["size"].interp( - time=new_time, - method="nearest", - kwargs=dict( - fill_value=0, - ), - ) - - return example.assign( - { - "time": new_time, - "spectrogram": spectrogram, - "detection": detection, - "class": classification, - "size": size, - } - ) + return example.assign( + { + "time": new_time, + "spectrogram": spectrogram, + "detection": detection, + "class": classification, + "size": size, + } + ) def mask_axis( @@ -423,7 +428,7 @@ def mask_axis( """Mask values along a specified dimension. Sets values in the DataArray to `mask_value` where the coordinate along - `dim` falls within the range [`start`, `end`). Values outside this range + `dim` falls within the range (`start`, `end`). Values outside this range are kept. Used as a helper for time/frequency masking. Parameters @@ -462,7 +467,8 @@ def mask_axis( if callable(mask_value): mask_value = mask_value(array) - return array.where(condition, other=mask_value) + with xr.set_options(keep_attrs=True): + return array.where(condition, other=mask_value) class TimeMaskAugmentationConfig(BaseConfig): @@ -511,7 +517,8 @@ def mask_time( end = start + mask_size spectrogram = mask_axis(spectrogram, "time", start, end) - return example.assign(spectrogram=spectrogram) + with xr.set_options(keep_attrs=True): + return example.assign(spectrogram=spectrogram) class FrequencyMaskAugmentationConfig(BaseConfig): @@ -566,7 +573,8 @@ def mask_frequency( end = start + mask_size spectrogram = mask_axis(spectrogram, "frequency", start, end) - return example.assign(spectrogram=spectrogram) + with xr.set_options(keep_attrs=True): + return example.assign(spectrogram=spectrogram) AugmentationConfig = Annotated[ diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 14372ae..0368d10 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -26,6 +26,7 @@ from pathlib import Path from typing import Callable, Optional, Sequence import xarray as xr +from loguru import logger from soundevent import data from tqdm.auto import tqdm