Fixing imports after restructuring

mbsantiago 2025-04-22 00:36:34 +01:00
parent dcae411ccb
commit 7c89e82579
10 changed files with 991 additions and 506 deletions

View File

@@ -10,13 +10,13 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from batdetect2.configs import BaseConfig
 from batdetect2.models import (
+    BackboneConfig,
     BBoxHead,
     ClassifierHead,
-    ModelConfig,
     ModelOutput,
-    build_architecture,
+    build_backbone,
 )
-from batdetect2.post_process import (
+from batdetect2.postprocess import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
@@ -37,7 +37,7 @@ __all__ = [
 class ModuleConfig(BaseConfig):
     train: TrainingConfig = Field(default_factory=TrainingConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
-    architecture: ModelConfig = Field(default_factory=ModelConfig)
+    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
@@ -58,7 +58,7 @@ class DetectorModel(L.LightningModule):
         self.config = config or ModuleConfig()
         self.save_hyperparameters()
-        self.backbone = build_architecture(self.config.architecture)
+        self.backbone = build_model_backbone(self.config.architecture)
         self.classifier = ClassifierHead(
             num_classes=len(self.config.targets.classes),
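Taken together, these hunks swap the old `ModelConfig`/`build_architecture` pair for `BackboneConfig`/`build_backbone`. As a rough sketch, constructing the module under the new layout might look like this (`ModuleConfig` and `DetectorModel` are the classes edited above; note that the hunk imports `build_backbone` but calls `build_model_backbone`, so the sketch only runs once those two names agree):

from batdetect2.models import BackboneConfig

# Assumed usage; the import path for ModuleConfig/DetectorModel
# is not visible in this diff.
config = ModuleConfig(architecture=BackboneConfig())
model = DetectorModel(config=config)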

View File

@@ -1,4 +1,30 @@
-from typing import Callable, Optional, Union
+"""Applies data augmentation techniques to BatDetect2 training examples.
+
+This module provides various functions and configurable components for applying
+data augmentation to the training examples (`xr.Dataset` containing audio,
+spectrogram, and target heatmaps) generated by the
+`batdetect2.train.preprocess` module.
+
+Data augmentation artificially increases the diversity of the training data by
+applying random transformations, which generally helps improve the robustness
+and generalization performance of trained models.
+
+Augmentations included:
+
+- Time-based: `select_subclip`, `warp_spectrogram`, `mask_time`.
+- Amplitude/Noise-based: `mix_examples`, `add_echo`, `scale_volume`,
+  `mask_frequency`.
+
+Some augmentations modify the audio waveform and require recomputing the
+spectrogram using the `PreprocessorProtocol`, while others operate directly
+on the spectrogram or target heatmaps. The entire augmentation pipeline can be
+configured using the `AugmentationsConfig` class, specifying a sequence of
+augmentation steps, each with its own parameters and application probability.
+The `build_augmentations` function constructs the final augmentation callable
+from this configuration.
+"""
+
+from functools import partial
+from typing import Annotated, Callable, List, Literal, Optional, Union

 import numpy as np
 import xarray as xr
@@ -6,15 +32,14 @@ from pydantic import Field
 from soundevent import arrays, data

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
-from batdetect2.preprocess.arrays import adjust_width
+from batdetect2.preprocess import PreprocessorProtocol
+from batdetect2.train.types import Augmentation
+from batdetect2.utils.arrays import adjust_width

-Augmentation = Callable[[xr.Dataset], xr.Dataset]

 __all__ = [
     "AugmentationsConfig",
-    "load_agumentation_config",
+    "load_augmentation_config",
+    "build_augmentations",
     "select_subclip",
     "mix_examples",
     "add_echo",
@@ -22,13 +47,21 @@ __all__ = [
     "warp_spectrogram",
     "mask_time",
     "mask_frequency",
-    "augment_example",
+    "MixAugmentationConfig",
+    "EchoAugmentationConfig",
+    "VolumeAugmentationConfig",
+    "WarpAugmentationConfig",
+    "TimeMaskAugmentationConfig",
+    "FrequencyMaskAugmentationConfig",
+    "AugmentationConfig",
+    "ExampleSource",
 ]

+ExampleSource = Callable[[], xr.Dataset]
+"""Type alias for a function that returns a training example (`xr.Dataset`).
+
+Used by the `mix_examples` augmentation to fetch another example to mix with.
+"""

-class BaseAugmentationConfig(BaseConfig):
-    enable: bool = True
-    probability: float = 0.2

 def select_subclip(
@@ -38,7 +71,47 @@ def select_subclip(
     width: Optional[int] = None,
     random: bool = False,
 ) -> xr.Dataset:
-    """Select a random subclip from a clip."""
+    """Extract a sub-clip (time segment) from a training example dataset.
+
+    Selects a portion of the 'time' dimension from all relevant DataArrays
+    (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
+    Dataset. The segment can be defined by a fixed start time and
+    duration/width, or a random start time can be chosen.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'audio', 'spectrogram', and
+        target heatmaps, all with compatible 'time' (or 'audio_time')
+        coordinates.
+    start_time : float, optional
+        Desired start time (seconds) of the subclip. If None and `random` is
+        False, starts from the beginning of the example. If None and `random`
+        is True, a random start time is chosen.
+    duration : float, optional
+        Desired duration (seconds) of the subclip. Either `duration` or `width`
+        must be provided.
+    width : int, optional
+        Desired width (number of time bins) of the subclip's
+        spectrogram/heatmaps. Either `duration` or `width` must be provided. If
+        both are given, `duration` takes precedence.
+    random : bool, default=False
+        If True and `start_time` is None, selects a random start time ensuring
+        the subclip fits within the original example's duration.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset containing only the selected time segment. Coordinates
+        are adjusted accordingly. Returns the original example if the requested
+        subclip cannot be extracted (e.g., duration too long).
+
+    Raises
+    ------
+    ValueError
+        If neither `duration` nor `width` is provided, or if time coordinates
+        are missing or invalid.
+    """
     step = arrays.get_dim_step(example, "time")  # type: ignore
     start, stop = arrays.get_dim_range(example, "time")  # type: ignore
@@ -77,22 +150,62 @@ def select_subclip(
     )


-class MixAugmentationConfig(BaseAugmentationConfig):
+class MixAugmentationConfig(BaseConfig):
+    """Configuration for MixUp augmentation (mixing two examples)."""
+
+    augmentation_type: Literal["mix_audio"] = "mix_audio"
+
+    probability: float = 0.2
+    """Probability of applying this augmentation to an example."""
+
     min_weight: float = 0.3
+    """Minimum mixing weight (lambda) applied to the primary example."""
+
     max_weight: float = 0.7
+    """Maximum mixing weight (lambda) applied to the primary example."""


 def mix_examples(
     example: xr.Dataset,
     other: xr.Dataset,
+    preprocessor: PreprocessorProtocol,
     weight: Optional[float] = None,
     min_weight: float = 0.3,
     max_weight: float = 0.7,
-    config: Optional[PreprocessingConfig] = None,
 ) -> xr.Dataset:
-    """Combine two audio clips."""
-    config = config or PreprocessingConfig()
+    """Combine two training examples using MixUp augmentation.
+
+    Performs a weighted linear combination of the audio waveforms from two
+    examples (`example` and `other`). The spectrogram is then *recomputed*
+    from the mixed audio using the provided `preprocessor`. Target heatmaps
+    (detection, class) are combined by taking the element-wise maximum. Target
+    size heatmaps are combined by element-wise addition.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The primary training example.
+    other : xr.Dataset
+        The second training example to mix with `example`.
+    preprocessor : PreprocessorProtocol
+        The preprocessor used to recompute the spectrogram from mixed audio.
+        Ensures consistency with original preprocessing.
+    weight : float, optional
+        The mixing weight (lambda) applied to the primary `example`. The weight
+        for `other` will be `(1 - weight)`. If None, a random weight is chosen
+        uniformly between `min_weight` and `max_weight`.
+    min_weight : float, default=0.3
+        Minimum value for the random weight lambda.
+    max_weight : float, default=0.7
+        Maximum value for the random weight lambda. Must be >= `min_weight`.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset representing the mixed example, with combined audio,
+        recomputed spectrogram, and combined target heatmaps. Attributes from
+        the primary `example` are preserved.
+    """
     if weight is None:
         weight = np.random.uniform(min_weight, max_weight)
@@ -101,9 +214,8 @@ def mix_examples(
     combined = weight * audio1 + (1 - weight) * audio2

-    spectrogram = compute_spectrogram(
-        combined.rename({"audio_time": "time"}),
-        config=config.spectrogram,
+    spectrogram = preprocessor.compute_spectrogram(
+        combined.rename({"audio_time": "time"})
     ).data

     # NOTE: The subclip's spectrogram might be slightly longer than the
@@ -147,7 +259,14 @@ def mix_examples(
     )


-class EchoAugmentationConfig(BaseAugmentationConfig):
+class EchoAugmentationConfig(BaseConfig):
+    """Configuration for adding synthetic echo/reverb."""
+
+    augmentation_type: Literal["add_echo"] = "add_echo"
+
+    probability: float = 0.2
+    """Probability of applying this augmentation."""
+
     max_delay: float = 0.005
     min_weight: float = 0.0
     max_weight: float = 1.0
@@ -155,15 +274,45 @@ class EchoAugmentationConfig(BaseAugmentationConfig):

 def add_echo(
     example: xr.Dataset,
+    preprocessor: PreprocessorProtocol,
     delay: Optional[float] = None,
     weight: Optional[float] = None,
     min_weight: float = 0.1,
     max_weight: float = 1.0,
     max_delay: float = 0.005,
-    config: Optional[PreprocessingConfig] = None,
 ) -> xr.Dataset:
-    """Add a delay to the audio."""
-    config = config or PreprocessingConfig()
+    """Add a synthetic echo to the audio waveform.
+
+    Creates an echo by adding a delayed and attenuated version of the original
+    audio waveform back onto itself. The spectrogram is then recomputed from
+    the modified audio. Target heatmaps remain unchanged.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'audio', 'spectrogram', etc.
+    preprocessor : PreprocessorProtocol
+        Preprocessor used to recompute the spectrogram from the modified audio.
+    delay : float, optional
+        The delay time in seconds for the echo. If None, chosen randomly
+        between 0 and `max_delay`.
+    weight : float, optional
+        The relative amplitude (weight) of the echo compared to the original.
+        If None, chosen randomly between `min_weight` and `max_weight`.
+    min_weight : float, default=0.1
+        Minimum value for the random echo weight.
+    max_weight : float, default=1.0
+        Maximum value for the random echo weight. Must be >= `min_weight`.
+    max_delay : float, default=0.005
+        Maximum value for the random echo delay in seconds. Must be >= 0.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with the echo added to the 'audio' variable and the
+        'spectrogram' variable recomputed. Other variables (targets, attrs)
+        are copied from the original example.
+    """
     if delay is None:
         delay = np.random.uniform(0, max_delay)
@@ -176,9 +325,8 @@ def add_echo(
     audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
     audio = audio + weight * audio_delay

-    spectrogram = compute_spectrogram(
+    spectrogram = preprocessor.compute_spectrogram(
         audio.rename({"audio_time": "time"}),
-        config=config.spectrogram,
     ).data

     # NOTE: The subclip's spectrogram might be slightly longer than the
@@ -202,7 +350,11 @@ def add_echo(
     )


-class VolumeAugmentationConfig(BaseAugmentationConfig):
+class VolumeAugmentationConfig(BaseConfig):
+    """Configuration for random volume scaling of the spectrogram."""
+
+    augmentation_type: Literal["scale_volume"] = "scale_volume"
+    probability: float = 0.2
     min_scaling: float = 0.0
     max_scaling: float = 2.0
@@ -213,14 +365,44 @@ def scale_volume(
     max_scaling: float = 2,
     min_scaling: float = 0,
 ) -> xr.Dataset:
-    """Scale the volume of a spectrogram."""
+    """Scale the amplitude of the spectrogram by a random factor.
+
+    Multiplies the entire spectrogram DataArray by a scaling factor chosen
+    uniformly between `min_scaling` and `max_scaling`. This simulates changes
+    in recording volume or distance to the sound source directly in the
+    spectrogram domain. Audio and target heatmaps are unchanged.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'spectrogram'.
+    factor : float, optional
+        The scaling factor to apply. If None, chosen randomly between
+        `min_scaling` and `max_scaling`.
+    min_scaling : float, default=0.0
+        Minimum value for the random scaling factor. Must be non-negative.
+    max_scaling : float, default=2.0
+        Maximum value for the random scaling factor. Must be >= `min_scaling`.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with the 'spectrogram' variable scaled.
+
+    Raises
+    ------
+    ValueError
+        If `min_scaling` > `max_scaling` or if `min_scaling` is negative.
+    """
     if factor is None:
         factor = np.random.uniform(min_scaling, max_scaling)

     return example.assign(spectrogram=example["spectrogram"] * factor)


-class WarpAugmentationConfig(BaseAugmentationConfig):
+class WarpAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["warp"] = "warp"
+    probability: float = 0.2
     delta: float = 0.04
@@ -229,11 +411,39 @@ def warp_spectrogram(
     factor: Optional[float] = None,
     delta: float = 0.04,
 ) -> xr.Dataset:
-    """Warp a spectrogram."""
+    """Apply time warping by resampling the time axis.
+
+    Stretches or compresses the time axis of the spectrogram and all target
+    heatmaps by a random `factor` (chosen between `1-delta` and `1+delta`).
+    Uses linear interpolation for the spectrogram and nearest neighbor for the
+    discrete-like target heatmaps. Updates the 'time' coordinate accordingly.
+    The audio waveform is not modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example with 'spectrogram', 'detection', 'class',
+        'size', and 'time' coordinates.
+    factor : float, optional
+        The warping factor. If None, chosen randomly between `1-delta` and
+        `1+delta`. Values > 1 stretch time, < 1 compress time.
+    delta : float, default=0.04
+        Controls the range `[1-delta, 1+delta]` for the random warp factor.
+        Must be >= 0 and < 1.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with time-warped spectrogram and target heatmaps, and
+        updated 'time' coordinates.
+    """
     if factor is None:
         factor = np.random.uniform(1 - delta, 1 + delta)

-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )
     duration = end_time - start_time

     new_time = np.linspace(
@@ -296,6 +506,39 @@ def mask_axis(
     end: float,
     mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
 ) -> xr.DataArray:
+    """Mask values along a specified dimension.
+
+    Sets values in the DataArray to `mask_value` where the coordinate along
+    `dim` falls within the range [`start`, `end`). Values outside this range
+    are kept. Used as a helper for time/frequency masking.
+
+    Parameters
+    ----------
+    array : xr.DataArray
+        The input DataArray (e.g., spectrogram).
+    dim : str
+        The name of the dimension along which to mask
+        (e.g., "time", "frequency").
+    start : float
+        The starting coordinate value for the mask range.
+    end : float
+        The ending coordinate value for the mask range (exclusive, typically).
+        Values >= start and < end will be masked.
+    mask_value : float or Callable[[xr.DataArray], float], default=np.mean
+        The value to use for masking. Can be a fixed float (e.g., 0.0) or a
+        callable (like `np.mean`, `np.median`) that computes the value from
+        the input `array`.
+
+    Returns
+    -------
+    xr.DataArray
+        The DataArray with the specified range along `dim` masked.
+
+    Raises
+    ------
+    ValueError
+        If `dim` is not found in the array's dimensions or coordinates.
+    """
     if dim not in array.dims:
         raise ValueError(f"Axis {dim} not found in array")
@@ -308,7 +551,9 @@ def mask_axis(
     return array.where(condition, other=mask_value)


-class TimeMaskAugmentationConfig(BaseAugmentationConfig):
+class TimeMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_time"] = "mask_time"
+    probability: float = 0.2
     max_perc: float = 0.05
     max_masks: int = 3
@@ -318,9 +563,32 @@ def mask_time(
     max_perc: float = 0.05,
     max_mask: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the time axis."""
+    """Apply random time masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of time intervals (up to `max_masks`) and masks
+    (sets to the mean value) the spectrogram within those intervals. The width
+    of each mask is chosen randomly up to `max_perc` of the total duration.
+    Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'time' coordinate.
+    max_perc : float, default=0.05
+        Maximum width of a single mask as a fraction of total duration.
+    max_masks : int, default=3
+        Maximum number of time masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with time masking applied to the 'spectrogram'.
+    """
     num_masks = np.random.randint(1, max_mask + 1)
-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )

     spectrogram = example["spectrogram"]
     for _ in range(num_masks):
@@ -332,7 +600,9 @@ def mask_time(
     return example.assign(spectrogram=spectrogram)


-class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
+class FrequencyMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_freq"] = "mask_freq"
+    probability: float = 0.2
     max_perc: float = 0.10
     max_masks: int = 3
@@ -342,9 +612,38 @@ def mask_frequency(
     max_perc: float = 0.10,
     max_masks: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the frequency axis."""
+    """Apply random frequency masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of frequency intervals (up to `max_masks`) and
+    masks (sets to the mean value) the spectrogram within those intervals. The
+    height of each mask is chosen randomly up to `max_perc` of the total
+    frequency range. Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'frequency'
+        coordinate.
+    max_perc : float, default=0.10
+        Maximum height of a single mask as a fraction of total frequency range.
+    max_masks : int, default=3
+        Maximum number of frequency masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with frequency masking applied to the 'spectrogram'.
+
+    Raises
+    ------
+    ValueError
+        If frequency coordinate info is missing or invalid.
+    """
     num_masks = np.random.randint(1, max_masks + 1)
-    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
+    min_freq, max_freq = arrays.get_dim_range(
+        example,  # type: ignore
+        "frequency",
+    )

     spectrogram = example["spectrogram"]
     for _ in range(num_masks):
@@ -356,88 +655,326 @@ def mask_frequency(
     return example.assign(spectrogram=spectrogram)

+AugmentationConfig = Annotated[
+    Union[
+        MixAugmentationConfig,
+        EchoAugmentationConfig,
+        VolumeAugmentationConfig,
+        WarpAugmentationConfig,
+        FrequencyMaskAugmentationConfig,
+        TimeMaskAugmentationConfig,
+    ],
+    Field(discriminator="augmentation_type"),
+]
+"""Type alias for the discriminated union of individual augmentation configs."""
 class AugmentationsConfig(BaseConfig):
-    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
-    echo: EchoAugmentationConfig = Field(
-        default_factory=EchoAugmentationConfig
-    )
-    volume: VolumeAugmentationConfig = Field(
-        default_factory=VolumeAugmentationConfig
-    )
-    warp: WarpAugmentationConfig = Field(
-        default_factory=WarpAugmentationConfig
-    )
-    time_mask: TimeMaskAugmentationConfig = Field(
-        default_factory=TimeMaskAugmentationConfig
-    )
-    frequency_mask: FrequencyMaskAugmentationConfig = Field(
-        default_factory=FrequencyMaskAugmentationConfig
-    )
+    """Configuration for a sequence of data augmentations.
+
+    Attributes
+    ----------
+    steps : List[AugmentationConfig]
+        An ordered list of configuration objects, each defining a single
+        augmentation step (e.g., MixAugmentationConfig,
+        TimeMaskAugmentationConfig). Each step's configuration must include an
+        `augmentation_type` field and a `probability` field, along with
+        type-specific parameters. The augmentations will be applied
+        (probabilistically) in the sequence defined by this list.
+    """
+
+    steps: List[AugmentationConfig] = Field(default_factory=list)
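For illustration, a pipeline of several steps could be declared like this; it is a sketch based only on the config classes defined in this file, with arbitrary parameter values. The `augmentation_type` discriminator is what lets pydantic pick the right class when the same list is loaded from YAML:

# Illustrative only; field names come from the classes above,
# the values are made up.
config = AugmentationsConfig(
    steps=[
        MixAugmentationConfig(probability=0.3, min_weight=0.4, max_weight=0.6),
        EchoAugmentationConfig(probability=0.2, max_delay=0.005),
        TimeMaskAugmentationConfig(probability=0.5, max_perc=0.05, max_masks=3),
    ]
)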
+class MaybeApply:
+    """Applies an augmentation function probabilistically."""
+
+    def __init__(
+        self,
+        augmentation: Augmentation,
+        probability: float = 0.2,
+    ):
+        """Initialize the wrapper.
+
+        Parameters
+        ----------
+        augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
+            The augmentation function to potentially apply.
+        probability : float, default=0.2
+            The probability (0.0 to 1.0) of applying the augmentation.
+        """
+        self.augmentation = augmentation
+        self.probability = probability
+
+    def __call__(self, example: xr.Dataset) -> xr.Dataset:
+        """Apply the wrapped augmentation with configured probability.
+
+        Parameters
+        ----------
+        example : xr.Dataset
+            The input training example.
+
+        Returns
+        -------
+        xr.Dataset
+            The potentially augmented training example.
+        """
+        if np.random.random() > self.probability:
+            return example
+
+        return self.augmentation(example)
+class AudioMixer:
+    """Callable class for MixUp augmentation, handling example fetching.
+
+    Wraps the `mix_examples` logic, providing the necessary `example_source`
+    to fetch a second example dynamically when called.
+
+    Parameters
+    ----------
+    min_weight : float
+        Minimum mixing weight (lambda) for the primary example.
+    max_weight : float
+        Maximum mixing weight (lambda) for the primary example.
+    example_source : ExampleSource (Callable[[], xr.Dataset])
+        A function that, when called, returns another random training example
+        dataset (`xr.Dataset`) to be mixed with the input example.
+    preprocessor : PreprocessorProtocol
+        The preprocessor needed to recompute the spectrogram after mixing
+        audio.
+    """
+
+    def __init__(
+        self,
+        min_weight: float,
+        max_weight: float,
+        example_source: ExampleSource,
+        preprocessor: PreprocessorProtocol,
+    ):
+        """Initialize the AudioMixer."""
+        self.min_weight = min_weight
+        self.example_source = example_source
+        self.max_weight = max_weight
+        self.preprocessor = preprocessor
+
+    def __call__(self, example: xr.Dataset) -> xr.Dataset:
+        """Fetch another example and perform mixup.
+
+        Parameters
+        ----------
+        example : xr.Dataset
+            The primary input training example.
+
+        Returns
+        -------
+        xr.Dataset
+            The mixed training example. Returns the original example if
+            fetching the second example fails.
+        """
+        other = self.example_source()
+        return mix_examples(
+            example,
+            other,
+            self.preprocessor,
+            min_weight=self.min_weight,
+            max_weight=self.max_weight,
+        )
+def build_augmentation_from_config(
+    config: AugmentationConfig,
+    preprocessor: PreprocessorProtocol,
+    example_source: Optional[ExampleSource] = None,
+) -> Augmentation:
+    """Factory function to build a single augmentation from its config.
+
+    Takes a configuration object for one augmentation step (which includes the
+    `augmentation_type` discriminator) and returns the corresponding functional
+    augmentation callable (e.g., a `partial` function or a callable class like
+    `AudioMixer`).
+
+    Parameters
+    ----------
+    config : AugmentationConfig
+        Configuration object for a single augmentation (e.g., instance of
+        `MixAugmentationConfig`, `EchoAugmentationConfig`, etc.).
+    preprocessor : PreprocessorProtocol
+        The preprocessor object, required by augmentations that modify audio
+        and need to recompute the spectrogram (e.g., mixup, echo).
+    example_source : ExampleSource, optional
+        A callable that provides other training examples. Required only if
+        the configuration includes `MixAugmentationConfig` (`augmentation_type`
+        is "mix_audio").
+
+    Returns
+    -------
+    Augmentation (Callable[[xr.Dataset], xr.Dataset])
+        A callable function that takes a training example (`xr.Dataset`) and
+        returns a potentially augmented version.
+
+    Raises
+    ------
+    ValueError
+        If `config.augmentation_type` is "mix_audio" but `example_source` is
+        None.
+    NotImplementedError
+        If `config.augmentation_type` does not match any known augmentation
+        type.
+    """
+    if config.augmentation_type == "mix_audio":
+        if example_source is None:
+            raise ValueError(
+                "Mix audio augmentation ('mix_audio') requires an "
+                "'example_source' callable to be provided."
+            )
+
+        return AudioMixer(
+            example_source=example_source,
+            preprocessor=preprocessor,
+            min_weight=config.min_weight,
+            max_weight=config.max_weight,
+        )
+
+    if config.augmentation_type == "add_echo":
+        return partial(
+            add_echo,
+            preprocessor=preprocessor,
+            max_delay=config.max_delay,
+            min_weight=config.min_weight,
+            max_weight=config.max_weight,
+        )
+
+    if config.augmentation_type == "scale_volume":
+        return partial(
+            scale_volume,
+            max_scaling=config.max_scaling,
+            min_scaling=config.min_scaling,
+        )
+
+    if config.augmentation_type == "warp":
+        return partial(
+            warp_spectrogram,
+            delta=config.delta,
+        )
+
+    if config.augmentation_type == "mask_time":
+        return partial(
+            mask_time,
+            max_perc=config.max_perc,
+            max_mask=config.max_masks,
+        )
+
+    if config.augmentation_type == "mask_freq":
+        return partial(
+            mask_frequency,
+            max_perc=config.max_perc,
+            max_masks=config.max_masks,
+        )
+
+    raise NotImplementedError(
+        "Invalid or unimplemented augmentation type: "
+        f"{config.augmentation_type}"
+    )
-def load_agumentation_config(
+def build_augmentations(
+    config: AugmentationsConfig,
+    preprocessor: PreprocessorProtocol,
+    example_source: Optional[ExampleSource] = None,
+) -> Augmentation:
+    """Build a composite augmentation pipeline function from configuration.
+
+    Takes the overall `AugmentationsConfig` (containing a list of individual
+    augmentation steps), builds the callable function for each step using
+    `build_augmentation_from_config`, wraps each function with `MaybeApply`
+    to handle its application probability, and returns a single
+    `Augmentation` function that applies the entire sequence.
+
+    Parameters
+    ----------
+    config : AugmentationsConfig
+        The configuration object detailing the sequence of augmentation steps.
+    preprocessor : PreprocessorProtocol
+        The preprocessor object, needed for audio-modifying augmentations.
+    example_source : ExampleSource, optional
+        A callable providing other examples, required if 'mix_audio' is used.
+
+    Returns
+    -------
+    Augmentation (Callable[[xr.Dataset], xr.Dataset])
+        A single callable function that takes a training example (`xr.Dataset`)
+        and applies the configured sequence of augmentations probabilistically,
+        returning the augmented example. Returns the original example if
+        `config.steps` is empty.
+
+    Raises
+    ------
+    ValueError
+        If 'mix_audio' is configured but `example_source` is not provided.
+    NotImplementedError
+        If an unknown `augmentation_type` is encountered in `config.steps`.
+    """
+    augmentations = []
+
+    for step_config in config.steps:
+        augmentation = build_augmentation_from_config(
+            step_config,
+            preprocessor=preprocessor,
+            example_source=example_source,
+        )
+        augmentations.append(
+            MaybeApply(
+                augmentation=augmentation,
+                probability=step_config.probability,
+            )
+        )
+
+    return partial(_apply_augmentations, augmentations=augmentations)
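A minimal sketch of how this factory might be used in a training setup, assuming a `preprocessor` satisfying `PreprocessorProtocol` and a dataset object with a `get_random_example` method (as in `batdetect2.train.dataset`) are already available:

# Sketch; `config`, `preprocessor`, and `dataset` are assumed to exist.
augment = build_augmentations(
    config,
    preprocessor=preprocessor,
    example_source=dataset.get_random_example,  # only needed for "mix_audio"
)
augmented_example = augment(example)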
+def load_augmentation_config(
     path: data.PathLike, field: Optional[str] = None
 ) -> AugmentationsConfig:
+    """Load the augmentations configuration from a file.
+
+    Reads a configuration file (YAML) and validates it against the
+    `AugmentationsConfig` schema, potentially extracting data from a nested
+    field.
+
+    Parameters
+    ----------
+    path : PathLike
+        Path to the configuration file.
+    field : str, optional
+        Dot-separated path to a nested section within the file containing the
+        augmentations configuration (e.g., "training.augmentations"). If None,
+        the entire file content is used.
+
+    Returns
+    -------
+    AugmentationsConfig
+        The loaded and validated augmentations configuration object.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the config file path does not exist.
+    yaml.YAMLError
+        If the file content is not valid YAML.
+    pydantic.ValidationError
+        If the loaded config data does not conform to `AugmentationsConfig`.
+    KeyError, TypeError
+        If `field` specifies an invalid path.
+    """
     return load_config(path, schema=AugmentationsConfig, field=field)
-def should_apply(config: BaseAugmentationConfig) -> bool:
-    if not config.enable:
-        return False
-    return np.random.uniform() < config.probability
-
-
-def augment_example(
-    example: xr.Dataset,
-    config: AugmentationsConfig,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    others: Optional[Callable[[], xr.Dataset]] = None,
-) -> xr.Dataset:
-    if should_apply(config.mix) and (others is not None):
-        other = others()
-        example = mix_examples(
-            example,
-            other,
-            min_weight=config.mix.min_weight,
-            max_weight=config.mix.max_weight,
-            config=preprocessing_config,
-        )
-    if should_apply(config.echo):
-        example = add_echo(
-            example,
-            max_delay=config.echo.max_delay,
-            min_weight=config.echo.min_weight,
-            max_weight=config.echo.max_weight,
-            config=preprocessing_config,
-        )
-    if should_apply(config.volume):
-        example = scale_volume(
-            example,
-            max_scaling=config.volume.max_scaling,
-            min_scaling=config.volume.min_scaling,
-        )
-    if should_apply(config.warp):
-        example = warp_spectrogram(
-            example,
-            delta=config.warp.delta,
-        )
-    if should_apply(config.time_mask):
-        example = mask_time(
-            example,
-            max_perc=config.time_mask.max_perc,
-            max_mask=config.time_mask.max_masks,
-        )
-    if should_apply(config.frequency_mask):
-        example = mask_frequency(
-            example,
-            max_perc=config.frequency_mask.max_perc,
-            max_masks=config.frequency_mask.max_masks,
-        )
-    return example
+def _apply_augmentations(
+    example: xr.Dataset,
+    augmentations: List[Augmentation],
+):
+    """Apply a sequence of augmentation functions to an example."""
+    for augmentation in augmentations:
+        example = augmentation(example)
+    return example

View File

@@ -3,14 +3,15 @@ from lightning.pytorch.callbacks import Callback
 from torch.utils.data import DataLoader

 from batdetect2.evaluate import match_predictions_and_annotations
-from batdetect2.post_process import postprocess_model_outputs
+from batdetect2.postprocess import PostprocessorProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.types import ModelOutput


 class ValidationMetrics(Callback):
-    def __init__(self):
+    def __init__(self, postprocessor: PostprocessorProtocol):
         super().__init__()
+        self.postprocessor = postprocessor
         self.predictions = []

     def on_validation_epoch_start(
@@ -36,20 +37,20 @@ class ValidationMetrics(Callback):
         assert isinstance(dataset, LabeledDataset)
         clip_annotation = dataset.get_clip_annotation(batch_idx)

-        clip_prediction = postprocess_model_outputs(
-            outputs,
-            clips=[clip_annotation.clip],
-            classes=self.class_names,
-            decoder=self.decoder,
-            config=self.config.postprocessing,
-        )[0]
-
-        matches = match_predictions_and_annotations(
-            clip_annotation,
-            clip_prediction,
-        )
-
-        self.validation_predictions.extend(matches)
-        return super().on_validation_batch_end(
-            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
-        )
+        # clip_prediction = postprocess_model_outputs(
+        #     outputs,
+        #     clips=[clip_annotation.clip],
+        #     classes=self.class_names,
+        #     decoder=self.decoder,
+        #     config=self.config.postprocessing,
+        # )[0]
+        #
+        # matches = match_predictions_and_annotations(
+        #     clip_annotation,
+        #     clip_prediction,
+        # )
+        #
+        # self.validation_predictions.extend(matches)
+        # return super().on_validation_batch_end(
+        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        # )
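The callback now receives its postprocessor explicitly instead of rebuilding it from config. A hypothetical wiring (only `ValidationMetrics` and `PostprocessorProtocol` come from the hunk above; the postprocessor construction and trainer setup are assumed):

import lightning as L

# `postprocessor` is assumed to satisfy PostprocessorProtocol.
metrics = ValidationMetrics(postprocessor=postprocessor)
trainer = L.Trainer(callbacks=[metrics])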

View File

@@ -10,13 +10,13 @@ from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.tensors import adjust_width
 from batdetect2.train.augmentations import (
     AugmentationsConfig,
     augment_example,
     select_subclip,
 )
-from batdetect2.train.preprocess import PreprocessingConfig
+from batdetect2.train.preprocess import PreprocessorProtocol
+from batdetect2.utils.tensors import adjust_width

 __all__ = [
     "TrainExample",
@@ -51,15 +51,15 @@ class DatasetConfig(BaseConfig):
 class LabeledDataset(Dataset):
     def __init__(
         self,
+        preprocessor: PreprocessorProtocol,
         filenames: Sequence[PathLike],
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
+        self.preprocessor = preprocessor
         self.filenames = filenames
         self.subclip = subclip
         self.augmentation = augmentation
-        self.preprocessing = preprocessing or PreprocessingConfig()

     def __len__(self):
         return len(self.filenames)
@@ -79,7 +79,7 @@ class LabeledDataset(Dataset):
             dataset = augment_example(
                 dataset,
                 self.augmentation,
-                preprocessing_config=self.preprocessing,
+                preprocessor=self.preprocessor,
                 others=self.get_random_example,
             )
@@ -95,16 +95,16 @@ class LabeledDataset(Dataset):
     def from_directory(
         cls,
         directory: PathLike,
+        preprocessor: PreprocessorProtocol,
         extension: str = ".nc",
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         return cls(
-            get_preprocessed_files(directory, extension),
+            preprocessor=preprocessor,
+            filenames=get_preprocessed_files(directory, extension),
             subclip=subclip,
             augmentation=augmentation,
-            preprocessing=preprocessing,
         )

     def get_random_example(self) -> xr.Dataset:
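Under the new signature the dataset takes the preprocessor as an explicit dependency. A sketch, assuming a directory of preprocessed `.nc` example files and a `preprocessor` built elsewhere (note this file still imports `augment_example`, which the augmentations hunk above removes in favour of `build_augmentations`):

# Sketch; `preprocessor` is assumed to satisfy PreprocessorProtocol.
dataset = LabeledDataset.from_directory(
    "path/to/preprocessed",  # directory of .nc training examples
    preprocessor=preprocessor,
)
example = dataset[0]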

View File

@@ -1,48 +1,47 @@
 """Generate heatmap training targets for BatDetect2 models.

-This module represents the final step in the `batdetect2.targets` pipeline,
-converting processed sound event annotations from an audio clip into the
-specific heatmap formats required for training the BatDetect2 neural network.
-
-It integrates the filtering, transformation, and class encoding logic defined
-in the preceding configuration steps (`filtering`, `transform`, `classes`)
-and applies them to generate three core outputs for a given spectrogram:
-
-1. **Detection Heatmap**: Indicates the presence and location of relevant
-   sound events.
-2. **Class Heatmap**: Indicates the location and predicted class label for
-   events that match a specific target class.
-3. **Size Heatmap**: Encodes the dimensions (width/time duration,
-   height/frequency bandwidth) of the detected sound events at their
-   reference locations.
-
-The primary function generated by this module is a `ClipLabeller`, which takes
-a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`)
-and returns the calculated `Heatmaps`. Configuration options allow tuning of
-the heatmap generation process (e.g., Gaussian smoothing sigma, reference
-point within bounding boxes).
+This module is responsible for creating the target labels used for training
+BatDetect2 models. It converts sound event annotations for an audio clip into
+the specific multi-channel heatmap formats required by the neural network.
+
+It uses a pre-configured object adhering to the `TargetProtocol` (from
+`batdetect2.targets`) which encapsulates all the logic for filtering
+annotations, transforming tags, encoding class names, and mapping annotation
+geometry (ROIs) to target positions and sizes. This module then focuses on
+rendering this information onto the heatmap grids.
+
+The pipeline generates three core outputs for a given spectrogram:
+
+1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
+2. **Class Heatmap**: Indicates location and class identity for specifically
+   classified events.
+3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
+
+The primary function generated by this module is a `ClipLabeller` (defined in
+`.types`), which takes a `ClipAnnotation` object and its corresponding
+spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
+parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
+defined in `LabelConfig`.
 """

 import logging
 from collections.abc import Iterable
 from functools import partial
-from typing import Callable, List, NamedTuple, Optional
+from typing import Optional

 import numpy as np
 import xarray as xr
 from scipy.ndimage import gaussian_filter
-from soundevent import arrays, data, geometry
-from soundevent.geometry.operations import Positions
+from soundevent import arrays, data

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets.classes import SoundEventEncoder
-from batdetect2.targets.filtering import SoundEventFilter
-from batdetect2.targets.transform import SoundEventTransformation
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.types import (
+    ClipLabeller,
+    Heatmaps,
+)

 __all__ = [
     "LabelConfig",
-    "Heatmaps",
-    "ClipLabeller",
     "build_clip_labeler",
     "generate_clip_label",
     "generate_heatmaps",
@@ -50,109 +49,45 @@ __all__ = [
 ]

+SIZE_DIMENSION = "dimension"
+"""Dimension name for the size heatmap."""
+
 logger = logging.getLogger(__name__)

-
-class Heatmaps(NamedTuple):
-    """Structure holding the generated heatmap targets.
-
-    Attributes
-    ----------
-    detection : xr.DataArray
-        Heatmap indicating the probability of sound event presence. Typically
-        smoothed with a Gaussian kernel centered on event reference points.
-        Shape matches the input spectrogram. Values normalized [0, 1].
-    classes : xr.DataArray
-        Heatmap indicating the probability of specific class presence. Has an
-        additional 'category' dimension corresponding to the target class
-        names. Each category slice is typically smoothed with a Gaussian
-        kernel. Values normalized [0, 1] per category.
-    size : xr.DataArray
-        Heatmap encoding the size (width, height) of detected events. Has an
-        additional 'dimension' coordinate ('width', 'height'). Values
-        represent scaled dimensions placed at the event reference points.
-    """
-
-    detection: xr.DataArray
-    classes: xr.DataArray
-    size: xr.DataArray
-
-
-ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
-"""Type alias for the final clip labelling function.
-
-This function takes the complete annotations for a clip and the corresponding
-spectrogram, applies all configured filtering, transformation, and encoding
-steps, and returns the final `Heatmaps` used for model training.
-"""
 class LabelConfig(BaseConfig):
     """Configuration parameters for heatmap generation.

     Attributes
     ----------
-    position : Positions, default="bottom-left"
-        Specifies the reference point within each sound event's geometry
-        (bounding box) that is used to place the 'peak' or value on the
-        heatmaps. Options include 'center', 'bottom-left', 'top-right', etc.
-        See `soundevent.geometry.operations.Positions` for valid options.
     sigma : float, default=3.0
         The standard deviation (in pixels/bins) of the Gaussian kernel applied
         to smooth the detection and class heatmaps. Larger values create more
         diffuse targets.
-    time_scale : float, default=1000.0
-        A scaling factor applied to the time duration (width) of sound event
-        bounding boxes before storing them in the 'width' dimension of the
-        `size` heatmap. The appropriate value depends on how the model input/
-        output resolution relates to physical time units (e.g., if time axis
-        is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs
-        to match model expectations.
-    frequency_scale : float, default=1/859.375
-        A scaling factor applied to the frequency bandwidth (height) of sound
-        event bounding boxes before storing them in the 'height' dimension of
-        the `size` heatmap. The appropriate value depends on the relationship
-        between the frequency axis resolution (e.g., kHz or Hz per bin) and
-        the desired output units/scale for the model. Needs to match model
-        expectations. (The default suggests input frequency might be in Hz
-        and output is scaled relative to some reference).
     """

-    position: Positions = "bottom-left"
     sigma: float = 3.0
-    time_scale: float = 1000.0
-    frequency_scale: float = 1 / 859.375
 def build_clip_labeler(
-    filter_fn: SoundEventFilter,
-    transform_fn: SoundEventTransformation,
-    encoder_fn: SoundEventEncoder,
-    class_names: List[str],
+    targets: TargetProtocol,
     config: LabelConfig,
 ) -> ClipLabeller:
-    """Construct the clip labelling function.
-
-    This function takes the pre-built components from the previous target
-    definition steps (filtering, transformation, encoding) and the label
-    configuration, then returns a single callable (`ClipLabeller`) that
-    performs the end-to-end heatmap generation for a given clip and
-    spectrogram.
+    """Construct the final clip labelling function.
+
+    This factory function prepares the callable that will perform the
+    end-to-end heatmap generation for a given clip and spectrogram during
+    training data loading. It takes the fully configured `targets` object and
+    the `LabelConfig` and binds them to the `generate_clip_label` function.

     Parameters
     ----------
-    filter_fn : SoundEventFilter
-        Function to filter irrelevant sound event annotations.
-    transform_fn : SoundEventTransformation
-        Function to transform tags of sound event annotations.
-    encoder_fn : SoundEventEncoder
-        Function to encode a sound event annotation into a class name.
-    class_names : List[str]
-        Ordered list of unique target class names for the classification
-        heatmap.
+    targets : TargetProtocol
+        An initialized object conforming to the `TargetProtocol`, providing
+        all necessary methods for filtering, transforming, encoding, and ROI
+        mapping.
     config : LabelConfig
-        Configuration object containing heatmap generation parameters (sigma,
-        etc.).
+        Configuration object containing heatmap generation parameters.

     Returns
     -------
@@ -162,10 +97,7 @@ def build_clip_labeler(
     """
     return partial(
         generate_clip_label,
-        filter_fn=filter_fn,
-        transform_fn=transform_fn,
-        encoder_fn=encoder_fn,
-        class_names=class_names,
+        targets=targets,
         config=config,
     )
@@ -173,123 +105,97 @@ def build_clip_labeler(
 def generate_clip_label(
     clip_annotation: data.ClipAnnotation,
     spec: xr.DataArray,
-    filter_fn: SoundEventFilter,
-    transform_fn: SoundEventTransformation,
-    encoder_fn: SoundEventEncoder,
-    class_names: List[str],
+    targets: TargetProtocol,
     config: LabelConfig,
 ) -> Heatmaps:
-    """Generate heatmaps for a single clip by applying all processing steps.
-
-    This function orchestrates the process for one clip:
-    1. Filters the sound events using `filter_fn`.
-    2. Transforms the tags of filtered events using `transform_fn`.
-    3. Passes the processed annotations and other parameters to
-       `generate_heatmaps` to create the final target heatmaps.
+    """Generate training heatmaps for a single annotated clip.
+
+    This function orchestrates the target generation process for one clip:
+    1. Filters and transforms sound events using `targets.filter` and
+       `targets.transform`.
+    2. Passes the resulting processed annotations, along with the spectrogram,
+       the `targets` object, and the Gaussian `sigma` from `config`, to the
+       core `generate_heatmaps` function.

     Parameters
     ----------
     clip_annotation : data.ClipAnnotation
-        The complete annotation data for the audio clip.
+        The complete annotation data for the audio clip, including the list
+        of `sound_events` to process.
     spec : xr.DataArray
-        The spectrogram corresponding to the `clip_annotation`.
-    filter_fn : SoundEventFilter
-        Function to filter sound event annotations.
-    transform_fn : SoundEventTransformation
-        Function to transform tags of sound event annotations.
-    encoder_fn : SoundEventEncoder
-        Function to encode a sound event annotation into a class name.
-    class_names : List[str]
-        Ordered list of unique target class names.
+        The spectrogram corresponding to the `clip_annotation`. Must have
+        'time' and 'frequency' dimensions/coordinates.
+    targets : TargetProtocol
+        The fully configured target definition object, providing methods for
+        filtering, transformation, encoding, and ROI mapping.
     config : LabelConfig
-        Configuration object containing heatmap generation parameters.
+        Configuration object providing heatmap parameters (primarily `sigma`).

     Returns
     -------
     Heatmaps
-        The generated detection, classes, and size heatmaps for the clip.
+        A NamedTuple containing the generated 'detection', 'classes', and
+        'size' heatmaps for this clip.
     """
     return generate_heatmaps(
         (
-            transform_fn(sound_event_annotation)
+            targets.transform(sound_event_annotation)
             for sound_event_annotation in clip_annotation.sound_events
-            if filter_fn(sound_event_annotation)
+            if targets.filter(sound_event_annotation)
         ),
         spec=spec,
-        class_names=class_names,
-        encoder=encoder_fn,
+        targets=targets,
         target_sigma=config.sigma,
-        position=config.position,
-        time_scale=config.time_scale,
-        frequency_scale=config.frequency_scale,
     )
 def generate_heatmaps(
     sound_events: Iterable[data.SoundEventAnnotation],
     spec: xr.DataArray,
-    class_names: List[str],
-    encoder: SoundEventEncoder,
+    targets: TargetProtocol,
     target_sigma: float = 3.0,
-    position: Positions = "bottom-left",
-    time_scale: float = 1000.0,
-    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Heatmaps:
     """Generate detection, class, and size heatmaps from sound events.

-    Processes an iterable of sound event annotations (assumed to be already
-    filtered and transformed) and creates heatmap representations suitable
-    for training models like BatDetect2.
-
-    The process involves:
-    1. Initializing empty heatmaps based on the spectrogram shape.
-    2. Iterating through each sound event.
-    3. For each event, finding its reference point and placing a '1.0'
-       on the detection heatmap at that point.
-    4. Calculating the scaled bounding box size and placing it on the size
-       heatmap at the reference point.
-    5. Encoding the event to get its class name and placing a '1.0' on the
-       corresponding class heatmap slice at the reference point (if
-       classified).
-    6. Applying Gaussian smoothing to detection and class heatmaps.
-    7. Normalizing detection and class heatmaps to the range [0, 1].
+    Creates heatmap representations from an iterable of sound event
+    annotations. This function relies on the provided `targets` object to get
+    the reference position, scaled size, and class encoding for each
+    annotation.
+
+    Process:
+    1. Initializes empty heatmap arrays based on `spec` shape and `targets`
+       metadata.
+    2. Iterates through `sound_events`.
+    3. For each event:
+       a. Gets geometry. Skips if missing.
+       b. Gets reference position using `targets.get_position()`. Skips if
+          out of bounds.
+       c. Places a peak (1.0) on the detection heatmap at the position.
+       d. Gets scaled size using `targets.get_size()` and places it on the
+          size heatmap.
+       e. Encodes class using `targets.encode()` and places a peak (1.0) on
+          the corresponding class heatmap layer if a specific class is
+          returned.
+    4. Applies Gaussian smoothing (using `target_sigma`) to detection and
+       class heatmaps.
+    5. Normalizes detection and class heatmaps to range [0, 1].

     Parameters
     ----------
     sound_events : Iterable[data.SoundEventAnnotation]
-        An iterable of sound event annotations to include in the heatmaps.
-        These should ideally be the result of prior filtering and tag
-        transformation steps.
+        An iterable of sound event annotations to render onto the heatmaps.
     spec : xr.DataArray
         The spectrogram array corresponding to the time/frequency range of
         the annotations. Used for shape and coordinate information. Must have
-        'time' and 'frequency' dimensions.
-    class_names : List[str]
-        An ordered list of unique class names. The class heatmap will have
-        a channel ('category' dimension) for each name in this list. Must not
-        be empty.
-    encoder : SoundEventEncoder
-        A function that takes a SoundEventAnnotation and returns the
-        corresponding class name (str) or None if it doesn't belong to a
-        specific class (e.g., it falls into the generic 'Bat' category).
+        'time' and 'frequency' dimensions/coordinates.
+    targets : TargetProtocol
+        The fully configured target definition object. Used to access class
+        names, dimension names, and the methods `get_position`, `get_size`,
+        `encode`.
     target_sigma : float, default=3.0
         Standard deviation (in pixels/bins) of the Gaussian kernel applied to
-        smooth the detection and class heatmaps after initial point placement.
-    position : Positions, default="bottom-left"
-        The reference point within each annotation's geometry bounding box
-        used to place the signal on the heatmaps (e.g., "center",
-        "bottom-left"). See `soundevent.geometry.operations.Positions`.
-    time_scale : float, default=1000.0
-        Scaling factor applied to the time duration (width in seconds) of
-        annotations when storing them in the size heatmap. The resulting
-        value's unit depends on this scale (e.g., 1000.0 might convert
-        seconds to ms).
-    frequency_scale : float, default=1/859.375
-        Scaling factor applied to the frequency bandwidth (height in Hz or
-        kHz) of annotations when storing them in the size heatmap. The
-        resulting value's unit depends on this scale and the input unit.
-        (Default scaling relative to ~860 Hz).
+        smooth the detection and class heatmaps.
     dtype : type, default=np.float32
         The data type for the generated heatmap arrays (e.g., `np.float32`).
@@ -303,28 +209,10 @@ def generate_heatmaps(
     ------
     ValueError
         If the input spectrogram `spec` does not have both 'time' and
-        'frequency' dimensions, or if `class_names` is empty.
-
-    Notes
-    -----
-    * This function expects `sound_events` to be already filtered and
-      transformed.
-    * It includes error handling to skip individual annotations that cause
-      issues (e.g., missing geometry, out-of-bounds coordinates, encoder
-      errors) allowing the rest of the clip to be processed. Warnings or
-      errors are logged.
-    * The `time_scale` and `frequency_scale` parameters are crucial and must
-      be set according to the expectations of the specific BatDetect2 model
-      architecture being trained. Consult model documentation for required
-      units/scales.
-    * Gaussian filtering and normalization are applied only to detection and
-      class heatmaps, not the size heatmap.
+        'frequency' dimensions, or if `targets.class_names` is empty.
     """
     shape = dict(zip(spec.dims, spec.shape))

-    if len(class_names) == 0:
-        raise ValueError("No class names provided.")
-
     if "time" not in shape or "frequency" not in shape:
         raise ValueError(
             "Spectrogram must have time and frequency dimensions."
@ -333,18 +221,18 @@ def generate_heatmaps(
# Initialize heatmaps # Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype) detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray( class_heatmap = xr.DataArray(
data=np.zeros((len(class_names), *spec.shape), dtype=dtype), data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
dims=["category", *spec.dims], dims=["category", *spec.dims],
coords={ coords={
"category": [*class_names], "category": [*targets.class_names],
**spec.coords, **spec.coords,
}, },
) )
size_heatmap = xr.DataArray( size_heatmap = xr.DataArray(
data=np.zeros((2, *spec.shape), dtype=dtype), data=np.zeros((2, *spec.shape), dtype=dtype),
dims=["dimension", *spec.dims], dims=[SIZE_DIMENSION, *spec.dims],
coords={ coords={
"dimension": ["width", "height"], SIZE_DIMENSION: targets.dimension_names,
**spec.coords, **spec.coords,
}, },
) )
@ -359,7 +247,7 @@ def generate_heatmaps(
continue continue
# Get the position of the sound event # Get the position of the sound event
time, frequency = geometry.get_geometry_point(geom, position=position) time, frequency = targets.get_position(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap # Set 1.0 at the position of the sound event in the detection heatmap
try: try:
@ -379,17 +267,7 @@ def generate_heatmaps(
) )
continue continue
# Set the size of the sound event at the position in the size heatmap size = targets.get_size(sound_event_annotation)
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
geom
)
size = np.array(
[
(end_time - start_time) * time_scale,
(high_freq - low_freq) * frequency_scale,
]
)
size_heatmap = arrays.set_value_at_pos( size_heatmap = arrays.set_value_at_pos(
size_heatmap, size_heatmap,
@@ -400,7 +278,7 @@ def generate_heatmaps(
# Get the class name of the sound event # Get the class name of the sound event
try: try:
class_name = encoder(sound_event_annotation) class_name = targets.encode(sound_event_annotation)
except ValueError as e: except ValueError as e:
logger.warning( logger.warning(
"Skipping annotation %s: Unexpected error while encoding " "Skipping annotation %s: Unexpected error while encoding "
@@ -414,19 +292,6 @@ def generate_heatmaps(
# If the label is None skip the sound event # If the label is None skip the sound event
continue continue
if class_name not in class_names:
# If the label is not in the class names skip the sound event
logger.warning(
(
"Skipping annotation %s for class heatmap: "
"class name '%s' not in class names. Class names: %s"
),
sound_event_annotation.uuid,
class_name,
class_names,
)
continue
try: try:
class_heatmap = arrays.set_value_at_pos( class_heatmap = arrays.set_value_at_pos(
class_heatmap, class_heatmap,
@@ -453,7 +318,7 @@ def generate_heatmaps(
) )
class_heatmap = class_heatmap.groupby("category").map( class_heatmap = class_heatmap.groupby("category").map(
gaussian_filter, gaussian_filter, # type: ignore
args=(target_sigma,), args=(target_sigma,),
) )
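Smoothing is applied per class slice via `groupby("category").map`, so each category is blurred independently. A minimal sketch of that idiom, assuming `gaussian_filter` is a thin wrapper around `scipy.ndimage.gaussian_filter` that preserves the `DataArray` coordinates (the real helper may differ):

import numpy as np
import xarray as xr
from scipy import ndimage


def gaussian_filter(arr: xr.DataArray, sigma: float) -> xr.DataArray:
    # Smooth the underlying array while preserving dims/coords
    return arr.copy(data=ndimage.gaussian_filter(arr.data, sigma))


heatmap = xr.DataArray(
    np.random.rand(2, 16, 32).astype(np.float32),
    dims=["category", "frequency", "time"],
    coords={"category": ["bat_a", "bat_b"]},
)

# Smooth each class slice independently, then renormalize to [0, 1]
smoothed = heatmap.groupby("category").map(gaussian_filter, args=(1.0,))
smoothed = smoothed / smoothed.max(dim=["frequency", "time"])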

View File

@@ -1,98 +1,101 @@
"""Module for preprocessing data for training.""" """Preprocesses datasets for BatDetect2 model training.
This module provides functions to take a collection of annotated audio clips
(`soundevent.data.ClipAnnotation`) and process them into the final format
required for training a BatDetect2 model. This typically involves:
1. Loading the relevant audio segment for each annotation using a configured
`PreprocessorProtocol`.
2. Generating the corresponding input spectrogram using the
`PreprocessorProtocol`.
3. Generating the target heatmaps (detection, classification, size) using a
configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
4. Packaging the input spectrogram, target heatmaps, and potentially the
processed audio waveform into an `xarray.Dataset`.
5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
in an output directory.
This offline preprocessing is often preferred for large datasets as it avoids
computationally intensive steps during the actual training loop. The module
includes utilities for parallel processing using `multiprocessing`.
"""
import os
from functools import partial from functools import partial
from multiprocessing import Pool from multiprocessing import Pool
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Sequence, Union from typing import Callable, Optional, Sequence
import xarray as xr import xarray as xr
from pydantic import Field
from soundevent import data from soundevent import data
from tqdm.auto import tqdm from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.preprocess import ( from batdetect2.train.types import ClipLabeller
PreprocessingConfig,
compute_spectrogram,
load_clip_audio,
)
from batdetect2.targets import (
LabelConfig,
TargetConfig,
build_sound_event_filter,
build_target_encoder,
generate_heatmaps,
get_class_names,
)
PathLike = Union[Path, str, os.PathLike]
FilenameFn = Callable[[data.ClipAnnotation], str]
__all__ = [ __all__ = [
"preprocess_annotations", "preprocess_annotations",
"preprocess_single_annotation", "preprocess_single_annotation",
"generate_train_example", "generate_train_example",
"TrainPreprocessingConfig",
] ]
FilenameFn = Callable[[data.ClipAnnotation], str]
class TrainPreprocessingConfig(BaseConfig): """Type alias for a function that generates an output filename."""
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
target: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def generate_train_example( def generate_train_example(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
preprocessing_config: Optional[PreprocessingConfig] = None, preprocessor: PreprocessorProtocol,
target_config: Optional[TargetConfig] = None, labeller: ClipLabeller,
label_config: Optional[LabelConfig] = None,
) -> xr.Dataset: ) -> xr.Dataset:
"""Generate a training example.""" """Generate a complete training example for one annotation.
config = TrainPreprocessingConfig(
preprocessing=preprocessing_config or PreprocessingConfig(),
target=target_config or TargetConfig(),
labels=label_config or LabelConfig(),
)
wave = load_clip_audio( This function takes a single `ClipAnnotation`, applies the configured
clip_annotation.clip, preprocessing (`PreprocessorProtocol`) to get the processed waveform and
config=config.preprocessing.audio, input spectrogram, applies the configured target generation
) (`ClipLabeller`) to get the target heatmaps, and packages them all into a
single `xr.Dataset`.
spectrogram = compute_spectrogram( Parameters
wave, ----------
config=config.preprocessing.spectrogram, clip_annotation : data.ClipAnnotation
) The annotated clip to process. Contains the reference to the `Clip`
(audio segment) and the associated `SoundEventAnnotation` objects.
preprocessor : PreprocessorProtocol
An initialized preprocessor object responsible for loading/processing
audio and computing the input spectrogram.
labeller : ClipLabeller
An initialized clip labeller function responsible for generating the
target heatmaps (detection, class, size) from the `clip_annotation`
and the computed spectrogram.
filter_fn = build_sound_event_filter( Returns
include=config.target.include, -------
exclude=config.target.exclude, xr.Dataset
) An xarray Dataset containing the following data variables:
- `audio`: The preprocessed audio waveform (dims: 'audio_time').
- `spectrogram`: The computed input spectrogram
(dims: 'time', 'frequency').
- `detection`: The target detection heatmap
(dims: 'time', 'frequency').
- `class`: The target class heatmap
(dims: 'category', 'time', 'frequency').
- `size`: The target size heatmap
(dims: 'dimension', 'time', 'frequency').
The Dataset also includes metadata in its attributes.
selected_events = [ Notes
event for event in clip_annotation.sound_events if filter_fn(event) -----
] - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
within the output Dataset to avoid coordinate conflicts with the
spectrogram's 'time' dimension when stored together.
- The original `ClipAnnotation` metadata is stored as a JSON string in the
Dataset's attributes for provenance.
"""
wave = preprocessor.load_clip_audio(clip_annotation.clip)
encoder = build_target_encoder( spectrogram = preprocessor.compute_spectrogram(wave)
config.target.classes,
replacement_rules=config.target.replace,
)
class_names = get_class_names(config.target.classes)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps( heatmaps = labeller(clip_annotation, spectrogram)
selected_events,
spectrogram,
class_names,
encoder,
target_sigma=config.labels.heatmaps.sigma,
position=config.labels.heatmaps.position,
time_scale=config.labels.heatmaps.time_scale,
frequency_scale=config.labels.heatmaps.frequency_scale,
)
dataset = xr.Dataset( dataset = xr.Dataset(
{ {
@@ -102,15 +105,14 @@ def generate_train_example(
# as the waveform. # as the waveform.
"audio": wave.rename({"time": "audio_time"}), "audio": wave.rename({"time": "audio_time"}),
"spectrogram": spectrogram, "spectrogram": spectrogram,
"detection": detection_heatmap, "detection": heatmaps.detection,
"class": class_heatmap, "class": heatmaps.classes,
"size": size_heatmap, "size": heatmaps.size,
} }
) )
return dataset.assign_attrs( return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}", title=f"Training example for {clip_annotation.uuid}",
config=config.model_dump_json(),
clip_annotation=clip_annotation.model_dump_json( clip_annotation=clip_annotation.model_dump_json(
exclude_none=True, exclude_none=True,
exclude_defaults=True, exclude_defaults=True,
@@ -119,10 +121,21 @@ def generate_train_example(
) )
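A short usage sketch for a single annotation, assuming `preprocessor` and `labeller` are already configured:

example = generate_train_example(clip_annotation, preprocessor, labeller)

# Inputs and targets are bundled under fixed variable names
print(example["spectrogram"].dims)         # time/frequency axes
print(example["class"].sizes["category"])  # number of target classes
print(example.attrs["title"])              # provenance metadata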
def save_to_file( def _save_xr_dataset_to_file(
dataset: xr.Dataset, dataset: xr.Dataset,
path: PathLike, path: data.PathLike,
) -> None: ) -> None:
"""Save an xarray Dataset to a NetCDF file with compression.
Internal helper function used by `preprocess_single_annotation`.
Parameters
----------
dataset : xr.Dataset
The training example dataset to save.
path : data.PathLike
The output file path (e.g., 'output/uuid.nc').
"""
dataset.to_netcdf( dataset.to_netcdf(
path, path,
encoding={ encoding={
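The `encoding` mapping is cut off at the hunk boundary here. A plausible completion, assuming per-variable zlib compression (the actual keys in the source may differ):

dataset.to_netcdf(
    path,
    encoding={
        # Hypothetical reconstruction: compress each data variable.
        "audio": {"zlib": True},
        "spectrogram": {"zlib": True},
        "detection": {"zlib": True},
        "class": {"zlib": True},
        "size": {"zlib": True},
    },
)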
@@ -135,20 +148,60 @@ def save_to_file(
def _get_filename(clip_annotation: data.ClipAnnotation) -> str: def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}.nc" return f"{clip_annotation.uuid}.nc"
def preprocess_annotations( def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation], clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike, output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename, filename_fn: FilenameFn = _get_filename,
replace: bool = False, replace: bool = False,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
max_workers: Optional[int] = None, max_workers: Optional[int] = None,
) -> None: ) -> None:
"""Preprocess annotations and save to disk.""" """Preprocess a sequence of ClipAnnotations and save results to disk.
Generates the full training example (spectrogram, heatmaps, etc.) for each
`ClipAnnotation` in the input sequence using the provided `preprocessor`
and `labeller`. Saves each example as a separate NetCDF file in the
`output_dir`. Utilizes multiprocessing for potentially faster processing.
Parameters
----------
clip_annotations : Sequence[data.ClipAnnotation]
A sequence (e.g., list) of the clip annotations to preprocess.
output_dir : data.PathLike
Path to the directory where the processed NetCDF files will be saved.
Will be created if it doesn't exist.
preprocessor : PreprocessorProtocol
Initialized preprocessor object to generate spectrograms.
labeller : ClipLabeller
Initialized labeller function to generate target heatmaps.
filename_fn : FilenameFn, optional
Function to generate the output filename (without extension) for each
`ClipAnnotation`. Defaults to using the annotation UUID via
`_get_filename`.
replace : bool, default=False
If True, existing files in `output_dir` with the same generated name
will be overwritten. If False (default), existing files are skipped.
max_workers : int, optional
Maximum number of worker processes to use for parallel processing.
If None (default), uses the number of CPUs available (`os.cpu_count()`).
Returns
-------
None
This function does not return anything; its side effect is creating
files in the `output_dir`.
Raises
------
RuntimeError
If processing fails for any individual annotation when using
multiprocessing. The original exception will be attached as the cause.
"""
output_dir = Path(output_dir) output_dir = Path(output_dir)
if not output_dir.is_dir(): if not output_dir.is_dir():
@@ -163,9 +216,8 @@ def preprocess_annotations(
output_dir=output_dir, output_dir=output_dir,
filename_fn=filename_fn, filename_fn=filename_fn,
replace=replace, replace=replace,
preprocessing_config=preprocessing_config, preprocessor=preprocessor,
target_config=target_config, labeller=labeller,
label_config=label_config,
), ),
clip_annotations, clip_annotations,
), ),
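The fan-out above is the standard `functools.partial` plus `multiprocessing.Pool` idiom: bind the shared keyword arguments once, then map the resulting worker over the annotations. A generic, runnable illustration of the pattern:

from functools import partial
from multiprocessing import Pool


def process(item: int, scale: int = 1) -> int:
    # Stand-in for preprocess_single_annotation
    return item * scale


if __name__ == "__main__":
    worker = partial(process, scale=10)  # bind shared kwargs once
    with Pool(processes=4) as pool:
        results = pool.map(worker, [1, 2, 3])  # [10, 20, 30]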
@@ -176,13 +228,34 @@ def preprocess_annotations(
def preprocess_single_annotation( def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
output_dir: PathLike, output_dir: data.PathLike,
preprocessing_config: Optional[PreprocessingConfig] = None, preprocessor: PreprocessorProtocol,
target_config: Optional[TargetConfig] = None, labeller: ClipLabeller,
label_config: Optional[LabelConfig] = None,
filename_fn: FilenameFn = _get_filename, filename_fn: FilenameFn = _get_filename,
replace: bool = False, replace: bool = False,
) -> None: ) -> None:
"""Process a single ClipAnnotation and save the result to a file.
Internal function designed to be called by `preprocess_annotations`, often
in parallel worker processes. It generates the training example using
`generate_train_example` and saves it using `_save_xr_dataset_to_file`. Handles
file existence checks based on the `replace` flag.
Parameters
----------
clip_annotation : data.ClipAnnotation
The single annotation to process.
output_dir : data.PathLike
The directory where the output NetCDF file should be saved.
preprocessor : PreprocessorProtocol
Initialized preprocessor object.
labeller : ClipLabeller
Initialized labeller function.
filename_fn : FilenameFn, default=_get_filename
Function to determine the output filename.
replace : bool, default=False
Whether to overwrite existing output files.
"""
output_dir = Path(output_dir) output_dir = Path(output_dir)
filename = filename_fn(clip_annotation) filename = filename_fn(clip_annotation)
@@ -197,13 +270,12 @@ def preprocess_single_annotation(
try: try:
sample = generate_train_example( sample = generate_train_example(
clip_annotation, clip_annotation,
preprocessing_config=preprocessing_config, preprocessor=preprocessor,
target_config=target_config, labeller=labeller,
label_config=label_config,
) )
except Exception as error: except Exception as error:
raise RuntimeError( raise RuntimeError(
f"Failed to process annotation: {clip_annotation.uuid}" f"Failed to process annotation: {clip_annotation.uuid}"
) from error ) from error
save_to_file(sample, path) _save_xr_dataset_to_file(sample, path)
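Called directly, the worker writes `<uuid>.nc` and silently skips existing files unless `replace=True`. A hedged usage sketch (`preprocessor` and `labeller` assumed configured elsewhere):

from pathlib import Path

out = Path("train_examples")
out.mkdir(parents=True, exist_ok=True)

preprocess_single_annotation(
    clip_annotation,            # a data.ClipAnnotation
    output_dir=out,
    preprocessor=preprocessor,  # assumed built elsewhere
    labeller=labeller,          # assumed built elsewhere
    replace=False,              # re-running is then a no-op for existing files
)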

batdetect2/train/types.py Normal file (48 lines added)
View File

@@ -0,0 +1,48 @@
from typing import Callable, NamedTuple
import xarray as xr
from soundevent import data
__all__ = [
"Heatmaps",
"ClipLabeller",
"Augmentation",
]
class Heatmaps(NamedTuple):
"""Structure holding the generated heatmap targets.
Attributes
----------
detection : xr.DataArray
Heatmap indicating the probability of sound event presence. Typically
smoothed with a Gaussian kernel centered on event reference points.
Shape matches the input spectrogram. Values normalized [0, 1].
classes : xr.DataArray
Heatmap indicating the probability of specific class presence. Has an
additional 'category' dimension corresponding to the target class
names. Each category slice is typically smoothed with a Gaussian
kernel. Values normalized [0, 1] per category.
size : xr.DataArray
Heatmap encoding the size (width, height) of detected events. Has an
additional 'dimension' coordinate ('width', 'height'). Values represent
scaled dimensions placed at the event reference points.
"""
detection: xr.DataArray
classes: xr.DataArray
size: xr.DataArray
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.
This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
Augmentation = Callable[[xr.Dataset], xr.Dataset]
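Because `ClipLabeller` and `Augmentation` are plain `Callable` aliases, any function with the matching signature conforms; no registration or subclassing is needed. Two toy conformers, purely illustrative:

import xarray as xr
from soundevent import data


def toy_labeller(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
) -> Heatmaps:
    # Toy ClipLabeller: empty targets with the expected shapes,
    # using a single hypothetical class name.
    zeros = xr.zeros_like(spec)
    return Heatmaps(
        detection=zeros,
        classes=zeros.expand_dims(category=["bat"]),
        size=zeros.expand_dims(dimension=["width", "height"]),
    )


def halve_volume(example: xr.Dataset) -> xr.Dataset:
    # Toy Augmentation: rescale the spectrogram, leave targets untouched.
    return example.assign(spectrogram=example["spectrogram"] * 0.5)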

View File

@@ -1,14 +1,10 @@
"""Types used in the code base.""" """Types used in the code base."""
from typing import Any, List, NamedTuple, Optional from typing import Any, List, NamedTuple, Optional, TypedDict
import numpy as np import numpy as np
import torch import torch
from typing import TypedDict
try: try:
from typing import Protocol from typing import Protocol
except ImportError: except ImportError:
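The `except` body is truncated by the hunk. The conventional completion of this guard falls back to `typing_extensions`, shown here as a hedged reconstruction rather than the verbatim source:

try:
    from typing import Protocol
except ImportError:  # Python < 3.8
    from typing_extensions import Protocol  # type: ignore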

View File

@@ -1,34 +0,0 @@
import torch
from hypothesis import given
from hypothesis import strategies as st
from batdetect2.models import ModelConfig, ModelType, build_architecture
@given(
input_width=st.integers(min_value=50, max_value=1500),
input_height=st.integers(min_value=1, max_value=16),
model_type=st.sampled_from(ModelType),
)
def test_model_can_process_spectrograms_of_any_width(
input_width,
input_height,
model_type,
):
# Input height must be divisible by 8
input_height = 8 * input_height
input = torch.rand([1, 1, input_height, input_width])
model = build_architecture(
config=ModelConfig(
name=model_type, # type: ignore
input_height=input_height,
),
)
output = model(input)
assert output.shape[0] == 1
assert output.shape[1] == model.out_channels
assert output.shape[2] == input_height
assert output.shape[3] == input_width
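This deleted test exercised the removed `build_architecture`/`ModelConfig`/`ModelType` API. A hypothetical port to the renamed API, assuming `build_backbone` and `BackboneConfig` accept analogous arguments (unverified; the backbone-type strategy is omitted because the new config's equivalent field is unknown):

import torch
from hypothesis import given
from hypothesis import strategies as st

from batdetect2.models import BackboneConfig, build_backbone


@given(
    input_width=st.integers(min_value=50, max_value=1500),
    input_height=st.integers(min_value=1, max_value=16),
)
def test_backbone_can_process_spectrograms_of_any_width(
    input_width,
    input_height,
):
    # Input height must be divisible by 8 (as in the original test)
    input_height = 8 * input_height
    spec = torch.rand([1, 1, input_height, input_width])
    backbone = build_backbone(BackboneConfig(input_height=input_height))
    output = backbone(spec)
    assert output.shape == (1, backbone.out_channels, input_height, input_width)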

View File

@@ -4,7 +4,7 @@ import numpy as np
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.targets import generate_heatmaps from batdetect2.train.labels import generate_heatmaps
recording = data.Recording( recording = data.Recording(
samplerate=256_000, samplerate=256_000,