Make sure preprocessing is batchable

mbsantiago 2025-08-27 23:58:38 +01:00
parent 0b5ac96fe8
commit 34ef9e92a1
17 changed files with 446 additions and 373 deletions

View File

@ -26,19 +26,32 @@ def create_ax(
def plot_spectrogram(
spec: Union[torch.Tensor, np.ndarray],
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
min_freq: Optional[float] = None,
max_freq: Optional[float] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap="gray",
) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize)
if start_time is None:
start_time = 0
if end_time is None:
end_time = spec.shape[-1]
if min_freq is None:
min_freq = 0
if max_freq is None:
max_freq = spec.shape[-2]
ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
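
With the time/frequency bounds now optional, a spectrogram can be plotted straight in pixel coordinates. A minimal usage sketch (not part of the commit; the import path is a guess and the shapes are made up):

import numpy as np

from batdetect2.plot import plot_spectrogram  # assumed import path

spec = np.random.rand(128, 512)  # (freq_bins, time_bins)

# Missing bounds fall back to pixel indices: 0..width for time,
# 0..height for frequency.
ax = plot_spectrogram(spec)

# Physical coordinates can still be passed explicitly, as before.
ax = plot_spectrogram(
    spec,
    start_time=0.0,
    end_time=0.512,
    min_freq=10_000,
    max_freq=120_000,
)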

View File

@ -2,6 +2,7 @@
from typing import List, Optional
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
@ -20,13 +21,15 @@ from batdetect2.postprocess.nms import (
)
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing import ModelOutput, PreprocessorProtocol, TargetProtocol
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
Detections,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"DEFAULT_CLASSIFICATION_THRESHOLD",
@ -128,7 +131,6 @@ def load_postprocess_config(
def build_postprocessor(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: Optional[PostprocessConfig] = None,
) -> PostprocessorProtocol:
@ -139,29 +141,52 @@ def build_postprocessor(
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
preprocessor=preprocessor,
config=config,
samplerate=preprocessor.output_samplerate,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
top_k_per_sec=config.top_k_per_sec,
detection_threshold=config.detection_threshold,
)
class Postprocessor(PostprocessorProtocol):
class Postprocessor(torch.nn.Module, PostprocessorProtocol):
"""Standard implementation of the postprocessing pipeline."""
targets: TargetProtocol
preprocessor: PreprocessorProtocol
def __init__(
self,
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: PostprocessConfig,
samplerate: float,
min_freq: float,
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
):
"""Initialize the Postprocessor."""
self.targets = targets
self.preprocessor = preprocessor
self.config = config
super().__init__()
self.samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
def forward(self, output: ModelOutput) -> List[Detections]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
return [
map_detection_to_clip(
detection,
start_time=0,
end_time=duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection in detections
]
def get_detections(
self,
@ -169,13 +194,13 @@ class Postprocessor(PostprocessorProtocol):
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]:
width = output.detection_probs.shape[-1]
duration = width / self.preprocessor.output_samplerate
max_detections = int(self.config.top_k_per_sec * duration)
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.config.detection_threshold,
threshold=self.detection_threshold,
)
if clips is None:
@ -186,96 +211,116 @@ class Postprocessor(PostprocessorProtocol):
detection,
start_time=clip.start_time,
end_time=clip.end_time,
min_freq=self.preprocessor.min_freq,
max_freq=self.preprocessor.max_freq,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, clip in zip(detections, clips)
]
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes raw model output through remapping, NMS, detection, data
extraction, and geometry recovery via the configured
`targets.recover_roi`.
def get_raw_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Processes raw model output through remapping, NMS, detection, data
extraction, and geometry recovery via the configured
`targets.recover_roi`.
Returns
-------
List[List[RawPrediction]]
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detections = self.get_detections(output, clips)
return [
convert_detections_to_raw_predictions(
dataset,
targets=self.targets,
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[List[RawPrediction]]
List of lists (one inner list per input clip). Each inner list
contains `RawPrediction` objects for detections in that clip.
"""
detections = postprocessor.get_detections(output, clips)
return [
convert_detections_to_raw_predictions(
dataset,
targets=targets,
)
for dataset in detections
]
def get_sound_event_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
clips,
targets=targets,
postprocessor=postprocessor,
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for dataset in detections
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_sound_event_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[BatDetect2Prediction]]:
raw_predictions = self.get_raw_predictions(output, clips)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = self.get_raw_predictions(output, clips)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=self.targets,
classification_threshold=self.config.classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
clips,
targets=targets,
postprocessor=postprocessor,
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
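
After this refactor the Postprocessor keeps only scalar settings (samplerate, frequency range, thresholds), and clip- and tag-aware decoding lives in module-level helpers that take the targets and postprocessor explicitly. A rough call-pattern sketch (not part of the commit; `preprocessor`, `targets`, `model_output` and `clips` are assumed to have been built elsewhere):

from batdetect2.postprocess import (
    build_postprocessor,
    get_sound_event_predictions,
)

# build_postprocessor no longer needs targets, only the preprocessor.
postprocessor = build_postprocessor(preprocessor=preprocessor)

# Postprocessor is now a torch.nn.Module: calling it runs forward() and
# returns one Detections object per batch item on a [0, duration] axis.
detections = postprocessor(model_output)

# Tag decoding is a free function that receives targets explicitly.
predictions = get_sound_event_predictions(
    model_output,
    clips,
    targets=targets,
    postprocessor=postprocessor,
)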

View File

@ -139,7 +139,21 @@ class FrequencyClip(torch.nn.Module):
self.high_index = high_index
def forward(self, spec: torch.Tensor) -> torch.Tensor:
return spec[self.low_index : self.high_index]
low_index = self.low_index
if low_index is None:
low_index = 0
if self.high_index is None:
length = spec.shape[-2] - low_index
else:
length = self.high_index - low_index
return torch.narrow(
spec,
dim=-2,
start=low_index,
length=length,
)
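
torch.narrow slices along dim=-2 no matter how many leading batch/channel dimensions the spectrogram carries, which is what makes the frequency crop batch-safe. A standalone check of that behaviour (illustrative shapes):

import torch

spec_2d = torch.rand(128, 512)        # (freq, time)
spec_4d = torch.rand(8, 1, 128, 512)  # (batch, channel, freq, time)

low, high = 10, 100
for spec in (spec_2d, spec_4d):
    clipped = torch.narrow(spec, dim=-2, start=low, length=high - low)
    # Only the frequency axis shrinks; leading dims are untouched.
    assert clipped.shape[-2] == high - low
    assert clipped.shape[-1] == spec.shape[-1]
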
class PcenConfig(BaseConfig):
@ -256,16 +270,22 @@ class ResizeSpec(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
return (
torch.nn.functional.interpolate(
spec.unsqueeze(0).unsqueeze(0),
size=(self.height, target_length),
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
original_ndim = spec.ndim
while spec.ndim < 4:
spec = spec.unsqueeze(0)
resized = torch.nn.functional.interpolate(
spec,
size=(self.height, target_length),
mode="bilinear",
)
while resized.ndim != original_ndim:
resized = resized.squeeze(0)
return resized
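
F.interpolate with mode="bilinear" expects a 4D (N, C, H, W) input, so the resize pads the rank up to 4 and squeezes back afterwards, handling batched and unbatched spectrograms alike. A standalone check of the pattern (illustrative sizes):

import torch

def resize(spec: torch.Tensor, height: int, time_factor: float) -> torch.Tensor:
    target_length = int(time_factor * spec.shape[-1])
    original_ndim = spec.ndim
    while spec.ndim < 4:  # bilinear interpolation needs (N, C, H, W)
        spec = spec.unsqueeze(0)
    resized = torch.nn.functional.interpolate(
        spec, size=(height, target_length), mode="bilinear"
    )
    while resized.ndim != original_ndim:  # restore the original rank
        resized = resized.squeeze(0)
    return resized

assert resize(torch.rand(128, 512), 256, 0.5).shape == (256, 256)
assert resize(torch.rand(4, 1, 128, 512), 256, 0.5).shape == (4, 1, 256, 256)
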
class PeakNormalizeConfig(BaseConfig):
name: Literal["peak_normalize"] = "peak_normalize"

View File

@ -2,6 +2,7 @@ from batdetect2.train.augmentations import (
AugmentationsConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
RandomExampleSource,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
@ -23,7 +24,6 @@ from batdetect2.train.config import (
)
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
list_preprocessed_files,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config

View File

@ -1,6 +1,7 @@
"""Applies data augmentation techniques to BatDetect2 training examples."""
import warnings
from collections.abc import Sequence
from typing import Annotated, Callable, List, Literal, Optional, Tuple, Union
import numpy as np
@ -10,8 +11,12 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import Augmentation, PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
from batdetect2.utils.arrays import adjust_width
__all__ = [
@ -39,21 +44,6 @@ ExampleSource = Callable[[], PreprocessedExample]
"""Type alias for a function that returns a training example"""
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
def mix_examples(
example: PreprocessedExample,
other: PreprocessedExample,
@ -149,7 +139,12 @@ def add_echo(
audio = example.audio
delay_steps = int(preprocessor.input_samplerate * delay)
audio_delay = adjust_width(audio[delay_steps:], audio.shape[-1])
slices = [slice(None)] * audio.ndim
slices[-1] = slice(None, -delay_steps)
audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
delay_steps, dims=-1
)
audio = audio + weight * audio_delay
spectrogram = preprocessor(audio)
@ -184,7 +179,7 @@ class VolumeAugmentationConfig(BaseConfig):
class ScaleVolume(torch.nn.Module):
def __init__(self, min_scaling: float, max_scaling: float):
def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
super().__init__()
self.min_scaling = min_scaling
self.max_scaling = max_scaling
@ -228,32 +223,22 @@ def warp_spectrogram(
example: PreprocessedExample, factor: float
) -> PreprocessedExample:
"""Apply time warping by resampling the time axis."""
target_shape = example.spectrogram.shape
width = example.spectrogram.shape[-1]
height = example.spectrogram.shape[-2]
target_shape = [height, width]
new_width = int(target_shape[-1] * factor)
spectrogram = (
torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width)
.unsqueeze(0)
.unsqueeze(0),
size=target_shape,
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)
spectrogram = torch.nn.functional.interpolate(
adjust_width(example.spectrogram, new_width).unsqueeze(0),
size=target_shape,
mode="bilinear",
).squeeze(0)
detection = (
torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width)
.unsqueeze(0)
.unsqueeze(0),
size=target_shape,
mode="nearest",
)
.squeeze(0)
.squeeze(0)
)
detection = torch.nn.functional.interpolate(
adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
size=target_shape,
mode="nearest",
).squeeze(0)
classification = torch.nn.functional.interpolate(
adjust_width(example.class_heatmap, new_width).unsqueeze(1),
@ -284,10 +269,16 @@ class TimeMaskAugmentationConfig(BaseConfig):
class MaskTime(torch.nn.Module):
def __init__(self, max_perc: float = 0.05, max_masks: int = 3) -> None:
def __init__(
self,
max_perc: float = 0.05,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
@ -306,20 +297,28 @@ class MaskTime(torch.nn.Module):
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_time(example, masks)
return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_time(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply time masking to the spectrogram."""
for start, end in masks:
example.spectrogram[:, start:end] = example.spectrogram.mean()
example.class_heatmap[:, :, start:end] = 0
example.size_heatmap[:, :, start:end] = 0
example.detection_heatmap[:, start:end] = 0
slices = [slice(None)] * example.spectrogram.ndim
slices[-1] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
@ -335,13 +334,20 @@ class FrequencyMaskAugmentationConfig(BaseConfig):
probability: float = 0.2
max_perc: float = 0.10
max_masks: int = 3
mask_heatmaps: bool = False
class MaskFrequency(torch.nn.Module):
def __init__(self, max_perc: float = 0.10, max_masks: int = 3) -> None:
def __init__(
self,
max_perc: float = 0.10,
max_masks: int = 3,
mask_heatmaps: bool = False,
) -> None:
super().__init__()
self.max_perc = max_perc
self.max_masks = max_masks
self.mask_heatmaps = mask_heatmaps
def forward(self, example: PreprocessedExample) -> PreprocessedExample:
num_masks = np.random.randint(1, self.max_masks + 1)
@ -360,19 +366,26 @@ class MaskFrequency(torch.nn.Module):
masks = [
(start, start + size) for start, size in zip(mask_start, mask_size)
]
return mask_frequency(example, masks)
return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
def mask_frequency(
example: PreprocessedExample,
masks: List[Tuple[int, int]],
mask_heatmaps: bool = False,
) -> PreprocessedExample:
"""Apply frequency masking to the spectrogram."""
for start, end in masks:
example.spectrogram[start:end, :] = example.spectrogram.mean()
example.class_heatmap[:, start:end, :] = 0
example.size_heatmap[:, start:end, :] = 0
example.detection_heatmap[start:end, :] = 0
slices = [slice(None)] * example.spectrogram.ndim
slices[-2] = slice(start, end)
example.spectrogram[tuple(slices)] = 0
if not mask_heatmaps:
continue
example.class_heatmap[tuple(slices)] = 0
example.size_heatmap[tuple(slices)] = 0
example.detection_heatmap[tuple(slices)] = 0
return PreprocessedExample(
audio=example.audio,
@ -383,6 +396,50 @@ def mask_frequency(
)
class MixAugmentationConfig(BaseConfig):
"""Configuration for MixUp augmentation (mixing two examples)."""
augmentation_type: Literal["mix_audio"] = "mix_audio"
probability: float = 0.2
"""Probability of applying this augmentation to an example."""
min_weight: float = 0.3
"""Minimum mixing weight (lambda) applied to the primary example."""
max_weight: float = 0.7
"""Maximum mixing weight (lambda) applied to the primary example."""
class MixAudio(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
min_weight: float = 0.3,
max_weight: float = 0.7,
):
"""Initialize the MixAudio augmentation."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
AugmentationConfig = Annotated[
Union[
MixAugmentationConfig,
@ -445,35 +502,6 @@ class MaybeApply(torch.nn.Module):
return self.augmentation(example)
class AudioMixer(torch.nn.Module):
"""Callable class for MixUp augmentation, handling example fetching."""
def __init__(
self,
min_weight: float,
max_weight: float,
example_source: ExampleSource,
preprocessor: PreprocessorProtocol,
):
"""Initialize the AudioMixer."""
super().__init__()
self.min_weight = min_weight
self.example_source = example_source
self.max_weight = max_weight
self.preprocessor = preprocessor
def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
"""Fetch another example and perform mixup."""
other = self.example_source()
weight = np.random.uniform(self.min_weight, self.max_weight)
return mix_examples(
example,
other,
self.preprocessor,
weight=weight,
)
def build_augmentation_from_config(
config: AugmentationConfig,
preprocessor: PreprocessorProtocol,
@ -489,7 +517,7 @@ def build_augmentation_from_config(
)
return None
return AudioMixer(
return MixAudio(
example_source=example_source,
preprocessor=preprocessor,
min_weight=config.min_weight,
@ -585,3 +613,25 @@ def load_augmentation_config(
) -> AugmentationsConfig:
"""Load the augmentations configuration from a file."""
return load_config(path, schema=AugmentationsConfig, field=field)
class RandomExampleSource:
def __init__(
self,
filenames: Sequence[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example
@classmethod
def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
filenames = list_preprocessed_files(path)
return cls(filenames, clipper=clipper)
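
RandomExampleSource now lives next to the augmentations and reads preprocessed files directly, which keeps MixUp independent of the Dataset class. A wiring sketch (not part of the commit; `preprocessor` and `clipper` are assumed to have been built elsewhere, and "train_examples/" is a hypothetical directory of preprocessed .npz files):

from batdetect2.train.augmentations import MixAudio, RandomExampleSource

example_source = RandomExampleSource.from_directory(
    "train_examples/",  # hypothetical path
    clipper=clipper,
)

mix = MixAudio(
    example_source=example_source,
    preprocessor=preprocessor,
    min_weight=0.3,
    max_weight=0.7,
)

# augmented = mix(example)  # `example` is a PreprocessedExample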

View File

@ -14,7 +14,9 @@ from batdetect2.evaluate.match import (
MatchConfig,
match_sound_events_and_raw_predictions,
)
from batdetect2.models import Model
from batdetect2.plotting.evaluation import plot_example_gallery
from batdetect2.postprocess import get_sound_event_predictions
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
@ -22,7 +24,6 @@ from batdetect2.typing import (
MatchEvaluation,
MetricsProtocol,
ModelOutput,
PostprocessorProtocol,
TargetProtocol,
TrainExample,
)
@ -127,8 +128,7 @@ class ValidationMetrics(Callback):
batch,
outputs,
dataset=self.get_dataset(trainer),
postprocessor=pl_module.model.postprocessor,
targets=pl_module.model.targets,
model=pl_module.model,
)
)
@ -137,15 +137,14 @@ def _get_batch_clips_and_predictions(
batch: TrainExample,
outputs: ModelOutput,
dataset: LabeledDataset,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
model: Model,
) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
clip_annotations = [
_get_subclip(
dataset.get_clip_annotation(example_id),
start_time=start_time.item(),
end_time=end_time.item(),
targets=targets,
targets=model.targets,
)
for example_id, start_time, end_time in zip(
batch.idx,
@ -156,9 +155,11 @@ def _get_batch_clips_and_predictions(
clips = [clip_annotation.clip for clip_annotation in clip_annotations]
raw_predictions = postprocessor.get_sound_event_predictions(
raw_predictions = get_sound_event_predictions(
outputs,
clips,
targets=model.targets,
postprocessor=model.postprocessor
)
return [

View File

@ -8,7 +8,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.typing import ClipperProtocol
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.train import PreprocessedExample
from batdetect2.utils.arrays import adjust_width
from batdetect2.utils.arrays import adjust_width, slice_tensor
DEFAULT_TRAIN_CLIP_DURATION = 0.512
DEFAULT_MAX_EMPTY_CLIP = 0.1
@ -90,7 +90,12 @@ def select_subclip(
audio_start = int(np.floor(start * input_samplerate))
audio = adjust_width(
example.audio[audio_start : audio_start + audio_width],
slice_tensor(
example.audio,
start=audio_start,
end=audio_start + audio_width,
dim=-1,
),
audio_width,
value=fill_value,
)
@ -100,19 +105,39 @@ def select_subclip(
return PreprocessedExample(
audio=audio,
spectrogram=adjust_width(
example.spectrogram[:, spec_start : spec_start + spec_width],
slice_tensor(
example.spectrogram,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
class_heatmap=adjust_width(
example.class_heatmap[:, :, spec_start : spec_start + spec_width],
slice_tensor(
example.class_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
detection_heatmap=adjust_width(
example.detection_heatmap[:, spec_start : spec_start + spec_width],
slice_tensor(
example.detection_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
size_heatmap=adjust_width(
example.size_heatmap[:, :, spec_start : spec_start + spec_width],
slice_tensor(
example.size_heatmap,
start=spec_start,
end=spec_start + spec_width,
dim=-1,
),
spec_width,
),
)

View File

@ -44,8 +44,8 @@ class PLTrainerConfig(BaseConfig):
class DataLoaderConfig(BaseConfig):
batch_size: int
shuffle: bool
batch_size: int = 8
shuffle: bool = False
num_workers: int = 0

View File

@ -1,5 +1,4 @@
from pathlib import Path
from typing import List, Optional, Sequence, Tuple
from typing import Optional, Sequence, Tuple
import numpy as np
import torch
@ -7,6 +6,10 @@ from soundevent import data
from torch.utils.data import Dataset
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.preprocess import (
list_preprocessed_files,
load_preprocessed_example,
)
from batdetect2.typing import ClipperProtocol, TrainExample
from batdetect2.typing.train import PreprocessedExample
@ -38,8 +41,8 @@ class LabeledDataset(Dataset):
example = self.augmentation(example)
return TrainExample(
spec=example.spectrogram.unsqueeze(0),
detection_heatmap=example.detection_heatmap.unsqueeze(0),
spec=example.spectrogram,
detection_heatmap=example.detection_heatmap,
class_heatmap=example.class_heatmap,
size_heatmap=example.size_heatmap,
idx=torch.tensor(idx),
@ -73,37 +76,3 @@ class LabeledDataset(Dataset):
def get_clip_annotation(self, idx) -> data.ClipAnnotation:
item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
return item["clip_annotation"].tolist()
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
class RandomExampleSource:
def __init__(
self,
filenames: List[data.PathLike],
clipper: ClipperProtocol,
):
self.filenames = filenames
self.clipper = clipper
def __call__(self) -> PreprocessedExample:
index = int(np.random.randint(len(self.filenames)))
filename = self.filenames[index]
example = load_preprocessed_example(filename)
example, _, _ = self.clipper(example)
return example
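
Because the preprocessor output and the heatmaps now carry their own channel dimension, the dataset no longer unsqueezes before batching; the default collate simply stacks examples. A shape sketch (illustrative; the field names mirror TrainExample but the concrete sizes are made up):

import torch
from torch.utils.data import default_collate

example = {
    "spec": torch.rand(1, 128, 512),               # (channel, freq, time)
    "detection_heatmap": torch.rand(1, 128, 512),  # (1, freq, time)
    "class_heatmap": torch.rand(17, 128, 512),     # (classes, freq, time)
    "size_heatmap": torch.rand(2, 128, 512),       # (dims, freq, time)
}

batch = default_collate([example, example])
assert batch["spec"].shape == (2, 1, 128, 512)  # batch dim added by collate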

View File

@ -41,7 +41,6 @@ from batdetect2.typing import (
__all__ = [
"LabelConfig",
"build_clip_labeler",
"generate_clip_label",
"generate_heatmaps",
"load_label_config",
]
@ -99,21 +98,26 @@ def build_clip_labeler(
lambda: config.to_yaml_string(),
)
return partial(
generate_clip_label,
generate_heatmaps,
targets=targets,
config=config,
min_freq=min_freq,
max_freq=max_freq,
target_sigma=config.sigma,
)
def generate_clip_label(
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
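
map_to_pixels linearly rescales a physical coordinate onto the heatmap grid and clamps it to [0, size]. A tiny worked example (the function is repeated here only so the snippet runs standalone):

import numpy as np

def map_to_pixels(x, size, min_val, max_val) -> int:
    return int(np.interp(x, [min_val, max_val], [0, size]))

# An event 0.256 s into a 0.512 s clip rendered on 512 time bins lands
# in the middle column; out-of-range values are clamped.
assert map_to_pixels(0.256, 512, 0.0, 0.512) == 256
assert map_to_pixels(1.0, 512, 0.0, 0.512) == 512
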
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
config: LabelConfig,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip.
@ -150,57 +154,14 @@ def generate_clip_label(
num=len(clip_annotation.sound_events),
)
sound_events = []
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_events.append(targets.transform(sound_event_annotation))
return generate_heatmaps(
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
spec=spec,
targets=targets,
target_sigma=config.sigma,
min_freq=min_freq,
max_freq=max_freq,
)
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=torch.float32,
) -> Heatmaps:
if not spec.ndim == 2:
raise ValueError(
"Expecting a 2-dimensional tensor of shape (H, W), "
"H is the height of the spectrogram "
"(frequency bins), and W is the width of the spectrogram "
f"(temporal bins). Instead got: {spec.shape}"
)
height, width = spec.shape
height = spec.shape[-2]
width = spec.shape[-1]
num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names)
clip = clip_annotation.clip
# Initialize heatmaps
detection_heatmap = torch.zeros([height, width], dtype=dtype)
detection_heatmap = torch.zeros([1, height, width], dtype=dtype)
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
@ -214,6 +175,16 @@ def generate_heatmaps(
times = times.to(spec.device)
for sound_event_annotation in clip_annotation.sound_events:
if not targets.filter(sound_event_annotation):
logger.debug(
"Sound event {sound_event} did not pass the filter. Tags: {tags}",
sound_event=sound_event_annotation,
tags=sound_event_annotation.tags,
)
continue
sound_event_annotation = targets.transform(sound_event_annotation)
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
@ -245,7 +216,10 @@ def generate_heatmaps(
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
detection_heatmap[0] = torch.maximum(
detection_heatmap[0],
gaussian_blob,
)
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
# Get the class name of the sound event

View File

@ -34,7 +34,7 @@ class TrainingModule(L.LightningModule):
return self.model(spec)
def training_step(self, batch: TrainExample):
outputs = self.model(batch.spec)
outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/train", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/train", losses.total, logger=True)
@ -47,7 +47,7 @@ class TrainingModule(L.LightningModule):
batch: TrainExample,
batch_idx: int,
) -> ModelOutput:
outputs = self.model(batch.spec)
outputs = self.model.detector(batch.spec)
losses = self.loss(outputs, batch)
self.log("total_loss/val", losses.total, prog_bar=True, logger=True)
self.log("detection_loss/val", losses.total, logger=True)

View File

@ -2,7 +2,7 @@
import os
from pathlib import Path
from typing import Callable, Optional, Sequence, TypedDict
from typing import Callable, List, Optional, Sequence, TypedDict
import numpy as np
import torch
@ -28,6 +28,8 @@ __all__ = [
"preprocess_dataset",
"TrainPreprocessConfig",
"load_train_preprocessing_config",
"save_preprocessed_example",
"load_preprocessed_example",
]
FilenameFn = Callable[[data.ClipAnnotation], str]
@ -94,8 +96,10 @@ def generate_train_example(
labeller: ClipLabeller,
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(audio_loader.load_clip(clip_annotation.clip))
spectrogram = preprocessor(wave)
wave = torch.tensor(
audio_loader.load_clip(clip_annotation.clip)
).unsqueeze(0)
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
heatmaps = labeller(clip_annotation, spectrogram)
return PreprocessedExample(
audio=wave,
@ -145,7 +149,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
labeller=self.labeller,
)
save_example_to_file(example, clip_annotation, path)
save_preprocessed_example(example, clip_annotation, path)
return idx
@ -153,7 +157,7 @@ class PreprocessingDataset(torch.utils.data.Dataset):
return len(self.clips)
def save_example_to_file(
def save_preprocessed_example(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
@ -169,6 +173,23 @@ def save_example_to_file(
)
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}"

View File

@ -15,13 +15,15 @@ from batdetect2.evaluate.metrics import (
DetectionAveragePrecision,
)
from batdetect2.models import build_model
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.augmentations import (
RandomExampleSource,
build_augmentations,
)
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
from batdetect2.train.dataset import (
LabeledDataset,
RandomExampleSource,
)
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger

View File

@ -95,69 +95,10 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[Detections]: ...
def get_detections(
self,
output: ModelOutput,
clips: Optional[List[data.Clip]] = None,
) -> List[Detections]: ...
def get_raw_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch.
Processes the raw model output for a batch through remapping, NMS,
detection, data extraction, and geometry recovery to produce a list of
`RawPrediction` objects for each corresponding input clip. This provides
a simplified, intermediate representation before final tag decoding.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[List[RawPrediction]]
A list of lists (one inner list per input clip, in order). Each
inner list contains the `RawPrediction` objects extracted for the
corresponding input clip.
"""
...
def get_sound_event_predictions(
self, output: ModelOutput, clips: List[data.Clip]
) -> List[List[BatDetect2Prediction]]: ...
def get_predictions(
self,
output: ModelOutput,
clips: List[data.Clip],
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output for a batch and corresponding clips, applies the
entire postprocessing chain, and returns the final, interpretable
predictions as a list of `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
The raw output from the neural network model for a batch.
clips : List[data.Clip]
A list of `soundevent.data.Clip` objects corresponding to the batch
items, providing context. Must match the batch size of `output`.
Returns
-------
List[data.ClipPrediction]
A list containing one `ClipPrediction` object for each input clip
(in the same order), populated with `SoundEventPrediction` objects
representing the final detections with decoded tags and geometry.
"""
...

View File

@ -12,8 +12,8 @@ that components responsible for these tasks can be interacted with consistently
throughout BatDetect2.
"""
from collections.abc import Callable
from typing import List, Optional, Protocol
from collections.abc import Callable, Iterable
from typing import List, Optional, Protocol, Tuple
import numpy as np
from soundevent import data

View File

@ -1,3 +1,5 @@
from typing import Optional
import numpy as np
import torch
import xarray as xr
@ -80,3 +82,14 @@ def adjust_width(
for index in range(dims)
]
return tensor[tuple(slices)]
def slice_tensor(
tensor: torch.Tensor,
start: Optional[int] = None,
end: Optional[int] = None,
dim: int = -1,
) -> torch.Tensor:
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start, end)
return tensor[tuple(slices)]
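
slice_tensor builds a full slice tuple and cuts along a single named dimension, so the same call works whether or not leading batch/channel dims are present; a negative dim indexes the slice list exactly like it indexes tensor dims. A standalone check (the helper is repeated so the snippet runs on its own):

from typing import Optional

import torch

def slice_tensor(
    tensor: torch.Tensor,
    start: Optional[int] = None,
    end: Optional[int] = None,
    dim: int = -1,
) -> torch.Tensor:
    slices = [slice(None)] * tensor.ndim
    slices[dim] = slice(start, end)
    return tensor[tuple(slices)]

audio = torch.rand(2, 48_000)       # (channel, samples)
heatmap = torch.rand(17, 128, 512)  # (classes, freq, time)

assert slice_tensor(audio, 100, 200, dim=-1).shape == (2, 100)
assert slice_tensor(heatmap, 10, 20, dim=-1).shape == (17, 128, 10)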

View File

@ -38,7 +38,6 @@ def build_from_config(
max_freq=preprocessor.max_freq,
)
postprocessor = build_postprocessor(
targets,
preprocessor=preprocessor,
config=postprocessing_config,
)