mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Fix augmentations
This commit is contained in:
parent
8a6ed3dec7
commit
d8cf1db19f
@ -18,7 +18,6 @@ from batdetect2.train.config import TrainingConfig, load_train_config
|
|||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
SubclipConfig,
|
|
||||||
TrainExample,
|
TrainExample,
|
||||||
list_preprocessed_files,
|
list_preprocessed_files,
|
||||||
)
|
)
|
||||||
@ -37,7 +36,6 @@ __all__ = [
|
|||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"RandomExampleSource",
|
"RandomExampleSource",
|
||||||
"SubclipConfig",
|
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
"TrainerConfig",
|
||||||
|
@ -126,7 +126,8 @@ def mix_examples(
|
|||||||
audio1 = example["audio"]
|
audio1 = example["audio"]
|
||||||
audio2 = adjust_width(other["audio"].values, len(audio1))
|
audio2 = adjust_width(other["audio"].values, len(audio1))
|
||||||
|
|
||||||
combined = weight * audio1 + (1 - weight) * audio2
|
with xr.set_options(keep_attrs=True):
|
||||||
|
combined = weight * audio1 + (1 - weight) * audio2
|
||||||
|
|
||||||
spectrogram = preprocessor.compute_spectrogram(
|
spectrogram = preprocessor.compute_spectrogram(
|
||||||
combined.rename({"audio_time": "time"})
|
combined.rename({"audio_time": "time"})
|
||||||
@ -227,7 +228,6 @@ def add_echo(
|
|||||||
'spectrogram' variable recomputed. Other variables (targets, attrs)
|
'spectrogram' variable recomputed. Other variables (targets, attrs)
|
||||||
are copied from the original example.
|
are copied from the original example.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if delay is None:
|
if delay is None:
|
||||||
delay = np.random.uniform(0, max_delay)
|
delay = np.random.uniform(0, max_delay)
|
||||||
|
|
||||||
@ -237,7 +237,9 @@ def add_echo(
|
|||||||
audio = example["audio"]
|
audio = example["audio"]
|
||||||
step = arrays.get_dim_step(audio, "audio_time")
|
step = arrays.get_dim_step(audio, "audio_time")
|
||||||
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
|
audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
|
||||||
audio = audio + weight * audio_delay
|
|
||||||
|
with xr.set_options(keep_attrs=True):
|
||||||
|
audio = audio + weight * audio_delay
|
||||||
|
|
||||||
spectrogram = preprocessor.compute_spectrogram(
|
spectrogram = preprocessor.compute_spectrogram(
|
||||||
audio.rename({"audio_time": "time"}),
|
audio.rename({"audio_time": "time"}),
|
||||||
@ -260,6 +262,7 @@ def add_echo(
|
|||||||
data=spectrogram,
|
data=spectrogram,
|
||||||
dims=example["spectrogram"].dims,
|
dims=example["spectrogram"].dims,
|
||||||
coords=example["spectrogram"].coords,
|
coords=example["spectrogram"].coords,
|
||||||
|
attrs=example["spectrogram"].attrs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -311,7 +314,8 @@ def scale_volume(
|
|||||||
if factor is None:
|
if factor is None:
|
||||||
factor = np.random.uniform(min_scaling, max_scaling)
|
factor = np.random.uniform(min_scaling, max_scaling)
|
||||||
|
|
||||||
return example.assign(spectrogram=example["spectrogram"] * factor)
|
with xr.set_options(keep_attrs=True):
|
||||||
|
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||||
|
|
||||||
|
|
||||||
class WarpAugmentationConfig(BaseConfig):
|
class WarpAugmentationConfig(BaseConfig):
|
||||||
@ -366,51 +370,52 @@ def warp_spectrogram(
|
|||||||
example.time.size,
|
example.time.size,
|
||||||
)
|
)
|
||||||
|
|
||||||
spectrogram = (
|
with xr.set_options(keep_attrs=True):
|
||||||
example["spectrogram"]
|
spectrogram = (
|
||||||
.interp(
|
example["spectrogram"]
|
||||||
coords={"time": new_time},
|
.interp(
|
||||||
method="linear",
|
coords={"time": new_time},
|
||||||
|
method="linear",
|
||||||
|
kwargs=dict(
|
||||||
|
fill_value=0,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
.clip(min=0)
|
||||||
|
)
|
||||||
|
|
||||||
|
detection = example["detection"].interp(
|
||||||
|
time=new_time,
|
||||||
|
method="nearest",
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
fill_value=0,
|
fill_value=0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.clip(min=0)
|
|
||||||
)
|
|
||||||
|
|
||||||
detection = example["detection"].interp(
|
classification = example["class"].interp(
|
||||||
time=new_time,
|
time=new_time,
|
||||||
method="nearest",
|
method="nearest",
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
fill_value=0,
|
fill_value=0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
classification = example["class"].interp(
|
size = example["size"].interp(
|
||||||
time=new_time,
|
time=new_time,
|
||||||
method="nearest",
|
method="nearest",
|
||||||
kwargs=dict(
|
kwargs=dict(
|
||||||
fill_value=0,
|
fill_value=0,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
size = example["size"].interp(
|
return example.assign(
|
||||||
time=new_time,
|
{
|
||||||
method="nearest",
|
"time": new_time,
|
||||||
kwargs=dict(
|
"spectrogram": spectrogram,
|
||||||
fill_value=0,
|
"detection": detection,
|
||||||
),
|
"class": classification,
|
||||||
)
|
"size": size,
|
||||||
|
}
|
||||||
return example.assign(
|
)
|
||||||
{
|
|
||||||
"time": new_time,
|
|
||||||
"spectrogram": spectrogram,
|
|
||||||
"detection": detection,
|
|
||||||
"class": classification,
|
|
||||||
"size": size,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def mask_axis(
|
def mask_axis(
|
||||||
@ -423,7 +428,7 @@ def mask_axis(
|
|||||||
"""Mask values along a specified dimension.
|
"""Mask values along a specified dimension.
|
||||||
|
|
||||||
Sets values in the DataArray to `mask_value` where the coordinate along
|
Sets values in the DataArray to `mask_value` where the coordinate along
|
||||||
`dim` falls within the range [`start`, `end`). Values outside this range
|
`dim` falls within the range (`start`, `end`). Values outside this range
|
||||||
are kept. Used as a helper for time/frequency masking.
|
are kept. Used as a helper for time/frequency masking.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -462,7 +467,8 @@ def mask_axis(
|
|||||||
if callable(mask_value):
|
if callable(mask_value):
|
||||||
mask_value = mask_value(array)
|
mask_value = mask_value(array)
|
||||||
|
|
||||||
return array.where(condition, other=mask_value)
|
with xr.set_options(keep_attrs=True):
|
||||||
|
return array.where(condition, other=mask_value)
|
||||||
|
|
||||||
|
|
||||||
class TimeMaskAugmentationConfig(BaseConfig):
|
class TimeMaskAugmentationConfig(BaseConfig):
|
||||||
@ -511,7 +517,8 @@ def mask_time(
|
|||||||
end = start + mask_size
|
end = start + mask_size
|
||||||
spectrogram = mask_axis(spectrogram, "time", start, end)
|
spectrogram = mask_axis(spectrogram, "time", start, end)
|
||||||
|
|
||||||
return example.assign(spectrogram=spectrogram)
|
with xr.set_options(keep_attrs=True):
|
||||||
|
return example.assign(spectrogram=spectrogram)
|
||||||
|
|
||||||
|
|
||||||
class FrequencyMaskAugmentationConfig(BaseConfig):
|
class FrequencyMaskAugmentationConfig(BaseConfig):
|
||||||
@ -566,7 +573,8 @@ def mask_frequency(
|
|||||||
end = start + mask_size
|
end = start + mask_size
|
||||||
spectrogram = mask_axis(spectrogram, "frequency", start, end)
|
spectrogram = mask_axis(spectrogram, "frequency", start, end)
|
||||||
|
|
||||||
return example.assign(spectrogram=spectrogram)
|
with xr.set_options(keep_attrs=True):
|
||||||
|
return example.assign(spectrogram=spectrogram)
|
||||||
|
|
||||||
|
|
||||||
AugmentationConfig = Annotated[
|
AugmentationConfig = Annotated[
|
||||||
|
@ -26,6 +26,7 @@ from pathlib import Path
|
|||||||
from typing import Callable, Optional, Sequence
|
from typing import Callable, Optional, Sequence
|
||||||
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user