Fix augmentations

This commit is contained in:
mbsantiago 2025-04-24 00:20:38 +01:00
parent 8a6ed3dec7
commit d8cf1db19f
3 changed files with 55 additions and 48 deletions

View File

@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, LabeledDataset,
RandomExampleSource, RandomExampleSource,
SubclipConfig,
TrainExample, TrainExample,
list_preprocessed_files, list_preprocessed_files,
) )
@ -37,7 +36,6 @@ __all__ = [
"LabeledDataset", "LabeledDataset",
"LossFunction", "LossFunction",
"RandomExampleSource", "RandomExampleSource",
"SubclipConfig",
"TimeMaskAugmentationConfig", "TimeMaskAugmentationConfig",
"TrainExample", "TrainExample",
"TrainerConfig", "TrainerConfig",

View File

@ -126,7 +126,8 @@ def mix_examples(
audio1 = example["audio"] audio1 = example["audio"]
audio2 = adjust_width(other["audio"].values, len(audio1)) audio2 = adjust_width(other["audio"].values, len(audio1))
combined = weight * audio1 + (1 - weight) * audio2 with xr.set_options(keep_attrs=True):
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor.compute_spectrogram( spectrogram = preprocessor.compute_spectrogram(
combined.rename({"audio_time": "time"}) combined.rename({"audio_time": "time"})
@ -227,7 +228,6 @@ def add_echo(
'spectrogram' variable recomputed. Other variables (targets, attrs) 'spectrogram' variable recomputed. Other variables (targets, attrs)
are copied from the original example. are copied from the original example.
""" """
if delay is None: if delay is None:
delay = np.random.uniform(0, max_delay) delay = np.random.uniform(0, max_delay)
@ -237,7 +237,9 @@ def add_echo(
audio = example["audio"] audio = example["audio"]
step = arrays.get_dim_step(audio, "audio_time") step = arrays.get_dim_step(audio, "audio_time")
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0) audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
audio = audio + weight * audio_delay
with xr.set_options(keep_attrs=True):
audio = audio + weight * audio_delay
spectrogram = preprocessor.compute_spectrogram( spectrogram = preprocessor.compute_spectrogram(
audio.rename({"audio_time": "time"}), audio.rename({"audio_time": "time"}),
@ -260,6 +262,7 @@ def add_echo(
data=spectrogram, data=spectrogram,
dims=example["spectrogram"].dims, dims=example["spectrogram"].dims,
coords=example["spectrogram"].coords, coords=example["spectrogram"].coords,
attrs=example["spectrogram"].attrs,
), ),
) )
@ -311,7 +314,8 @@ def scale_volume(
if factor is None: if factor is None:
factor = np.random.uniform(min_scaling, max_scaling) factor = np.random.uniform(min_scaling, max_scaling)
return example.assign(spectrogram=example["spectrogram"] * factor) with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=example["spectrogram"] * factor)
class WarpAugmentationConfig(BaseConfig): class WarpAugmentationConfig(BaseConfig):
@ -366,51 +370,52 @@ def warp_spectrogram(
example.time.size, example.time.size,
) )
spectrogram = ( with xr.set_options(keep_attrs=True):
example["spectrogram"] spectrogram = (
.interp( example["spectrogram"]
coords={"time": new_time}, .interp(
method="linear", coords={"time": new_time},
method="linear",
kwargs=dict(
fill_value=0,
),
)
.clip(min=0)
)
detection = example["detection"].interp(
time=new_time,
method="nearest",
kwargs=dict( kwargs=dict(
fill_value=0, fill_value=0,
), ),
) )
.clip(min=0)
)
detection = example["detection"].interp( classification = example["class"].interp(
time=new_time, time=new_time,
method="nearest", method="nearest",
kwargs=dict( kwargs=dict(
fill_value=0, fill_value=0,
), ),
) )
classification = example["class"].interp( size = example["size"].interp(
time=new_time, time=new_time,
method="nearest", method="nearest",
kwargs=dict( kwargs=dict(
fill_value=0, fill_value=0,
), ),
) )
size = example["size"].interp( return example.assign(
time=new_time, {
method="nearest", "time": new_time,
kwargs=dict( "spectrogram": spectrogram,
fill_value=0, "detection": detection,
), "class": classification,
) "size": size,
}
return example.assign( )
{
"time": new_time,
"spectrogram": spectrogram,
"detection": detection,
"class": classification,
"size": size,
}
)
def mask_axis( def mask_axis(
@ -423,7 +428,7 @@ def mask_axis(
"""Mask values along a specified dimension. """Mask values along a specified dimension.
Sets values in the DataArray to `mask_value` where the coordinate along Sets values in the DataArray to `mask_value` where the coordinate along
`dim` falls within the range [`start`, `end`). Values outside this range `dim` falls within the range (`start`, `end`). Values outside this range
are kept. Used as a helper for time/frequency masking. are kept. Used as a helper for time/frequency masking.
Parameters Parameters
@ -462,7 +467,8 @@ def mask_axis(
if callable(mask_value): if callable(mask_value):
mask_value = mask_value(array) mask_value = mask_value(array)
return array.where(condition, other=mask_value) with xr.set_options(keep_attrs=True):
return array.where(condition, other=mask_value)
class TimeMaskAugmentationConfig(BaseConfig): class TimeMaskAugmentationConfig(BaseConfig):
@ -511,7 +517,8 @@ def mask_time(
end = start + mask_size end = start + mask_size
spectrogram = mask_axis(spectrogram, "time", start, end) spectrogram = mask_axis(spectrogram, "time", start, end)
return example.assign(spectrogram=spectrogram) with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)
class FrequencyMaskAugmentationConfig(BaseConfig): class FrequencyMaskAugmentationConfig(BaseConfig):
@ -566,7 +573,8 @@ def mask_frequency(
end = start + mask_size end = start + mask_size
spectrogram = mask_axis(spectrogram, "frequency", start, end) spectrogram = mask_axis(spectrogram, "frequency", start, end)
return example.assign(spectrogram=spectrogram) with xr.set_options(keep_attrs=True):
return example.assign(spectrogram=spectrogram)
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[

View File

@ -26,6 +26,7 @@ from pathlib import Path
from typing import Callable, Optional, Sequence from typing import Callable, Optional, Sequence
import xarray as xr import xarray as xr
from loguru import logger
from soundevent import data from soundevent import data
from tqdm.auto import tqdm from tqdm.auto import tqdm