Remove train preprocessing

mbsantiago 2025-08-31 18:28:52 +01:00
parent 1cec332dd5
commit 40f6b64611
13 changed files with 712 additions and 664 deletions

View File

@ -17,7 +17,7 @@ dependencies = [
"torch>=1.13.1,<2.5.0", "torch>=1.13.1,<2.5.0",
"torchaudio>=1.13.1,<2.5.0", "torchaudio>=1.13.1,<2.5.0",
"torchvision>=0.14.0", "torchvision>=0.14.0",
"soundevent[audio,geometry,plot]>=2.7.0", "soundevent[audio,geometry,plot]>=2.8.0",
"click>=8.1.7", "click>=8.1.7",
"netcdf4>=1.6.5", "netcdf4>=1.6.5",
"tqdm>=4.66.2", "tqdm>=4.66.2",

View File

@ -6,19 +6,19 @@ import click
from loguru import logger from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train import ( from batdetect2.train import (
FullTrainingConfig, FullTrainingConfig,
load_full_training_config, load_full_training_config,
train, train,
) )
from batdetect2.train.dataset import list_preprocessed_files
__all__ = ["train_command"] __all__ = ["train_command"]
@cli.command(name="train") @cli.command(name="train")
@click.argument("train_dir", type=click.Path(exists=True)) @click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dir", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str) @click.option("--config-field", type=str)
@ -31,8 +31,8 @@ __all__ = ["train_command"]
help="Increase verbosity. -v for INFO, -vv for DEBUG.", help="Increase verbosity. -v for INFO, -vv for DEBUG.",
) )
def train_command( def train_command(
train_dir: Path, train_dataset: Path,
val_dir: Optional[Path] = None, val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None, model_path: Optional[Path] = None,
config: Optional[Path] = None, config: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
@ -58,29 +58,27 @@ def train_command(
else FullTrainingConfig() else FullTrainingConfig()
) )
logger.info("Scanning for training and validation data...") logger.info("Loading training dataset...")
train_examples = list_preprocessed_files(train_dir) train_annotations = load_dataset_from_config(train_dataset)
logger.debug( logger.debug(
"Found {num_files} training examples in {path}", "Loaded {num_annotations} training examples",
num_files=len(train_examples), num_annotations=len(train_annotations),
path=train_dir,
) )
val_examples = None val_annotations = None
if val_dir is not None: if val_dataset is not None:
val_examples = list_preprocessed_files(val_dir) val_annotations = load_dataset_from_config(val_dataset)
logger.debug( logger.debug(
"Found {num_files} validation examples in {path}", "Loaded {num_annotations} validation examples",
num_files=len(val_examples), num_annotations=len(val_annotations),
path=val_dir,
) )
else: else:
logger.debug("No validation directory provided.") logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...") logger.info("Configuration and data loaded. Starting training...")
train( train(
train_examples=train_examples, train_annotations=train_annotations,
val_examples=val_examples, val_annotations=val_annotations,
config=conf, config=conf,
model_path=model_path, model_path=model_path,
train_workers=train_workers, train_workers=train_workers,
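
For context, a minimal sketch (not part of the commit) of the library calls the reworked command now performs; the dataset config file names below are placeholders.

from batdetect2.data import load_dataset_from_config
from batdetect2.train import FullTrainingConfig, train

# Hypothetical dataset config files describing the annotated clips.
train_annotations = load_dataset_from_config("train_dataset.yaml")
val_annotations = load_dataset_from_config("val_dataset.yaml")

train(
    train_annotations=train_annotations,
    val_annotations=val_annotations,
    config=FullTrainingConfig(),
)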

View File

@ -38,6 +38,8 @@ def plot_spectrogram(
if isinstance(spec, torch.Tensor): if isinstance(spec, torch.Tensor):
spec = spec.numpy() spec = spec.numpy()
spec = spec.squeeze()
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
if start_time is None: if start_time is None:

View File

@ -25,6 +25,8 @@ def plot_detection_heatmap(
if isinstance(heatmap, torch.Tensor): if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy() heatmap = heatmap.numpy()
heatmap = heatmap.squeeze()
if threshold is not None: if threshold is not None:
heatmap = np.ma.masked_where( heatmap = np.ma.masked_where(
heatmap < threshold, heatmap < threshold,

View File

@ -2,7 +2,7 @@ from batdetect2.train.augmentations import (
AugmentationsConfig, AugmentationsConfig,
EchoAugmentationConfig, EchoAugmentationConfig,
FrequencyMaskAugmentationConfig, FrequencyMaskAugmentationConfig,
RandomExampleSource, RandomAudioSource,
TimeMaskAugmentationConfig, TimeMaskAugmentationConfig,
VolumeAugmentationConfig, VolumeAugmentationConfig,
WarpAugmentationConfig, WarpAugmentationConfig,
@ -10,7 +10,7 @@ from batdetect2.train.augmentations import (
build_augmentations, build_augmentations,
mask_frequency, mask_frequency,
mask_time, mask_time,
mix_examples, mix_audio,
scale_volume, scale_volume,
warp_spectrogram, warp_spectrogram,
) )
@ -22,10 +22,7 @@ from batdetect2.train.config import (
load_full_training_config, load_full_training_config,
load_train_config, load_train_config,
) )
from batdetect2.train.dataset import ( from batdetect2.train.dataset import TrainingDataset
LabeledDataset,
list_preprocessed_files,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.losses import ( from batdetect2.train.losses import (
@ -56,11 +53,11 @@ __all__ = [
"EchoAugmentationConfig", "EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig", "FrequencyMaskAugmentationConfig",
"FullTrainingConfig", "FullTrainingConfig",
"LabeledDataset", "TrainingDataset",
"LossConfig", "LossConfig",
"LossFunction", "LossFunction",
"PLTrainerConfig", "PLTrainerConfig",
"RandomExampleSource", "RandomAudioSource",
"SizeLossConfig", "SizeLossConfig",
"TimeMaskAugmentationConfig", "TimeMaskAugmentationConfig",
"TrainingConfig", "TrainingConfig",
@ -78,13 +75,12 @@ __all__ = [
"build_val_dataset", "build_val_dataset",
"build_val_loader", "build_val_loader",
"generate_train_example", "generate_train_example",
"list_preprocessed_files",
"load_full_training_config", "load_full_training_config",
"load_label_config", "load_label_config",
"load_train_config", "load_train_config",
"mask_frequency", "mask_frequency",
"mask_time", "mask_time",
"mix_examples", "mix_audio",
"preprocess_annotations", "preprocess_annotations",
"scale_volume", "scale_volume",
"select_subclip", "select_subclip",

View File

@ -9,14 +9,12 @@ import torch
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.preprocess import ( from batdetect2.train.clips import get_subclip_annotation
list_preprocessed_files, from batdetect2.typing import Augmentation
load_preprocessed_example, from batdetect2.typing.preprocess import AudioLoader
)
from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
from batdetect2.utils.arrays import adjust_width from batdetect2.utils.arrays import adjust_width
__all__ = [ __all__ = [
@ -24,7 +22,7 @@ __all__ = [
"AugmentationsConfig", "AugmentationsConfig",
"DEFAULT_AUGMENTATION_CONFIG", "DEFAULT_AUGMENTATION_CONFIG",
"EchoAugmentationConfig", "EchoAugmentationConfig",
"ExampleSource", "AudioSource",
"FrequencyMaskAugmentationConfig", "FrequencyMaskAugmentationConfig",
"MixAugmentationConfig", "MixAugmentationConfig",
"TimeMaskAugmentationConfig", "TimeMaskAugmentationConfig",
@ -35,365 +33,12 @@ __all__ = [
"load_augmentation_config", "load_augmentation_config",
"mask_frequency", "mask_frequency",
"mask_time", "mask_time",
"mix_examples", "mix_audio",
"scale_volume", "scale_volume",
"warp_spectrogram", "warp_spectrogram",
] ]
ExampleSource = Callable[[], PreprocessedExample] AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
"""Type alias for a function that returns a training example"""
def mix_examples(
example: PreprocessedExample,
other: PreprocessedExample,
preprocessor: PreprocessorProtocol,
weight: float,
) -> PreprocessedExample:
"""Combine two training examples."""
audio1 = example.audio
audio2 = adjust_width(other.audio, audio1.shape[-1])
combined = weight * audio1 + (1 - weight) * audio2
spectrogram = preprocessor(combined)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
previous_width = example.spectrogram.shape[-1]
spectrogram = adjust_width(spectrogram, previous_width)
detection_heatmap = torch.maximum(
example.detection_heatmap,
adjust_width(other.detection_heatmap, previous_width),
)
class_heatmap = torch.maximum(
example.class_heatmap,
adjust_width(other.class_heatmap, previous_width),
)
size_heatmap = torch.maximum(
example.size_heatmap,
adjust_width(other.size_heatmap, previous_width),
)
return PreprocessedExample(
audio=combined,
spectrogram=spectrogram,
detection_heatmap=detection_heatmap,
class_heatmap=class_heatmap,
size_heatmap=size_heatmap,
)
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
probability: float = 0.2
"""Probability of applying this augmentation."""
max_delay: float = 0.005
min_weight: float = 0.0
max_weight: float = 1.0
class AddEcho(torch.nn.Module):
def __init__(
self,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.1,
max_weight: float = 1.0,
max_delay: float = 0.005,
):
super().__init__()
self.preprocessor = preprocessor
self.min_weight = min_weight
self.max_weight = max_weight
self.max_delay = max_delay
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
delay = np.random.uniform(0, self.max_delay)
weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(
example,
preprocessor=self.preprocessor,
delay=delay,
weight=weight,
)
def add_echo(
example: PreprocessedExample,
preprocessor: PreprocessorProtocol,
delay: float,
weight: float,
) -> PreprocessedExample:
"""Add a synthetic echo to the audio waveform."""
audio = example.audio
delay_steps = int(preprocessor.input_samplerate * delay)
slices = [slice(None)] * audio.ndim
slices[-1] = slice(None, -delay_steps)
audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
delay_steps, dims=-1
)
audio = audio + weight * audio_delay
spectrogram = preprocessor(audio)
# NOTE: The subclip's spectrogram might be slightly longer than the
# spectrogram computed from the subclip's audio. This is due to a
# simplification in the subclip process: It doesn't account for the
# spectrogram parameters to precisely determine the corresponding audio
# samples. To work around this, we pad the computed spectrogram with zeros
# as needed.
spectrogram = adjust_width(
spectrogram,
example.spectrogram.shape[-1],
)
return PreprocessedExample(
audio=audio,
spectrogram=spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
)
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(example, factor=factor)
def scale_volume(
example: PreprocessedExample,
factor: Optional[float] = None,
) -> PreprocessedExample:
"""Scale the amplitude of the spectrogram by a random factor."""
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram * factor,
)
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
class WarpSpectrogram(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None:
super().__init__()
self.delta = delta
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
factor = np.random.uniform(1 - self.delta, 1 + self.delta)
return warp_spectrogram(example, factor=factor)
def warp_spectrogram(
example: PreprocessedExample, factor: float
) -> PreprocessedExample:
"""Apply time warping by resampling the time axis."""
width = example.spectrogram.shape[-1]
height = example.spectrogram.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
spectrogram = torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
).squeeze(0)
detection = torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
size=target_shape,
mode="nearest",
).squeeze(0)
classification = torch.nn.functional.interpolate(
adjust_width(example.class_heatmap, new_width).unsqueeze(1),
size=target_shape,
mode="nearest",
).squeeze(1)
size = torch.nn.functional.interpolate(
adjust_width(example.size_heatmap, new_width).unsqueeze(1),
size=target_shape,
mode="nearest",
).squeeze(1)
return PreprocessedExample(
audio=example.audio,
size_heatmap=size,
class_heatmap=classification,
detection_heatmap=detection,
spectrogram=spectrogram,
)
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
class MaskTime(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
width = example.spectrogram.shape[-1]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * width),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=width - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_time(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply time masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * example.spectrogram.ndim
slices[-1] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram,
)
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
height = example.spectrogram.shape[-2]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * height),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=height - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_frequency(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * example.spectrogram.ndim
slices[-2] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
size_heatmap=example.size_heatmap,
class_heatmap=example.class_heatmap,
detection_heatmap=example.detection_heatmap,
spectrogram=example.spectrogram,
)
class MixAugmentationConfig(BaseConfig): class MixAugmentationConfig(BaseConfig):
@ -416,8 +61,7 @@ class MixAudio(torch.nn.Module):
def __init__( def __init__(
self, self,
example_source: ExampleSource, example_source: AudioSource,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.3, min_weight: float = 0.3,
max_weight: float = 0.7, max_weight: float = 0.7,
): ):
@ -426,20 +70,364 @@ class MixAudio(torch.nn.Module):
self.min_weight = min_weight self.min_weight = min_weight
self.example_source = example_source self.example_source = example_source
self.max_weight = max_weight self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample: def __call__(
self,
wav: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
"""Fetch another example and perform mixup.""" """Fetch another example and perform mixup."""
other = self.example_source() other_wav, other_clip_annotation = self.example_source(
clip_annotation.clip.duration
)
weight = np.random.uniform(self.min_weight, self.max_weight) weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples( mixed_audio = mix_audio(wav, other_wav, weight=weight)
example, mixed_annotations = combine_clip_annotations(
other, clip_annotation,
self.preprocessor, other_clip_annotation,
weight=weight, )
return mixed_audio, mixed_annotations
def mix_audio(
wav1: torch.Tensor,
wav2: torch.Tensor,
weight: float,
) -> torch.Tensor:
"""Combine two training examples."""
wav2 = adjust_width(wav2, wav1.shape[-1])
return weight * wav1 + (1 - weight) * wav2
def shift_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
time: float,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
sound_event = sound_event.model_copy(
update=dict(geometry=shift_geometry(geometry, time=time))
)
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
def combine_clip_annotations(
clip_annotation1: data.ClipAnnotation,
clip_annotation2: data.ClipAnnotation,
) -> data.ClipAnnotation:
time_shift = (
clip_annotation1.clip.start_time - clip_annotation2.clip.start_time
)
return clip_annotation1.model_copy(
update=dict(
sound_events=[
*clip_annotation1.sound_events,
*[
shift_sound_event_annotation(sound_event, time=time_shift)
for sound_event in clip_annotation2.sound_events
],
]
)
)
class EchoAugmentationConfig(BaseConfig):
"""Configuration for adding synthetic echo/reverb."""
augmentation_type: Literal["add_echo"] = "add_echo"
probability: float = 0.2
max_delay: float = 0.005
min_weight: float = 0.0
max_weight: float = 1.0
class AddEcho(torch.nn.Module):
def __init__(
self,
min_weight: float = 0.1,
max_weight: float = 1.0,
max_delay: int = 2560,
):
super().__init__()
self.min_weight = min_weight
self.max_weight = max_weight
self.max_delay = max_delay
def forward(
self,
wav: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
delay = np.random.randint(0, self.max_delay)
weight = np.random.uniform(self.min_weight, self.max_weight)
return add_echo(wav, delay=delay, weight=weight), clip_annotation
def add_echo(
wav: torch.Tensor,
delay: int,
weight: float,
) -> torch.Tensor:
"""Add a synthetic echo to the audio waveform."""
slices = [slice(None)] * wav.ndim
slices[-1] = slice(None, -delay)
audio_delay = adjust_width(wav[tuple(slices)], wav.shape[-1]).roll(
delay, dims=-1
)
return mix_audio(wav, audio_delay, weight)
class VolumeAugmentationConfig(BaseConfig):
"""Configuration for random volume scaling of the spectrogram."""
augmentation_type: Literal["scale_volume"] = "scale_volume"
probability: float = 0.2
min_scaling: float = 0.0
max_scaling: float = 2.0
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
factor = np.random.uniform(self.min_scaling, self.max_scaling)
return scale_volume(spec, factor=factor), clip_annotation
def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
"""Scale the amplitude of the spectrogram by a factor."""
return spec * factor
class WarpAugmentationConfig(BaseConfig):
augmentation_type: Literal["warp"] = "warp"
probability: float = 0.2
delta: float = 0.04
class WarpSpectrogram(torch.nn.Module):
def __init__(self, delta: float = 0.04) -> None:
super().__init__()
self.delta = delta
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
factor = np.random.uniform(1 - self.delta, 1 + self.delta)
return (
warp_spectrogram(spec, factor=factor),
warp_clip_annotation(clip_annotation, factor=factor),
) )
def warp_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
factor: float,
anchor: float,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
sound_event = sound_event.model_copy(
update=dict(
geometry=scale_geometry(
geometry,
time=1 / factor,
time_anchor=anchor,
)
),
)
return sound_event_annotation.model_copy(
update=dict(sound_event=sound_event)
)
def warp_clip_annotation(
clip_annotation: data.ClipAnnotation,
factor: float,
) -> data.ClipAnnotation:
return clip_annotation.model_copy(
update=dict(
sound_events=[
warp_sound_event_annotation(
sound_event,
factor=factor,
anchor=clip_annotation.clip.start_time,
)
for sound_event in clip_annotation.sound_events
]
)
)
def warp_spectrogram(
spec: torch.Tensor,
factor: float,
) -> torch.Tensor:
"""Apply time warping by resampling the time axis."""
width = spec.shape[-1]
height = spec.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
return torch.nn.functional.interpolate(
adjust_width(spec, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
).squeeze(0)
class TimeMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_time"] = "mask_time"
probability: float = 0.2
max_perc: float = 0.05
max_masks: int = 3
class MaskTime(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
num_masks = np.random.randint(1, self.max_masks + 1)
width = spec.shape[-1]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * width),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=width - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(spec, masks), clip_annotation
def mask_time(
spec: torch.Tensor,
masks: List[Tuple[int, int]],
value: float = 0,
) -> torch.Tensor:
"""Apply time masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * spec.ndim
slices[-1] = slice(start, end)
spec[tuple(slices)] = value
return spec
class FrequencyMaskAugmentationConfig(BaseConfig):
augmentation_type: Literal["mask_freq"] = "mask_freq"
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(
self,
spec: torch.Tensor,
clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
num_masks = np.random.randint(1, self.max_masks + 1)
height = spec.shape[-2]
mask_size = np.random.randint(
low=0,
high=int(self.max_perc * height),
size=num_masks,
)
mask_start = np.random.randint(
low=0,
high=height - mask_size,
size=num_masks,
)
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(spec, masks), clip_annotation
def mask_frequency(
spec: torch.Tensor,
masks: List[Tuple[int, int]],
) -> torch.Tensor:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
slices = [slice(None)] * spec.ndim
slices[-2] = slice(start, end)
spec[tuple(slices)] = 0
return spec
AudioAugmentationConfig = Annotated[
Union[
MixAugmentationConfig,
EchoAugmentationConfig,
],
Field(discriminator="augmentation_type"),
]
SpectrogramAugmentationConfig = Annotated[
Union[
VolumeAugmentationConfig,
WarpAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
],
Field(discriminator="augmentation_type"),
]
AugmentationConfig = Annotated[ AugmentationConfig = Annotated[
Union[ Union[
MixAugmentationConfig, MixAugmentationConfig,
@ -459,7 +447,11 @@ class AugmentationsConfig(BaseConfig):
enabled: bool = True enabled: bool = True
steps: List[AugmentationConfig] = Field(default_factory=list) audio: List[AudioAugmentationConfig] = Field(default_factory=list)
spectrogram: List[SpectrogramAugmentationConfig] = Field(
default_factory=list
)
class MaybeApply(torch.nn.Module): class MaybeApply(torch.nn.Module):
@ -470,46 +462,31 @@ class MaybeApply(torch.nn.Module):
augmentation: Augmentation, augmentation: Augmentation,
probability: float = 0.2, probability: float = 0.2,
): ):
"""Initialize the wrapper. """Initialize the wrapper."""
Parameters
----------
augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
The augmentation function to potentially apply.
probability : float, default=0.5
The probability (0.0 to 1.0) of applying the augmentation.
"""
super().__init__() super().__init__()
self.augmentation = augmentation self.augmentation = augmentation
self.probability = probability self.probability = probability
def __call__(self, example: PreprocessedExample) -> PreprocessedExample: def __call__(
"""Apply the wrapped augmentation with configured probability. self,
tensor: torch.Tensor,
Parameters clip_annotation: data.ClipAnnotation,
---------- ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
example : xr.Dataset """Apply the wrapped augmentation with configured probability."""
The input training example.
Returns
-------
xr.Dataset
The potentially augmented training example.
"""
if np.random.random() > self.probability: if np.random.random() > self.probability:
return example return tensor, clip_annotation
return self.augmentation(example) return self.augmentation(tensor, clip_annotation)
def build_augmentation_from_config( def build_augmentation_from_config(
config: AugmentationConfig, config: AugmentationConfig,
preprocessor: PreprocessorProtocol, samplerate: int,
example_source: Optional[ExampleSource] = None, audio_source: Optional[AudioSource] = None,
) -> Optional[Augmentation]: ) -> Optional[Augmentation]:
"""Factory function to build a single augmentation from its config.""" """Factory function to build a single augmentation from its config."""
if config.augmentation_type == "mix_audio": if config.augmentation_type == "mix_audio":
if example_source is None: if audio_source is None:
warnings.warn( warnings.warn(
"Mix audio augmentation ('mix_audio') requires an " "Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.", "'example_source' callable to be provided.",
@ -518,16 +495,14 @@ def build_augmentation_from_config(
return None return None
return MixAudio( return MixAudio(
example_source=example_source, example_source=audio_source,
preprocessor=preprocessor,
min_weight=config.min_weight, min_weight=config.min_weight,
max_weight=config.max_weight, max_weight=config.max_weight,
) )
if config.augmentation_type == "add_echo": if config.augmentation_type == "add_echo":
return AddEcho( return AddEcho(
preprocessor=preprocessor, max_delay=int(config.max_delay * samplerate),
max_delay=config.max_delay,
min_weight=config.min_weight, min_weight=config.min_weight,
max_weight=config.max_weight, max_weight=config.max_weight,
) )
@ -562,37 +537,35 @@ def build_augmentation_from_config(
DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig( DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
steps=[ enabled=True,
audio=[
MixAugmentationConfig(), MixAugmentationConfig(),
EchoAugmentationConfig(), EchoAugmentationConfig(),
],
spectrogram=[
VolumeAugmentationConfig(), VolumeAugmentationConfig(),
WarpAugmentationConfig(), WarpAugmentationConfig(),
TimeMaskAugmentationConfig(), TimeMaskAugmentationConfig(),
FrequencyMaskAugmentationConfig(), FrequencyMaskAugmentationConfig(),
] ],
) )
def build_augmentations( def build_augmentation_sequence(
preprocessor: PreprocessorProtocol, samplerate: int,
config: Optional[AugmentationsConfig] = None, steps: Optional[Sequence[AugmentationConfig]] = None,
example_source: Optional[ExampleSource] = None, audio_source: Optional[AudioSource] = None,
) -> Augmentation: ) -> Optional[Augmentation]:
"""Build a composite augmentation pipeline function from configuration.""" if not steps:
config = config or DEFAULT_AUGMENTATION_CONFIG return None
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
augmentations = [] augmentations = []
for step_config in config.steps: for step_config in steps:
augmentation = build_augmentation_from_config( augmentation = build_augmentation_from_config(
step_config, step_config,
preprocessor=preprocessor, samplerate=samplerate,
example_source=example_source, audio_source=audio_source,
) )
if augmentation is None: if augmentation is None:
@ -608,6 +581,33 @@ def build_augmentations(
return torch.nn.Sequential(*augmentations) return torch.nn.Sequential(*augmentations)
def build_augmentations(
samplerate: int,
config: Optional[AugmentationsConfig] = None,
audio_source: Optional[AudioSource] = None,
) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
"""Build a composite augmentation pipeline function from configuration."""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
audio_augmentation = build_augmentation_sequence(
samplerate,
steps=config.audio,
audio_source=audio_source,
)
spectrogram_augmentation = build_augmentation_sequence(
samplerate,
steps=config.spectrogram,
audio_source=audio_source,
)
return audio_augmentation, spectrogram_augmentation
def load_augmentation_config( def load_augmentation_config(
path: data.PathLike, field: Optional[str] = None path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig: ) -> AugmentationsConfig:
@ -615,23 +615,24 @@ def load_augmentation_config(
return load_config(path, schema=AugmentationsConfig, field=field) return load_config(path, schema=AugmentationsConfig, field=field)
class RandomExampleSource: class RandomAudioSource:
def __init__( def __init__(
self, self,
filenames: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
clipper: ClipperProtocol, audio_loader: AudioLoader,
): ):
self.filenames = filenames self.audio_loader = audio_loader
self.clipper = clipper self.clip_annotations = clip_annotations
def __call__(self) -> PreprocessedExample: def __call__(
index = int(np.random.randint(len(self.filenames))) self,
filename = self.filenames[index] duration: float,
example = load_preprocessed_example(filename) ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
example, _, _ = self.clipper(example) index = int(np.random.randint(len(self.clip_annotations)))
return example clip_annotation = get_subclip_annotation(
self.clip_annotations[index],
@classmethod duration=duration,
def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol): max_empty=0,
filenames = list_preprocessed_files(path) )
return cls(filenames, clipper=clipper) wav = self.audio_loader.load_clip(clip_annotation.clip)
return torch.from_numpy(wav).unsqueeze(0), clip_annotation
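
To illustrate the reorganised pipeline (a sketch based only on the code above, not part of the commit): augmentations are now configured in two groups, waveform-level (audio) and spectrogram-level (spectrogram), and the low-level helpers operate on plain tensors, with echo delays given in samples.

import torch

from batdetect2.train.augmentations import (
    AugmentationsConfig,
    EchoAugmentationConfig,
    FrequencyMaskAugmentationConfig,
    MixAugmentationConfig,
    TimeMaskAugmentationConfig,
    add_echo,
    mix_audio,
)

# Waveform-level helpers take raw (channels, samples) tensors.
wav1 = torch.randn(1, 51_200)
wav2 = torch.randn(1, 48_000)  # shorter clip; mix_audio adjusts its width to match
mixed = mix_audio(wav1, wav2, weight=0.6)  # 0.6 * wav1 + 0.4 * wav2
echoed = add_echo(mixed, delay=1_280, weight=0.3)  # delay expressed in samples

# Mirrors DEFAULT_AUGMENTATION_CONFIG: audio steps run on the waveform before
# the spectrogram is computed, spectrogram steps run afterwards.
config = AugmentationsConfig(
    enabled=True,
    audio=[MixAugmentationConfig(), EchoAugmentationConfig()],
    spectrogram=[
        TimeMaskAugmentationConfig(),
        FrequencyMaskAugmentationConfig(),
    ],
)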

View File

@ -17,7 +17,7 @@ from batdetect2.evaluate.match import (
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.plotting.evaluation import plot_example_gallery from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import LabeledDataset from batdetect2.train.dataset import TrainingDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import ( from batdetect2.typing import (
BatDetect2Prediction, BatDetect2Prediction,
@ -49,11 +49,11 @@ class ValidationMetrics(Callback):
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]] Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
] = [] ] = []
def get_dataset(self, trainer: Trainer) -> LabeledDataset: def get_dataset(self, trainer: Trainer) -> TrainingDataset:
dataloaders = trainer.val_dataloaders dataloaders = trainer.val_dataloaders
assert isinstance(dataloaders, DataLoader) assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset dataset = dataloaders.dataset
assert isinstance(dataset, LabeledDataset) assert isinstance(dataset, TrainingDataset)
return dataset return dataset
def plot_examples( def plot_examples(
@ -136,12 +136,12 @@ class ValidationMetrics(Callback):
def _get_batch_clips_and_predictions( def _get_batch_clips_and_predictions(
batch: TrainExample, batch: TrainExample,
outputs: ModelOutput, outputs: ModelOutput,
dataset: LabeledDataset, dataset: TrainingDataset,
model: Model, model: Model,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]: ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [ clip_annotations = [
_get_subclip( _get_subclip(
dataset.get_clip_annotation(example_id), dataset.clip_annotations[int(example_id)],
start_time=start_time.item(), start_time=start_time.item(),
end_time=end_time.item(), end_time=end_time.item(),
targets=model.targets, targets=model.targets,

View File

@ -1,14 +1,12 @@
from typing import Optional, Tuple from typing import List, Optional
import numpy as np import numpy as np
import torch
from loguru import logger from loguru import logger
from soundevent import data
from soundevent.geometry import compute_bounds, intervals_overlap
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol from batdetect2.typing import ClipperProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width, slice_tensor
DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_TRAIN_CLIP_DURATION = 0.256
DEFAULT_MAX_EMPTY_CLIP = 0.1 DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -18,50 +16,127 @@ class ClipingConfig(BaseConfig):
duration: float = DEFAULT_TRAIN_CLIP_DURATION duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP max_empty: float = DEFAULT_MAX_EMPTY_CLIP
min_sound_event_overlap: float = 0
class Clipper(torch.nn.Module): class Clipper:
def __init__( def __init__(
self, self,
preprocessor: PreprocessorProtocol,
duration: float = 0.5, duration: float = 0.5,
max_empty: float = 0.2, max_empty: float = 0.2,
random: bool = True, random: bool = True,
min_sound_event_overlap: float = 0,
): ):
super().__init__() super().__init__()
self.preprocessor = preprocessor
self.duration = duration self.duration = duration
self.random = random self.random = random
self.max_empty = max_empty self.max_empty = max_empty
self.min_sound_event_overlap = min_sound_event_overlap
def forward( def __call__(
self, self,
example: PreprocessedExample, clip_annotation: data.ClipAnnotation,
) -> Tuple[PreprocessedExample, float, float]: ) -> data.ClipAnnotation:
start_time = 0 return get_subclip_annotation(
duration = example.audio.shape[-1] / self.preprocessor.input_samplerate clip_annotation,
random=self.random,
if self.random: duration=self.duration,
start_time = np.random.uniform( max_empty=self.max_empty,
-self.max_empty, min_sound_event_overlap=self.min_sound_event_overlap,
duration - self.duration + self.max_empty,
)
return (
select_subclip(
example,
start=start_time,
duration=self.duration,
input_samplerate=self.preprocessor.input_samplerate,
output_samplerate=self.preprocessor.output_samplerate,
),
start_time,
start_time + self.duration,
) )
def get_subclip_annotation(
clip_annotation: data.ClipAnnotation,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
min_sound_event_overlap: float = 0,
) -> data.ClipAnnotation:
clip = clip_annotation.clip
subclip = select_subclip(
clip,
random=random,
duration=duration,
max_empty=max_empty,
)
sound_events = select_sound_event_annotations(
clip_annotation,
subclip,
min_overlap=min_sound_event_overlap,
)
return clip_annotation.model_copy(
update=dict(
clip=subclip,
sound_events=sound_events,
)
)
def select_subclip(
clip: data.Clip,
random: bool = True,
duration: float = 0.5,
max_empty: float = 0.2,
) -> data.Clip:
start_time = clip.start_time
end_time = clip.end_time
if duration > clip.duration + max_empty or not random:
return clip.model_copy(
update=dict(
start_time=start_time,
end_time=start_time + duration,
)
)
random_start_time = np.random.uniform(
low=start_time,
high=end_time + max_empty - duration,
)
return clip.model_copy(
update=dict(
start_time=random_start_time,
end_time=random_start_time + duration,
)
)
def select_sound_event_annotations(
clip_annotation: data.ClipAnnotation,
subclip: data.Clip,
min_overlap: float = 0,
) -> List[data.SoundEventAnnotation]:
selected = []
start_time = subclip.start_time
end_time = subclip.end_time
for sound_event_annotation in clip_annotation.sound_events:
geometry = sound_event_annotation.sound_event.geometry
if geometry is None:
continue
geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)
if not intervals_overlap(
(start_time, end_time),
(geom_start_time, geom_end_time),
min_absolute_overlap=min_overlap,
):
continue
selected.append(sound_event_annotation)
return selected
def build_clipper( def build_clipper(
preprocessor: PreprocessorProtocol,
config: Optional[ClipingConfig] = None, config: Optional[ClipingConfig] = None,
random: Optional[bool] = None, random: Optional[bool] = None,
) -> ClipperProtocol: ) -> ClipperProtocol:
@ -71,73 +146,7 @@ def build_clipper(
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
return Clipper( return Clipper(
preprocessor=preprocessor,
duration=config.duration, duration=config.duration,
max_empty=config.max_empty, max_empty=config.max_empty,
random=config.random if random else False, random=config.random if random else False,
) )
def select_subclip(
example: PreprocessedExample,
start: float,
duration: float,
input_samplerate: float,
output_samplerate: float,
fill_value: float = 0,
) -> PreprocessedExample:
audio_width = int(np.floor(duration * input_samplerate))
audio_start = int(np.floor(start * input_samplerate))
audio = adjust_width(
slice_tensor(
example.audio,
start=audio_start,
end=audio_start + audio_width,
dim=-1,
),
audio_width,
value=fill_value,
)
spec_start = int(np.floor(start * output_samplerate))
spec_width = int(np.floor(duration * output_samplerate))
return PreprocessedExample(
audio=audio,
spectrogram=adjust_width(
slice_tensor(
example.spectrogram,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
class_heatmap=adjust_width(
slice_tensor(
example.class_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
detection_heatmap=adjust_width(
slice_tensor(
example.detection_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
size_heatmap=adjust_width(
slice_tensor(
example.size_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
)

View File

@ -6,11 +6,13 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG, DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig, AugmentationsConfig,
) )
from batdetect2.train.clips import ClipingConfig from batdetect2.train.clips import ClipingConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig from batdetect2.train.losses import LossConfig
@ -50,7 +52,7 @@ class DataLoaderConfig(BaseConfig):
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True) DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False) DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
class LoadersConfig(BaseConfig): class LoadersConfig(BaseConfig):
@ -73,6 +75,8 @@ class TrainingConfig(BaseConfig):
cliping: ClipingConfig = Field(default_factory=ClipingConfig) cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def load_train_config( def load_train_config(
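
As a small illustrative note (not from the commit): with the separate preprocessing step removed, target and label settings now live directly on the training config.

from batdetect2.train.config import TrainingConfig

config = TrainingConfig()
config.targets  # TargetConfig
config.labels   # LabelConfig, passed to build_clip_labeler at training time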

View File

@ -1,78 +1,77 @@
from typing import Optional, Sequence, Tuple from typing import Optional, Sequence
import numpy as np
import torch import torch
from soundevent import data from soundevent import data
from torch.utils.data import Dataset from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.train import PreprocessedExample from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
from batdetect2.typing.train import Augmentation, ClipLabeller
__all__ = [ __all__ = [
"LabeledDataset", "TrainingDataset",
] ]
class LabeledDataset(Dataset): class TrainingDataset(Dataset):
def __init__( def __init__(
self, self,
filenames: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
clipper: ClipperProtocol, audio_loader: AudioLoader,
augmentation: Optional[Augmentation] = None, preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
clipper: Optional[ClipperProtocol] = None,
audio_augmentation: Optional[Augmentation] = None,
spectrogram_augmentation: Optional[Augmentation] = None,
audio_dir: Optional[data.PathLike] = None,
): ):
self.filenames = filenames self.clip_annotations = clip_annotations
self.clipper = clipper self.clipper = clipper
self.augmentation = augmentation self.labeller = labeller
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_augmentation = audio_augmentation
self.spectrogram_augmentation = spectrogram_augmentation
self.audio_dir = audio_dir
def __len__(self): def __len__(self):
return len(self.filenames) return len(self.clip_annotations)
def __getitem__(self, idx) -> TrainExample: def __getitem__(self, idx) -> TrainExample:
example = self.get_example(idx) clip_annotation = self.clip_annotations[idx]
example, start_time, end_time = self.clipper(example) if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
if self.augmentation: clip = clip_annotation.clip
example = self.augmentation(example)
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
# Add channel dim
wav_tensor = torch.tensor(wav).unsqueeze(0)
if self.audio_augmentation is not None:
wav_tensor, clip_annotation = self.audio_augmentation(
wav_tensor,
clip_annotation,
)
spectrogram = self.preprocessor(wav_tensor)
if self.spectrogram_augmentation is not None:
spectrogram, clip_annotation = self.spectrogram_augmentation(
spectrogram,
clip_annotation,
)
heatmaps = self.labeller(clip_annotation, spectrogram)
return TrainExample( return TrainExample(
spec=example.spectrogram, spec=spectrogram,
detection_heatmap=example.detection_heatmap, detection_heatmap=heatmaps.detection,
class_heatmap=example.class_heatmap, class_heatmap=heatmaps.classes,
size_heatmap=example.size_heatmap, size_heatmap=heatmaps.size,
idx=torch.tensor(idx), idx=torch.tensor(idx),
start_time=torch.tensor(start_time), start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(end_time), end_time=torch.tensor(clip.end_time),
) )
@classmethod
def from_directory(
cls,
directory: data.PathLike,
clipper: ClipperProtocol,
extension: str = ".npz",
augmentation: Optional[Augmentation] = None,
):
return cls(
filenames=list_preprocessed_files(directory, extension),
clipper=clipper,
augmentation=augmentation,
)
def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
idx = np.random.randint(0, len(self))
dataset = self.get_example(idx)
dataset, start_time, end_time = self.clipper(dataset)
return dataset, start_time, end_time
def get_example(self, idx) -> PreprocessedExample:
return load_preprocessed_example(self.filenames[idx])
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
return item["clip_annotation"].tolist()

View File

@ -15,16 +15,18 @@ from batdetect2.evaluate.metrics import (
DetectionAveragePrecision, DetectionAveragePrecision,
) )
from batdetect2.models import Model, build_model from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
RandomExampleSource, RandomAudioSource,
build_augmentations, build_augmentations,
) )
from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import ( from batdetect2.train.dataset import (
LabeledDataset, TrainingDataset,
) )
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss from batdetect2.train.losses import build_loss
@ -33,6 +35,7 @@ from batdetect2.typing import (
TargetProtocol, TargetProtocol,
TrainExample, TrainExample,
) )
from batdetect2.typing.train import ClipLabeller
from batdetect2.utils.arrays import adjust_width from batdetect2.utils.arrays import adjust_width
__all__ = [ __all__ = [
@ -46,8 +49,8 @@ __all__ = [
def train( def train(
train_examples: Sequence[data.PathLike], train_annotations: Sequence[data.ClipAnnotation],
val_examples: Optional[Sequence[data.PathLike]] = None, val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
config: Optional[FullTrainingConfig] = None, config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None, model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None, train_workers: Optional[int] = None,
@ -59,8 +62,19 @@ def train(
trainer = build_trainer(config, targets=model.targets) trainer = build_trainer(config, targets=model.targets)
audio_loader = build_audio_loader(config=config.preprocess.audio)
labeller = build_clip_labeler(
model.targets,
min_freq=model.preprocessor.min_freq,
max_freq=model.preprocessor.max_freq,
config=config.train.labels,
)
train_dataloader = build_train_loader( train_dataloader = build_train_loader(
train_examples, train_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=model.preprocessor, preprocessor=model.preprocessor,
config=config.train, config=config.train,
num_workers=train_workers, num_workers=train_workers,
@ -68,12 +82,14 @@ def train(
val_dataloader = ( val_dataloader = (
build_val_loader( build_val_loader(
val_examples, val_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=model.preprocessor, preprocessor=model.preprocessor,
config=config.train, config=config.train,
num_workers=val_workers, num_workers=val_workers,
) )
if val_examples is not None if val_annotations is not None
else None else None
) )
@ -153,19 +169,23 @@ def build_trainer(
def build_train_loader( def build_train_loader(
train_examples: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None, config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
) -> DataLoader: ) -> DataLoader:
config = config or TrainingConfig() config = config or TrainingConfig()
logger.info("Building training data loader...")
train_dataset = build_train_dataset( train_dataset = build_train_dataset(
train_examples, clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config, config=config,
) )
logger.info("Building training data loader...")
loader_conf = config.dataloaders.train loader_conf = config.dataloaders.train
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Training data loader config: \n{config}", "Training data loader config: \n{config}",
@ -182,16 +202,20 @@ def build_train_loader(
def build_val_loader( def build_val_loader(
val_examples: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None, config: Optional[TrainingConfig] = None,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
): ):
logger.info("Building validation data loader...")
config = config or TrainingConfig() config = config or TrainingConfig()
logger.info("Building validation data loader...")
val_dataset = build_val_dataset( val_dataset = build_val_dataset(
val_examples, clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config, config=config,
) )
@ -203,7 +227,7 @@ def build_val_loader(
num_workers = num_workers or loader_conf.num_workers num_workers = num_workers or loader_conf.num_workers
return DataLoader( return DataLoader(
val_dataset, val_dataset,
batch_size=loader_conf.batch_size, batch_size=1,
shuffle=loader_conf.shuffle, shuffle=loader_conf.shuffle,
num_workers=num_workers, num_workers=num_workers,
collate_fn=_collate_fn, collate_fn=_collate_fn,
@ -232,52 +256,60 @@ def _collate_fn(batch: List[TrainExample]) -> TrainExample:
def build_train_dataset( def build_train_dataset(
examples: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None, config: Optional[TrainingConfig] = None,
) -> LabeledDataset: ) -> TrainingDataset:
logger.info("Building training dataset...") logger.info("Building training dataset...")
config = config or TrainingConfig() config = config or TrainingConfig()
clipper = build_clipper( clipper = build_clipper(
preprocessor=preprocessor,
config=config.cliping, config=config.cliping,
random=True, random=True,
) )
random_example_source = RandomExampleSource( random_example_source = RandomAudioSource(
list(examples), clip_annotations,
clipper=clipper, audio_loader=audio_loader,
) )
if config.augmentations.enabled and config.augmentations.steps: if config.augmentations.enabled:
augmentations = build_augmentations( audio_augmentation, spectrogram_augmentation = build_augmentations(
preprocessor, samplerate=preprocessor.input_samplerate,
config=config.augmentations, config=config.augmentations,
example_source=random_example_source, audio_source=random_example_source,
) )
else: else:
logger.debug("No augmentations configured for training dataset.") logger.debug("No augmentations configured for training dataset.")
augmentations = None audio_augmentation = None
spectrogram_augmentation = None
return LabeledDataset( return TrainingDataset(
examples, clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
clipper=clipper, clipper=clipper,
augmentation=augmentations, preprocessor=preprocessor,
audio_augmentation=audio_augmentation,
spectrogram_augmentation=spectrogram_augmentation,
) )
def build_val_dataset( def build_val_dataset(
examples: Sequence[data.PathLike], clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
labeller: ClipLabeller,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None, config: Optional[TrainingConfig] = None,
train: bool = True, ) -> TrainingDataset:
) -> LabeledDataset:
logger.info("Building validation dataset...") logger.info("Building validation dataset...")
config = config or TrainingConfig() config = config or TrainingConfig()
clipper = build_clipper(
return TrainingDataset(
clip_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.cliping,
random=train,
) )
return LabeledDataset(examples, clipper=clipper)

View File

@ -49,7 +49,11 @@ spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training. steps, and returns the final `Heatmaps` used for model training.
""" """
Augmentation = Callable[[PreprocessedExample], PreprocessedExample]
Augmentation = Callable[
[torch.Tensor, data.ClipAnnotation],
Tuple[torch.Tensor, data.ClipAnnotation],
]
class TrainExample(NamedTuple): class TrainExample(NamedTuple):
@ -97,5 +101,6 @@ class LossProtocol(Protocol):
class ClipperProtocol(Protocol): class ClipperProtocol(Protocol):
def __call__( def __call__(
self, example: PreprocessedExample self,
) -> Tuple[PreprocessedExample, float, float]: ... clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation: ...
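
For clarity, a minimal, hypothetical function satisfying the new Augmentation signature (not part of the commit):

from typing import Tuple

import torch
from soundevent import data


def identity_augmentation(
    tensor: torch.Tensor,
    clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
    # A no-op augmentation: real implementations transform the tensor and, when
    # timing changes, shift or scale the annotation geometries to match.
    return tensor, clip_annotation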

View File

@ -6,7 +6,7 @@ from soundevent import data
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
add_echo, add_echo,
mix_examples, mix_audio,
) )
from batdetect2.train.clips import select_subclip from batdetect2.train.clips import select_subclip
from batdetect2.train.preprocess import generate_train_example from batdetect2.train.preprocess import generate_train_example
@ -41,7 +41,7 @@ def test_mix_examples(
labeller=sample_labeller, labeller=sample_labeller,
) )
mixed = mix_examples( mixed = mix_audio(
example1, example1,
example2, example2,
weight=0.3, weight=0.3,
@ -86,7 +86,7 @@ def test_mix_examples_of_different_durations(
labeller=sample_labeller, labeller=sample_labeller,
) )
mixed = mix_examples( mixed = mix_audio(
example1, example1,
example2, example2,
weight=0.3, weight=0.3,