mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix augmentations
This commit is contained in:
parent
8a6ed3dec7
commit
d8cf1db19f
@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
SubclipConfig,
|
||||
TrainExample,
|
||||
list_preprocessed_files,
|
||||
)
|
||||
@ -37,7 +36,6 @@ __all__ = [
|
||||
"LabeledDataset",
|
||||
"LossFunction",
|
||||
"RandomExampleSource",
|
||||
"SubclipConfig",
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
|
@ -126,6 +126,7 @@ def mix_examples(
|
||||
audio1 = example["audio"]
|
||||
audio2 = adjust_width(other["audio"].values, len(audio1))
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
combined = weight * audio1 + (1 - weight) * audio2
|
||||
|
||||
spectrogram = preprocessor.compute_spectrogram(
|
||||
@ -227,7 +228,6 @@ def add_echo(
|
||||
'spectrogram' variable recomputed. Other variables (targets, attrs)
|
||||
are copied from the original example.
|
||||
"""
|
||||
|
||||
if delay is None:
|
||||
delay = np.random.uniform(0, max_delay)
|
||||
|
||||
@ -237,6 +237,8 @@ def add_echo(
|
||||
audio = example["audio"]
|
||||
step = arrays.get_dim_step(audio, "audio_time")
|
||||
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
audio = audio + weight * audio_delay
|
||||
|
||||
spectrogram = preprocessor.compute_spectrogram(
|
||||
@ -260,6 +262,7 @@ def add_echo(
|
||||
data=spectrogram,
|
||||
dims=example["spectrogram"].dims,
|
||||
coords=example["spectrogram"].coords,
|
||||
attrs=example["spectrogram"].attrs,
|
||||
),
|
||||
)
|
||||
|
||||
@ -311,6 +314,7 @@ def scale_volume(
|
||||
if factor is None:
|
||||
factor = np.random.uniform(min_scaling, max_scaling)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||
|
||||
|
||||
@ -366,6 +370,7 @@ def warp_spectrogram(
|
||||
example.time.size,
|
||||
)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
spectrogram = (
|
||||
example["spectrogram"]
|
||||
.interp(
|
||||
@ -423,7 +428,7 @@ def mask_axis(
|
||||
"""Mask values along a specified dimension.
|
||||
|
||||
Sets values in the DataArray to `mask_value` where the coordinate along
|
||||
`dim` falls within the range [`start`, `end`). Values outside this range
|
||||
`dim` falls within the range (`start`, `end`). Values outside this range
|
||||
are kept. Used as a helper for time/frequency masking.
|
||||
|
||||
Parameters
|
||||
@ -462,6 +467,7 @@ def mask_axis(
|
||||
if callable(mask_value):
|
||||
mask_value = mask_value(array)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
return array.where(condition, other=mask_value)
|
||||
|
||||
|
||||
@ -511,6 +517,7 @@ def mask_time(
|
||||
end = start + mask_size
|
||||
spectrogram = mask_axis(spectrogram, "time", start, end)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
return example.assign(spectrogram=spectrogram)
|
||||
|
||||
|
||||
@ -566,6 +573,7 @@ def mask_frequency(
|
||||
end = start + mask_size
|
||||
spectrogram = mask_axis(spectrogram, "frequency", start, end)
|
||||
|
||||
with xr.set_options(keep_attrs=True):
|
||||
return example.assign(spectrogram=spectrogram)
|
||||
|
||||
|
||||
|
@ -26,6 +26,7 @@ from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence
|
||||
|
||||
import xarray as xr
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user