Fix augmentations

This commit is contained in:
mbsantiago 2025-04-24 00:20:38 +01:00
parent 8a6ed3dec7
commit d8cf1db19f
3 changed files with 55 additions and 48 deletions

View File

@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
RandomExampleSource, RandomExampleSource,
SubclipConfig,
TrainExample, TrainExample,
list_preprocessed_files, list_preprocessed_files,
) )
@ -37,7 +36,6 @@ __all__ = [
"LabeledDataset", "LabeledDataset",
"LossFunction", "LossFunction",
"RandomExampleSource", "RandomExampleSource",
"SubclipConfig",
"TimeMaskAugmentationConfig", "TimeMaskAugmentationConfig",
"TrainExample", "TrainExample",
"TrainerConfig", "TrainerConfig",

View File

@ -126,6 +126,7 @@ def mix_examples(
audio1 = example["audio"] audio1 = example["audio"]
audio2 = adjust_width(other["audio"].values, len(audio1)) audio2 = adjust_width(other["audio"].values, len(audio1))
with xr.set_options(keep_attrs=True):
combined = weight * audio1 + (1 - weight) * audio2 combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor.compute_spectrogram( spectrogram = preprocessor.compute_spectrogram(
@ -227,7 +228,6 @@ def add_echo(
'spectrogram' variable recomputed. Other variables (targets, attrs) 'spectrogram' variable recomputed. Other variables (targets, attrs)
are copied from the original example. are copied from the original example.
""" """
if delay is None: if delay is None:
delay = np.random.uniform(0, max_delay) delay = np.random.uniform(0, max_delay)
@ -237,6 +237,8 @@ def add_echo(
audio = example["audio"] audio = example["audio"]
step = arrays.get_dim_step(audio, "audio_time") step = arrays.get_dim_step(audio, "audio_time")
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0) audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
with xr.set_options(keep_attrs=True):
audio = audio + weight * audio_delay audio = audio + weight * audio_delay
spectrogram = preprocessor.compute_spectrogram( spectrogram = preprocessor.compute_spectrogram(
@ -260,6 +262,7 @@ def add_echo(
data=spectrogram, data=spectrogram,
dims=example["spectrogram"].dims, dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords, coords=example["spectrogram"].coords,
attrs=example["spectrogram"].attrs,
), ),
) )
@ -311,6 +314,7 @@ def scale_volume(
if factor is None: if factor is None:
factor = np.random.uniform(min_scaling, max_scaling) factor = np.random.uniform(min_scaling, max_scaling)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=example["spectrogram"] * factor) return example.assign(spectrogram=example["spectrogram"] * factor)
@ -366,6 +370,7 @@ def warp_spectrogram(
example.time.size, example.time.size,
) )
with xr.set_options(keep_attrs=True):
spectrogram = ( spectrogram = (
example["spectrogram"] example["spectrogram"]
.interp( .interp(
@ -423,7 +428,7 @@ def mask_axis(
"""Mask values along a specified dimension. """Mask values along a specified dimension.
Sets values in the DataArray to `mask_value` where the coordinate along Sets values in the DataArray to `mask_value` where the coordinate along
`dim` falls within the range [`start`, `end`). Values outside this range `dim` falls within the range (`start`, `end`). Values outside this range
are kept. Used as a helper for time/frequency masking. are kept. Used as a helper for time/frequency masking.
Parameters Parameters
@ -462,6 +467,7 @@ def mask_axis(
if callable(mask_value): if callable(mask_value):
mask_value = mask_value(array) mask_value = mask_value(array)
with xr.set_options(keep_attrs=True):
return array.where(condition, other=mask_value) return array.where(condition, other=mask_value)
@ -511,6 +517,7 @@ def mask_time(
end = start + mask_size end = start + mask_size
spectrogram = mask_axis(spectrogram, "time", start, end) spectrogram = mask_axis(spectrogram, "time", start, end)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram) return example.assign(spectrogram=spectrogram)
@ -566,6 +573,7 @@ def mask_frequency(
end = start + mask_size end = start + mask_size
spectrogram = mask_axis(spectrogram, "frequency", start, end) spectrogram = mask_axis(spectrogram, "frequency", start, end)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram) return example.assign(spectrogram=spectrogram)

View File

@ -26,6 +26,7 @@ from pathlib import Path
from typing import Callable, Optional, Sequence from typing import Callable, Optional, Sequence
import xarray as xr import xarray as xr
from loguru import logger
from soundevent import data from soundevent import data
from tqdm.auto import tqdm from tqdm.auto import tqdm