Fix augmentations

This commit is contained in:
mbsantiago 2025-04-24 00:20:38 +01:00
parent 8a6ed3dec7
commit d8cf1db19f
3 changed files with 55 additions and 48 deletions

View File

@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
SubclipConfig,
TrainExample,
list_preprocessed_files,
)
@ -37,7 +36,6 @@ __all__ = [
"LabeledDataset",
"LossFunction",
"RandomExampleSource",
"SubclipConfig",
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainerConfig",

View File

@ -126,7 +126,8 @@ def mix_examples(
audio1 = example["audio"]
audio2 = adjust_width(other["audio"].values, len(audio1))
combined = weight * audio1 + (1 - weight) * audio2
with xr.set_options(keep_attrs=True):
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor.compute_spectrogram(
combined.rename({"audio_time": "time"})
@ -227,7 +228,6 @@ def add_echo(
'spectrogram' variable recomputed. Other variables (targets, attrs)
are copied from the original example.
"""
if delay is None:
delay = np.random.uniform(0, max_delay)
@ -237,7 +237,9 @@ def add_echo(
audio = example["audio"]
step = arrays.get_dim_step(audio, "audio_time")
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
audio = audio + weight * audio_delay
with xr.set_options(keep_attrs=True):
audio = audio + weight * audio_delay
spectrogram = preprocessor.compute_spectrogram(
audio.rename({"audio_time": "time"}),
@ -260,6 +262,7 @@ def add_echo(
data=spectrogram,
dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords,
attrs=example["spectrogram"].attrs,
),
)
@ -311,7 +314,8 @@ def scale_volume(
if factor is None:
factor = np.random.uniform(min_scaling, max_scaling)
return example.assign(spectrogram=example["spectrogram"] * factor)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=example["spectrogram"] * factor)
class WarpAugmentationConfig(BaseConfig):
@ -366,51 +370,52 @@ def warp_spectrogram(
example.time.size,
)
spectrogram = (
example["spectrogram"]
.interp(
coords={"time": new_time},
method="linear",
with xr.set_options(keep_attrs=True):
spectrogram = (
example["spectrogram"]
.interp(
coords={"time": new_time},
method="linear",
kwargs=dict(
fill_value=0,
),
)
.clip(min=0)
)
detection = example["detection"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
.clip(min=0)
)
detection = example["detection"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
classification = example["class"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
classification = example["class"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
size = example["size"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
size = example["size"].interp(
time=new_time,
method="nearest",
kwargs=dict(
fill_value=0,
),
)
return example.assign(
{
"time": new_time,
"spectrogram": spectrogram,
"detection": detection,
"class": classification,
"size": size,
}
)
return example.assign(
{
"time": new_time,
"spectrogram": spectrogram,
"detection": detection,
"class": classification,
"size": size,
}
)
def mask_axis(
@ -423,7 +428,7 @@ def mask_axis(
"""Mask values along a specified dimension.
Sets values in the DataArray to `mask_value` where the coordinate along
`dim` falls within the range [`start`, `end`). Values outside this range
`dim` falls within the range (`start`, `end`). Values outside this range
are kept. Used as a helper for time/frequency masking.
Parameters
@ -462,7 +467,8 @@ def mask_axis(
if callable(mask_value):
mask_value = mask_value(array)
return array.where(condition, other=mask_value)
with xr.set_options(keep_attrs=True):
return array.where(condition, other=mask_value)
class TimeMaskAugmentationConfig(BaseConfig):
@ -511,7 +517,8 @@ def mask_time(
end = start + mask_size
spectrogram = mask_axis(spectrogram, "time", start, end)
return example.assign(spectrogram=spectrogram)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)
class FrequencyMaskAugmentationConfig(BaseConfig):
@ -566,7 +573,8 @@ def mask_frequency(
end = start + mask_size
spectrogram = mask_axis(spectrogram, "frequency", start, end)
return example.assign(spectrogram=spectrogram)
with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)
AugmentationConfig = Annotated[

View File

@ -26,6 +26,7 @@ from pathlib import Path
from typing import Callable, Optional, Sequence
import xarray as xr
from loguru import logger
from soundevent import data
from tqdm.auto import tqdm