Remove train preprocessing

This commit is contained in:
mbsantiago 2025-08-31 18:28:52 +01:00
parent 1cec332dd5
commit 40f6b64611
13 changed files with 712 additions and 664 deletions

View File

@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.7.0",
"soundevent[audio,geometry,plot]>=2.8.0",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",

View File

@ -6,19 +6,19 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
from batdetect2.train.dataset import list_preprocessed_files
__all__ = ["train_command"]
@cli.command(name="train")
@click.argument("train_dir", type=click.Path(exists=True))
@click.option("--val-dir", type=click.Path(exists=True))
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@ -31,8 +31,8 @@ __all__ = ["train_command"]
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dir: Path,
val_dir: Optional[Path] = None,
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
@ -58,29 +58,27 @@ def train_command(
else FullTrainingConfig()
)
logger.info("Scanning for training and validation data...")
train_examples = list_preprocessed_files(train_dir)
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(train_dataset)
logger.debug(
"Found {num_files} training examples in {path}",
num_files=len(train_examples),
path=train_dir,
"Loaded {num_annotations} training examples",
num_annotations=len(train_annotations),
)
val_examples = None
if val_dir is not None:
val_examples = list_preprocessed_files(val_dir)
val_annotations = None
if val_dataset is not None:
val_annotations = load_dataset_from_config(val_dataset)
logger.debug(
"Found {num_files} validation examples in {path}",
num_files=len(val_examples),
path=val_dir,
"Loaded {num_annotations} validation examples",
num_annotations=len(val_annotations),
)
else:
logger.debug("No validation dataset provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_examples=train_examples,
val_examples=val_examples,
train_annotations=train_annotations,
val_annotations=val_annotations,
config=conf,
model_path=model_path,
train_workers=train_workers,

View File

@ -38,6 +38,8 @@ def plot_spectrogram(
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
spec = spec.squeeze()
ax = create_ax(ax=ax, figsize=figsize)
if start_time is None:

View File

@ -25,6 +25,8 @@ def plot_detection_heatmap(
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
heatmap = heatmap.squeeze()
if threshold is not None:
heatmap = np.ma.masked_where(
heatmap < threshold,

View File

@ -2,7 +2,7 @@ from batdetect2.train.augmentations import (
AugmentationsConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
RandomExampleSource,
RandomAudioSource,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
@ -10,7 +10,7 @@ from batdetect2.train.augmentations import (
build_augmentations,
mask_frequency,
mask_time,
mix_examples,
mix_audio,
scale_volume,
warp_spectrogram,
)
@ -22,10 +22,7 @@ from batdetect2.train.config import (
load_full_training_config,
load_train_config,
)
from batdetect2.train.dataset import (
LabeledDataset,
list_preprocessed_files,
)
from batdetect2.train.dataset import TrainingDataset
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import (
@ -56,11 +53,11 @@ __all__ = [
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"FullTrainingConfig",
"LabeledDataset",
"TrainingDataset",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
"RandomExampleSource",
"RandomAudioSource",
"SizeLossConfig",
"TimeMaskAugmentationConfig",
"TrainingConfig",
@ -78,13 +75,12 @@ __all__ = [
"build_val_dataset",
"build_val_loader",
"generate_train_example",
"list_preprocessed_files",
"load_full_training_config",
"load_label_config",
"load_train_config",
"mask_frequency",
"mask_time",
"mix_examples",
"mix_audio",
"preprocess_annotations",
"scale_volume",
"select_subclip",

View File

@ -9,14 +9,12 @@ import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import Augmentation
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.utils.arrays import adjust_width
__all__ = [
@ -24,7 +22,7 @@ __all__ = [
"AugmentationsConfig",
"DEFAULT_AUGMENTATION_CONFIG",
"EchoAugmentationConfig",
"ExampleSource",
"AudioSource",
"FrequencyMaskAugmentationConfig",
"MixAugmentationConfig",
"TimeMaskAugmentationConfig",
@ -35,365 +33,12 @@ __all__ = [
"load_augmentation_config",
"mask_frequency",
"mask_time",
"mix_examples",
"mix_audio",
"scale_volume",
"warp_spectrogram",
]
ExampleSource = Callable[[], PreprocessedExample]
"""Type alias for a function that returns a training example"""
def mix_examples(
example: PreprocessedExample,
other: PreprocessedExample,
preprocessor: PreprocessorProtocol,
weight: float,
) -> PreprocessedExample:
"""Combine two training examples."""
audio1 = example.audio
audio2 = adjust_width(other.audio, audio1.shape[-1])
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor(combined)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
previous_width = example.spectrogram.shape[-1]
spectrogram = adjust_width(spectrogram, previous_width)
detection_heatmap = torch.maximum(
example.detection_heatmap,
adjust_width(other.detection_heatmap, previous_width),
)
class_heatmap = torch.maximum(
example.class_heatmap,
adjust_width(other.class_heatmap, previous_width),
)
size_heatmap = torch.maximum(
example.size_heatmap,
adjust_width(other.size_heatmap, previous_width),
)
return PreprocessedExample(
audio=combined,
spectrogram=spectrogram,
detection_heatmap=detection_heatmap,
class_heatmap=class_heatmap,
size_heatmap=size_heatmap,
)
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
probability: float = 0.2
"""Probability of applying this augmentation."""
max_delay: float = 0.005
min_weight: float = 0.0
max_weight: float = 1.0
class AddEcho(torch.nn.Module):
def __init__(
self,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.1,
max_weight: float = 1.0,
max_delay: float = 0.005,
):
super().__init__()
self.preprocessor = preprocessor
self.min_weight = min_weight
self.max_weight = max_weight
self.max_delay = max_delay
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
delay = np.random.uniform(0, self.max_delay)
weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(
example,
preprocessor=self.preprocessor,
delay=delay,
weight=weight,
)
def add_echo(
example: PreprocessedExample,
preprocessor: PreprocessorProtocol,
delay: float,
weight: float,
) -> PreprocessedExample:
"""Add a synthetic echo to the audio waveform."""
audio = example.audio
delay_steps = int(preprocessor.input_samplerate * delay)
slices = [slice(None)] * audio.ndim
slices[-1] = slice(None, -delay_steps)
audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
delay_steps, dims=-1
)
audio = audio + weight * audio_delay
spectrogram = preprocessor(audio)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
spectrogram = adjust_width(
spectrogram,
example.spectrogram.shape[-1],
)
return PreprocessedExample(
audio=audio,
spectrogram=spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
)
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(example, factor=factor)
def scale_volume(
example: PreprocessedExample,
factor: Optional[float] = None,
) -> PreprocessedExample:
"""Scale the amplitude of the spectrogram by a random factor."""
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram * factor,
)
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
class WarpSpectrogram(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None:
super().__init__()
self.delta = delta
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
factor = np.random.uniform(1 - self.delta, 1 + self.delta)
return warp_spectrogram(example, factor=factor)
def warp_spectrogram(
example: PreprocessedExample, factor: float
) -> PreprocessedExample:
"""Apply time warping by resampling the time axis."""
width = example.spectrogram.shape[-1]
height = example.spectrogram.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
spectrogram = torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
).squeeze(0)
detection = torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
size=target_shape,
mode="nearest",
).squeeze(0)
classification = torch.nn.functional.interpolate(
adjust_width(example.class_heatmap, new_width).unsqueeze(1),
size=target_shape,
mode="nearest",
).squeeze(1)
size = torch.nn.functional.interpolate(
adjust_width(example.size_heatmap, new_width).unsqueeze(1),
size=target_shape,
mode="nearest",
).squeeze(1)
return PreprocessedExample(
audio=example.audio,
size_heatmap=size,
class_heatmap=classification,
detection_heatmap=detection,
spectrogram=spectrogram,
)
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
class MaskTime(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
width = example.spectrogram.shape[-1]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * width),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=width - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_time(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply time masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * example.spectrogram.ndim
slices[-1] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram,
)
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
height = example.spectrogram.shape[-2]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * height),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=height - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_frequency(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * example.spectrogram.ndim
slices[-2] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram,
)
AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
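"""Type alias for a callable that takes a requested clip duration (in seconds) and returns an audio tensor together with its clip annotation."""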
class MixAugmentationConfig(BaseConfig):
@ -416,8 +61,7 @@ class MixAudio(torch.nn.Module):
def __init__(
self,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
example_source: AudioSource,
min_weight: float = 0.3,
max_weight: float = 0.7,
):
@ -426,20 +70,364 @@ class MixAudio(torch.nn.Module):
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
def __call__(
self,
wav: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
"""Fetch another example and perform mixup."""
other = self.example_source()
other_wav, other_clip_annotation = self.example_source(
clip_annotation.clip.duration
)
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
mixed_audio = mix_audio(wav, other_wav, weight=weight)
mixed_annotations = combine_clip_annotations(
clip_annotation,
other_clip_annotation,
)
return mixed_audio, mixed_annotations
def mix_audio(
wav1: torch.Tensor,
wav2: torch.Tensor,
weight: float,
) -> torch.Tensor:
"""Mix two waveforms: weight * wav1 + (1 - weight) * wav2."""
wav2 = adjust_width(wav2, wav1.shape[-1])
return weight * wav1 + (1 - weight) * wav2
def shift_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
time: float,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
sound_event = sound_event.model_copy(
update=dict(geometry=shift_geometry(geometry, time=time))
)
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
def combine_clip_annotations(
clip_annotation1: data.ClipAnnotation,
clip_annotation2: data.ClipAnnotation,
) -> data.ClipAnnotation:
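# Events from the second clip are shifted by the offset between the two
# clips' start times so they land at the correct position in the first
# clip's timeline.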
time_shift = (
clip_annotation1.clip.start_time - clip_annotation2.clip.start_time
)
return clip_annotation1.model_copy(
update=dict(
sound_events=[
*clip_annotation1.sound_events,
*[
shift_sound_event_annotation(sound_event, time=time_shift)
for sound_event in clip_annotation2.sound_events
],
]
)
)
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
probability: float = 0.2
max_delay: float = 0.005
min_weight: float = 0.0
max_weight: float = 1.0
class AddEcho(torch.nn.Module):
def __init__(
self,
min_weight: float = 0.1,
max_weight: float = 1.0,
max_delay: int = 2560,
):
super().__init__()
self.min_weight = min_weight
self.max_weight = max_weight
self.max_delay = max_delay
def forward(
self,
wav: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
delay = np.random.randint(0, self.max_delay)
weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(wav, delay=delay, weight=weight), clip_annotation
def add_echo(
wav: torch.Tensor,
delay: int,
weight: float,
) -> torch.Tensor:
"""Add a synthetic echo to the audio waveform."""
slices = [slice(None)] * wav.ndim
slices[-1] = slice(None, -delay)
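# Shift the truncated waveform right by `delay` samples (zero padded at the
# front) to create the echo, then blend it with the original via mix_audio.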
audio_delay = adjust_width(wav[tuple(slices)], wav.shape[-1]).roll(
delay, dims=-1
)
return mix_audio(wav, audio_delay, weight)
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(spec, factor=factor), clip_annotation
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
"""Scale the amplitude of the spectrogram by a factor."""
return spec * factor
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
class WarpSpectrogram(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None:
super().__init__()
self.delta = delta
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
factor = np.random.uniform(1 - self.delta, 1 + self.delta)
return (
warp_spectrogram(spec, factor=factor),
warp_clip_annotation(clip_annotation, factor=factor),
)
def warp_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
factor: float,
anchor: float,
) -> data.SoundEventAnnotation:
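# Warping stretches the spectrogram's time axis by `factor`, so annotation
# geometries are scaled by 1 / factor (anchored at the clip start) to stay
# aligned with the warped spectrogram.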
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
sound_event = sound_event.model_copy(
update=dict(
geometry=scale_geometry(
geometry,
time=1 / factor,
time_anchor=anchor,
)
),
)
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
def warp_clip_annotation(
clip_annotation: data.ClipAnnotation,
factor: float,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
warp_sound_event_annotation(
sound_event,
factor=factor,
anchor=clip_annotation.clip.start_time,
)
for sound_event in clip_annotation.sound_events
]
)
)
def warp_spectrogram(
spec: torch.Tensor,
factor: float,
) -> torch.Tensor:
"""Apply time warping by resampling the time axis."""
width = spec.shape[-1]
height = spec.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
return torch.nn.functional.interpolate(
adjust_width(spec, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
).squeeze(0)
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
class MaskTime(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
num_masks = np.random.randint(1, self.max_masks + 1)
width = spec.shape[-1]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * width),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=width - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(spec, masks), clip_annotation
def mask_time(
spec: torch.Tensor,
masks: List[Tuple[int, int]],
value: float = 0,
) -> torch.Tensor:
"""Apply time masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * spec.ndim
slices[-1] = slice(start, end)
spec[tuple(slices)] = value
return spec
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
num_masks = np.random.randint(1, self.max_masks + 1)
height = spec.shape[-2]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * height),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=height - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(spec, masks), clip_annotation
def mask_frequency(
spec: torch.Tensor,
masks: List[Tuple[int, int]],
) -> torch.Tensor:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * spec.ndim
slices[-2] = slice(start, end)
spec[tuple(slices)] = 0
return spec
AudioAugmentationConfig = Annotated[
Union[
MixAugmentationConfig,
EchoAugmentationConfig,
],
Field(discriminator="augmentation_type"),
]
SpectrogramAugmentationConfig = Annotated[
Union[
VolumeAugmentationConfig,
WarpAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
]
AugmentationConfig = Annotated[
Union[
MixAugmentationConfig,
@ -459,7 +447,11 @@ class AugmentationsConfig(BaseConfig):
enabled: bool = True
steps: List[AugmentationConfig] = Field(default_factory=list)
audio: List[AudioAugmentationConfig] = Field(default_factory=list)
spectrogram: List[SpectrogramAugmentationConfig] = Field(
default_factory=list
)
class MaybeApply(torch.nn.Module):
@ -470,46 +462,31 @@ class MaybeApply(torch.nn.Module):
augmentation: Augmentation,
probability: float = 0.2,
):
"""Initialize the wrapper.
Parameters
----------
augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
The augmentation function to potentially apply.
probability : float, default=0.5
The probability (0.0 to 1.0) of applying the augmentation.
"""
"""Initialize the wrapper."""
super().__init__()
self.augmentation = augmentation
self.probability = probability
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Apply the wrapped augmentation with configured probability.
Parameters
----------
example : xr.Dataset
The input training example.
Returns
-------
xr.Dataset
The potentially augmented training example.
"""
def __call__(
self,
tensor: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
"""Apply the wrapped augmentation with configured probability."""
if np.random.random() > self.probability:
return example
return tensor, clip_annotation
return self.augmentation(example)
return self.augmentation(tensor, clip_annotation)
def build_augmentation_from_config(
config: AugmentationConfig,
preprocessor: PreprocessorProtocol,
example_source: Optional[ExampleSource] = None,
samplerate: int,
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
"""Factory function to build a single augmentation from its config."""
if config.augmentation_type == "mix_audio":
if example_source is None:
if audio_source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
@ -518,16 +495,14 @@ def build_augmentation_from_config(
return None
return MixAudio(
example_source=example_source,
preprocessor=preprocessor,
example_source=audio_source,
min_weight=config.min_weight,
max_weight=config.max_weight,
)
if config.augmentation_type == "add_echo":
return AddEcho(
preprocessor=preprocessor,
max_delay=config.max_delay,
max_delay=int(config.max_delay * samplerate),
min_weight=config.min_weight,
max_weight=config.max_weight,
)
@ -562,37 +537,35 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
steps=[
enabled=True,
audio=[
MixAugmentationConfig(),
EchoAugmentationConfig(),
],
spectrogram=[
VolumeAugmentationConfig(),
WarpAugmentationConfig(),
TimeMaskAugmentationConfig(),
FrequencyMaskAugmentationConfig(),
]
],
)
def build_augmentations(
preprocessor: PreprocessorProtocol,
config: Optional[AugmentationsConfig] = None,
example_source: Optional[ExampleSource] = None,
) -> Augmentation:
"""Build a composite augmentation pipeline function from configuration."""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
def build_augmentation_sequence(
samplerate: int,
steps: Optional[Sequence[AugmentationConfig]] = None,
audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]:
if not steps:
return None
augmentations = []
for step_config in config.steps:
for step_config in steps:
augmentation = build_augmentation_from_config(
step_config,
preprocessor=preprocessor,
example_source=example_source,
samplerate=samplerate,
audio_source=audio_source,
)
if augmentation is None:
@ -608,6 +581,33 @@ def build_augmentations(
return torch.nn.Sequential(*augmentations)
def build_augmentations(
samplerate: int,
config: Optional[AugmentationsConfig] = None,
audio_source: Optional[AudioSource] = None,
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
"""Build the audio and spectrogram augmentation pipelines from configuration."""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
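# Audio augmentations are applied to the raw waveform before the spectrogram
# is computed; spectrogram augmentations are applied to the resulting
# spectrogram.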
audio_augmentation = build_augmentation_sequence(
samplerate,
steps=config.audio,
audio_source=audio_source,
)
spectrogram_augmentation = build_augmentation_sequence(
samplerate,
steps=config.spectrogram,
audio_source=audio_source,
)
return audio_augmentation, spectrogram_augmentation
def load_augmentation_config(
path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig:
@ -615,23 +615,24 @@ def load_augmentation_config(
return load_config(path, schema=AugmentationsConfig, field=field)
class RandomExampleSource:
class RandomAudioSource:
def __init__(
self,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
):
self.filenames = filenames
self.clipper = clipper
self.audio_loader = audio_loader
self.clip_annotations = clip_annotations
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example
@classmethod
def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
filenames = list_preprocessed_files(path)
return cls(filenames, clipper=clipper)
def __call__(
self,
duration: float,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
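# Pick a random annotated clip, crop a random subclip of the requested
# duration, and load its audio as a single-channel tensor.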
index = int(np.random.randint(len(self.clip_annotations)))
clip_annotation = get_subclip_annotation(
self.clip_annotations[index],
duration=duration,
max_empty=0,
)
wav = self.audio_loader.load_clip(clip_annotation.clip)
return torch.from_numpy(wav).unsqueeze(0), clip_annotation

View File

@ -17,7 +17,7 @@ from batdetect2.evaluate.match import (
from batdetect2.models import Model
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.dataset import TrainingDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
BatDetect2Prediction,
@ -49,11 +49,11 @@ class ValidationMetrics(Callback):
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
def get_dataset(self, trainer: Trainer) -> TrainingDataset:
dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset)
assert isinstance(dataset, TrainingDataset)
return dataset
def plot_examples(
@ -136,12 +136,12 @@ class ValidationMetrics(Callback):
def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
dataset: TrainingDataset,
model: Model,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
dataset.clip_annotations[int(example_id)],
start_time=start_time.item(),
end_time=end_time.item(),
targets=model.targets,

View File

@ -1,14 +1,12 @@
from typing import Optional, Tuple
from typing import List, Optional
import numpy as np
import torch
from loguru import logger
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width, slice_tensor
DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -18,50 +16,127 @@ class ClipingConfig(BaseConfig):
duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
min_sound_event_overlap: float = 0
class Clipper(torch.nn.Module):
class Clipper:
def __init__(
self,
preprocessor: PreprocessorProtocol,
duration: float = 0.5,
max_empty: float = 0.2,
random: bool = True,
min_sound_event_overlap: float = 0,
):
super().__init__()
self.preprocessor = preprocessor
self.duration = duration
self.random = random
self.max_empty = max_empty
self.min_sound_event_overlap = min_sound_event_overlap
def forward(
def __call__(
self,
example: PreprocessedExample,
) -> Tuple[PreprocessedExample, float, float]:
start_time = 0
duration = example.audio.shape[-1] / self.preprocessor.input_samplerate
if self.random:
start_time = np.random.uniform(
-self.max_empty,
duration - self.duration + self.max_empty,
)
return (
select_subclip(
example,
start=start_time,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
return get_subclip_annotation(
clip_annotation,
random=self.random,
duration=self.duration,
input_samplerate=self.preprocessor.input_samplerate,
output_samplerate=self.preprocessor.output_samplerate,
),
start_time,
start_time + self.duration,
max_empty=self.max_empty,
min_sound_event_overlap=self.min_sound_event_overlap,
)
def get_subclip_annotation(
clip_annotation: data.ClipAnnotation,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
min_sound_event_overlap: float = 0,
) -> data.ClipAnnotation:
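# Select a (possibly random) subclip of the requested duration and keep only
# the sound events that overlap it by at least `min_sound_event_overlap`.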
clip = clip_annotation.clip
subclip = select_subclip(
clip,
random=random,
duration=duration,
max_empty=max_empty,
)
sound_events = select_sound_event_annotations(
clip_annotation,
subclip,
min_overlap=min_sound_event_overlap,
)
return clip_annotation.model_copy(
update=dict(
clip=subclip,
sound_events=sound_events,
)
)
def select_subclip(
clip: data.Clip,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
) -> data.Clip:
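# If the requested duration does not fit (or random cropping is disabled) the
# subclip is anchored at the original start time; otherwise a random start is
# drawn so the window extends at most `max_empty` seconds past the clip end.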
start_time = clip.start_time
end_time = clip.end_time
if duration > clip.duration + max_empty or not random:
return clip.model_copy(
update=dict(
start_time=start_time,
end_time=start_time + duration,
)
)
random_start_time = np.random.uniform(
low=start_time,
high=end_time + max_empty - duration,
)
return clip.model_copy(
update=dict(
start_time=random_start_time,
end_time=random_start_time + duration,
)
)
def select_sound_event_annotations(
clip_annotation: data.ClipAnnotation,
subclip: data.Clip,
min_overlap: float = 0,
) -> List[data.SoundEventAnnotation]:
selected = []
start_time = subclip.start_time
end_time = subclip.end_time
for sound_event_annotation in clip_annotation.sound_events:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
continue
geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)
if not intervals_overlap(
(start_time, end_time),
(geom_start_time, geom_end_time),
min_absolute_overlap=min_overlap,
):
continue
selected.append(sound_event_annotation)
return selected
def build_clipper(
preprocessor: PreprocessorProtocol,
config: Optional[ClipingConfig] = None,
random: Optional[bool] = None,
) -> ClipperProtocol:
@ -71,73 +146,7 @@ def build_clipper(
lambda: config.to_yaml_string(),
)
return Clipper(
preprocessor=preprocessor,
duration=config.duration,
max_empty=config.max_empty,
random=config.random if random else False,
)
def select_subclip(
example: PreprocessedExample,
start: float,
duration: float,
input_samplerate: float,
output_samplerate: float,
fill_value: float = 0,
) -> PreprocessedExample:
audio_width = int(np.floor(duration * input_samplerate))
audio_start = int(np.floor(start * input_samplerate))
audio = adjust_width(
slice_tensor(
example.audio,
start=audio_start,
end=audio_start + audio_width,
dim=-1,
),
audio_width,
value=fill_value,
)
spec_start = int(np.floor(start * output_samplerate))
spec_width = int(np.floor(duration * output_samplerate))
return PreprocessedExample(
audio=audio,
spectrogram=adjust_width(
slice_tensor(
example.spectrogram,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
class_heatmap=adjust_width(
slice_tensor(
example.class_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
detection_heatmap=adjust_width(
slice_tensor(
example.detection_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
size_heatmap=adjust_width(
slice_tensor(
example.size_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
)

View File

@ -6,11 +6,13 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import ClipingConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig
@ -50,7 +52,7 @@ class DataLoaderConfig(BaseConfig):
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
class LoadersConfig(BaseConfig):
@ -73,6 +75,8 @@ class TrainingConfig(BaseConfig):
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def load_train_config(

View File

@ -1,78 +1,77 @@
from typing import Optional, Sequence, Tuple
from typing import Optional, Sequence
import numpy as np
import torch
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.train import PreprocessedExample
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
__all__ = [
"LabeledDataset",
"TrainingDataset",
]
class LabeledDataset(Dataset):
class TrainingDataset(Dataset):
def __init__(
self,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
augmentation: Optional[Augmentation] = None,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None,
audio_augmentation: Optional[Augmentation] = None,
spectrogram_augmentation: Optional[Augmentation] = None,
audio_dir: Optional[data.PathLike] = None,
):
self.filenames = filenames
self.clip_annotations = clip_annotations
self.clipper = clipper
self.augmentation = augmentation
self.labeller = labeller
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_augmentation = audio_augmentation
self.spectrogram_augmentation = spectrogram_augmentation
self.audio_dir = audio_dir
def __len__(self):
return len(self.filenames)
return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample:
example = self.get_example(idx)
clip_annotation = self.clip_annotations[idx]
example, start_time, end_time = self.clipper(example)
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
if self.augmentation:
example = self.augmentation(example)
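# Pipeline: crop the annotation to a training clip, load the audio, apply
# waveform augmentations, compute the spectrogram, apply spectrogram
# augmentations, and finally generate the target heatmaps.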
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
# Add channel dim
wav_tensor = torch.tensor(wav).unsqueeze(0)
if self.audio_augmentation is not None:
wav_tensor, clip_annotation = self.audio_augmentation(
wav_tensor,
clip_annotation,
)
spectrogram = self.preprocessor(wav_tensor)
if self.spectrogram_augmentation is not None:
spectrogram, clip_annotation = self.spectrogram_augmentation(
spectrogram,
clip_annotation,
)
heatmaps = self.labeller(clip_annotation, spectrogram)
return TrainExample(
spec=example.spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
spec=spectrogram,
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
idx=torch.tensor(idx),
start_time=torch.tensor(start_time),
end_time=torch.tensor(end_time),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
@classmethod
def from_directory(
cls,
directory: data.PathLike,
clipper: ClipperProtocol,
extension: str = ".npz",
augmentation: Optional[Augmentation] = None,
):
return cls(
filenames=list_preprocessed_files(directory, extension),
clipper=clipper,
augmentation=augmentation,
)
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
idx = np.random.randint(0, len(self))
dataset = self.get_example(idx)
dataset, start_time, end_time = self.clipper(dataset)
return dataset, start_time, end_time
def get_example(self, idx) -> PreprocessedExample:
return load_preprocessed_example(self.filenames[idx])
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
return item["clip_annotation"].tolist()

View File

@ -15,16 +15,18 @@ from batdetect2.evaluate.metrics import (
DetectionAveragePrecision,
)
from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.train.augmentations import (
RandomExampleSource,
RandomAudioSource,
build_augmentations,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
TrainingDataset,
)
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
@ -33,6 +35,7 @@ from batdetect2.typing import (
TargetProtocol,
TrainExample,
)
from batdetect2.typing.train import ClipLabeller
from batdetect2.utils.arrays import adjust_width
__all__ = [
@ -46,8 +49,8 @@ __all__ = [
def train(
train_examples: Sequence[data.PathLike],
val_examples: Optional[Sequence[data.PathLike]] = None,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
@ -59,8 +62,19 @@ def train(
trainer = build_trainer(config, targets=model.targets)
audio_loader = build_audio_loader(config=config.preprocess.audio)
labeller = build_clip_labeler(
model.targets,
min_freq=model.preprocessor.min_freq,
max_freq=model.preprocessor.max_freq,
config=config.train.labels,
)
train_dataloader = build_train_loader(
train_examples,
train_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=model.preprocessor,
config=config.train,
num_workers=train_workers,
@ -68,12 +82,14 @@ def train(
val_dataloader = (
build_val_loader(
val_examples,
val_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=model.preprocessor,
config=config.train,
num_workers=val_workers,
)
if val_examples is not None
if val_annotations is not None
else None
)
@ -153,19 +169,23 @@ def build_trainer(
def build_train_loader(
train_examples: Sequence[data.PathLike],
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader:
config = config or TrainingConfig()
logger.info("Building training data loader...")
train_dataset = build_train_dataset(
train_examples,
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
logger.info("Building training data loader...")
loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug(
"Training data loader config: \n{config}",
@ -182,16 +202,20 @@ def build_train_loader(
def build_val_loader(
val_examples: Sequence[data.PathLike],
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None,
):
logger.info("Building validation data loader...")
config = config or TrainingConfig()
logger.info("Building validation data loader...")
val_dataset = build_val_dataset(
val_examples,
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config,
)
@ -203,7 +227,7 @@ def build_val_loader(
num_workers = num_workers or loader_conf.num_workers
return DataLoader(
val_dataset,
batch_size=loader_conf.batch_size,
batch_size=1,
shuffle=loader_conf.shuffle,
num_workers=num_workers,
collate_fn=_collate_fn,
@ -232,52 +256,60 @@ def _collate_fn(batch: List[TrainExample]) -> TrainExample:
def build_train_dataset(
examples: Sequence[data.PathLike],
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
) -> TrainingDataset:
logger.info("Building training dataset...")
config = config or TrainingConfig()
clipper = build_clipper(
preprocessor=preprocessor,
config=config.cliping,
random=True,
)
random_example_source = RandomExampleSource(
list(examples),
clipper=clipper,
random_example_source = RandomAudioSource(
clip_annotations,
audio_loader=audio_loader,
)
if config.augmentations.enabled and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
if config.augmentations.enabled:
audio_augmentation, spectrogram_augmentation = build_augmentations(
samplerate=preprocessor.input_samplerate,
config=config.augmentations,
example_source=random_example_source,
audio_source=random_example_source,
)
else:
logger.debug("No augmentations configured for training dataset.")
augmentations = None
audio_augmentation = None
spectrogram_augmentation = None
return LabeledDataset(
examples,
return TrainingDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
clipper=clipper,
augmentation=augmentations,
preprocessor=preprocessor,
audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation,
)
def build_val_dataset(
examples: Sequence[data.PathLike],
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
) -> TrainingDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
clipper = build_clipper(
return TrainingDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config.cliping,
random=train,
)
return LabeledDataset(examples, clipper=clipper)

View File

@ -49,7 +49,11 @@ spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
Augmentation = Callable[[PreprocessedExample], PreprocessedExample]
Augmentation = Callable[
[torch.Tensor, data.ClipAnnotation],
Tuple[torch.Tensor, data.ClipAnnotation],
]
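"""Callable that maps a tensor (waveform or spectrogram) and its clip annotation to an augmented tensor/annotation pair."""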
class TrainExample(NamedTuple):
@ -97,5 +101,6 @@ class LossProtocol(Protocol):
class ClipperProtocol(Protocol):
def __call__(
self, example: PreprocessedExample
) -> Tuple[PreprocessedExample, float, float]: ...
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation: ...

View File

@ -6,7 +6,7 @@ from soundevent import data
from batdetect2.train.augmentations import (
add_echo,
mix_examples,
mix_audio,
)
from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example
@ -41,7 +41,7 @@ def test_mix_examples(
labeller=sample_labeller,
)
mixed = mix_examples(
mixed = mix_audio(
example1,
example2,
weight=0.3,
@ -86,7 +86,7 @@ def test_mix_examples_of_different_durations(
labeller=sample_labeller,
)
mixed = mix_examples(
mixed = mix_audio(
example1,
example2,
weight=0.3,