From 34ef9e92a1b135914de5d16f42221747857b50d2 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 27 Aug 2025 23:58:38 +0100 Subject: [PATCH] Make sure preprocessing is batchable --- src/batdetect2/plotting/common.py | 21 +- src/batdetect2/postprocess/__init__.py | 241 ++++++++++++++--------- src/batdetect2/preprocess/spectrogram.py | 38 +++- src/batdetect2/train/__init__.py | 2 +- src/batdetect2/train/augmentations.py | 216 ++++++++++++-------- src/batdetect2/train/callbacks.py | 15 +- src/batdetect2/train/clips.py | 37 +++- src/batdetect2/train/config.py | 4 +- src/batdetect2/train/dataset.py | 45 +---- src/batdetect2/train/labels.py | 78 +++----- src/batdetect2/train/lightning.py | 4 +- src/batdetect2/train/preprocess.py | 31 ++- src/batdetect2/train/train.py | 6 +- src/batdetect2/typing/postprocess.py | 63 +----- src/batdetect2/typing/targets.py | 4 +- src/batdetect2/utils/arrays.py | 13 ++ tests/test_train/test_preprocessing.py | 1 - 17 files changed, 446 insertions(+), 373 deletions(-) diff --git a/src/batdetect2/plotting/common.py b/src/batdetect2/plotting/common.py index f9459a9..b0adf80 100644 --- a/src/batdetect2/plotting/common.py +++ b/src/batdetect2/plotting/common.py @@ -26,19 +26,32 @@ def create_ax( def plot_spectrogram( spec: Union[torch.Tensor, np.ndarray], - start_time: float, - end_time: float, - min_freq: float, - max_freq: float, + start_time: Optional[float] = None, + end_time: Optional[float] = None, + min_freq: Optional[float] = None, + max_freq: Optional[float] = None, ax: Optional[axes.Axes] = None, figsize: Optional[Tuple[int, int]] = None, cmap="gray", ) -> axes.Axes: + if isinstance(spec, torch.Tensor): spec = spec.numpy() ax = create_ax(ax=ax, figsize=figsize) + if start_time is None: + start_time = 0 + + if end_time is None: + end_time = spec.shape[-1] + + if min_freq is None: + min_freq = 0 + + if max_freq is None: + max_freq = spec.shape[-2] + ax.pcolormesh( np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True), np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True), diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py index 07f0551..ba77c1a 100644 --- a/src/batdetect2/postprocess/__init__.py +++ b/src/batdetect2/postprocess/__init__.py @@ -2,6 +2,7 @@ from typing import List, Optional +import torch from loguru import logger from pydantic import Field from soundevent import data @@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import ( ) from batdetect2.postprocess.remapping import map_detection_to_clip from batdetect2.preprocess import MAX_FREQ, MIN_FREQ -from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol +from batdetect2.typing import ModelOutput from batdetect2.typing.postprocess import ( BatDetect2Prediction, Detections, PostprocessorProtocol, RawPrediction, ) +from batdetect2.typing.preprocess import PreprocessorProtocol +from batdetect2.typing.targets import TargetProtocol __all__ = [ "DEFAULT_CLASSIFICATION_THRESHOLD", @@ -128,7 +131,6 @@ def load_postprocess_config( def build_postprocessor( - targets: TargetProtocol, preprocessor: PreprocessorProtocol, config: Optional[PostprocessConfig] = None, ) -> PostprocessorProtocol: @@ -139,29 +141,52 @@ def build_postprocessor( lambda: config.to_yaml_string(), ) return Postprocessor( - targets=targets, - preprocessor=preprocessor, - config=config, + samplerate=preprocessor.output_samplerate, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + top_k_per_sec=config.top_k_per_sec, + 
detection_threshold=config.detection_threshold, ) -class Postprocessor(PostprocessorProtocol): +class Postprocessor(torch.nn.Module, PostprocessorProtocol): """Standard implementation of the postprocessing pipeline.""" - targets: TargetProtocol - - preprocessor: PreprocessorProtocol - def __init__( self, - targets: TargetProtocol, - preprocessor: PreprocessorProtocol, - config: PostprocessConfig, + samplerate: float, + min_freq: float, + max_freq: float, + top_k_per_sec: int = 200, + detection_threshold: float = 0.01, ): """Initialize the Postprocessor.""" - self.targets = targets - self.preprocessor = preprocessor - self.config = config + super().__init__() + self.samplerate = samplerate + self.min_freq = min_freq + self.max_freq = max_freq + self.top_k_per_sec = top_k_per_sec + self.detection_threshold = detection_threshold + + def forward(self, output: ModelOutput) -> List[Detections]: + width = output.detection_probs.shape[-1] + duration = width / self.samplerate + max_detections = int(self.top_k_per_sec * duration) + detections = extract_prediction_tensor( + output, + max_detections=max_detections, + threshold=self.detection_threshold, + ) + return [ + map_detection_to_clip( + detection, + start_time=0, + end_time=duration, + min_freq=self.min_freq, + max_freq=self.max_freq, + ) + for detection in detections + ] def get_detections( self, @@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol): clips: Optional[List[data.Clip]] = None, ) -> List[Detections]: width = output.detection_probs.shape[-1] - duration = width / self.preprocessor.output_samplerate - max_detections = int(self.config.top_k_per_sec * duration) + duration = width / self.samplerate + max_detections = int(self.top_k_per_sec * duration) detections = extract_prediction_tensor( output, max_detections=max_detections, - threshold=self.config.detection_threshold, + threshold=self.detection_threshold, ) if clips is None: @@ -186,96 +211,116 @@ class Postprocessor(PostprocessorProtocol): detection, start_time=clip.start_time, end_time=clip.end_time, - min_freq=self.preprocessor.min_freq, - max_freq=self.preprocessor.max_freq, + min_freq=self.min_freq, + max_freq=self.max_freq, ) for detection, clip in zip(detections, clips) ] - def get_raw_predictions( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[List[RawPrediction]]: - """Extract intermediate RawPrediction objects for a batch. - Processes raw model output through remapping, NMS, detection, data - extraction, and geometry recovery via the configured - `targets.recover_roi`. +def get_raw_predictions( + output: ModelOutput, + clips: List[data.Clip], + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, +) -> List[List[RawPrediction]]: + """Extract intermediate RawPrediction objects for a batch. - Parameters - ---------- - output : ModelOutput - Raw output from the neural network model for a batch. - clips : List[data.Clip] - List of `soundevent.data.Clip` objects corresponding to the batch. + Processes raw model output through remapping, NMS, detection, data + extraction, and geometry recovery via the configured + `targets.recover_roi`. - Returns - ------- - List[List[RawPrediction]] - List of lists (one inner list per input clip). Each inner list - contains `RawPrediction` objects for detections in that clip. 
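
# A minimal sketch of the batched call path this refactor enables. It is
# illustrative only: the sample rate and frequency bounds are made-up
# stand-ins, and `model` / `spec_batch` are assumed to exist in scope.
# Because Postprocessor now stores only scalars, forward() needs no Clip
# objects and maps a batched ModelOutput straight to per-item Detections:
import torch

from batdetect2.postprocess import Postprocessor

postprocessor = Postprocessor(
    samplerate=1_000.0,    # output frames per second (hypothetical value)
    min_freq=10_000.0,     # hypothetical frequency bounds, in Hz
    max_freq=120_000.0,
)

with torch.no_grad():
    output = model.detector(spec_batch)  # ModelOutput for a batch of clips
    detections = postprocessor(output)   # List[Detections], one per item
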
- """ - detections = self.get_detections(output, clips) - return [ - convert_detections_to_raw_predictions( - dataset, - targets=self.targets, + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. + + Returns + ------- + List[List[RawPrediction]] + List of lists (one inner list per input clip). Each inner list + contains `RawPrediction` objects for detections in that clip. + """ + detections = postprocessor.get_detections(output, clips) + return [ + convert_detections_to_raw_predictions( + dataset, + targets=targets, + ) + for dataset in detections + ] + + +def get_sound_event_predictions( + output: ModelOutput, + clips: List[data.Clip], + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, + classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, +) -> List[List[BatDetect2Prediction]]: + raw_predictions = get_raw_predictions( + output, + clips, + targets=targets, + postprocessor=postprocessor, + ) + return [ + [ + BatDetect2Prediction( + raw=raw, + sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( + raw, + recording=clip.recording, + targets=targets, + classification_threshold=classification_threshold, + ), ) - for dataset in detections + for raw in predictions ] + for predictions, clip in zip(raw_predictions, clips) + ] - def get_sound_event_predictions( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[List[BatDetect2Prediction]]: - raw_predictions = self.get_raw_predictions(output, clips) - return [ - [ - BatDetect2Prediction( - raw=raw, - sound_event_prediction=convert_raw_prediction_to_sound_event_prediction( - raw, - recording=clip.recording, - targets=self.targets, - classification_threshold=self.config.classification_threshold, - ), - ) - for raw in predictions - ] - for predictions, clip in zip(raw_predictions, clips) - ] - def get_predictions( - self, output: ModelOutput, clips: List[data.Clip] - ) -> List[data.ClipPrediction]: - """Perform the full postprocessing pipeline for a batch. +def get_predictions( + output: ModelOutput, + clips: List[data.Clip], + targets: TargetProtocol, + postprocessor: PostprocessorProtocol, + classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, +) -> List[data.ClipPrediction]: + """Perform the full postprocessing pipeline for a batch. - Takes raw model output and corresponding clips, applies the entire - configured chain (NMS, remapping, extraction, geometry recovery, class - decoding), producing final `soundevent.data.ClipPrediction` objects. + Takes raw model output and corresponding clips, applies the entire + configured chain (NMS, remapping, extraction, geometry recovery, class + decoding), producing final `soundevent.data.ClipPrediction` objects. - Parameters - ---------- - output : ModelOutput - Raw output from the neural network model for a batch. - clips : List[data.Clip] - List of `soundevent.data.Clip` objects corresponding to the batch. + Parameters + ---------- + output : ModelOutput + Raw output from the neural network model for a batch. + clips : List[data.Clip] + List of `soundevent.data.Clip` objects corresponding to the batch. - Returns - ------- - List[data.ClipPrediction] - List containing one `ClipPrediction` object for each input clip, - populated with `SoundEventPrediction` objects. 
- """ - raw_predictions = self.get_raw_predictions(output, clips) - return [ - convert_raw_predictions_to_clip_prediction( - prediction, - clip, - targets=self.targets, - classification_threshold=self.config.classification_threshold, - ) - for prediction, clip in zip(raw_predictions, clips) - ] + Returns + ------- + List[data.ClipPrediction] + List containing one `ClipPrediction` object for each input clip, + populated with `SoundEventPrediction` objects. + """ + raw_predictions = get_raw_predictions( + output, + clips, + targets=targets, + postprocessor=postprocessor, + ) + return [ + convert_raw_predictions_to_clip_prediction( + prediction, + clip, + targets=targets, + classification_threshold=classification_threshold, + ) + for prediction, clip in zip(raw_predictions, clips) + ] diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index dde3e13..b79ef81 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module): self.high_index = high_index def forward(self, spec: torch.Tensor) -> torch.Tensor: - return spec[self.low_index : self.high_index] + low_index = self.low_index + if low_index is None: + low_index = 0 + + if self.high_index is None: + length = spec.shape[-2] - low_index + else: + length = self.high_index - low_index + + return torch.narrow( + spec, + dim=-2, + start=low_index, + length=length, + ) class PcenConfig(BaseConfig): @@ -256,16 +270,22 @@ class ResizeSpec(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: current_length = spec.shape[-1] target_length = int(self.time_factor * current_length) - return ( - torch.nn.functional.interpolate( - spec.unsqueeze(0).unsqueeze(0), - size=(self.height, target_length), - mode="bilinear", - ) - .squeeze(0) - .squeeze(0) + + original_ndim = spec.ndim + while spec.ndim < 4: + spec = spec.unsqueeze(0) + + resized = torch.nn.functional.interpolate( + spec, + size=(self.height, target_length), + mode="bilinear", ) + while resized.ndim != original_ndim: + resized = resized.squeeze(0) + + return resized + class PeakNormalizeConfig(BaseConfig): name: Literal["peak_normalize"] = "peak_normalize" diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index fd161e0..f2658e7 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -2,6 +2,7 @@ from batdetect2.train.augmentations import ( AugmentationsConfig, EchoAugmentationConfig, FrequencyMaskAugmentationConfig, + RandomExampleSource, TimeMaskAugmentationConfig, VolumeAugmentationConfig, WarpAugmentationConfig, @@ -23,7 +24,6 @@ from batdetect2.train.config import ( ) from batdetect2.train.dataset import ( LabeledDataset, - RandomExampleSource, list_preprocessed_files, ) from batdetect2.train.labels import build_clip_labeler, load_label_config diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 84499c0..b350e07 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -1,6 +1,7 @@ """Applies data augmentation techniques to BatDetect2 training examples.""" import warnings +from collections.abc import Sequence from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union import numpy as np @@ -10,8 +11,12 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config +from batdetect2.train.preprocess import ( + 
list_preprocessed_files, + load_preprocessed_example, +) from batdetect2.typing import Augmentation, PreprocessorProtocol -from batdetect2.typing.train import PreprocessedExample +from batdetect2.typing.train import ClipperProtocol, PreprocessedExample from batdetect2.utils.arrays import adjust_width __all__ = [ @@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample] """Type alias for a function that returns a training example""" -class MixAugmentationConfig(BaseConfig): - """Configuration for MixUp augmentation (mixing two examples).""" - - augmentation_type: Literal["mix_audio"] = "mix_audio" - - probability: float = 0.2 - """Probability of applying this augmentation to an example.""" - - min_weight: float = 0.3 - """Minimum mixing weight (lambda) applied to the primary example.""" - - max_weight: float = 0.7 - """Maximum mixing weight (lambda) applied to the primary example.""" - - def mix_examples( example: PreprocessedExample, other: PreprocessedExample, @@ -149,7 +139,12 @@ def add_echo( audio = example.audio delay_steps = int(preprocessor.input_samplerate * delay) - audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1]) + + slices = [slice(None)] * audio.ndim + slices[-1] = slice(None, -delay_steps) + audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll( + delay_steps, dims=-1 + ) audio = audio + weight * audio_delay spectrogram = preprocessor(audio) @@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig): class ScaleVolume(torch.nn.Module): - def __init__(self, min_scaling: float, max_scaling: float): + def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0): super().__init__() self.min_scaling = min_scaling self.max_scaling = max_scaling @@ -228,32 +223,22 @@ def warp_spectrogram( example: PreprocessedExample, factor: float ) -> PreprocessedExample: """Apply time warping by resampling the time axis.""" - target_shape = example.spectrogram.shape + width = example.spectrogram.shape[-1] + height = example.spectrogram.shape[-2] + target_shape = [height, width] new_width = int(target_shape[-1] * factor) - spectrogram = ( - torch.nn.functional.interpolate( - adjust_width(example.spectrogram, new_width) - .unsqueeze(0) - .unsqueeze(0), - size=target_shape, - mode="bilinear", - ) - .squeeze(0) - .squeeze(0) - ) + spectrogram = torch.nn.functional.interpolate( + adjust_width(example.spectrogram, new_width).unsqueeze(0), + size=target_shape, + mode="bilinear", + ).squeeze(0) - detection = ( - torch.nn.functional.interpolate( - adjust_width(example.detection_heatmap, new_width) - .unsqueeze(0) - .unsqueeze(0), - size=target_shape, - mode="nearest", - ) - .squeeze(0) - .squeeze(0) - ) + detection = torch.nn.functional.interpolate( + adjust_width(example.detection_heatmap, new_width).unsqueeze(0), + size=target_shape, + mode="nearest", + ).squeeze(0) classification = torch.nn.functional.interpolate( adjust_width(example.class_heatmap, new_width).unsqueeze(1), @@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig): class MaskTime(torch.nn.Module): - def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None: + def __init__( + self, + max_perc: float = 0.05, + max_masks: int = 3, + mask_heatmaps: bool = False, + ) -> None: super().__init__() self.max_perc = max_perc self.max_masks = max_masks + self.mask_heatmaps = mask_heatmaps def forward(self, example: PreprocessedExample) -> PreprocessedExample: num_masks = np.random.randint(1, self.max_masks + 1) @@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module): 
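
# The echo construction used by add_echo above, in isolation: drop the tail,
# pad back to the original width, then roll right so the echo starts
# `delay_steps` samples in. A tiny self-contained check (values illustrative;
# torch.nn.functional.pad stands in for adjust_width, assuming it right-pads):
import torch

audio = torch.arange(8.0)           # stand-in waveform, shape (..., T)
delay_steps = 3
head = audio[..., :-delay_steps]    # the dynamic slicing built in add_echo
padded = torch.nn.functional.pad(head, (0, delay_steps))
echo = padded.roll(delay_steps, dims=-1)
assert torch.equal(echo[..., delay_steps:], audio[..., :-delay_steps])
assert torch.equal(echo[..., :delay_steps], torch.zeros(delay_steps))
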
masks = [ (start, start + size) for start, size in zip(mask_start, mask_size) ] - return mask_time(example, masks) + return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps) def mask_time( example: PreprocessedExample, masks: List[Tuple[int, int]], + mask_heatmaps: bool = False, ) -> PreprocessedExample: """Apply time masking to the spectrogram.""" for start, end in masks: - example.spectrogram[:, start:end] = example.spectrogram.mean() - example.class_heatmap[:, :, start:end] = 0 - example.size_heatmap[:, :, start:end] = 0 - example.detection_heatmap[:, start:end] = 0 + slices = [slice(None)] * example.spectrogram.ndim + slices[-1] = slice(start, end) + + example.spectrogram[tuple(slices)] = 0 + + if not mask_heatmaps: + continue + + example.class_heatmap[tuple(slices)] = 0 + example.size_heatmap[tuple(slices)] = 0 + example.detection_heatmap[tuple(slices)] = 0 return PreprocessedExample( audio=example.audio, @@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig): probability: float = 0.2 max_perc: float = 0.10 max_masks: int = 3 + mask_heatmaps: bool = False class MaskFrequency(torch.nn.Module): - def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None: + def __init__( + self, + max_perc: float = 0.10, + max_masks: int = 3, + mask_heatmaps: bool = False, + ) -> None: super().__init__() self.max_perc = max_perc self.max_masks = max_masks + self.mask_heatmaps = mask_heatmaps def forward(self, example: PreprocessedExample) -> PreprocessedExample: num_masks = np.random.randint(1, self.max_masks + 1) @@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module): masks = [ (start, start + size) for start, size in zip(mask_start, mask_size) ] - return mask_frequency(example, masks) + return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps) def mask_frequency( example: PreprocessedExample, masks: List[Tuple[int, int]], + mask_heatmaps: bool = False, ) -> PreprocessedExample: """Apply frequency masking to the spectrogram.""" for start, end in masks: - example.spectrogram[start:end, :] = example.spectrogram.mean() - example.class_heatmap[:, start:end, :] = 0 - example.size_heatmap[:, start:end, :] = 0 - example.detection_heatmap[start:end, :] = 0 + slices = [slice(None)] * example.spectrogram.ndim + slices[-2] = slice(start, end) + example.spectrogram[tuple(slices)] = 0 + + if not mask_heatmaps: + continue + + example.class_heatmap[tuple(slices)] = 0 + example.size_heatmap[tuple(slices)] = 0 + example.detection_heatmap[tuple(slices)] = 0 return PreprocessedExample( audio=example.audio, @@ -383,6 +396,50 @@ def mask_frequency( ) +class MixAugmentationConfig(BaseConfig): + """Configuration for MixUp augmentation (mixing two examples).""" + + augmentation_type: Literal["mix_audio"] = "mix_audio" + + probability: float = 0.2 + """Probability of applying this augmentation to an example.""" + + min_weight: float = 0.3 + """Minimum mixing weight (lambda) applied to the primary example.""" + + max_weight: float = 0.7 + """Maximum mixing weight (lambda) applied to the primary example.""" + + +class MixAudio(torch.nn.Module): + """Callable class for MixUp augmentation, handling example fetching.""" + + def __init__( + self, + example_source: ExampleSource, + preprocessor: PreprocessorProtocol, + min_weight: float = 0.3, + max_weight: float = 0.7, + ): + """Initialize the AudioMixer.""" + super().__init__() + self.min_weight = min_weight + self.example_source = example_source + self.max_weight = max_weight + self.preprocessor = preprocessor + + def 
__call__(self, example: PreprocessedExample) -> PreprocessedExample: + """Fetch another example and perform mixup.""" + other = self.example_source() + weight = np.random.uniform(self.min_weight, self.max_weight) + return mix_examples( + example, + other, + self.preprocessor, + weight=weight, + ) + + AugmentationConfig = Annotated[ Union[ MixAugmentationConfig, @@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module): return self.augmentation(example) -class AudioMixer(torch.nn.Module): - """Callable class for MixUp augmentation, handling example fetching.""" - - def __init__( - self, - min_weight: float, - max_weight: float, - example_source: ExampleSource, - preprocessor: PreprocessorProtocol, - ): - """Initialize the AudioMixer.""" - super().__init__() - self.min_weight = min_weight - self.example_source = example_source - self.max_weight = max_weight - self.preprocessor = preprocessor - - def __call__(self, example: PreprocessedExample) -> PreprocessedExample: - """Fetch another example and perform mixup.""" - other = self.example_source() - weight = np.random.uniform(self.min_weight, self.max_weight) - return mix_examples( - example, - other, - self.preprocessor, - weight=weight, - ) - - def build_augmentation_from_config( config: AugmentationConfig, preprocessor: PreprocessorProtocol, @@ -489,7 +517,7 @@ def build_augmentation_from_config( ) return None - return AudioMixer( + return MixAudio( example_source=example_source, preprocessor=preprocessor, min_weight=config.min_weight, @@ -585,3 +613,25 @@ def load_augmentation_config( ) -> AugmentationsConfig: """Load the augmentations configuration from a file.""" return load_config(path, schema=AugmentationsConfig, field=field) + + +class RandomExampleSource: + def __init__( + self, + filenames: Sequence[data.PathLike], + clipper: ClipperProtocol, + ): + self.filenames = filenames + self.clipper = clipper + + def __call__(self) -> PreprocessedExample: + index = int(np.random.randint(len(self.filenames))) + filename = self.filenames[index] + example = load_preprocessed_example(filename) + example, _, _ = self.clipper(example) + return example + + @classmethod + def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol): + filenames = list_preprocessed_files(path) + return cls(filenames, clipper=clipper) diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index c7862c8..4184440 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -14,7 +14,9 @@ from batdetect2.evaluate.match import ( MatchConfig, match_sound_events_and_raw_predictions, ) +from batdetect2.models import Model from batdetect2.plotting.evaluation import plot_example_gallery +from batdetect2.postprocess import get_sound_event_predictions from batdetect2.train.dataset import LabeledDataset from batdetect2.train.lightning import TrainingModule from batdetect2.typing import ( @@ -22,7 +24,6 @@ from batdetect2.typing import ( MatchEvaluation, MetricsProtocol, ModelOutput, - PostprocessorProtocol, TargetProtocol, TrainExample, ) @@ -127,8 +128,7 @@ class ValidationMetrics(Callback): batch, outputs, dataset=self.get_dataset(trainer), - postprocessor=pl_module.model.postprocessor, - targets=pl_module.model.targets, + model=pl_module.model, ) ) @@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions( batch: TrainExample, outputs: ModelOutput, dataset: LabeledDataset, - postprocessor: PostprocessorProtocol, - targets: TargetProtocol, + model: Model, ) -> List[Tuple[data.ClipAnnotation, 
List[BatDetect2Prediction]]]: clip_annotations = [ _get_subclip( dataset.get_clip_annotation(example_id), start_time=start_time.item(), end_time=end_time.item(), - targets=targets, + targets=model.targets, ) for example_id, start_time, end_time in zip( batch.idx, @@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions( clips = [clip_annotation.clip for clip_annotation in clip_annotations] - raw_predictions = postprocessor.get_sound_event_predictions( + raw_predictions = get_sound_event_predictions( outputs, clips, + targets=model.targets, + postprocessor=model.postprocessor ) return [ diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py index 67b0ad5..9ed9147 100644 --- a/src/batdetect2/train/clips.py +++ b/src/batdetect2/train/clips.py @@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig from batdetect2.typing import ClipperProtocol from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.train import PreprocessedExample -from batdetect2.utils.arrays import adjust_width +from batdetect2.utils.arrays import adjust_width, slice_tensor DEFAULT_TRAIN_CLIP_DURATION = 0.512 DEFAULT_MAX_EMPTY_CLIP = 0.1 @@ -90,7 +90,12 @@ def select_subclip( audio_start = int(np.floor(start * input_samplerate)) audio = adjust_width( - example.audio[audio_start : audio_start + audio_width], + slice_tensor( + example.audio, + start=audio_start, + end=audio_start + audio_width, + dim=-1, + ), audio_width, value=fill_value, ) @@ -100,19 +105,39 @@ def select_subclip( return PreprocessedExample( audio=audio, spectrogram=adjust_width( - example.spectrogram[:, spec_start : spec_start + spec_width], + slice_tensor( + example.spectrogram, + start=spec_start, + end=spec_start + spec_width, + dim=-1, + ), spec_width, ), class_heatmap=adjust_width( - example.class_heatmap[:, :, spec_start : spec_start + spec_width], + slice_tensor( + example.class_heatmap, + start=spec_start, + end=spec_start + spec_width, + dim=-1, + ), spec_width, ), detection_heatmap=adjust_width( - example.detection_heatmap[:, spec_start : spec_start + spec_width], + slice_tensor( + example.detection_heatmap, + start=spec_start, + end=spec_start + spec_width, + dim=-1, + ), spec_width, ), size_heatmap=adjust_width( - example.size_heatmap[:, :, spec_start : spec_start + spec_width], + slice_tensor( + example.size_heatmap, + start=spec_start, + end=spec_start + spec_width, + dim=-1, + ), spec_width, ), ) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index 5e77dc5..1e82cc2 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig): class DataLoaderConfig(BaseConfig): - batch_size: int - shuffle: bool + batch_size: int = 8 + shuffle: bool = False num_workers: int = 0 diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index ae41bd0..da2aa55 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -1,5 +1,4 @@ -from pathlib import Path -from typing import List, Optional, Sequence, Tuple +from typing import Optional, Sequence, Tuple import numpy as np import torch @@ -7,6 +6,10 @@ from soundevent import data from torch.utils.data import Dataset from batdetect2.train.augmentations import Augmentation +from batdetect2.train.preprocess import ( + list_preprocessed_files, + load_preprocessed_example, +) from batdetect2.typing import ClipperProtocol, TrainExample from batdetect2.typing.train import PreprocessedExample @@ -38,8 +41,8 @@ class 
LabeledDataset(Dataset): example = self.augmentation(example) return TrainExample( - spec=example.spectrogram.unsqueeze(0), - detection_heatmap=example.detection_heatmap.unsqueeze(0), + spec=example.spectrogram, + detection_heatmap=example.detection_heatmap, class_heatmap=example.class_heatmap, size_heatmap=example.size_heatmap, idx=torch.tensor(idx), @@ -73,37 +76,3 @@ class LabeledDataset(Dataset): def get_clip_annotation(self, idx) -> data.ClipAnnotation: item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+") return item["clip_annotation"].tolist() - - -def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample: - item = np.load(path, mmap_mode="r+") - return PreprocessedExample( - audio=torch.tensor(item["audio"]), - spectrogram=torch.tensor(item["spectrogram"]), - size_heatmap=torch.tensor(item["size_heatmap"]), - detection_heatmap=torch.tensor(item["detection_heatmap"]), - class_heatmap=torch.tensor(item["class_heatmap"]), - ) - - -def list_preprocessed_files( - directory: data.PathLike, extension: str = ".npz" -) -> List[Path]: - return list(Path(directory).glob(f"*{extension}")) - - -class RandomExampleSource: - def __init__( - self, - filenames: List[data.PathLike], - clipper: ClipperProtocol, - ): - self.filenames = filenames - self.clipper = clipper - - def __call__(self) -> PreprocessedExample: - index = int(np.random.randint(len(self.filenames))) - filename = self.filenames[index] - example = load_preprocessed_example(filename) - example, _, _ = self.clipper(example) - return example diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 176ee81..dd42110 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -41,7 +41,6 @@ from batdetect2.typing import ( __all__ = [ "LabelConfig", "build_clip_labeler", - "generate_clip_label", "generate_heatmaps", "load_label_config", ] @@ -99,21 +98,26 @@ def build_clip_labeler( lambda: config.to_yaml_string(), ) return partial( - generate_clip_label, + generate_heatmaps, targets=targets, - config=config, min_freq=min_freq, max_freq=max_freq, + target_sigma=config.sigma, ) -def generate_clip_label( +def map_to_pixels(x, size, min_val, max_val) -> int: + return int(np.interp(x, [min_val, max_val], [0, size])) + + +def generate_heatmaps( clip_annotation: data.ClipAnnotation, spec: torch.Tensor, targets: TargetProtocol, - config: LabelConfig, min_freq: float, max_freq: float, + target_sigma: float = 3.0, + dtype=torch.float32, ) -> Heatmaps: """Generate training heatmaps for a single annotated clip. @@ -150,57 +154,14 @@ def generate_clip_label( num=len(clip_annotation.sound_events), ) - sound_events = [] - - for sound_event_annotation in clip_annotation.sound_events: - if not targets.filter(sound_event_annotation): - logger.debug( - "Sound event {sound_event} did not pass the filter. 
Tags: {tags}", - sound_event=sound_event_annotation, - tags=sound_event_annotation.tags, - ) - continue - - sound_events.append(targets.transform(sound_event_annotation)) - - return generate_heatmaps( - clip_annotation.model_copy(update=dict(sound_events=sound_events)), - spec=spec, - targets=targets, - target_sigma=config.sigma, - min_freq=min_freq, - max_freq=max_freq, - ) - - -def map_to_pixels(x, size, min_val, max_val) -> int: - return int(np.interp(x, [min_val, max_val], [0, size])) - - -def generate_heatmaps( - clip_annotation: data.ClipAnnotation, - spec: torch.Tensor, - targets: TargetProtocol, - min_freq: float, - max_freq: float, - target_sigma: float = 3.0, - dtype=torch.float32, -) -> Heatmaps: - if not spec.ndim == 2: - raise ValueError( - "Expecting a 2-dimensional tensor of shape (H, W), " - "H is the height of the spectrogram " - "(frequency bins), and W is the width of the spectrogram " - f"(temporal bins). Instead got: {spec.shape}" - ) - - height, width = spec.shape + height = spec.shape[-2] + width = spec.shape[-1] num_classes = len(targets.class_names) num_dims = len(targets.dimension_names) clip = clip_annotation.clip # Initialize heatmaps - detection_heatmap = torch.zeros([height, width], dtype=dtype) + detection_heatmap = torch.zeros([1, height, width], dtype=dtype) class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype) size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype) @@ -214,6 +175,16 @@ def generate_heatmaps( times = times.to(spec.device) for sound_event_annotation in clip_annotation.sound_events: + if not targets.filter(sound_event_annotation): + logger.debug( + "Sound event {sound_event} did not pass the filter. Tags: {tags}", + sound_event=sound_event_annotation, + tags=sound_event_annotation.tags, + ) + continue + + sound_event_annotation = targets.transform(sound_event_annotation) + geom = sound_event_annotation.sound_event.geometry if geom is None: logger.debug( @@ -245,7 +216,10 @@ def generate_heatmaps( distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2 gaussian_blob = torch.exp(-distance / (2 * target_sigma**2)) - detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob) + detection_heatmap[0] = torch.maximum( + detection_heatmap[0], + gaussian_blob, + ) size_heatmap[:, freq_index, time_index] = torch.tensor(size[:]) # Get the class name of the sound event diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 1e42651..6e33ed9 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule): return self.model(spec) def training_step(self, batch: TrainExample): - outputs = self.model(batch.spec) + outputs = self.model.detector(batch.spec) losses = self.loss(outputs, batch) self.log("total_loss/train", losses.total, prog_bar=True, logger=True) self.log("detection_loss/train", losses.total, logger=True) @@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule): batch: TrainExample, batch_idx: int, ) -> ModelOutput: - outputs = self.model(batch.spec) + outputs = self.model.detector(batch.spec) losses = self.loss(outputs, batch) self.log("total_loss/val", losses.total, prog_bar=True, logger=True) self.log("detection_loss/val", losses.total, logger=True) diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py index 1a21a77..9d6a660 100644 --- a/src/batdetect2/train/preprocess.py +++ b/src/batdetect2/train/preprocess.py @@ -2,7 +2,7 @@ import os from pathlib 
import Path -from typing import Callable, Optional, Sequence, TypedDict +from typing import Callable, List, Optional, Sequence, TypedDict import numpy as np import torch @@ -28,6 +28,8 @@ __all__ = [ "preprocess_dataset", "TrainPreprocessConfig", "load_train_preprocessing_config", + "save_preprocessed_example", + "load_preprocessed_example", ] FilenameFn = Callable[[data.ClipAnnotation], str] @@ -94,8 +96,10 @@ def generate_train_example( labeller: ClipLabeller, ) -> PreprocessedExample: """Generate a complete training example for one annotation.""" - wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip)) - spectrogram = preprocessor(wave) + wave = torch.tensor( + audio_loader.load_clip(clip_annotation.clip) + ).unsqueeze(0) + spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0) heatmaps = labeller(clip_annotation, spectrogram) return PreprocessedExample( audio=wave, @@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset): labeller=self.labeller, ) - save_example_to_file(example, clip_annotation, path) + save_preprocessed_example(example, clip_annotation, path) return idx @@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset): return len(self.clips) -def save_example_to_file( +def save_preprocessed_example( example: PreprocessedExample, clip_annotation: data.ClipAnnotation, path: data.PathLike, @@ -169,6 +173,23 @@ def save_example_to_file( ) +def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample: + item = np.load(path, mmap_mode="r+") + return PreprocessedExample( + audio=torch.tensor(item["audio"]), + spectrogram=torch.tensor(item["spectrogram"]), + size_heatmap=torch.tensor(item["size_heatmap"]), + detection_heatmap=torch.tensor(item["detection_heatmap"]), + class_heatmap=torch.tensor(item["class_heatmap"]), + ) + + +def list_preprocessed_files( + directory: data.PathLike, extension: str = ".npz" +) -> List[Path]: + return list(Path(directory).glob(f"*{extension}")) + + def _get_filename(clip_annotation: data.ClipAnnotation) -> str: """Generate a default output filename based on the annotation UUID.""" return f"{clip_annotation.uuid}" diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 1cb899a..376c08c 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -15,13 +15,15 @@ from batdetect2.evaluate.metrics import ( DetectionAveragePrecision, ) from batdetect2.models import build_model -from batdetect2.train.augmentations import build_augmentations +from batdetect2.train.augmentations import ( + RandomExampleSource, + build_augmentations, +) from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.clips import build_clipper from batdetect2.train.config import FullTrainingConfig, TrainingConfig from batdetect2.train.dataset import ( LabeledDataset, - RandomExampleSource, ) from batdetect2.train.lightning import TrainingModule from batdetect2.train.logging import build_logger diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index dbf23da..e876c3d 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -95,69 +95,10 @@ class BatDetect2Prediction: class PostprocessorProtocol(Protocol): """Protocol defining the interface for the full postprocessing pipeline.""" + def __call__(self, output: ModelOutput) -> List[Detections]: ... + def get_detections( self, output: ModelOutput, clips: Optional[List[data.Clip]] = None, ) -> List[Detections]: ... 
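
# The slimmed-down protocol is purely structural: any tensor-level
# implementation with this surface type-checks against it. A do-nothing
# stand-in, shown only to make the required interface concrete:
from typing import List, Optional

from soundevent import data

from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import Detections, PostprocessorProtocol


class NullPostprocessor:
    """Satisfies PostprocessorProtocol while never detecting anything."""

    def __call__(self, output: ModelOutput) -> List[Detections]:
        return []

    def get_detections(
        self,
        output: ModelOutput,
        clips: Optional[List[data.Clip]] = None,
    ) -> List[Detections]:
        return []


_check: PostprocessorProtocol = NullPostprocessor()  # structural check only
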
- - def get_raw_predictions( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[List[RawPrediction]]: - """Extract intermediate RawPrediction objects for a batch. - - Processes the raw model output for a batch through remapping, NMS, - detection, data extraction, and geometry recovery to produce a list of - `RawPrediction` objects for each corresponding input clip. This provides - a simplified, intermediate representation before final tag decoding. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing context. Must match the batch size of `output`. - - Returns - ------- - List[List[RawPrediction]] - A list of lists (one inner list per input clip, in order). Each - inner list contains the `RawPrediction` objects extracted for the - corresponding input clip. - """ - ... - - def get_sound_event_predictions( - self, output: ModelOutput, clips: List[data.Clip] - ) -> List[List[BatDetect2Prediction]]: ... - - def get_predictions( - self, - output: ModelOutput, - clips: List[data.Clip], - ) -> List[data.ClipPrediction]: - """Perform the full postprocessing pipeline for a batch. - - Takes raw model output for a batch and corresponding clips, applies the - entire postprocessing chain, and returns the final, interpretable - predictions as a list of `soundevent.data.ClipPrediction` objects. - - Parameters - ---------- - output : ModelOutput - The raw output from the neural network model for a batch. - clips : List[data.Clip] - A list of `soundevent.data.Clip` objects corresponding to the batch - items, providing context. Must match the batch size of `output`. - - Returns - ------- - List[data.ClipPrediction] - A list containing one `ClipPrediction` object for each input clip - (in the same order), populated with `SoundEventPrediction` objects - representing the final detections with decoded tags and geometry. - """ - ... diff --git a/src/batdetect2/typing/targets.py b/src/batdetect2/typing/targets.py index 2846a0e..db74baf 100644 --- a/src/batdetect2/typing/targets.py +++ b/src/batdetect2/typing/targets.py @@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently throughout BatDetect2. 
""" -from collections.abc import Callable -from typing import List, Optional, Protocol +from collections.abc import Callable, Iterable +from typing import List, Optional, Protocol, Tuple import numpy as np from soundevent import data diff --git a/src/batdetect2/utils/arrays.py b/src/batdetect2/utils/arrays.py index 7a46dd7..c3204c5 100644 --- a/src/batdetect2/utils/arrays.py +++ b/src/batdetect2/utils/arrays.py @@ -1,3 +1,5 @@ +from typing import Optional + import numpy as np import torch import xarray as xr @@ -80,3 +82,14 @@ def adjust_width( for index in range(dims) ] return tensor[tuple(slices)] + + +def slice_tensor( + tensor: torch.Tensor, + start: Optional[int] = None, + end: Optional[int] = None, + dim: int = -1, +) -> torch.Tensor: + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start, end) + return tensor[tuple(slices)] diff --git a/tests/test_train/test_preprocessing.py b/tests/test_train/test_preprocessing.py index fc30501..0660705 100644 --- a/tests/test_train/test_preprocessing.py +++ b/tests/test_train/test_preprocessing.py @@ -38,7 +38,6 @@ def build_from_config( max_freq=preprocessor.max_freq, ) postprocessor = build_postprocessor( - targets, preprocessor=preprocessor, config=postprocessing_config, )