Fix augmentations

This commit is contained in:
mbsantiago 2025-04-24 00:20:38 +01:00
parent 8a6ed3dec7
commit d8cf1db19f
3 changed files with 55 additions and 48 deletions

View File

@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
SubclipConfig,
TrainExample,
list_preprocessed_files,
)
@ -37,7 +36,6 @@ __all__ = [
"LabeledDataset",
"LossFunction",
"RandomExampleSource",
"SubclipConfig",
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainerConfig",

View File

@ -126,6 +126,7 @@ def mix_examples(
audio1 = example["audio"]
audio2 = adjust_width(other["audio"].values, len(audio1))
with xr.set_options(keep_attrs=True):
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor.compute_spectrogram(
@ -227,7 +228,6 @@ def add_echo(
'spectrogram' variable recomputed. Other variables (targets, attrs)
are copied from the original example.
"""
if delay is None:
delay = np.random.uniform(0, max_delay)
@ -237,6 +237,8 @@ def add_echo(
audio = example["audio"]
step = arrays.get_dim_step(audio, "audio_time")
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
with xr.set_options(keep_attrs=True):
audio = audio + weight * audio_delay
spectrogram = preprocessor.compute_spectrogram(
@ -260,6 +262,7 @@ def add_echo(
data=spectrogram,
dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords,
attrs=example["spectrogram"].attrs,
),
)
@ -311,6 +314,7 @@ def scale_volume(
if factor is None:
factor = np.random.uniform(min_scaling, max_scaling)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=example["spectrogram"] * factor)
@ -366,6 +370,7 @@ def warp_spectrogram(
example.time.size,
)
with xr.set_options(keep_attrs=True):
spectrogram = (
example["spectrogram"]
.interp(
@ -423,7 +428,7 @@ def mask_axis(
"""Mask values along a specified dimension.
Sets values in the DataArray to `mask_value` where the coordinate along
`dim` falls within the range [`start`, `end`). Values outside this range
`dim` falls within the range (`start`, `end`). Values outside this range
are kept. Used as a helper for time/frequency masking.
Parameters
@ -462,6 +467,7 @@ def mask_axis(
if callable(mask_value):
mask_value = mask_value(array)
with xr.set_options(keep_attrs=True):
return array.where(condition, other=mask_value)
@ -511,6 +517,7 @@ def mask_time(
end = start + mask_size
spectrogram = mask_axis(spectrogram, "time", start, end)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)
@ -566,6 +573,7 @@ def mask_frequency(
end = start + mask_size
spectrogram = mask_axis(spectrogram, "frequency", start, end)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)

View File

@ -26,6 +26,7 @@ from pathlib import Path
from typing import Callable, Optional, Sequence
import xarray as xr
from loguru import logger
from soundevent import data
from tqdm.auto import tqdm