Make sure preprocessing is batchable

mbsantiago 2025-08-27 23:58:38 +01:00
parent 0b5ac96fe8
commit 34ef9e92a1
17 changed files with 446 additions and 373 deletions

View File

@@ -26,19 +26,32 @@ def create_ax(
 def plot_spectrogram(
     spec: Union[torch.Tensor, np.ndarray],
-    start_time: float,
-    end_time: float,
-    min_freq: float,
-    max_freq: float,
+    start_time: Optional[float] = None,
+    end_time: Optional[float] = None,
+    min_freq: Optional[float] = None,
+    max_freq: Optional[float] = None,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
     cmap="gray",
 ) -> axes.Axes:
     if isinstance(spec, torch.Tensor):
         spec = spec.numpy()
 
     ax = create_ax(ax=ax, figsize=figsize)
 
+    if start_time is None:
+        start_time = 0
+
+    if end_time is None:
+        end_time = spec.shape[-1]
+
+    if min_freq is None:
+        min_freq = 0
+
+    if max_freq is None:
+        max_freq = spec.shape[-2]
+
     ax.pcolormesh(
         np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
         np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
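With all four coordinates now optional, the axes fall back to pixel indices when limits are omitted. A minimal usage sketch (array shape and frequency values are illustrative):

import numpy as np

spec = np.random.rand(128, 512)  # (freq_bins, time_bins)

# Bare call: x axis spans 0..512 time bins, y axis spans 0..128 frequency bins.
ax = plot_spectrogram(spec)

# Real coordinates still work exactly as before.
ax = plot_spectrogram(
    spec,
    start_time=0.0,
    end_time=0.512,
    min_freq=10_000,
    max_freq=120_000,
)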

View File

@@ -2,6 +2,7 @@
 from typing import List, Optional
 
+import torch
 from loguru import logger
 from pydantic import Field
 from soundevent import data
@@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import (
 )
 from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
+from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
     BatDetect2Prediction,
     Detections,
     PostprocessorProtocol,
     RawPrediction,
 )
+from batdetect2.typing.preprocess import PreprocessorProtocol
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "DEFAULT_CLASSIFICATION_THRESHOLD",
@@ -128,7 +131,6 @@ def load_postprocess_config(
 
 def build_postprocessor(
-    targets: TargetProtocol,
     preprocessor: PreprocessorProtocol,
     config: Optional[PostprocessConfig] = None,
 ) -> PostprocessorProtocol:
@@ -139,29 +141,52 @@ def build_postprocessor(
         lambda: config.to_yaml_string(),
     )
     return Postprocessor(
-        targets=targets,
-        preprocessor=preprocessor,
-        config=config,
+        samplerate=preprocessor.output_samplerate,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+        top_k_per_sec=config.top_k_per_sec,
+        detection_threshold=config.detection_threshold,
     )
 
 
-class Postprocessor(PostprocessorProtocol):
+class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     """Standard implementation of the postprocessing pipeline."""
 
-    targets: TargetProtocol
-    preprocessor: PreprocessorProtocol
-
     def __init__(
         self,
-        targets: TargetProtocol,
-        preprocessor: PreprocessorProtocol,
-        config: PostprocessConfig,
+        samplerate: float,
+        min_freq: float,
+        max_freq: float,
+        top_k_per_sec: int = 200,
+        detection_threshold: float = 0.01,
     ):
         """Initialize the Postprocessor."""
-        self.targets = targets
-        self.preprocessor = preprocessor
-        self.config = config
+        super().__init__()
+        self.samplerate = samplerate
+        self.min_freq = min_freq
+        self.max_freq = max_freq
+        self.top_k_per_sec = top_k_per_sec
+        self.detection_threshold = detection_threshold
+
+    def forward(self, output: ModelOutput) -> List[Detections]:
+        width = output.detection_probs.shape[-1]
+        duration = width / self.samplerate
+        max_detections = int(self.top_k_per_sec * duration)
+        detections = extract_prediction_tensor(
+            output,
+            max_detections=max_detections,
+            threshold=self.detection_threshold,
+        )
+        return [
+            map_detection_to_clip(
+                detection,
+                start_time=0,
+                end_time=duration,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
+            )
+            for detection in detections
+        ]
 
     def get_detections(
         self,
@@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol):
         clips: Optional[List[data.Clip]] = None,
     ) -> List[Detections]:
         width = output.detection_probs.shape[-1]
-        duration = width / self.preprocessor.output_samplerate
-        max_detections = int(self.config.top_k_per_sec * duration)
+        duration = width / self.samplerate
+        max_detections = int(self.top_k_per_sec * duration)
         detections = extract_prediction_tensor(
             output,
             max_detections=max_detections,
-            threshold=self.config.detection_threshold,
+            threshold=self.detection_threshold,
         )
 
         if clips is None:
@@ -186,96 +211,116 @@ class Postprocessor(PostprocessorProtocol):
                 detection,
                 start_time=clip.start_time,
                 end_time=clip.end_time,
-                min_freq=self.preprocessor.min_freq,
-                max_freq=self.preprocessor.max_freq,
+                min_freq=self.min_freq,
+                max_freq=self.max_freq,
             )
             for detection, clip in zip(detections, clips)
         ]
 
-    def get_raw_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[RawPrediction]]:
-        """Extract intermediate RawPrediction objects for a batch.
-
-        Processes raw model output through remapping, NMS, detection, data
-        extraction, and geometry recovery via the configured
-        `targets.recover_roi`.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            List of `soundevent.data.Clip` objects corresponding to the batch.
-
-        Returns
-        -------
-        List[List[RawPrediction]]
-            List of lists (one inner list per input clip). Each inner list
-            contains `RawPrediction` objects for detections in that clip.
-        """
-        detections = self.get_detections(output, clips)
-        return [
-            convert_detections_to_raw_predictions(
-                dataset,
-                targets=self.targets,
-            )
-            for dataset in detections
-        ]
-
-    def get_sound_event_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[BatDetect2Prediction]]:
-        raw_predictions = self.get_raw_predictions(output, clips)
-        return [
-            [
-                BatDetect2Prediction(
-                    raw=raw,
-                    sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                        raw,
-                        recording=clip.recording,
-                        targets=self.targets,
-                        classification_threshold=self.config.classification_threshold,
-                    ),
-                )
-                for raw in predictions
-            ]
-            for predictions, clip in zip(raw_predictions, clips)
-        ]
-
-    def get_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[data.ClipPrediction]:
-        """Perform the full postprocessing pipeline for a batch.
-
-        Takes raw model output and corresponding clips, applies the entire
-        configured chain (NMS, remapping, extraction, geometry recovery, class
-        decoding), producing final `soundevent.data.ClipPrediction` objects.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            Raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            List of `soundevent.data.Clip` objects corresponding to the batch.
-
-        Returns
-        -------
-        List[data.ClipPrediction]
-            List containing one `ClipPrediction` object for each input clip,
-            populated with `SoundEventPrediction` objects.
-        """
-        raw_predictions = self.get_raw_predictions(output, clips)
-        return [
-            convert_raw_predictions_to_clip_prediction(
-                prediction,
-                clip,
-                targets=self.targets,
-                classification_threshold=self.config.classification_threshold,
-            )
-            for prediction, clip in zip(raw_predictions, clips)
-        ]
+
+def get_raw_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+) -> List[List[RawPrediction]]:
+    """Extract intermediate RawPrediction objects for a batch.
+
+    Processes raw model output through remapping, NMS, detection, data
+    extraction, and geometry recovery via the configured
+    `targets.recover_roi`.
+
+    Parameters
+    ----------
+    output : ModelOutput
+        Raw output from the neural network model for a batch.
+    clips : List[data.Clip]
+        List of `soundevent.data.Clip` objects corresponding to the batch.
+
+    Returns
+    -------
+    List[List[RawPrediction]]
+        List of lists (one inner list per input clip). Each inner list
+        contains `RawPrediction` objects for detections in that clip.
+    """
+    detections = postprocessor.get_detections(output, clips)
+    return [
+        convert_detections_to_raw_predictions(
+            dataset,
+            targets=targets,
+        )
+        for dataset in detections
+    ]
+
+
+def get_sound_event_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
+) -> List[List[BatDetect2Prediction]]:
+    raw_predictions = get_raw_predictions(
+        output,
+        clips,
+        targets=targets,
+        postprocessor=postprocessor,
+    )
+    return [
+        [
+            BatDetect2Prediction(
+                raw=raw,
+                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
+                    raw,
+                    recording=clip.recording,
+                    targets=targets,
+                    classification_threshold=classification_threshold,
+                ),
+            )
+            for raw in predictions
+        ]
+        for predictions, clip in zip(raw_predictions, clips)
+    ]
+
+
+def get_predictions(
+    output: ModelOutput,
+    clips: List[data.Clip],
+    targets: TargetProtocol,
+    postprocessor: PostprocessorProtocol,
+    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
+) -> List[data.ClipPrediction]:
+    """Perform the full postprocessing pipeline for a batch.
+
+    Takes raw model output and corresponding clips, applies the entire
+    configured chain (NMS, remapping, extraction, geometry recovery, class
+    decoding), producing final `soundevent.data.ClipPrediction` objects.
+
+    Parameters
+    ----------
+    output : ModelOutput
+        Raw output from the neural network model for a batch.
+    clips : List[data.Clip]
+        List of `soundevent.data.Clip` objects corresponding to the batch.
+
+    Returns
+    -------
+    List[data.ClipPrediction]
+        List containing one `ClipPrediction` object for each input clip,
+        populated with `SoundEventPrediction` objects.
+    """
+    raw_predictions = get_raw_predictions(
+        output,
+        clips,
+        targets=targets,
+        postprocessor=postprocessor,
+    )
+    return [
+        convert_raw_predictions_to_clip_prediction(
+            prediction,
+            clip,
+            targets=targets,
+            classification_threshold=classification_threshold,
+        )
+        for prediction, clip in zip(raw_predictions, clips)
+    ]
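A sketch of how the refactored pieces compose; `preprocessor`, `targets`, `config`, `output`, and `clips` are assumed to exist already:

# The Postprocessor no longer owns targets; it is a torch.nn.Module built
# from scalar settings pulled off the preprocessor and config.
postprocessor = build_postprocessor(preprocessor=preprocessor, config=config)

# Batch-friendly path: forward() needs no clips, so coordinates come out
# relative to the start of each item in the batch.
detections = postprocessor(output)

# The clip-aware pipeline is now a free function that takes its
# collaborators explicitly instead of living on the Postprocessor.
clip_predictions = get_predictions(
    output,
    clips,
    targets=targets,
    postprocessor=postprocessor,
)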

View File

@@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module):
         self.high_index = high_index
 
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
-        return spec[self.low_index : self.high_index]
+        low_index = self.low_index
+        if low_index is None:
+            low_index = 0
+
+        if self.high_index is None:
+            length = spec.shape[-2] - low_index
+        else:
+            length = self.high_index - low_index
+
+        return torch.narrow(
+            spec,
+            dim=-2,
+            start=low_index,
+            length=length,
+        )
 
 
 class PcenConfig(BaseConfig):
@@ -256,16 +270,22 @@ class ResizeSpec(torch.nn.Module):
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         current_length = spec.shape[-1]
         target_length = int(self.time_factor * current_length)
-        return (
-            torch.nn.functional.interpolate(
-                spec.unsqueeze(0).unsqueeze(0),
-                size=(self.height, target_length),
-                mode="bilinear",
-            )
-            .squeeze(0)
-            .squeeze(0)
-        )
+
+        original_ndim = spec.ndim
+        while spec.ndim < 4:
+            spec = spec.unsqueeze(0)
+
+        resized = torch.nn.functional.interpolate(
+            spec,
+            size=(self.height, target_length),
+            mode="bilinear",
+        )
+
+        while resized.ndim != original_ndim:
+            resized = resized.squeeze(0)
+
+        return resized
 
 
 class PeakNormalizeConfig(BaseConfig):
     name: Literal["peak_normalize"] = "peak_normalize"
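A quick check of why these transforms are now batchable: both act on the trailing (frequency, time) axes, so 2-D and 4-D inputs behave the same. The constructor keyword names below are assumed from the attributes shown above:

import torch

clip = FrequencyClip(low_index=10, high_index=100)  # assumed signature

# torch.narrow on dim=-2 keeps 90 frequency rows whatever the leading dims.
print(clip(torch.rand(128, 512)).shape)        # torch.Size([90, 512])
print(clip(torch.rand(8, 1, 128, 512)).shape)  # torch.Size([8, 1, 90, 512])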

View File

@@ -2,6 +2,7 @@ from batdetect2.train.augmentations import (
     AugmentationsConfig,
     EchoAugmentationConfig,
     FrequencyMaskAugmentationConfig,
+    RandomExampleSource,
     TimeMaskAugmentationConfig,
     VolumeAugmentationConfig,
     WarpAugmentationConfig,
@@ -23,7 +24,6 @@ from batdetect2.train.config import (
 )
 from batdetect2.train.dataset import (
     LabeledDataset,
-    RandomExampleSource,
     list_preprocessed_files,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config

View File

@@ -1,6 +1,7 @@
 """Applies data augmentation techniques to BatDetect2 training examples."""
 
 import warnings
+from collections.abc import Sequence
 from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
 
 import numpy as np
@@ -10,8 +11,12 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.preprocess import (
+    list_preprocessed_files,
+    load_preprocessed_example,
+)
 from batdetect2.typing import Augmentation, PreprocessorProtocol
-from batdetect2.typing.train import PreprocessedExample
+from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
 from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
@@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample]
 """Type alias for a function that returns a training example"""
 
 
-class MixAugmentationConfig(BaseConfig):
-    """Configuration for MixUp augmentation (mixing two examples)."""
-
-    augmentation_type: Literal["mix_audio"] = "mix_audio"
-    probability: float = 0.2
-    """Probability of applying this augmentation to an example."""
-
-    min_weight: float = 0.3
-    """Minimum mixing weight (lambda) applied to the primary example."""
-
-    max_weight: float = 0.7
-    """Maximum mixing weight (lambda) applied to the primary example."""
-
-
 def mix_examples(
     example: PreprocessedExample,
     other: PreprocessedExample,
@@ -149,7 +139,12 @@ def add_echo(
     audio = example.audio
     delay_steps = int(preprocessor.input_samplerate * delay)
-    audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
+
+    slices = [slice(None)] * audio.ndim
+    slices[-1] = slice(None, -delay_steps)
+    audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
+        delay_steps, dims=-1
+    )
     audio = audio + weight * audio_delay
 
     spectrogram = preprocessor(audio)
@@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig):
 
 class ScaleVolume(torch.nn.Module):
-    def __init__(self, min_scaling: float, max_scaling: float):
+    def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
         super().__init__()
         self.min_scaling = min_scaling
         self.max_scaling = max_scaling
@@ -228,32 +223,22 @@ def warp_spectrogram(
     example: PreprocessedExample, factor: float
 ) -> PreprocessedExample:
     """Apply time warping by resampling the time axis."""
-    target_shape = example.spectrogram.shape
+    width = example.spectrogram.shape[-1]
+    height = example.spectrogram.shape[-2]
+    target_shape = [height, width]
     new_width = int(target_shape[-1] * factor)
 
-    spectrogram = (
-        torch.nn.functional.interpolate(
-            adjust_width(example.spectrogram, new_width)
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size=target_shape,
-            mode="bilinear",
-        )
-        .squeeze(0)
-        .squeeze(0)
-    )
+    spectrogram = torch.nn.functional.interpolate(
+        adjust_width(example.spectrogram, new_width).unsqueeze(0),
+        size=target_shape,
+        mode="bilinear",
+    ).squeeze(0)
 
-    detection = (
-        torch.nn.functional.interpolate(
-            adjust_width(example.detection_heatmap, new_width)
-            .unsqueeze(0)
-            .unsqueeze(0),
-            size=target_shape,
-            mode="nearest",
-        )
-        .squeeze(0)
-        .squeeze(0)
-    )
+    detection = torch.nn.functional.interpolate(
+        adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
+        size=target_shape,
+        mode="nearest",
+    ).squeeze(0)
 
     classification = torch.nn.functional.interpolate(
         adjust_width(example.class_heatmap, new_width).unsqueeze(1),
@@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig):
 
 class MaskTime(torch.nn.Module):
-    def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None:
+    def __init__(
+        self,
+        max_perc: float = 0.05,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
         super().__init__()
         self.max_perc = max_perc
         self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
 
     def forward(self, example: PreprocessedExample) -> PreprocessedExample:
         num_masks = np.random.randint(1, self.max_masks + 1)
@@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module):
         masks = [
             (start, start + size) for start, size in zip(mask_start, mask_size)
         ]
-        return mask_time(example, masks)
+        return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
 
 
 def mask_time(
     example: PreprocessedExample,
     masks: List[Tuple[int, int]],
+    mask_heatmaps: bool = False,
 ) -> PreprocessedExample:
     """Apply time masking to the spectrogram."""
     for start, end in masks:
-        example.spectrogram[:, start:end] = example.spectrogram.mean()
-        example.class_heatmap[:, :, start:end] = 0
-        example.size_heatmap[:, :, start:end] = 0
-        example.detection_heatmap[:, start:end] = 0
+        slices = [slice(None)] * example.spectrogram.ndim
+        slices[-1] = slice(start, end)
+
+        example.spectrogram[tuple(slices)] = 0
+
+        if not mask_heatmaps:
+            continue
+
+        example.class_heatmap[tuple(slices)] = 0
+        example.size_heatmap[tuple(slices)] = 0
+        example.detection_heatmap[tuple(slices)] = 0
 
     return PreprocessedExample(
         audio=example.audio,
@@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig):
     probability: float = 0.2
     max_perc: float = 0.10
     max_masks: int = 3
+    mask_heatmaps: bool = False
 
 
 class MaskFrequency(torch.nn.Module):
-    def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None:
+    def __init__(
+        self,
+        max_perc: float = 0.10,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
         super().__init__()
         self.max_perc = max_perc
         self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
 
     def forward(self, example: PreprocessedExample) -> PreprocessedExample:
         num_masks = np.random.randint(1, self.max_masks + 1)
@@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module):
         masks = [
             (start, start + size) for start, size in zip(mask_start, mask_size)
         ]
-        return mask_frequency(example, masks)
+        return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
 
 
 def mask_frequency(
     example: PreprocessedExample,
     masks: List[Tuple[int, int]],
+    mask_heatmaps: bool = False,
 ) -> PreprocessedExample:
     """Apply frequency masking to the spectrogram."""
     for start, end in masks:
-        example.spectrogram[start:end, :] = example.spectrogram.mean()
-        example.class_heatmap[:, start:end, :] = 0
-        example.size_heatmap[:, start:end, :] = 0
-        example.detection_heatmap[start:end, :] = 0
+        slices = [slice(None)] * example.spectrogram.ndim
+        slices[-2] = slice(start, end)
+
+        example.spectrogram[tuple(slices)] = 0
+
+        if not mask_heatmaps:
+            continue
+
+        example.class_heatmap[tuple(slices)] = 0
+        example.size_heatmap[tuple(slices)] = 0
+        example.detection_heatmap[tuple(slices)] = 0
 
     return PreprocessedExample(
         audio=example.audio,
@@ -383,6 +396,50 @@ def mask_frequency(
     )
 
 
+class MixAugmentationConfig(BaseConfig):
+    """Configuration for MixUp augmentation (mixing two examples)."""
+
+    augmentation_type: Literal["mix_audio"] = "mix_audio"
+    probability: float = 0.2
+    """Probability of applying this augmentation to an example."""
+
+    min_weight: float = 0.3
+    """Minimum mixing weight (lambda) applied to the primary example."""
+
+    max_weight: float = 0.7
+    """Maximum mixing weight (lambda) applied to the primary example."""
+
+
+class MixAudio(torch.nn.Module):
+    """Callable class for MixUp augmentation, handling example fetching."""
+
+    def __init__(
+        self,
+        example_source: ExampleSource,
+        preprocessor: PreprocessorProtocol,
+        min_weight: float = 0.3,
+        max_weight: float = 0.7,
+    ):
+        """Initialize the AudioMixer."""
+        super().__init__()
+        self.min_weight = min_weight
+        self.example_source = example_source
+        self.max_weight = max_weight
+        self.preprocessor = preprocessor
+
+    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
+        """Fetch another example and perform mixup."""
+        other = self.example_source()
+        weight = np.random.uniform(self.min_weight, self.max_weight)
+        return mix_examples(
+            example,
+            other,
+            self.preprocessor,
+            weight=weight,
+        )
+
+
 AugmentationConfig = Annotated[
     Union[
         MixAugmentationConfig,
@@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module):
         return self.augmentation(example)
 
 
-class AudioMixer(torch.nn.Module):
-    """Callable class for MixUp augmentation, handling example fetching."""
-
-    def __init__(
-        self,
-        min_weight: float,
-        max_weight: float,
-        example_source: ExampleSource,
-        preprocessor: PreprocessorProtocol,
-    ):
-        """Initialize the AudioMixer."""
-        super().__init__()
-        self.min_weight = min_weight
-        self.example_source = example_source
-        self.max_weight = max_weight
-        self.preprocessor = preprocessor
-
-    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
-        """Fetch another example and perform mixup."""
-        other = self.example_source()
-        weight = np.random.uniform(self.min_weight, self.max_weight)
-        return mix_examples(
-            example,
-            other,
-            self.preprocessor,
-            weight=weight,
-        )
-
-
 def build_augmentation_from_config(
     config: AugmentationConfig,
     preprocessor: PreprocessorProtocol,
@@ -489,7 +517,7 @@ def build_augmentation_from_config(
             )
             return None
 
-        return AudioMixer(
+        return MixAudio(
             example_source=example_source,
             preprocessor=preprocessor,
             min_weight=config.min_weight,
@@ -585,3 +613,25 @@ def load_augmentation_config(
 ) -> AugmentationsConfig:
     """Load the augmentations configuration from a file."""
     return load_config(path, schema=AugmentationsConfig, field=field)
+
+
+class RandomExampleSource:
+    def __init__(
+        self,
+        filenames: Sequence[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
+        self.filenames = filenames
+        self.clipper = clipper
+
+    def __call__(self) -> PreprocessedExample:
+        index = int(np.random.randint(len(self.filenames)))
+        filename = self.filenames[index]
+        example = load_preprocessed_example(filename)
+        example, _, _ = self.clipper(example)
+        return example
+
+    @classmethod
+    def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
+        filenames = list_preprocessed_files(path)
+        return cls(filenames, clipper=clipper)
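The masking rewrite hinges on a small ndim-agnostic slicing idiom. A standalone sketch of the pattern (shapes illustrative):

import torch

spec = torch.rand(8, 1, 128, 512)  # (batch, channel, freq, time)

# Zero out time bins 100..200 no matter how many leading dims the tensor has.
slices = [slice(None)] * spec.ndim
slices[-1] = slice(100, 200)
spec[tuple(slices)] = 0

assert spec[..., 100:200].abs().sum() == 0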

View File

@@ -14,7 +14,9 @@ from batdetect2.evaluate.match import (
     MatchConfig,
     match_sound_events_and_raw_predictions,
 )
+from batdetect2.models import Model
 from batdetect2.plotting.evaluation import plot_example_gallery
+from batdetect2.postprocess import get_sound_event_predictions
 from batdetect2.train.dataset import LabeledDataset
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.typing import (
@@ -22,7 +24,6 @@ from batdetect2.typing import (
     MatchEvaluation,
     MetricsProtocol,
     ModelOutput,
-    PostprocessorProtocol,
     TargetProtocol,
     TrainExample,
 )
@@ -127,8 +128,7 @@ class ValidationMetrics(Callback):
                 batch,
                 outputs,
                 dataset=self.get_dataset(trainer),
-                postprocessor=pl_module.model.postprocessor,
-                targets=pl_module.model.targets,
+                model=pl_module.model,
             )
         )
@@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions(
     batch: TrainExample,
     outputs: ModelOutput,
     dataset: LabeledDataset,
-    postprocessor: PostprocessorProtocol,
-    targets: TargetProtocol,
+    model: Model,
 ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
     clip_annotations = [
         _get_subclip(
             dataset.get_clip_annotation(example_id),
             start_time=start_time.item(),
             end_time=end_time.item(),
-            targets=targets,
+            targets=model.targets,
         )
         for example_id, start_time, end_time in zip(
            batch.idx,
@@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions(
 
     clips = [clip_annotation.clip for clip_annotation in clip_annotations]
 
-    raw_predictions = postprocessor.get_sound_event_predictions(
+    raw_predictions = get_sound_event_predictions(
         outputs,
         clips,
+        targets=model.targets,
+        postprocessor=model.postprocessor
     )
 
     return [

View File

@@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
 from batdetect2.typing.preprocess import PreprocessorProtocol
 from batdetect2.typing.train import PreprocessedExample
-from batdetect2.utils.arrays import adjust_width
+from batdetect2.utils.arrays import adjust_width, slice_tensor
 
 DEFAULT_TRAIN_CLIP_DURATION = 0.512
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -90,7 +90,12 @@ def select_subclip(
     audio_start = int(np.floor(start * input_samplerate))
     audio = adjust_width(
-        example.audio[audio_start : audio_start + audio_width],
+        slice_tensor(
+            example.audio,
+            start=audio_start,
+            end=audio_start + audio_width,
+            dim=-1,
+        ),
         audio_width,
         value=fill_value,
     )
@@ -100,19 +105,39 @@ def select_subclip(
     return PreprocessedExample(
         audio=audio,
         spectrogram=adjust_width(
-            example.spectrogram[:, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.spectrogram,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         class_heatmap=adjust_width(
-            example.class_heatmap[:, :, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.class_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         detection_heatmap=adjust_width(
-            example.detection_heatmap[:, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.detection_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
         size_heatmap=adjust_width(
-            example.size_heatmap[:, :, spec_start : spec_start + spec_width],
+            slice_tensor(
+                example.size_heatmap,
+                start=spec_start,
+                end=spec_start + spec_width,
+                dim=-1,
+            ),
             spec_width,
         ),
     )

View File

@@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig):
 
 class DataLoaderConfig(BaseConfig):
-    batch_size: int
-    shuffle: bool
+    batch_size: int = 8
+    shuffle: bool = False
     num_workers: int = 0

View File

@@ -1,5 +1,4 @@
-from pathlib import Path
-from typing import List, Optional, Sequence, Tuple
+from typing import Optional, Sequence, Tuple
 
 import numpy as np
 import torch
@@ -7,6 +6,10 @@ from soundevent import data
 from torch.utils.data import Dataset
 
 from batdetect2.train.augmentations import Augmentation
+from batdetect2.train.preprocess import (
+    list_preprocessed_files,
+    load_preprocessed_example,
+)
 from batdetect2.typing import ClipperProtocol, TrainExample
 from batdetect2.typing.train import PreprocessedExample
@@ -38,8 +41,8 @@ class LabeledDataset(Dataset):
             example = self.augmentation(example)
 
         return TrainExample(
-            spec=example.spectrogram.unsqueeze(0),
-            detection_heatmap=example.detection_heatmap.unsqueeze(0),
+            spec=example.spectrogram,
+            detection_heatmap=example.detection_heatmap,
             class_heatmap=example.class_heatmap,
             size_heatmap=example.size_heatmap,
             idx=torch.tensor(idx),
@@ -73,37 +76,3 @@ class LabeledDataset(Dataset):
     def get_clip_annotation(self, idx) -> data.ClipAnnotation:
         item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
         return item["clip_annotation"].tolist()
-
-
-def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path, mmap_mode="r+")
-    return PreprocessedExample(
-        audio=torch.tensor(item["audio"]),
-        spectrogram=torch.tensor(item["spectrogram"]),
-        size_heatmap=torch.tensor(item["size_heatmap"]),
-        detection_heatmap=torch.tensor(item["detection_heatmap"]),
-        class_heatmap=torch.tensor(item["class_heatmap"]),
-    )
-
-
-def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".npz"
-) -> List[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
-class RandomExampleSource:
-    def __init__(
-        self,
-        filenames: List[data.PathLike],
-        clipper: ClipperProtocol,
-    ):
-        self.filenames = filenames
-        self.clipper = clipper
-
-    def __call__(self) -> PreprocessedExample:
-        index = int(np.random.randint(len(self.filenames)))
-        filename = self.filenames[index]
-        example = load_preprocessed_example(filename)
-        example, _, _ = self.clipper(example)
-        return example

View File

@@ -41,7 +41,6 @@ from batdetect2.typing import (
 __all__ = [
     "LabelConfig",
     "build_clip_labeler",
-    "generate_clip_label",
     "generate_heatmaps",
     "load_label_config",
 ]
@@ -99,21 +98,26 @@ def build_clip_labeler(
         lambda: config.to_yaml_string(),
     )
     return partial(
-        generate_clip_label,
+        generate_heatmaps,
         targets=targets,
-        config=config,
         min_freq=min_freq,
         max_freq=max_freq,
+        target_sigma=config.sigma,
     )
 
 
-def generate_clip_label(
+def map_to_pixels(x, size, min_val, max_val) -> int:
+    return int(np.interp(x, [min_val, max_val], [0, size]))
+
+
+def generate_heatmaps(
     clip_annotation: data.ClipAnnotation,
     spec: torch.Tensor,
     targets: TargetProtocol,
-    config: LabelConfig,
     min_freq: float,
     max_freq: float,
+    target_sigma: float = 3.0,
+    dtype=torch.float32,
 ) -> Heatmaps:
     """Generate training heatmaps for a single annotated clip.
@@ -150,57 +154,14 @@ def generate_clip_label(
         num=len(clip_annotation.sound_events),
     )
 
-    sound_events = []
-
-    for sound_event_annotation in clip_annotation.sound_events:
-        if not targets.filter(sound_event_annotation):
-            logger.debug(
-                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
-                sound_event=sound_event_annotation,
-                tags=sound_event_annotation.tags,
-            )
-            continue
-
-        sound_events.append(targets.transform(sound_event_annotation))
-
-    return generate_heatmaps(
-        clip_annotation.model_copy(update=dict(sound_events=sound_events)),
-        spec=spec,
-        targets=targets,
-        target_sigma=config.sigma,
-        min_freq=min_freq,
-        max_freq=max_freq,
-    )
-
-
-def map_to_pixels(x, size, min_val, max_val) -> int:
-    return int(np.interp(x, [min_val, max_val], [0, size]))
-
-
-def generate_heatmaps(
-    clip_annotation: data.ClipAnnotation,
-    spec: torch.Tensor,
-    targets: TargetProtocol,
-    min_freq: float,
-    max_freq: float,
-    target_sigma: float = 3.0,
-    dtype=torch.float32,
-) -> Heatmaps:
-    if not spec.ndim == 2:
-        raise ValueError(
-            "Expecting a 2-dimensional tensor of shape (H, W), "
-            "H is the height of the spectrogram "
-            "(frequency bins), and W is the width of the spectrogram "
-            f"(temporal bins). Instead got: {spec.shape}"
-        )
-
-    height, width = spec.shape
+    height = spec.shape[-2]
+    width = spec.shape[-1]
 
     num_classes = len(targets.class_names)
     num_dims = len(targets.dimension_names)
     clip = clip_annotation.clip
 
     # Initialize heatmaps
-    detection_heatmap = torch.zeros([height, width], dtype=dtype)
+    detection_heatmap = torch.zeros([1, height, width], dtype=dtype)
     class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
     size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
@@ -214,6 +175,16 @@ def generate_heatmaps(
     times = times.to(spec.device)
 
     for sound_event_annotation in clip_annotation.sound_events:
+        if not targets.filter(sound_event_annotation):
+            logger.debug(
+                "Sound event {sound_event} did not pass the filter. Tags: {tags}",
+                sound_event=sound_event_annotation,
+                tags=sound_event_annotation.tags,
+            )
+            continue
+
+        sound_event_annotation = targets.transform(sound_event_annotation)
+
         geom = sound_event_annotation.sound_event.geometry
         if geom is None:
             logger.debug(
@@ -245,7 +216,10 @@ def generate_heatmaps(
         distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
         gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
 
-        detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
+        detection_heatmap[0] = torch.maximum(
+            detection_heatmap[0],
+            gaussian_blob,
+        )
 
         size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
 
         # Get the class name of the sound event
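A worked example of the `map_to_pixels` helper introduced above, which linearly maps a physical coordinate onto a pixel index (values illustrative):

import numpy as np

def map_to_pixels(x, size, min_val, max_val) -> int:
    return int(np.interp(x, [min_val, max_val], [0, size]))

# A 65 kHz call on a 128-row spectrogram spanning 10-120 kHz lands at
# row 64, the midpoint: (65000 - 10000) / (120000 - 10000) * 128 = 64.
print(map_to_pixels(65_000, 128, 10_000, 120_000))  # 64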

View File

@@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule):
         return self.model(spec)
 
     def training_step(self, batch: TrainExample):
-        outputs = self.model(batch.spec)
+        outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
 
         self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/train", losses.total, logger=True)
@@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule):
         batch: TrainExample,
         batch_idx: int,
     ) -> ModelOutput:
-        outputs = self.model(batch.spec)
+        outputs = self.model.detector(batch.spec)
         losses = self.loss(outputs, batch)
 
         self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
         self.log("detection_loss/val", losses.total, logger=True)

View File

@@ -2,7 +2,7 @@
 
 import os
 from pathlib import Path
-from typing import Callable, Optional, Sequence, TypedDict
+from typing import Callable, List, Optional, Sequence, TypedDict
 
 import numpy as np
 import torch
@@ -28,6 +28,8 @@ __all__ = [
     "preprocess_dataset",
     "TrainPreprocessConfig",
     "load_train_preprocessing_config",
+    "save_preprocessed_example",
+    "load_preprocessed_example",
 ]
 
 FilenameFn = Callable[[data.ClipAnnotation], str]
@@ -94,8 +96,10 @@ def generate_train_example(
     labeller: ClipLabeller,
 ) -> PreprocessedExample:
     """Generate a complete training example for one annotation."""
-    wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
-    spectrogram = preprocessor(wave)
+    wave = torch.tensor(
+        audio_loader.load_clip(clip_annotation.clip)
+    ).unsqueeze(0)
+    spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
     heatmaps = labeller(clip_annotation, spectrogram)
     return PreprocessedExample(
         audio=wave,
@@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
             labeller=self.labeller,
         )
 
-        save_example_to_file(example, clip_annotation, path)
+        save_preprocessed_example(example, clip_annotation, path)
 
         return idx
@@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
         return len(self.clips)
 
 
-def save_example_to_file(
+def save_preprocessed_example(
     example: PreprocessedExample,
     clip_annotation: data.ClipAnnotation,
     path: data.PathLike,
@@ -169,6 +173,23 @@ def save_example_to_file(
     )
 
 
+def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
+    item = np.load(path, mmap_mode="r+")
+    return PreprocessedExample(
+        audio=torch.tensor(item["audio"]),
+        spectrogram=torch.tensor(item["spectrogram"]),
+        size_heatmap=torch.tensor(item["size_heatmap"]),
+        detection_heatmap=torch.tensor(item["detection_heatmap"]),
+        class_heatmap=torch.tensor(item["class_heatmap"]),
+    )
+
+
+def list_preprocessed_files(
+    directory: data.PathLike, extension: str = ".npz"
+) -> List[Path]:
+    return list(Path(directory).glob(f"*{extension}"))
+
+
 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
     """Generate a default output filename based on the annotation UUID."""
     return f"{clip_annotation.uuid}"

View File

@@ -15,13 +15,15 @@ from batdetect2.evaluate.metrics import (
     DetectionAveragePrecision,
 )
 from batdetect2.models import build_model
-from batdetect2.train.augmentations import build_augmentations
+from batdetect2.train.augmentations import (
+    RandomExampleSource,
+    build_augmentations,
+)
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
-    RandomExampleSource,
 )
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger

View File

@@ -95,69 +95,10 @@ class BatDetect2Prediction:
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""
 
+    def __call__(self, output: ModelOutput) -> List[Detections]: ...
+
     def get_detections(
         self,
         output: ModelOutput,
         clips: Optional[List[data.Clip]] = None,
     ) -> List[Detections]: ...
-
-    def get_raw_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[List[RawPrediction]]:
-        """Extract intermediate RawPrediction objects for a batch.
-
-        Processes the raw model output for a batch through remapping, NMS,
-        detection, data extraction, and geometry recovery to produce a list of
-        `RawPrediction` objects for each corresponding input clip. This provides
-        a simplified, intermediate representation before final tag decoding.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            The raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            A list of `soundevent.data.Clip` objects corresponding to the batch
-            items, providing context. Must match the batch size of `output`.
-
-        Returns
-        -------
-        List[List[RawPrediction]]
-            A list of lists (one inner list per input clip, in order). Each
-            inner list contains the `RawPrediction` objects extracted for the
-            corresponding input clip.
-        """
-        ...
-
-    def get_sound_event_predictions(
-        self, output: ModelOutput, clips: List[data.Clip]
-    ) -> List[List[BatDetect2Prediction]]: ...
-
-    def get_predictions(
-        self,
-        output: ModelOutput,
-        clips: List[data.Clip],
-    ) -> List[data.ClipPrediction]:
-        """Perform the full postprocessing pipeline for a batch.
-
-        Takes raw model output for a batch and corresponding clips, applies the
-        entire postprocessing chain, and returns the final, interpretable
-        predictions as a list of `soundevent.data.ClipPrediction` objects.
-
-        Parameters
-        ----------
-        output : ModelOutput
-            The raw output from the neural network model for a batch.
-        clips : List[data.Clip]
-            A list of `soundevent.data.Clip` objects corresponding to the batch
-            items, providing context. Must match the batch size of `output`.
-
-        Returns
-        -------
-        List[data.ClipPrediction]
-            A list containing one `ClipPrediction` object for each input clip
-            (in the same order), populated with `SoundEventPrediction` objects
-            representing the final detections with decoded tags and geometry.
-        """
-        ...

View File

@@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
 throughout BatDetect2.
 """
 
-from collections.abc import Callable
-from typing import List, Optional, Protocol
+from collections.abc import Callable, Iterable
+from typing import List, Optional, Protocol, Tuple
 
 import numpy as np
 from soundevent import data

View File

@@ -1,3 +1,5 @@
+from typing import Optional
+
 import numpy as np
 import torch
 import xarray as xr
@@ -80,3 +82,14 @@ def adjust_width(
         for index in range(dims)
     ]
     return tensor[tuple(slices)]
+
+
+def slice_tensor(
+    tensor: torch.Tensor,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    dim: int = -1,
+) -> torch.Tensor:
+    slices = [slice(None)] * tensor.ndim
+    slices[dim] = slice(start, end)
+    return tensor[tuple(slices)]
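`slice_tensor` generalizes `tensor[..., start:end]` to an arbitrary axis. A minimal usage sketch:

import torch

x = torch.arange(24).reshape(2, 3, 4)

# Equivalent to x[..., 2:4]: last two entries along the final axis.
print(slice_tensor(x, start=2, end=4, dim=-1).shape)  # torch.Size([2, 3, 2])

# Equivalent to x[:, :2, :]: first two entries along the middle axis.
print(slice_tensor(x, end=2, dim=-2).shape)           # torch.Size([2, 2, 4])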

View File

@@ -38,7 +38,6 @@ def build_from_config(
         max_freq=preprocessor.max_freq,
     )
     postprocessor = build_postprocessor(
-        targets,
         preprocessor=preprocessor,
         config=postprocessing_config,
     )