From 40f6b646110eb5b2fcb7457cdd09410c99c850c5 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Sun, 31 Aug 2025 18:28:52 +0100
Subject: [PATCH] Remove train preprocessing

---
 pyproject.toml                          |   2 +-
 src/batdetect2/cli/train.py             |  34 +-
 src/batdetect2/plotting/common.py       |   2 +
 src/batdetect2/plotting/heatmaps.py     |   2 +
 src/batdetect2/train/__init__.py        |  16 +-
 src/batdetect2/train/augmentations.py   | 875 +++++++++++++------------
 src/batdetect2/train/callbacks.py       |  10 +-
 src/batdetect2/train/clips.py           | 203 +++---
 src/batdetect2/train/config.py          |   6 +-
 src/batdetect2/train/dataset.py         | 105 ++-
 src/batdetect2/train/train.py           | 104 ++-
 src/batdetect2/typing/train.py          |  11 +-
 tests/test_train/test_augmentations.py  |   6 +-
 13 files changed, 712 insertions(+), 664 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index d6993fb..e10c5c6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.7.0",
+    "soundevent[audio,geometry,plot]>=2.8.0",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py
index 3dcc96b..b16a587 100644
--- a/src/batdetect2/cli/train.py
+++ b/src/batdetect2/cli/train.py
@@ -6,19 +6,19 @@
 import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
+from batdetect2.data import load_dataset_from_config
 from batdetect2.train import (
     FullTrainingConfig,
     load_full_training_config,
     train,
 )
-from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = ["train_command"]
 
 @cli.command(name="train")
-@click.argument("train_dir", type=click.Path(exists=True))
-@click.option("--val-dir", type=click.Path(exists=True))
+@click.argument("train_dataset", type=click.Path(exists=True))
+@click.option("--val-dataset", type=click.Path(exists=True))
 @click.option("--model-path", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
@@ -31,8 +31,8 @@ __all__ = ["train_command"]
     help="Increase verbosity. -v for INFO, -vv for DEBUG.",
 )
 def train_command(
-    train_dir: Path,
-    val_dir: Optional[Path] = None,
+    train_dataset: Path,
+    val_dataset: Optional[Path] = None,
     model_path: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
@@ -58,29 +58,27 @@ def train_command(
         else FullTrainingConfig()
     )
 
-    logger.info("Scanning for training and validation data...")
-    train_examples = list_preprocessed_files(train_dir)
+    logger.info("Loading training dataset...")
+    train_annotations = load_dataset_from_config(train_dataset)
     logger.debug(
-        "Found {num_files} training examples in {path}",
-        num_files=len(train_examples),
-        path=train_dir,
+        "Loaded {num_annotations} training examples",
+        num_annotations=len(train_annotations),
     )
 
-    val_examples = None
-    if val_dir is not None:
-        val_examples = list_preprocessed_files(val_dir)
+    val_annotations = None
+    if val_dataset is not None:
+        val_annotations = load_dataset_from_config(val_dataset)
         logger.debug(
-            "Found {num_files} validation examples in {path}",
-            num_files=len(val_examples),
-            path=val_dir,
+            "Loaded {num_annotations} validation examples",
+            num_annotations=len(val_annotations),
         )
     else:
         logger.debug("No validation directory provided.")
 
     logger.info("Configuration and data loaded. 
Starting training...") train( - train_examples=train_examples, - val_examples=val_examples, + train_annotations=train_annotations, + val_annotations=val_annotations, config=conf, model_path=model_path, train_workers=train_workers, diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index b0adf80..de54b76 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -38,6 +38,8 @@ def plot_spectrogram( if isinstance(spec, torch.Tensor): spec = spec.numpy() + spec = spec.squeeze() + ax = create_ax(ax=ax, figsize=figsize) if start_time is None: diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index 29f261b..8354b38 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -25,6 +25,8 @@ def plot_detection_heatmap( if isinstance(heatmap, torch.Tensor): heatmap = heatmap.numpy() + heatmap = heatmap.squeeze() + if threshold is not None: heatmap = np.ma.masked_where( heatmap < threshold, diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index f2658e7..bf180e3 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -2,7 +2,7 @@ from batdetect2.train.augmentations import ( AugmentationsConfig, EchoAugmentationConfig, FrequencyMaskAugmentationConfig, - RandomExampleSource, + RandomAudioSource, TimeMaskAugmentationConfig, VolumeAugmentationConfig, WarpAugmentationConfig, @@ -10,7 +10,7 @@ from batdetect2.train.augmentations import ( build_augmentations, mask_frequency, mask_time, - mix_examples, + mix_audio, scale_volume, warp_spectrogram, ) @@ -22,10 +22,7 @@ from batdetect2.train.config import ( load_full_training_config, load_train_config, ) -from batdetect2.train.dataset import ( - LabeledDataset, - list_preprocessed_files, -) +from batdetect2.train.dataset import TrainingDataset from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.lightning import TrainingModule from batdetect2.train.losses import ( @@ -56,11 +53,11 @@ __all__ = [ "EchoAugmentationConfig", "FrequencyMaskAugmentationConfig", "FullTrainingConfig", - "LabeledDataset", + "TrainingDataset", "LossConfig", "LossFunction", "PLTrainerConfig", - "RandomExampleSource", + "RandomAudioSource", "SizeLossConfig", "TimeMaskAugmentationConfig", "TrainingConfig", @@ -78,13 +75,12 @@ __all__ = [ "build_val_dataset", "build_val_loader", "generate_train_example", - "list_preprocessed_files", "load_full_training_config", "load_label_config", "load_train_config", "mask_frequency", "mask_time", - "mix_examples", + "mix_audio", "preprocess_annotations", "scale_volume", "select_subclip", diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index b350e07..37d96b9 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -9,14 +9,12 @@ import torch from loguru import logger from pydantic import Field from soundevent import data +from soundevent.geometry import scale_geometry, shift_geometry from batdetect2.configs import BaseConfig, load_config -from batdetect2.train.preprocess import ( - list_preprocessed_files, - load_preprocessed_example, -) -from batdetect2.typing import Augmentation, PreprocessorProtocol -from batdetect2.typing.train import ClipperProtocol, PreprocessedExample +from batdetect2.train.clips import get_subclip_annotation +from batdetect2.typing import Augmentation +from batdetect2.typing.preprocess import AudioLoader from 
batdetect2.utils.arrays import adjust_width __all__ = [ @@ -24,7 +22,7 @@ __all__ = [ "AugmentationsConfig", "DEFAULT_AUGMENTATION_CONFIG", "EchoAugmentationConfig", - "ExampleSource", + "AudioSource", "FrequencyMaskAugmentationConfig", "MixAugmentationConfig", "TimeMaskAugmentationConfig", @@ -35,365 +33,12 @@ __all__ = [ "load_augmentation_config", "mask_frequency", "mask_time", - "mix_examples", + "mix_audio", "scale_volume", "warp_spectrogram", ] -ExampleSource = Callable[[], PreprocessedExample] -"""Type alias for a function that returns a training example""" - - -def mix_examples( - example: PreprocessedExample, - other: PreprocessedExample, - preprocessor: PreprocessorProtocol, - weight: float, -) -> PreprocessedExample: - """Combine two training examples.""" - audio1 = example.audio - audio2 = adjust_width(other.audio, audio1.shape[-1]) - - combined = weight * audio1 + (1 - weight) * audio2 - - spectrogram = preprocessor(combined) - - # NOTE: The subclip's spectrogram might be slightly longer than the - # spectrogram computed from the subclip's audio. This is due to a - # simplification in the subclip process: It doesn't account for the - # spectrogram parameters to precisely determine the corresponding audio - # samples. To work around this, we pad the computed spectrogram with zeros - # as needed. - previous_width = example.spectrogram.shape[-1] - spectrogram = adjust_width(spectrogram, previous_width) - - detection_heatmap = torch.maximum( - example.detection_heatmap, - adjust_width(other.detection_heatmap, previous_width), - ) - - class_heatmap = torch.maximum( - example.class_heatmap, - adjust_width(other.class_heatmap, previous_width), - ) - - size_heatmap = torch.maximum( - example.size_heatmap, - adjust_width(other.size_heatmap, previous_width), - ) - - return PreprocessedExample( - audio=combined, - spectrogram=spectrogram, - detection_heatmap=detection_heatmap, - class_heatmap=class_heatmap, - size_heatmap=size_heatmap, - ) - - -class EchoAugmentationConfig(BaseConfig): - """Configuration for adding synthetic echo/reverb.""" - - augmentation_type: Literal["add_echo"] = "add_echo" - - probability: float = 0.2 - """Probability of applying this augmentation.""" - - max_delay: float = 0.005 - min_weight: float = 0.0 - max_weight: float = 1.0 - - -class AddEcho(torch.nn.Module): - def __init__( - self, - preprocessor: PreprocessorProtocol, - min_weight: float = 0.1, - max_weight: float = 1.0, - max_delay: float = 0.005, - ): - super().__init__() - self.preprocessor = preprocessor - self.min_weight = min_weight - self.max_weight = max_weight - self.max_delay = max_delay - - def forward(self, example: PreprocessedExample) -> PreprocessedExample: - delay = np.random.uniform(0, self.max_delay) - weight = np.random.uniform(self.min_weight, self.max_weight) - return add_echo( - example, - preprocessor=self.preprocessor, - delay=delay, - weight=weight, - ) - - -def add_echo( - example: PreprocessedExample, - preprocessor: PreprocessorProtocol, - delay: float, - weight: float, -) -> PreprocessedExample: - """Add a synthetic echo to the audio waveform.""" - - audio = example.audio - delay_steps = int(preprocessor.input_samplerate * delay) - - slices = [slice(None)] * audio.ndim - slices[-1] = slice(None, -delay_steps) - audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll( - delay_steps, dims=-1 - ) - - audio = audio + weight * audio_delay - spectrogram = preprocessor(audio) - - # NOTE: The subclip's spectrogram might be slightly longer than the - # spectrogram 
computed from the subclip's audio. This is due to a - # simplification in the subclip process: It doesn't account for the - # spectrogram parameters to precisely determine the corresponding audio - # samples. To work around this, we pad the computed spectrogram with zeros - # as needed. - spectrogram = adjust_width( - spectrogram, - example.spectrogram.shape[-1], - ) - - return PreprocessedExample( - audio=audio, - spectrogram=spectrogram, - detection_heatmap=example.detection_heatmap, - class_heatmap=example.class_heatmap, - size_heatmap=example.size_heatmap, - ) - - -class VolumeAugmentationConfig(BaseConfig): - """Configuration for random volume scaling of the spectrogram.""" - - augmentation_type: Literal["scale_volume"] = "scale_volume" - probability: float = 0.2 - min_scaling: float = 0.0 - max_scaling: float = 2.0 - - -class ScaleVolume(torch.nn.Module): - def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0): - super().__init__() - self.min_scaling = min_scaling - self.max_scaling = max_scaling - - def forward(self, example: PreprocessedExample) -> PreprocessedExample: - factor = np.random.uniform(self.min_scaling, self.max_scaling) - return scale_volume(example, factor=factor) - - -def scale_volume( - example: PreprocessedExample, - factor: Optional[float] = None, -) -> PreprocessedExample: - """Scale the amplitude of the spectrogram by a random factor.""" - return PreprocessedExample( - audio=example.audio, - size_heatmap=example.size_heatmap, - class_heatmap=example.class_heatmap, - detection_heatmap=example.detection_heatmap, - spectrogram=example.spectrogram * factor, - ) - - -class WarpAugmentationConfig(BaseConfig): - augmentation_type: Literal["warp"] = "warp" - probability: float = 0.2 - delta: float = 0.04 - - -class WarpSpectrogram(torch.nn.Module): - def __init__(self, delta: float = 0.04) -> None: - super().__init__() - self.delta = delta - - def forward(self, example: PreprocessedExample) -> PreprocessedExample: - factor = np.random.uniform(1 - self.delta, 1 + self.delta) - return warp_spectrogram(example, factor=factor) - - -def warp_spectrogram( - example: PreprocessedExample, factor: float -) -> PreprocessedExample: - """Apply time warping by resampling the time axis.""" - width = example.spectrogram.shape[-1] - height = example.spectrogram.shape[-2] - target_shape = [height, width] - new_width = int(target_shape[-1] * factor) - - spectrogram = torch.nn.functional.interpolate( - adjust_width(example.spectrogram, new_width).unsqueeze(0), - size=target_shape, - mode="bilinear", - ).squeeze(0) - - detection = torch.nn.functional.interpolate( - adjust_width(example.detection_heatmap, new_width).unsqueeze(0), - size=target_shape, - mode="nearest", - ).squeeze(0) - - classification = torch.nn.functional.interpolate( - adjust_width(example.class_heatmap, new_width).unsqueeze(1), - size=target_shape, - mode="nearest", - ).squeeze(1) - - size = torch.nn.functional.interpolate( - adjust_width(example.size_heatmap, new_width).unsqueeze(1), - size=target_shape, - mode="nearest", - ).squeeze(1) - - return PreprocessedExample( - audio=example.audio, - size_heatmap=size, - class_heatmap=classification, - detection_heatmap=detection, - spectrogram=spectrogram, - ) - - -class TimeMaskAugmentationConfig(BaseConfig): - augmentation_type: Literal["mask_time"] = "mask_time" - probability: float = 0.2 - max_perc: float = 0.05 - max_masks: int = 3 - - -class MaskTime(torch.nn.Module): - def __init__( - self, - max_perc: float = 0.05, - max_masks: int = 3, - 
mask_heatmaps: bool = False, - ) -> None: - super().__init__() - self.max_perc = max_perc - self.max_masks = max_masks - self.mask_heatmaps = mask_heatmaps - - def forward(self, example: PreprocessedExample) -> PreprocessedExample: - num_masks = np.random.randint(1, self.max_masks + 1) - width = example.spectrogram.shape[-1] - - mask_size = np.random.randint( - low=0, - high=int(self.max_perc * width), - size=num_masks, - ) - mask_start = np.random.randint( - low=0, - high=width - mask_size, - size=num_masks, - ) - masks = [ - (start, start + size) for start, size in zip(mask_start, mask_size) - ] - return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps) - - -def mask_time( - example: PreprocessedExample, - masks: List[Tuple[int, int]], - mask_heatmaps: bool = False, -) -> PreprocessedExample: - """Apply time masking to the spectrogram.""" - - for start, end in masks: - slices = [slice(None)] * example.spectrogram.ndim - slices[-1] = slice(start, end) - - example.spectrogram[tuple(slices)] = 0 - - if not mask_heatmaps: - continue - - example.class_heatmap[tuple(slices)] = 0 - example.size_heatmap[tuple(slices)] = 0 - example.detection_heatmap[tuple(slices)] = 0 - - return PreprocessedExample( - audio=example.audio, - size_heatmap=example.size_heatmap, - class_heatmap=example.class_heatmap, - detection_heatmap=example.detection_heatmap, - spectrogram=example.spectrogram, - ) - - -class FrequencyMaskAugmentationConfig(BaseConfig): - augmentation_type: Literal["mask_freq"] = "mask_freq" - probability: float = 0.2 - max_perc: float = 0.10 - max_masks: int = 3 - mask_heatmaps: bool = False - - -class MaskFrequency(torch.nn.Module): - def __init__( - self, - max_perc: float = 0.10, - max_masks: int = 3, - mask_heatmaps: bool = False, - ) -> None: - super().__init__() - self.max_perc = max_perc - self.max_masks = max_masks - self.mask_heatmaps = mask_heatmaps - - def forward(self, example: PreprocessedExample) -> PreprocessedExample: - num_masks = np.random.randint(1, self.max_masks + 1) - height = example.spectrogram.shape[-2] - - mask_size = np.random.randint( - low=0, - high=int(self.max_perc * height), - size=num_masks, - ) - mask_start = np.random.randint( - low=0, - high=height - mask_size, - size=num_masks, - ) - masks = [ - (start, start + size) for start, size in zip(mask_start, mask_size) - ] - return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps) - - -def mask_frequency( - example: PreprocessedExample, - masks: List[Tuple[int, int]], - mask_heatmaps: bool = False, -) -> PreprocessedExample: - """Apply frequency masking to the spectrogram.""" - for start, end in masks: - slices = [slice(None)] * example.spectrogram.ndim - slices[-2] = slice(start, end) - example.spectrogram[tuple(slices)] = 0 - - if not mask_heatmaps: - continue - - example.class_heatmap[tuple(slices)] = 0 - example.size_heatmap[tuple(slices)] = 0 - example.detection_heatmap[tuple(slices)] = 0 - - return PreprocessedExample( - audio=example.audio, - size_heatmap=example.size_heatmap, - class_heatmap=example.class_heatmap, - detection_heatmap=example.detection_heatmap, - spectrogram=example.spectrogram, - ) +AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]] class MixAugmentationConfig(BaseConfig): @@ -416,8 +61,7 @@ class MixAudio(torch.nn.Module): def __init__( self, - example_source: ExampleSource, - preprocessor: PreprocessorProtocol, + example_source: AudioSource, min_weight: float = 0.3, max_weight: float = 0.7, ): @@ -426,20 +70,364 @@ class 
MixAudio(torch.nn.Module): self.min_weight = min_weight self.example_source = example_source self.max_weight = max_weight - self.preprocessor = preprocessor - def __call__(self, example: PreprocessedExample) -> PreprocessedExample: + def __call__( + self, + wav: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: """Fetch another example and perform mixup.""" - other = self.example_source() + other_wav, other_clip_annotation = self.example_source( + clip_annotation.clip.duration + ) weight = np.random.uniform(self.min_weight, self.max_weight) - return mix_examples( - example, - other, - self.preprocessor, - weight=weight, + mixed_audio = mix_audio(wav, other_wav, weight=weight) + mixed_annotations = combine_clip_annotations( + clip_annotation, + other_clip_annotation, + ) + return mixed_audio, mixed_annotations + + +def mix_audio( + wav1: torch.Tensor, + wav2: torch.Tensor, + weight: float, +) -> torch.Tensor: + """Combine two training examples.""" + wav2 = adjust_width(wav2, wav1.shape[-1]) + return weight * wav1 + (1 - weight) * wav2 + + +def shift_sound_event_annotation( + sound_event_annotation: data.SoundEventAnnotation, + time: float, +) -> data.SoundEventAnnotation: + sound_event = sound_event_annotation.sound_event + geometry = sound_event.geometry + + if geometry is None: + return sound_event_annotation + + sound_event = sound_event.model_copy( + update=dict(geometry=shift_geometry(geometry, time=time)) + ) + return sound_event_annotation.model_copy( + update=dict(sound_event=sound_event) + ) + + +def combine_clip_annotations( + clip_annotation1: data.ClipAnnotation, + clip_annotation2: data.ClipAnnotation, +) -> data.ClipAnnotation: + time_shift = ( + clip_annotation1.clip.start_time - clip_annotation2.clip.start_time + ) + return clip_annotation1.model_copy( + update=dict( + sound_events=[ + *clip_annotation1.sound_events, + *[ + shift_sound_event_annotation(sound_event, time=time_shift) + for sound_event in clip_annotation2.sound_events + ], + ] + ) + ) + + +class EchoAugmentationConfig(BaseConfig): + """Configuration for adding synthetic echo/reverb.""" + + augmentation_type: Literal["add_echo"] = "add_echo" + probability: float = 0.2 + max_delay: float = 0.005 + min_weight: float = 0.0 + max_weight: float = 1.0 + + +class AddEcho(torch.nn.Module): + def __init__( + self, + min_weight: float = 0.1, + max_weight: float = 1.0, + max_delay: int = 2560, + ): + super().__init__() + self.min_weight = min_weight + self.max_weight = max_weight + self.max_delay = max_delay + + def forward( + self, + wav: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + delay = np.random.randint(0, self.max_delay) + weight = np.random.uniform(self.min_weight, self.max_weight) + return add_echo(wav, delay=delay, weight=weight), clip_annotation + + +def add_echo( + wav: torch.Tensor, + delay: int, + weight: float, +) -> torch.Tensor: + """Add a synthetic echo to the audio waveform.""" + + slices = [slice(None)] * wav.ndim + slices[-1] = slice(None, -delay) + audio_delay = adjust_width(wav[tuple(slices)], wav.shape[-1]).roll( + delay, dims=-1 + ) + return mix_audio(wav, audio_delay, weight) + + +class VolumeAugmentationConfig(BaseConfig): + """Configuration for random volume scaling of the spectrogram.""" + + augmentation_type: Literal["scale_volume"] = "scale_volume" + probability: float = 0.2 + min_scaling: float = 0.0 + max_scaling: float = 2.0 + + +class ScaleVolume(torch.nn.Module): + def 
__init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0): + super().__init__() + self.min_scaling = min_scaling + self.max_scaling = max_scaling + + def forward( + self, + spec: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + factor = np.random.uniform(self.min_scaling, self.max_scaling) + return scale_volume(spec, factor=factor), clip_annotation + + +def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor: + """Scale the amplitude of the spectrogram by a factor.""" + return spec * factor + + +class WarpAugmentationConfig(BaseConfig): + augmentation_type: Literal["warp"] = "warp" + probability: float = 0.2 + delta: float = 0.04 + + +class WarpSpectrogram(torch.nn.Module): + def __init__(self, delta: float = 0.04) -> None: + super().__init__() + self.delta = delta + + def forward( + self, + spec: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + factor = np.random.uniform(1 - self.delta, 1 + self.delta) + return ( + warp_spectrogram(spec, factor=factor), + warp_clip_annotation(clip_annotation, factor=factor), ) +def warp_sound_event_annotation( + sound_event_annotation: data.SoundEventAnnotation, + factor: float, + anchor: float, +) -> data.SoundEventAnnotation: + sound_event = sound_event_annotation.sound_event + geometry = sound_event.geometry + + if geometry is None: + return sound_event_annotation + + sound_event = sound_event.model_copy( + update=dict( + geometry=scale_geometry( + geometry, + time=1 / factor, + time_anchor=anchor, + ) + ), + ) + return sound_event_annotation.model_copy( + update=dict(sound_event=sound_event) + ) + + +def warp_clip_annotation( + clip_annotation: data.ClipAnnotation, + factor: float, +) -> data.ClipAnnotation: + return clip_annotation.model_copy( + update=dict( + sound_events=[ + warp_sound_event_annotation( + sound_event, + factor=factor, + anchor=clip_annotation.clip.start_time, + ) + for sound_event in clip_annotation.sound_events + ] + ) + ) + + +def warp_spectrogram( + spec: torch.Tensor, + factor: float, +) -> torch.Tensor: + """Apply time warping by resampling the time axis.""" + width = spec.shape[-1] + height = spec.shape[-2] + target_shape = [height, width] + new_width = int(target_shape[-1] * factor) + return torch.nn.functional.interpolate( + adjust_width(spec, new_width).unsqueeze(0), + size=target_shape, + mode="bilinear", + ).squeeze(0) + + +class TimeMaskAugmentationConfig(BaseConfig): + augmentation_type: Literal["mask_time"] = "mask_time" + probability: float = 0.2 + max_perc: float = 0.05 + max_masks: int = 3 + + +class MaskTime(torch.nn.Module): + def __init__( + self, + max_perc: float = 0.05, + max_masks: int = 3, + mask_heatmaps: bool = False, + ) -> None: + super().__init__() + self.max_perc = max_perc + self.max_masks = max_masks + self.mask_heatmaps = mask_heatmaps + + def forward( + self, + spec: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + num_masks = np.random.randint(1, self.max_masks + 1) + width = spec.shape[-1] + + mask_size = np.random.randint( + low=0, + high=int(self.max_perc * width), + size=num_masks, + ) + mask_start = np.random.randint( + low=0, + high=width - mask_size, + size=num_masks, + ) + masks = [ + (start, start + size) for start, size in zip(mask_start, mask_size) + ] + return mask_time(spec, masks), clip_annotation + + +def mask_time( + spec: torch.Tensor, + masks: List[Tuple[int, int]], + value: float = 0, +) -> 
torch.Tensor: + """Apply time masking to the spectrogram.""" + for start, end in masks: + slices = [slice(None)] * spec.ndim + slices[-1] = slice(start, end) + spec[tuple(slices)] = value + + return spec + + +class FrequencyMaskAugmentationConfig(BaseConfig): + augmentation_type: Literal["mask_freq"] = "mask_freq" + probability: float = 0.2 + max_perc: float = 0.10 + max_masks: int = 3 + mask_heatmaps: bool = False + + +class MaskFrequency(torch.nn.Module): + def __init__( + self, + max_perc: float = 0.10, + max_masks: int = 3, + mask_heatmaps: bool = False, + ) -> None: + super().__init__() + self.max_perc = max_perc + self.max_masks = max_masks + self.mask_heatmaps = mask_heatmaps + + def forward( + self, + spec: torch.Tensor, + clip_annotation: data.ClipAnnotation, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + num_masks = np.random.randint(1, self.max_masks + 1) + height = spec.shape[-2] + + mask_size = np.random.randint( + low=0, + high=int(self.max_perc * height), + size=num_masks, + ) + mask_start = np.random.randint( + low=0, + high=height - mask_size, + size=num_masks, + ) + masks = [ + (start, start + size) for start, size in zip(mask_start, mask_size) + ] + return mask_frequency(spec, masks), clip_annotation + + +def mask_frequency( + spec: torch.Tensor, + masks: List[Tuple[int, int]], +) -> torch.Tensor: + """Apply frequency masking to the spectrogram.""" + for start, end in masks: + slices = [slice(None)] * spec.ndim + slices[-2] = slice(start, end) + spec[tuple(slices)] = 0 + + return spec + + +AudioAugmentationConfig = Annotated[ + Union[ + MixAugmentationConfig, + EchoAugmentationConfig, + ], + Field(discriminator="augmentation_type"), +] + + +SpectrogramAugmentationConfig = Annotated[ + Union[ + VolumeAugmentationConfig, + WarpAugmentationConfig, + FrequencyMaskAugmentationConfig, + TimeMaskAugmentationConfig, + ], + Field(discriminator="augmentation_type"), +] + AugmentationConfig = Annotated[ Union[ MixAugmentationConfig, @@ -459,7 +447,11 @@ class AugmentationsConfig(BaseConfig): enabled: bool = True - steps: List[AugmentationConfig] = Field(default_factory=list) + audio: List[AudioAugmentationConfig] = Field(default_factory=list) + + spectrogram: List[SpectrogramAugmentationConfig] = Field( + default_factory=list + ) class MaybeApply(torch.nn.Module): @@ -470,46 +462,31 @@ class MaybeApply(torch.nn.Module): augmentation: Augmentation, probability: float = 0.2, ): - """Initialize the wrapper. - - Parameters - ---------- - augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset]) - The augmentation function to potentially apply. - probability : float, default=0.5 - The probability (0.0 to 1.0) of applying the augmentation. - """ + """Initialize the wrapper.""" super().__init__() self.augmentation = augmentation self.probability = probability - def __call__(self, example: PreprocessedExample) -> PreprocessedExample: - """Apply the wrapped augmentation with configured probability. - - Parameters - ---------- - example : xr.Dataset - The input training example. - - Returns - ------- - xr.Dataset - The potentially augmented training example. 
-        """
+    def __call__(
+        self,
+        tensor: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        """Apply the wrapped augmentation with configured probability."""
         if np.random.random() > self.probability:
-            return example
+            return tensor, clip_annotation
 
-        return self.augmentation(example)
+        return self.augmentation(tensor, clip_annotation)
 
 
 def build_augmentation_from_config(
     config: AugmentationConfig,
-    preprocessor: PreprocessorProtocol,
-    example_source: Optional[ExampleSource] = None,
+    samplerate: int,
+    audio_source: Optional[AudioSource] = None,
 ) -> Optional[Augmentation]:
     """Factory function to build a single augmentation from its config."""
     if config.augmentation_type == "mix_audio":
-        if example_source is None:
+        if audio_source is None:
             warnings.warn(
                 "Mix audio augmentation ('mix_audio') requires an "
                 "'example_source' callable to be provided.",
@@ -518,16 +495,14 @@ def build_augmentation_from_config(
             return None
 
         return MixAudio(
-            example_source=example_source,
-            preprocessor=preprocessor,
+            example_source=audio_source,
             min_weight=config.min_weight,
             max_weight=config.max_weight,
         )
 
     if config.augmentation_type == "add_echo":
         return AddEcho(
-            preprocessor=preprocessor,
-            max_delay=config.max_delay,
+            max_delay=int(config.max_delay * samplerate),
             min_weight=config.min_weight,
             max_weight=config.max_weight,
         )
@@ -562,37 +537,35 @@ def build_augmentation_from_config(
 
 
 DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
-    steps=[
+    enabled=True,
+    audio=[
         MixAugmentationConfig(),
         EchoAugmentationConfig(),
+    ],
+    spectrogram=[
         VolumeAugmentationConfig(),
         WarpAugmentationConfig(),
         TimeMaskAugmentationConfig(),
         FrequencyMaskAugmentationConfig(),
-    ]
+    ],
 )
 
 
-def build_augmentations(
-    preprocessor: PreprocessorProtocol,
-    config: Optional[AugmentationsConfig] = None,
-    example_source: Optional[ExampleSource] = None,
-) -> Augmentation:
-    """Build a composite augmentation pipeline function from configuration."""
-    config = config or DEFAULT_AUGMENTATION_CONFIG
-
-    logger.opt(lazy=True).debug(
-        "Building augmentations with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
+def build_augmentation_sequence(
+    samplerate: int,
+    steps: Optional[Sequence[AugmentationConfig]] = None,
+    audio_source: Optional[AudioSource] = None,
+) -> Optional[Augmentation]:
+    if not steps:
+        return None
 
     augmentations = []
 
-    for step_config in config.steps:
+    for step_config in steps:
         augmentation = build_augmentation_from_config(
             step_config,
-            preprocessor=preprocessor,
-            example_source=example_source,
+            samplerate=samplerate,
+            audio_source=audio_source,
         )
 
         if augmentation is None:
@@ -608,6 +581,33 @@ def build_augmentations(
     return torch.nn.Sequential(*augmentations)
 
 
+def build_augmentations(
+    samplerate: int,
+    config: Optional[AugmentationsConfig] = None,
+    audio_source: Optional[AudioSource] = None,
+) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
+    """Build a composite augmentation pipeline function from configuration."""
+    config = config or DEFAULT_AUGMENTATION_CONFIG
+
+    logger.opt(lazy=True).debug(
+        "Building augmentations with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+
+    audio_augmentation = build_augmentation_sequence(
+        samplerate,
+        steps=config.audio,
+        audio_source=audio_source,
+    )
+    spectrogram_augmentation = build_augmentation_sequence(
+        samplerate,
+        steps=config.spectrogram,
+        audio_source=audio_source,
+    )
+
+    return audio_augmentation, spectrogram_augmentation
+
+
 def 
load_augmentation_config( path: data.PathLike, field: Optional[str] = None ) -> AugmentationsConfig: @@ -615,23 +615,24 @@ def load_augmentation_config( return load_config(path, schema=AugmentationsConfig, field=field) -class RandomExampleSource: +class RandomAudioSource: def __init__( self, - filenames: Sequence[data.PathLike], - clipper: ClipperProtocol, + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, ): - self.filenames = filenames - self.clipper = clipper + self.audio_loader = audio_loader + self.clip_annotations = clip_annotations - def __call__(self) -> PreprocessedExample: - index = int(np.random.randint(len(self.filenames))) - filename = self.filenames[index] - example = load_preprocessed_example(filename) - example, _, _ = self.clipper(example) - return example - - @classmethod - def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol): - filenames = list_preprocessed_files(path) - return cls(filenames, clipper=clipper) + def __call__( + self, + duration: float, + ) -> Tuple[torch.Tensor, data.ClipAnnotation]: + index = int(np.random.randint(len(self.clip_annotations))) + clip_annotation = get_subclip_annotation( + self.clip_annotations[index], + duration=duration, + max_empty=0, + ) + wav = self.audio_loader.load_clip(clip_annotation.clip) + return torch.from_numpy(wav).unsqueeze(0), clip_annotation diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 4184440..2195615 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -17,7 +17,7 @@ from batdetect2.evaluate.match import ( from batdetect2.models import Model from batdetect2.plotting.evaluation import plot_example_gallery from batdetect2.postprocess import get_sound_event_predictions -from batdetect2.train.dataset import LabeledDataset +from batdetect2.train.dataset import TrainingDataset from batdetect2.train.lightning import TrainingModule from batdetect2.typing import ( BatDetect2Prediction, @@ -49,11 +49,11 @@ class ValidationMetrics(Callback): Tuple[data.ClipAnnotation, List[BatDetect2Prediction]] ] = [] - def get_dataset(self, trainer: Trainer) -> LabeledDataset: + def get_dataset(self, trainer: Trainer) -> TrainingDataset: dataloaders = trainer.val_dataloaders assert isinstance(dataloaders, DataLoader) dataset = dataloaders.dataset - assert isinstance(dataset, LabeledDataset) + assert isinstance(dataset, TrainingDataset) return dataset def plot_examples( @@ -136,12 +136,12 @@ class ValidationMetrics(Callback): def _get_batch_clips_and_predictions( batch: TrainExample, outputs: ModelOutput, - dataset: LabeledDataset, + dataset: TrainingDataset, model: Model, ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]: clip_annotations = [ _get_subclip( - dataset.get_clip_annotation(example_id), + dataset.clip_annotations[int(example_id)], start_time=start_time.item(), end_time=end_time.item(), targets=model.targets, diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 04da59f..7578aa7 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -1,14 +1,12 @@ -from typing import Optional, Tuple +from typing import List, Optional import numpy as np -import torch from loguru import logger +from soundevent import data +from soundevent.geometry import compute_bounds, intervals_overlap from batdetect2.configs import BaseConfig from batdetect2.typing import ClipperProtocol -from batdetect2.typing.preprocess import PreprocessorProtocol -from 
batdetect2.typing.train import PreprocessedExample -from batdetect2.utils.arrays import adjust_width, slice_tensor DEFAULT_TRAIN_CLIP_DURATION = 0.256 DEFAULT_MAX_EMPTY_CLIP = 0.1 @@ -18,50 +16,127 @@ class ClipingConfig(BaseConfig): duration: float = DEFAULT_TRAIN_CLIP_DURATION random: bool = True max_empty: float = DEFAULT_MAX_EMPTY_CLIP + min_sound_event_overlap: float = 0 -class Clipper(torch.nn.Module): +class Clipper: def __init__( self, - preprocessor: PreprocessorProtocol, duration: float = 0.5, max_empty: float = 0.2, random: bool = True, + min_sound_event_overlap: float = 0, ): super().__init__() - self.preprocessor = preprocessor self.duration = duration self.random = random self.max_empty = max_empty + self.min_sound_event_overlap = min_sound_event_overlap - def forward( + def __call__( self, - example: PreprocessedExample, - ) -> Tuple[PreprocessedExample, float, float]: - start_time = 0 - duration = example.audio.shape[-1] / self.preprocessor.input_samplerate - - if self.random: - start_time = np.random.uniform( - -self.max_empty, - duration - self.duration + self.max_empty, - ) - - return ( - select_subclip( - example, - start=start_time, - duration=self.duration, - input_samplerate=self.preprocessor.input_samplerate, - output_samplerate=self.preprocessor.output_samplerate, - ), - start_time, - start_time + self.duration, + clip_annotation: data.ClipAnnotation, + ) -> data.ClipAnnotation: + return get_subclip_annotation( + clip_annotation, + random=self.random, + duration=self.duration, + max_empty=self.max_empty, + min_sound_event_overlap=self.min_sound_event_overlap, ) +def get_subclip_annotation( + clip_annotation: data.ClipAnnotation, + random: bool = True, + duration: float = 0.5, + max_empty: float = 0.2, + min_sound_event_overlap: float = 0, +) -> data.ClipAnnotation: + clip = clip_annotation.clip + + subclip = select_subclip( + clip, + random=random, + duration=duration, + max_empty=max_empty, + ) + + sound_events = select_sound_event_annotations( + clip_annotation, + subclip, + min_overlap=min_sound_event_overlap, + ) + + return clip_annotation.model_copy( + update=dict( + clip=subclip, + sound_events=sound_events, + ) + ) + + +def select_subclip( + clip: data.Clip, + random: bool = True, + duration: float = 0.5, + max_empty: float = 0.2, +) -> data.Clip: + start_time = clip.start_time + end_time = clip.end_time + + if duration > clip.duration + max_empty or not random: + return clip.model_copy( + update=dict( + start_time=start_time, + end_time=start_time + duration, + ) + ) + + random_start_time = np.random.uniform( + low=start_time, + high=end_time + max_empty - duration, + ) + + return clip.model_copy( + update=dict( + start_time=random_start_time, + end_time=random_start_time + duration, + ) + ) + + +def select_sound_event_annotations( + clip_annotation: data.ClipAnnotation, + subclip: data.Clip, + min_overlap: float = 0, +) -> List[data.SoundEventAnnotation]: + selected = [] + + start_time = subclip.start_time + end_time = subclip.end_time + + for sound_event_annotation in clip_annotation.sound_events: + geometry = sound_event_annotation.sound_event.geometry + + if geometry is None: + continue + + geom_start_time, _, geom_end_time, _ = compute_bounds(geometry) + + if not intervals_overlap( + (start_time, end_time), + (geom_start_time, geom_end_time), + min_absolute_overlap=min_overlap, + ): + continue + + selected.append(sound_event_annotation) + + return selected + + def build_clipper( - preprocessor: PreprocessorProtocol, config: Optional[ClipingConfig] = 
None, random: Optional[bool] = None, ) -> ClipperProtocol: @@ -71,73 +146,7 @@ def build_clipper( lambda: config.to_yaml_string(), ) return Clipper( - preprocessor=preprocessor, duration=config.duration, max_empty=config.max_empty, random=config.random if random else False, ) - - -def select_subclip( - example: PreprocessedExample, - start: float, - duration: float, - input_samplerate: float, - output_samplerate: float, - fill_value: float = 0, -) -> PreprocessedExample: - audio_width = int(np.floor(duration * input_samplerate)) - audio_start = int(np.floor(start * input_samplerate)) - - audio = adjust_width( - slice_tensor( - example.audio, - start=audio_start, - end=audio_start + audio_width, - dim=-1, - ), - audio_width, - value=fill_value, - ) - - spec_start = int(np.floor(start * output_samplerate)) - spec_width = int(np.floor(duration * output_samplerate)) - return PreprocessedExample( - audio=audio, - spectrogram=adjust_width( - slice_tensor( - example.spectrogram, - start=spec_start, - end=spec_start + spec_width, - dim=-1, - ), - spec_width, - ), - class_heatmap=adjust_width( - slice_tensor( - example.class_heatmap, - start=spec_start, - end=spec_start + spec_width, - dim=-1, - ), - spec_width, - ), - detection_heatmap=adjust_width( - slice_tensor( - example.detection_heatmap, - start=spec_start, - end=spec_start + spec_width, - dim=-1, - ), - spec_width, - ), - size_heatmap=adjust_width( - slice_tensor( - example.size_heatmap, - start=spec_start, - end=spec_start + spec_width, - dim=-1, - ), - spec_width, - ), - ) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 1e82cc2..acffb82 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -6,11 +6,13 @@ from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.evaluate import EvaluationConfig from batdetect2.models import ModelConfig +from batdetect2.targets import TargetConfig from batdetect2.train.augmentations import ( DEFAULT_AUGMENTATION_CONFIG, AugmentationsConfig, ) from batdetect2.train.clips import ClipingConfig +from batdetect2.train.labels import LabelConfig from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig from batdetect2.train.losses import LossConfig @@ -50,7 +52,7 @@ class DataLoaderConfig(BaseConfig): DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True) -DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False) +DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False) class LoadersConfig(BaseConfig): @@ -73,6 +75,8 @@ class TrainingConfig(BaseConfig): cliping: ClipingConfig = Field(default_factory=ClipingConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) + targets: TargetConfig = Field(default_factory=TargetConfig) + labels: LabelConfig = Field(default_factory=LabelConfig) def load_train_config( diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index da2aa55..f7f62f5 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -1,78 +1,77 @@ -from typing import Optional, Sequence, Tuple +from typing import Optional, Sequence -import numpy as np import torch from soundevent import data from torch.utils.data import Dataset -from batdetect2.train.augmentations import Augmentation -from batdetect2.train.preprocess import ( - list_preprocessed_files, - load_preprocessed_example, -) from batdetect2.typing import 
ClipperProtocol, TrainExample -from batdetect2.typing.train import PreprocessedExample +from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol +from batdetect2.typing.train import Augmentation, ClipLabeller __all__ = [ - "LabeledDataset", + "TrainingDataset", ] -class LabeledDataset(Dataset): +class TrainingDataset(Dataset): def __init__( self, - filenames: Sequence[data.PathLike], - clipper: ClipperProtocol, - augmentation: Optional[Augmentation] = None, + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + preprocessor: PreprocessorProtocol, + labeller: ClipLabeller, + clipper: Optional[ClipperProtocol] = None, + audio_augmentation: Optional[Augmentation] = None, + spectrogram_augmentation: Optional[Augmentation] = None, + audio_dir: Optional[data.PathLike] = None, ): - self.filenames = filenames + self.clip_annotations = clip_annotations self.clipper = clipper - self.augmentation = augmentation + self.labeller = labeller + self.preprocessor = preprocessor + self.audio_loader = audio_loader + self.audio_augmentation = audio_augmentation + self.spectrogram_augmentation = spectrogram_augmentation + self.audio_dir = audio_dir def __len__(self): - return len(self.filenames) + return len(self.clip_annotations) def __getitem__(self, idx) -> TrainExample: - example = self.get_example(idx) + clip_annotation = self.clip_annotations[idx] - example, start_time, end_time = self.clipper(example) + if self.clipper is not None: + clip_annotation = self.clipper(clip_annotation) - if self.augmentation: - example = self.augmentation(example) + clip = clip_annotation.clip + + wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) + + # Add channel dim + wav_tensor = torch.tensor(wav).unsqueeze(0) + + if self.audio_augmentation is not None: + wav_tensor, clip_annotation = self.audio_augmentation( + wav_tensor, + clip_annotation, + ) + + spectrogram = self.preprocessor(wav_tensor) + + if self.spectrogram_augmentation is not None: + spectrogram, clip_annotation = self.spectrogram_augmentation( + spectrogram, + clip_annotation, + ) + + heatmaps = self.labeller(clip_annotation, spectrogram) return TrainExample( - spec=example.spectrogram, - detection_heatmap=example.detection_heatmap, - class_heatmap=example.class_heatmap, - size_heatmap=example.size_heatmap, + spec=spectrogram, + detection_heatmap=heatmaps.detection, + class_heatmap=heatmaps.classes, + size_heatmap=heatmaps.size, idx=torch.tensor(idx), - start_time=torch.tensor(start_time), - end_time=torch.tensor(end_time), + start_time=torch.tensor(clip.start_time), + end_time=torch.tensor(clip.end_time), ) - - @classmethod - def from_directory( - cls, - directory: data.PathLike, - clipper: ClipperProtocol, - extension: str = ".npz", - augmentation: Optional[Augmentation] = None, - ): - return cls( - filenames=list_preprocessed_files(directory, extension), - clipper=clipper, - augmentation=augmentation, - ) - - def get_random_example(self) -> Tuple[PreprocessedExample, float, float]: - idx = np.random.randint(0, len(self)) - dataset = self.get_example(idx) - dataset, start_time, end_time = self.clipper(dataset) - return dataset, start_time, end_time - - def get_example(self, idx) -> PreprocessedExample: - return load_preprocessed_example(self.filenames[idx]) - - def get_clip_annotation(self, idx) -> data.ClipAnnotation: - item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+") - return item["clip_annotation"].tolist() diff --git a/src/batdetect2/train/train.py 
b/src/batdetect2/train/train.py index a24b187..32d43c4 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -15,16 +15,18 @@ from batdetect2.evaluate.metrics import ( DetectionAveragePrecision, ) from batdetect2.models import Model, build_model +from batdetect2.plotting.clips import AudioLoader, build_audio_loader from batdetect2.train.augmentations import ( - RandomExampleSource, + RandomAudioSource, build_augmentations, ) from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.dataset import ( - LabeledDataset, + TrainingDataset, ) +from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import build_logger from batdetect2.train.losses import build_loss @@ -33,6 +35,7 @@ from batdetect2.typing import ( TargetProtocol, TrainExample, ) +from batdetect2.typing.train import ClipLabeller from batdetect2.utils.arrays import adjust_width __all__ = [ @@ -46,8 +49,8 @@ __all__ = [ def train( - train_examples: Sequence[data.PathLike], - val_examples: Optional[Sequence[data.PathLike]] = None, + train_annotations: Sequence[data.ClipAnnotation], + val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, config: Optional[FullTrainingConfig] = None, model_path: Optional[data.PathLike] = None, train_workers: Optional[int] = None, @@ -59,8 +62,19 @@ def train( trainer = build_trainer(config, targets=model.targets) + audio_loader = build_audio_loader(config=config.preprocess.audio) + + labeller = build_clip_labeler( + model.targets, + min_freq=model.preprocessor.min_freq, + max_freq=model.preprocessor.max_freq, + config=config.train.labels, + ) + train_dataloader = build_train_loader( - train_examples, + train_annotations, + audio_loader=audio_loader, + labeller=labeller, preprocessor=model.preprocessor, config=config.train, num_workers=train_workers, @@ -68,12 +82,14 @@ def train( val_dataloader = ( build_val_loader( - val_examples, + val_annotations, + audio_loader=audio_loader, + labeller=labeller, preprocessor=model.preprocessor, config=config.train, num_workers=val_workers, ) - if val_examples is not None + if val_annotations is not None else None ) @@ -153,19 +169,23 @@ def build_trainer( def build_train_loader( - train_examples: Sequence[data.PathLike], + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + labeller: ClipLabeller, preprocessor: PreprocessorProtocol, config: Optional[TrainingConfig] = None, num_workers: Optional[int] = None, ) -> DataLoader: config = config or TrainingConfig() - - logger.info("Building training data loader...") train_dataset = build_train_dataset( - train_examples, + clip_annotations, + audio_loader=audio_loader, + labeller=labeller, preprocessor=preprocessor, config=config, ) + + logger.info("Building training data loader...") loader_conf = config.dataloaders.train logger.opt(lazy=True).debug( "Training data loader config: \n{config}", @@ -182,16 +202,20 @@ def build_train_loader( def build_val_loader( - val_examples: Sequence[data.PathLike], + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + labeller: ClipLabeller, preprocessor: PreprocessorProtocol, config: Optional[TrainingConfig] = None, num_workers: Optional[int] = None, ): + logger.info("Building validation data loader...") config = config or TrainingConfig() - logger.info("Building validation 
data loader...") val_dataset = build_val_dataset( - val_examples, + clip_annotations, + audio_loader=audio_loader, + labeller=labeller, preprocessor=preprocessor, config=config, ) @@ -203,7 +227,7 @@ def build_val_loader( num_workers = num_workers or loader_conf.num_workers return DataLoader( val_dataset, - batch_size=loader_conf.batch_size, + batch_size=1, shuffle=loader_conf.shuffle, num_workers=num_workers, collate_fn=_collate_fn, @@ -232,52 +256,60 @@ def _collate_fn(batch: List[TrainExample]) -> TrainExample: def build_train_dataset( - examples: Sequence[data.PathLike], + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + labeller: ClipLabeller, preprocessor: PreprocessorProtocol, config: Optional[TrainingConfig] = None, -) -> LabeledDataset: +) -> TrainingDataset: logger.info("Building training dataset...") config = config or TrainingConfig() clipper = build_clipper( - preprocessor=preprocessor, config=config.cliping, random=True, ) - random_example_source = RandomExampleSource( - list(examples), - clipper=clipper, + random_example_source = RandomAudioSource( + clip_annotations, + audio_loader=audio_loader, ) - if config.augmentations.enabled and config.augmentations.steps: - augmentations = build_augmentations( - preprocessor, + if config.augmentations.enabled: + audio_augmentation, spectrogram_augmentation = build_augmentations( + samplerate=preprocessor.input_samplerate, config=config.augmentations, - example_source=random_example_source, + audio_source=random_example_source, ) else: logger.debug("No augmentations configured for training dataset.") - augmentations = None + audio_augmentation = None + spectrogram_augmentation = None - return LabeledDataset( - examples, + return TrainingDataset( + clip_annotations, + audio_loader=audio_loader, + labeller=labeller, clipper=clipper, - augmentation=augmentations, + preprocessor=preprocessor, + audio_augmentation=audio_augmentation, + spectrogram_augmentation=spectrogram_augmentation, ) def build_val_dataset( - examples: Sequence[data.PathLike], + clip_annotations: Sequence[data.ClipAnnotation], + audio_loader: AudioLoader, + labeller: ClipLabeller, preprocessor: PreprocessorProtocol, config: Optional[TrainingConfig] = None, - train: bool = True, -) -> LabeledDataset: +) -> TrainingDataset: logger.info("Building validation dataset...") config = config or TrainingConfig() - clipper = build_clipper( + + return TrainingDataset( + clip_annotations, + audio_loader=audio_loader, + labeller=labeller, preprocessor=preprocessor, - config=config.cliping, - random=train, ) - return LabeledDataset(examples, clipper=clipper) diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py index 646f5d0..7edd401 100644 --- a/src/batdetect2/typing/train.py +++ b/src/batdetect2/typing/train.py @@ -49,7 +49,11 @@ spectrogram, applies all configured filtering, transformation, and encoding steps, and returns the final `Heatmaps` used for model training. """ -Augmentation = Callable[[PreprocessedExample], PreprocessedExample] + +Augmentation = Callable[ + [torch.Tensor, data.ClipAnnotation], + Tuple[torch.Tensor, data.ClipAnnotation], +] class TrainExample(NamedTuple): @@ -97,5 +101,6 @@ class LossProtocol(Protocol): class ClipperProtocol(Protocol): def __call__( - self, example: PreprocessedExample - ) -> Tuple[PreprocessedExample, float, float]: ... + self, + clip_annotation: data.ClipAnnotation, + ) -> data.ClipAnnotation: ... 
diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index b5348bb..344df17 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -6,7 +6,7 @@ from soundevent import data from batdetect2.train.augmentations import ( add_echo, - mix_examples, + mix_audio, ) from batdetect2.train.clips import select_subclip from batdetect2.train.preprocess import generate_train_example @@ -41,7 +41,7 @@ def test_mix_examples( labeller=sample_labeller, ) - mixed = mix_examples( + mixed = mix_audio( example1, example2, weight=0.3, @@ -86,7 +86,7 @@ def test_mix_examples_of_different_durations( labeller=sample_labeller, ) - mixed = mix_examples( + mixed = mix_audio( example1, example2, weight=0.3,
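Below is a minimal usage sketch (not part of the patch) of how the reworked entry point is intended to be driven after this change: datasets are loaded as annotation collections instead of directories of preprocessed `.npz` examples. It only uses names introduced or kept by the patch (`load_dataset_from_config`, `load_full_training_config`, `train`); the file paths are placeholders.

```python
# Hypothetical driver script; the YAML paths below are placeholders.
from batdetect2.data import load_dataset_from_config
from batdetect2.train import load_full_training_config, train

# Datasets are now plain sequences of ClipAnnotation objects; clipping,
# augmentation, spectrogram computation and labelling happen on the fly
# inside TrainingDataset rather than in a separate preprocessing step.
train_annotations = load_dataset_from_config("datasets/train.yaml")
val_annotations = load_dataset_from_config("datasets/val.yaml")

config = load_full_training_config("configs/training.yaml")

train(
    train_annotations=train_annotations,
    val_annotations=val_annotations,
    config=config,
)
```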