mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

Fixing imports after restructuring

This commit is contained in:
parent dcae411ccb
commit 7c89e82579
@@ -10,13 +10,13 @@ from torch.optim.lr_scheduler import CosineAnnealingLR

from batdetect2.configs import BaseConfig
from batdetect2.models import (
    BackboneConfig,
    BBoxHead,
    ClassifierHead,
    ModelConfig,
    ModelOutput,
    build_architecture,
    build_backbone,
)
from batdetect2.post_process import (
from batdetect2.postprocess import (
    PostprocessConfig,
    postprocess_model_outputs,
)

@@ -37,7 +37,7 @@ __all__ = [

class ModuleConfig(BaseConfig):
    train: TrainingConfig = Field(default_factory=TrainingConfig)
    targets: TargetConfig = Field(default_factory=TargetConfig)
    architecture: ModelConfig = Field(default_factory=ModelConfig)
    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )

@@ -58,7 +58,7 @@ class DetectorModel(L.LightningModule):

        self.config = config or ModuleConfig()
        self.save_hyperparameters()

        self.backbone = build_architecture(self.config.architecture)
        self.backbone = build_model_backbone(self.config.architecture)

        self.classifier = ClassifierHead(
            num_classes=len(self.config.targets.classes),
@@ -1,4 +1,30 @@

from typing import Callable, Optional, Union
"""Applies data augmentation techniques to BatDetect2 training examples.

This module provides various functions and configurable components for applying
data augmentation to the training examples (`xr.Dataset` containing audio,
spectrogram, and target heatmaps) generated by the
`batdetect2.train.preprocess` module.

Data augmentation artificially increases the diversity of the training data by
applying random transformations, which generally helps improve the robustness
and generalization performance of trained models.

Augmentations included:

- Time-based: `select_subclip`, `warp_spectrogram`, `mask_time`.
- Amplitude/Noise-based: `mix_examples`, `add_echo`, `scale_volume`,
  `mask_frequency`.

Some augmentations modify the audio waveform and require recomputing the
spectrogram using the `PreprocessorProtocol`, while others operate directly
on the spectrogram or target heatmaps. The entire augmentation pipeline can be
configured using the `AugmentationsConfig` class, specifying a sequence of
augmentation steps, each with its own parameters and application probability.
The `build_augmentations` function constructs the final augmentation callable
from this configuration.
"""

from functools import partial
from typing import Annotated, Callable, List, Literal, Optional, Union

import numpy as np
import xarray as xr

@@ -6,15 +32,14 @@ from pydantic import Field

from soundevent import arrays, data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
from batdetect2.preprocess.arrays import adjust_width

Augmentation = Callable[[xr.Dataset], xr.Dataset]

from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.train.types import Augmentation
from batdetect2.utils.arrays import adjust_width

__all__ = [
    "AugmentationsConfig",
    "load_agumentation_config",
    "load_augmentation_config",
    "build_augmentations",
    "select_subclip",
    "mix_examples",
    "add_echo",

@@ -22,13 +47,21 @@ __all__ = [

    "warp_spectrogram",
    "mask_time",
    "mask_frequency",
    "augment_example",
    "MixAugmentationConfig",
    "EchoAugmentationConfig",
    "VolumeAugmentationConfig",
    "WarpAugmentationConfig",
    "TimeMaskAugmentationConfig",
    "FrequencyMaskAugmentationConfig",
    "AugmentationConfig",
    "ExampleSource",
]

ExampleSource = Callable[[], xr.Dataset]
"""Type alias for a function that returns a training example (`xr.Dataset`).

class BaseAugmentationConfig(BaseConfig):
    enable: bool = True
    probability: float = 0.2

Used by the `mix_examples` augmentation to fetch another example to mix with.
"""


def select_subclip(

@@ -38,7 +71,47 @@ def select_subclip(

    width: Optional[int] = None,
    random: bool = False,
) -> xr.Dataset:
    """Select a random subclip from a clip."""
    """Extract a sub-clip (time segment) from a training example dataset.

    Selects a portion of the 'time' dimension from all relevant DataArrays
    (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
    Dataset. The segment can be defined by a fixed start time and
    duration/width, or a random start time can be chosen.

    Parameters
    ----------
    example : xr.Dataset
        The input training example containing 'audio', 'spectrogram', and
        target heatmaps, all with compatible 'time' (or 'audio_time')
        coordinates.
    start_time : float, optional
        Desired start time (seconds) of the subclip. If None and `random` is
        False, starts from the beginning of the example. If None and `random`
        is True, a random start time is chosen.
    duration : float, optional
        Desired duration (seconds) of the subclip. Either `duration` or `width`
        must be provided.
    width : int, optional
        Desired width (number of time bins) of the subclip's
        spectrogram/heatmaps. Either `duration` or `width` must be provided. If
        both are given, `duration` takes precedence.
    random : bool, default=False
        If True and `start_time` is None, selects a random start time ensuring
        the subclip fits within the original example's duration.

    Returns
    -------
    xr.Dataset
        A new dataset containing only the selected time segment. Coordinates
        are adjusted accordingly. Returns the original example if the requested
        subclip cannot be extracted (e.g., duration too long).

    Raises
    ------
    ValueError
        If neither `duration` nor `width` is provided, or if time coordinates
        are missing or invalid.
    """
    step = arrays.get_dim_step(example, "time")  # type: ignore
    start, stop = arrays.get_dim_range(example, "time")  # type: ignore
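
A usage sketch for the signature documented above (`example` is assumed to be a training example produced by `batdetect2.train.preprocess`):

# Take a random 0.5 s window from a training example.
subclip = select_subclip(example, duration=0.5, random=True)

# Or request an exact number of spectrogram time bins instead.
subclip = select_subclip(example, width=512, random=True)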

@@ -77,22 +150,62 @@ def select_subclip(

    )


class MixAugmentationConfig(BaseAugmentationConfig):
class MixAugmentationConfig(BaseConfig):
    """Configuration for MixUp augmentation (mixing two examples)."""

    augmentation_type: Literal["mix_audio"] = "mix_audio"

    probability: float = 0.2
    """Probability of applying this augmentation to an example."""

    min_weight: float = 0.3
    """Minimum mixing weight (lambda) applied to the primary example."""

    max_weight: float = 0.7
    """Maximum mixing weight (lambda) applied to the primary example."""


def mix_examples(
    example: xr.Dataset,
    other: xr.Dataset,
    preprocessor: PreprocessorProtocol,
    weight: Optional[float] = None,
    min_weight: float = 0.3,
    max_weight: float = 0.7,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Combine two audio clips."""
    config = config or PreprocessingConfig()
    """Combine two training examples using MixUp augmentation.

    Performs a weighted linear combination of the audio waveforms from two
    examples (`example` and `other`). The spectrogram is then *recomputed*
    from the mixed audio using the provided `preprocessor`. Target heatmaps
    (detection, class) are combined by taking the element-wise maximum. Target
    size heatmaps are combined by element-wise addition.

    Parameters
    ----------
    example : xr.Dataset
        The primary training example.
    other : xr.Dataset
        The second training example to mix with `example`.
    preprocessor : PreprocessorProtocol
        The preprocessor used to recompute the spectrogram from mixed audio.
        Ensures consistency with original preprocessing.
    weight : float, optional
        The mixing weight (lambda) applied to the primary `example`. The weight
        for `other` will be `(1 - weight)`. If None, a random weight is chosen
        uniformly between `min_weight` and `max_weight`.
    min_weight : float, default=0.3
        Minimum value for the random weight lambda.
    max_weight : float, default=0.7
        Maximum value for the random weight lambda. Must be >= `min_weight`.

    Returns
    -------
    xr.Dataset
        A new dataset representing the mixed example, with combined audio,
        recomputed spectrogram, and combined target heatmaps. Attributes from
        the primary `example` are preserved.
    """
    if weight is None:
        weight = np.random.uniform(min_weight, max_weight)
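
The combination rules described in the docstring, illustrated with plain NumPy arrays standing in for the xarray variables (a sketch, not the function body):

import numpy as np

rng = np.random.default_rng(0)
audio1 = rng.normal(size=1024)      # primary example's waveform
audio2 = rng.normal(size=1024)      # other example's waveform

weight = 0.4                        # lambda drawn from [min_weight, max_weight]
combined = weight * audio1 + (1 - weight) * audio2

det1 = rng.random((128, 64))        # detection/class heatmaps: element-wise max
det2 = rng.random((128, 64))
detection = np.maximum(det1, det2)

size = det1 + det2                  # size heatmaps: element-wise addition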

@@ -101,9 +214,8 @@ def mix_examples(

    combined = weight * audio1 + (1 - weight) * audio2

    spectrogram = compute_spectrogram(
        combined.rename({"audio_time": "time"}),
        config=config.spectrogram,
    spectrogram = preprocessor.compute_spectrogram(
        combined.rename({"audio_time": "time"})
    ).data

    # NOTE: The subclip's spectrogram might be slightly longer than the

@@ -147,7 +259,14 @@ def mix_examples(

    )


class EchoAugmentationConfig(BaseAugmentationConfig):
class EchoAugmentationConfig(BaseConfig):
    """Configuration for adding synthetic echo/reverb."""

    augmentation_type: Literal["add_echo"] = "add_echo"

    probability: float = 0.2
    """Probability of applying this augmentation."""

    max_delay: float = 0.005
    min_weight: float = 0.0
    max_weight: float = 1.0

@@ -155,15 +274,45 @@ class EchoAugmentationConfig(BaseAugmentationConfig):

def add_echo(
    example: xr.Dataset,
    preprocessor: PreprocessorProtocol,
    delay: Optional[float] = None,
    weight: Optional[float] = None,
    min_weight: float = 0.1,
    max_weight: float = 1.0,
    max_delay: float = 0.005,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Add a delay to the audio."""
    config = config or PreprocessingConfig()
    """Add a synthetic echo to the audio waveform.

    Creates an echo by adding a delayed and attenuated version of the original
    audio waveform back onto itself. The spectrogram is then recomputed from
    the modified audio. Target heatmaps remain unchanged.

    Parameters
    ----------
    example : xr.Dataset
        The input training example containing 'audio', 'spectrogram', etc.
    preprocessor : PreprocessorProtocol
        Preprocessor used to recompute the spectrogram from the modified audio.
    delay : float, optional
        The delay time in seconds for the echo. If None, chosen randomly
        between 0 and `max_delay`.
    weight : float, optional
        The relative amplitude (weight) of the echo compared to the original.
        If None, chosen randomly between `min_weight` and `max_weight`.
    min_weight : float, default=0.1
        Minimum value for the random echo weight.
    max_weight : float, default=1.0
        Maximum value for the random echo weight. Must be >= `min_weight`.
    max_delay : float, default=0.005
        Maximum value for the random echo delay in seconds. Must be >= 0.

    Returns
    -------
    xr.Dataset
        A new dataset with the echo added to the 'audio' variable and the
        'spectrogram' variable recomputed. Other variables (targets, attrs)
        are copied from the original example.
    """

    if delay is None:
        delay = np.random.uniform(0, max_delay)
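
The echo construction used below is a shift-and-add; a standalone NumPy illustration (the sample rate is an assumed placeholder):

import numpy as np

samplerate = 256_000                          # assumed, for illustration
audio = np.random.default_rng(1).normal(size=samplerate)
delay, weight = 0.002, 0.5                    # 2 ms echo at half amplitude

shift = int(delay * samplerate)               # delay expressed in samples
echo = np.concatenate([np.zeros(shift), audio[:-shift]])
with_echo = audio + weight * echo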

@@ -176,9 +325,8 @@ def add_echo(

    audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
    audio = audio + weight * audio_delay

    spectrogram = compute_spectrogram(
    spectrogram = preprocessor.compute_spectrogram(
        audio.rename({"audio_time": "time"}),
        config=config.spectrogram,
    ).data

    # NOTE: The subclip's spectrogram might be slightly longer than the

@@ -202,7 +350,11 @@ def add_echo(

    )


class VolumeAugmentationConfig(BaseAugmentationConfig):
class VolumeAugmentationConfig(BaseConfig):
    """Configuration for random volume scaling of the spectrogram."""

    augmentation_type: Literal["scale_volume"] = "scale_volume"
    probability: float = 0.2
    min_scaling: float = 0.0
    max_scaling: float = 2.0

@@ -213,14 +365,44 @@ def scale_volume(

    max_scaling: float = 2,
    min_scaling: float = 0,
) -> xr.Dataset:
    """Scale the volume of a spectrogram."""
    """Scale the amplitude of the spectrogram by a random factor.

    Multiplies the entire spectrogram DataArray by a scaling factor chosen
    uniformly between `min_scaling` and `max_scaling`. This simulates changes
    in recording volume or distance to the sound source directly in the
    spectrogram domain. Audio and target heatmaps are unchanged.

    Parameters
    ----------
    example : xr.Dataset
        The input training example containing 'spectrogram'.
    factor : float, optional
        The scaling factor to apply. If None, chosen randomly between
        `min_scaling` and `max_scaling`.
    min_scaling : float, default=0.0
        Minimum value for the random scaling factor. Must be non-negative.
    max_scaling : float, default=2.0
        Maximum value for the random scaling factor. Must be >= `min_scaling`.

    Returns
    -------
    xr.Dataset
        A new dataset with the 'spectrogram' variable scaled.

    Raises
    ------
    ValueError
        If `min_scaling` > `max_scaling` or if `min_scaling` is negative.
    """
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

    return example.assign(spectrogram=example["spectrogram"] * factor)


class WarpAugmentationConfig(BaseAugmentationConfig):
class WarpAugmentationConfig(BaseConfig):
    augmentation_type: Literal["warp"] = "warp"
    probability: float = 0.2
    delta: float = 0.04


@@ -229,11 +411,39 @@ def warp_spectrogram(

    factor: Optional[float] = None,
    delta: float = 0.04,
) -> xr.Dataset:
    """Warp a spectrogram."""
    """Apply time warping by resampling the time axis.

    Stretches or compresses the time axis of the spectrogram and all target
    heatmaps by a random `factor` (chosen between `1-delta` and `1+delta`).
    Uses linear interpolation for the spectrogram and nearest neighbor for the
    discrete-like target heatmaps. Updates the 'time' coordinate accordingly.
    The audio waveform is not modified.

    Parameters
    ----------
    example : xr.Dataset
        Input training example with 'spectrogram', 'detection', 'class',
        'size', and 'time' coordinates.
    factor : float, optional
        The warping factor. If None, chosen randomly between `1-delta` and
        `1+delta`. Values > 1 stretch time, < 1 compress time.
    delta : float, default=0.04
        Controls the range `[1-delta, 1+delta]` for the random warp factor.
        Must be >= 0 and < 1.

    Returns
    -------
    xr.Dataset
        A new dataset with time-warped spectrogram and target heatmaps, and
        updated 'time' coordinates.
    """
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    start_time, end_time = arrays.get_dim_range(
        example,  # type: ignore
        "time",
    )
    duration = end_time - start_time

    new_time = np.linspace(

@@ -296,6 +506,39 @@ def mask_axis(

    end: float,
    mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
) -> xr.DataArray:
    """Mask values along a specified dimension.

    Sets values in the DataArray to `mask_value` where the coordinate along
    `dim` falls within the range [`start`, `end`). Values outside this range
    are kept. Used as a helper for time/frequency masking.

    Parameters
    ----------
    array : xr.DataArray
        The input DataArray (e.g., spectrogram).
    dim : str
        The name of the dimension along which to mask
        (e.g., "time", "frequency").
    start : float
        The starting coordinate value for the mask range.
    end : float
        The ending coordinate value for the mask range (exclusive, typically).
        Values >= start and < end will be masked.
    mask_value : float or Callable[[xr.DataArray], float], default=np.mean
        The value to use for masking. Can be a fixed float (e.g., 0.0) or a
        callable (like `np.mean`, `np.median`) that computes the value from
        the input `array`.

    Returns
    -------
    xr.DataArray
        The DataArray with the specified range along `dim` masked.

    Raises
    ------
    ValueError
        If `dim` is not found in the array's dimensions or coordinates.
    """
    if dim not in array.dims:
        raise ValueError(f"Axis {dim} not found in array")

@@ -308,7 +551,9 @@ def mask_axis(

    return array.where(condition, other=mask_value)
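
A quick sketch of the helper in isolation, on a toy spectrogram-like DataArray:

import numpy as np
import xarray as xr

spec = xr.DataArray(
    np.random.default_rng(2).random((4, 100)),
    dims=("frequency", "time"),
    coords={"frequency": np.arange(4), "time": np.linspace(0, 0.1, 100)},
)

# Mask 10-20 ms with the array mean (the default), or with a fixed value.
masked = mask_axis(spec, dim="time", start=0.01, end=0.02)
masked_zero = mask_axis(spec, dim="time", start=0.01, end=0.02, mask_value=0.0)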


class TimeMaskAugmentationConfig(BaseAugmentationConfig):
class TimeMaskAugmentationConfig(BaseConfig):
    augmentation_type: Literal["mask_time"] = "mask_time"
    probability: float = 0.2
    max_perc: float = 0.05
    max_masks: int = 3

@@ -318,9 +563,32 @@ def mask_time(

    max_perc: float = 0.05,
    max_mask: int = 3,
) -> xr.Dataset:
    """Mask a random section of the time axis."""
    """Apply random time masking (SpecAugment) to the spectrogram.

    Randomly selects a number of time intervals (up to `max_masks`) and masks
    (sets to the mean value) the spectrogram within those intervals. The width
    of each mask is chosen randomly up to `max_perc` of the total duration.
    Only the 'spectrogram' variable is modified.

    Parameters
    ----------
    example : xr.Dataset
        Input training example containing 'spectrogram' and 'time' coordinate.
    max_perc : float, default=0.05
        Maximum width of a single mask as a fraction of total duration.
    max_masks : int, default=3
        Maximum number of time masks to apply.

    Returns
    -------
    xr.Dataset
        Dataset with time masking applied to the 'spectrogram'.
    """
    num_masks = np.random.randint(1, max_mask + 1)
    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
    start_time, end_time = arrays.get_dim_range(
        example,  # type: ignore
        "time",
    )

    spectrogram = example["spectrogram"]
    for _ in range(num_masks):

@@ -332,7 +600,9 @@ def mask_time(

    return example.assign(spectrogram=spectrogram)
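
In practice the two SpecAugment-style maskers are applied back to back; a sketch using this module's defaults (note the differing keyword names, `max_mask` here versus `max_masks` in `mask_frequency`, as defined in this file):

masked = mask_time(example, max_perc=0.05, max_mask=3)
masked = mask_frequency(masked, max_perc=0.10, max_masks=3)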


class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
class FrequencyMaskAugmentationConfig(BaseConfig):
    augmentation_type: Literal["mask_freq"] = "mask_freq"
    probability: float = 0.2
    max_perc: float = 0.10
    max_masks: int = 3

@@ -342,9 +612,38 @@ def mask_frequency(

    max_perc: float = 0.10,
    max_masks: int = 3,
) -> xr.Dataset:
    """Mask a random section of the frequency axis."""
    """Apply random frequency masking (SpecAugment) to the spectrogram.

    Randomly selects a number of frequency intervals (up to `max_masks`) and
    masks (sets to the mean value) the spectrogram within those intervals. The
    height of each mask is chosen randomly up to `max_perc` of the total
    frequency range. Only the 'spectrogram' variable is modified.

    Parameters
    ----------
    example : xr.Dataset
        Input training example containing 'spectrogram' and 'frequency'
        coordinate.
    max_perc : float, default=0.10
        Maximum height of a single mask as a fraction of total frequency range.
    max_masks : int, default=3
        Maximum number of frequency masks to apply.

    Returns
    -------
    xr.Dataset
        Dataset with frequency masking applied to the 'spectrogram'.

    Raises
    ------
    ValueError
        If frequency coordinate info is missing or invalid.
    """
    num_masks = np.random.randint(1, max_masks + 1)
    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
    min_freq, max_freq = arrays.get_dim_range(
        example,  # type: ignore
        "frequency",
    )

    spectrogram = example["spectrogram"]
    for _ in range(num_masks):

@@ -356,88 +655,326 @@ def mask_frequency(

    return example.assign(spectrogram=spectrogram)


AugmentationConfig = Annotated[
    Union[
        MixAugmentationConfig,
        EchoAugmentationConfig,
        VolumeAugmentationConfig,
        WarpAugmentationConfig,
        FrequencyMaskAugmentationConfig,
        TimeMaskAugmentationConfig,
    ],
    Field(discriminator="augmentation_type"),
]
"""Type alias for the discriminated union of individual augmentation configs."""


class AugmentationsConfig(BaseConfig):
    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
    echo: EchoAugmentationConfig = Field(
        default_factory=EchoAugmentationConfig
    )
    volume: VolumeAugmentationConfig = Field(
        default_factory=VolumeAugmentationConfig
    )
    warp: WarpAugmentationConfig = Field(
        default_factory=WarpAugmentationConfig
    )
    time_mask: TimeMaskAugmentationConfig = Field(
        default_factory=TimeMaskAugmentationConfig
    )
    frequency_mask: FrequencyMaskAugmentationConfig = Field(
        default_factory=FrequencyMaskAugmentationConfig
    """Configuration for a sequence of data augmentations.

    Attributes
    ----------
    steps : List[AugmentationConfig]
        An ordered list of configuration objects, each defining a single
        augmentation step (e.g., MixAugmentationConfig,
        TimeMaskAugmentationConfig). Each step's configuration must include an
        `augmentation_type` field and a `probability` field, along with
        type-specific parameters. The augmentations will be applied
        (probabilistically) in the sequence defined by this list.
    """

    steps: List[AugmentationConfig] = Field(default_factory=list)
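
Because `augmentation_type` is the discriminator, a plain mapping can be validated straight into the right step classes; a sketch assuming `BaseConfig` exposes pydantic v2's `model_validate`:

raw = {
    "steps": [
        {"augmentation_type": "scale_volume", "probability": 0.3},
        {"augmentation_type": "mask_freq", "max_perc": 0.1, "max_masks": 2},
    ]
}
config = AugmentationsConfig.model_validate(raw)
assert isinstance(config.steps[0], VolumeAugmentationConfig)
assert isinstance(config.steps[1], FrequencyMaskAugmentationConfig)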


class MaybeApply:
    """Applies an augmentation function probabilistically."""

    def __init__(
        self,
        augmentation: Augmentation,
        probability: float = 0.2,
    ):
        """Initialize the wrapper.

        Parameters
        ----------
        augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
            The augmentation function to potentially apply.
        probability : float, default=0.2
            The probability (0.0 to 1.0) of applying the augmentation.
        """
        self.augmentation = augmentation
        self.probability = probability

    def __call__(self, example: xr.Dataset) -> xr.Dataset:
        """Apply the wrapped augmentation with configured probability.

        Parameters
        ----------
        example : xr.Dataset
            The input training example.

        Returns
        -------
        xr.Dataset
            The potentially augmented training example.
        """
        if np.random.random() > self.probability:
            return example

        return self.augmentation(example)


class AudioMixer:
    """Callable class for MixUp augmentation, handling example fetching.

    Wraps the `mix_examples` logic, providing the necessary `example_source`
    to fetch a second example dynamically when called.

    Parameters
    ----------
    min_weight : float
        Minimum mixing weight (lambda) for the primary example.
    max_weight : float
        Maximum mixing weight (lambda) for the primary example.
    example_source : ExampleSource (Callable[[], xr.Dataset])
        A function that, when called, returns another random training example
        dataset (`xr.Dataset`) to be mixed with the input example.
    preprocessor : PreprocessorProtocol
        The preprocessor needed to recompute the spectrogram after mixing
        audio.
    """

    def __init__(
        self,
        min_weight: float,
        max_weight: float,
        example_source: ExampleSource,
        preprocessor: PreprocessorProtocol,
    ):
        """Initialize the AudioMixer."""
        self.min_weight = min_weight
        self.example_source = example_source
        self.max_weight = max_weight
        self.preprocessor = preprocessor

    def __call__(self, example: xr.Dataset) -> xr.Dataset:
        """Fetch another example and perform mixup.

        Parameters
        ----------
        example : xr.Dataset
            The primary input training example.

        Returns
        -------
        xr.Dataset
            The mixed training example. Returns the original example if
            fetching the second example fails.
        """
        other = self.example_source()
        return mix_examples(
            example,
            other,
            self.preprocessor,
            min_weight=self.min_weight,
            max_weight=self.max_weight,
        )


def build_augmentation_from_config(
    config: AugmentationConfig,
    preprocessor: PreprocessorProtocol,
    example_source: Optional[ExampleSource] = None,
) -> Augmentation:
    """Factory function to build a single augmentation from its config.

    Takes a configuration object for one augmentation step (which includes the
    `augmentation_type` discriminator) and returns the corresponding functional
    augmentation callable (e.g., a `partial` function or a callable class like
    `AudioMixer`).

    Parameters
    ----------
    config : AugmentationConfig
        Configuration object for a single augmentation (e.g., instance of
        `MixAugmentationConfig`, `EchoAugmentationConfig`, etc.).
    preprocessor : PreprocessorProtocol
        The preprocessor object, required by augmentations that modify audio
        and need to recompute the spectrogram (e.g., mixup, echo).
    example_source : ExampleSource, optional
        A callable that provides other training examples. Required only if
        the configuration includes `MixAugmentationConfig` (`augmentation_type`
        is "mix_audio").

    Returns
    -------
    Augmentation (Callable[[xr.Dataset], xr.Dataset])
        A callable function that takes a training example (`xr.Dataset`) and
        returns a potentially augmented version.

    Raises
    ------
    ValueError
        If `config.augmentation_type` is "mix_audio" but `example_source` is
        None.
    NotImplementedError
        If `config.augmentation_type` does not match any known augmentation
        type.
    """
    if config.augmentation_type == "mix_audio":
        if example_source is None:
            raise ValueError(
                "Mix audio augmentation ('mix_audio') requires an "
                "'example_source' callable to be provided."
            )

        return AudioMixer(
            example_source=example_source,
            preprocessor=preprocessor,
            min_weight=config.min_weight,
            max_weight=config.max_weight,
        )

    if config.augmentation_type == "add_echo":
        return partial(
            add_echo,
            preprocessor=preprocessor,
            max_delay=config.max_delay,
            min_weight=config.min_weight,
            max_weight=config.max_weight,
        )

    if config.augmentation_type == "scale_volume":
        return partial(
            scale_volume,
            max_scaling=config.max_scaling,
            min_scaling=config.min_scaling,
        )

    if config.augmentation_type == "warp":
        return partial(
            warp_spectrogram,
            delta=config.delta,
        )

    if config.augmentation_type == "mask_time":
        return partial(
            mask_time,
            max_perc=config.max_perc,
            max_mask=config.max_masks,
        )

    if config.augmentation_type == "mask_freq":
        return partial(
            mask_frequency,
            max_perc=config.max_perc,
            max_masks=config.max_masks,
        )

    raise NotImplementedError(
        "Invalid or unimplemented augmentation type: "
        f"{config.augmentation_type}"
    )


def load_agumentation_config(
def build_augmentations(
    config: AugmentationsConfig,
    preprocessor: PreprocessorProtocol,
    example_source: Optional[ExampleSource] = None,
) -> Augmentation:
    """Build a composite augmentation pipeline function from configuration.

    Takes the overall `AugmentationsConfig` (containing a list of individual
    augmentation steps), builds the callable function for each step using
    `build_augmentation_from_config`, wraps each function with `MaybeApply`
    to handle its application probability, and returns a single
    `Augmentation` function that applies the entire sequence.

    Parameters
    ----------
    config : AugmentationsConfig
        The configuration object detailing the sequence of augmentation steps.
    preprocessor : PreprocessorProtocol
        The preprocessor object, needed for audio-modifying augmentations.
    example_source : ExampleSource, optional
        A callable providing other examples, required if 'mix_audio' is used.

    Returns
    -------
    Augmentation (Callable[[xr.Dataset], xr.Dataset])
        A single callable function that takes a training example (`xr.Dataset`)
        and applies the configured sequence of augmentations probabilistically,
        returning the augmented example. Returns the original example if
        `config.steps` is empty.

    Raises
    ------
    ValueError
        If 'mix_audio' is configured but `example_source` is not provided.
    NotImplementedError
        If an unknown `augmentation_type` is encountered in `config.steps`.
    """
    augmentations = []

    for step_config in config.steps:
        augmentation = build_augmentation_from_config(
            step_config,
            preprocessor=preprocessor,
            example_source=example_source,
        )
        augmentations.append(
            MaybeApply(
                augmentation=augmentation,
                probability=step_config.probability,
            )
        )

    return partial(_apply_augmentations, augmentations=augmentations)
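
End to end, the factory is called once per training run and the resulting callable applied per example; a sketch (the `preprocessor` and `dataset` objects are assumed to come from the preprocessing and dataset modules):

config = AugmentationsConfig(
    steps=[
        MixAugmentationConfig(probability=0.2),
        EchoAugmentationConfig(probability=0.2),
        VolumeAugmentationConfig(probability=0.2),
    ]
)
augment = build_augmentations(
    config,
    preprocessor=preprocessor,                  # PreprocessorProtocol
    example_source=dataset.get_random_example,  # needed for "mix_audio"
)
augmented = augment(example)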


def load_augmentation_config(
    path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig:
    """Load the augmentations configuration from a file.

    Reads a configuration file (YAML) and validates it against the
    `AugmentationsConfig` schema, potentially extracting data from a nested
    field.

    Parameters
    ----------
    path : PathLike
        Path to the configuration file.
    field : str, optional
        Dot-separated path to a nested section within the file containing the
        augmentations configuration (e.g., "training.augmentations"). If None,
        the entire file content is used.

    Returns
    -------
    AugmentationsConfig
        The loaded and validated augmentations configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    yaml.YAMLError
        If the file content is not valid YAML.
    pydantic.ValidationError
        If the loaded config data does not conform to `AugmentationsConfig`.
    KeyError, TypeError
        If `field` specifies an invalid path.
    """
    return load_config(path, schema=AugmentationsConfig, field=field)


def should_apply(config: BaseAugmentationConfig) -> bool:
    if not config.enable:
        return False

    return np.random.uniform() < config.probability


def augment_example(
def _apply_augmentations(
    example: xr.Dataset,
    config: AugmentationsConfig,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    others: Optional[Callable[[], xr.Dataset]] = None,
) -> xr.Dataset:
    if should_apply(config.mix) and (others is not None):
        other = others()
        example = mix_examples(
            example,
            other,
            min_weight=config.mix.min_weight,
            max_weight=config.mix.max_weight,
            config=preprocessing_config,
        )

    if should_apply(config.echo):
        example = add_echo(
            example,
            max_delay=config.echo.max_delay,
            min_weight=config.echo.min_weight,
            max_weight=config.echo.max_weight,
            config=preprocessing_config,
        )

    if should_apply(config.volume):
        example = scale_volume(
            example,
            max_scaling=config.volume.max_scaling,
            min_scaling=config.volume.min_scaling,
        )

    if should_apply(config.warp):
        example = warp_spectrogram(
            example,
            delta=config.warp.delta,
        )

    if should_apply(config.time_mask):
        example = mask_time(
            example,
            max_perc=config.time_mask.max_perc,
            max_mask=config.time_mask.max_masks,
        )

    if should_apply(config.frequency_mask):
        example = mask_frequency(
            example,
            max_perc=config.frequency_mask.max_perc,
            max_masks=config.frequency_mask.max_masks,
        )

    augmentations: List[Augmentation],
):
    """Apply a sequence of augmentation functions to an example."""
    for augmentation in augmentations:
        example = augmentation(example)
    return example

@@ -3,14 +3,15 @@ from lightning.pytorch.callbacks import Callback

from torch.utils.data import DataLoader

from batdetect2.evaluate import match_predictions_and_annotations
from batdetect2.post_process import postprocess_model_outputs
from batdetect2.postprocess import PostprocessorProtocol
from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.types import ModelOutput


class ValidationMetrics(Callback):
    def __init__(self):
    def __init__(self, postprocessor: PostprocessorProtocol):
        super().__init__()
        self.postprocessor = postprocessor
        self.predictions = []

    def on_validation_epoch_start(

@@ -36,20 +37,20 @@ class ValidationMetrics(Callback):

        assert isinstance(dataset, LabeledDataset)
        clip_annotation = dataset.get_clip_annotation(batch_idx)

        clip_prediction = postprocess_model_outputs(
            outputs,
            clips=[clip_annotation.clip],
            classes=self.class_names,
            decoder=self.decoder,
            config=self.config.postprocessing,
        )[0]

        matches = match_predictions_and_annotations(
            clip_annotation,
            clip_prediction,
        )

        self.validation_predictions.extend(matches)
        return super().on_validation_batch_end(
            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        )
        # clip_prediction = postprocess_model_outputs(
        #     outputs,
        #     clips=[clip_annotation.clip],
        #     classes=self.class_names,
        #     decoder=self.decoder,
        #     config=self.config.postprocessing,
        # )[0]
        #
        # matches = match_predictions_and_annotations(
        #     clip_annotation,
        #     clip_prediction,
        # )
        #
        # self.validation_predictions.extend(matches)
        # return super().on_validation_batch_end(
        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
        # )

@@ -10,13 +10,13 @@ from soundevent import data

from torch.utils.data import Dataset

from batdetect2.configs import BaseConfig
from batdetect2.preprocess.tensors import adjust_width
from batdetect2.train.augmentations import (
    AugmentationsConfig,
    augment_example,
    select_subclip,
)
from batdetect2.train.preprocess import PreprocessingConfig
from batdetect2.train.preprocess import PreprocessorProtocol
from batdetect2.utils.tensors import adjust_width

__all__ = [
    "TrainExample",

@@ -51,15 +51,15 @@ class DatasetConfig(BaseConfig):

class LabeledDataset(Dataset):
    def __init__(
        self,
        preprocessor: PreprocessorProtocol,
        filenames: Sequence[PathLike],
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
    ):
        self.preprocessor = preprocessor
        self.filenames = filenames
        self.subclip = subclip
        self.augmentation = augmentation
        self.preprocessing = preprocessing or PreprocessingConfig()

    def __len__(self):
        return len(self.filenames)

@@ -79,7 +79,7 @@ class LabeledDataset(Dataset):

            dataset = augment_example(
                dataset,
                self.augmentation,
                preprocessing_config=self.preprocessing,
                preprocessor=self.preprocessor,
                others=self.get_random_example,
            )

@@ -95,16 +95,16 @@ class LabeledDataset(Dataset):

    def from_directory(
        cls,
        directory: PathLike,
        preprocessor: PreprocessorProtocol,
        extension: str = ".nc",
        subclip: Optional[SubclipConfig] = None,
        augmentation: Optional[AugmentationsConfig] = None,
        preprocessing: Optional[PreprocessingConfig] = None,
    ):
        return cls(
            get_preprocessed_files(directory, extension),
            preprocessor=preprocessor,
            filenames=get_preprocessed_files(directory, extension),
            subclip=subclip,
            augmentation=augmentation,
            preprocessing=preprocessing,
        )
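
A construction sketch for the new signature (the directory path is a hypothetical placeholder and `preprocessor` an assumed `PreprocessorProtocol` instance):

dataset = LabeledDataset.from_directory(
    "path/to/preprocessed",                 # hypothetical directory of .nc files
    preprocessor=preprocessor,
    augmentation=AugmentationsConfig(steps=[]),
)
example = dataset.get_random_example()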

    def get_random_example(self) -> xr.Dataset:

@@ -1,48 +1,47 @@

"""Generate heatmap training targets for BatDetect2 models.

This module represents the final step in the `batdetect2.targets` pipeline,
converting processed sound event annotations from an audio clip into the
specific heatmap formats required for training the BatDetect2 neural network.
This module is responsible for creating the target labels used for training
BatDetect2 models. It converts sound event annotations for an audio clip into
the specific multi-channel heatmap formats required by the neural network.

It integrates the filtering, transformation, and class encoding logic defined
in the preceding configuration steps (`filtering`, `transform`, `classes`)
and applies them to generate three core outputs for a given spectrogram:
It uses a pre-configured object adhering to the `TargetProtocol` (from
`batdetect2.targets`) which encapsulates all the logic for filtering
annotations, transforming tags, encoding class names, and mapping annotation
geometry (ROIs) to target positions and sizes. This module then focuses on
rendering this information onto the heatmap grids.

1. **Detection Heatmap**: Indicates the presence and location of relevant
   sound events.
2. **Class Heatmap**: Indicates the location and predicted class label for
   events that match a specific target class.
3. **Size Heatmap**: Encodes the dimensions (width/time duration,
   height/frequency bandwidth) of the detected sound events at their
   reference locations.
The pipeline generates three core outputs for a given spectrogram:
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
2. **Class Heatmap**: Indicates location and class identity for specifically
   classified events.
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.

The primary function generated by this module is a `ClipLabeller`, which takes
a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`)
and returns the calculated `Heatmaps`. Configuration options allow tuning of
the heatmap generation process (e.g., Gaussian smoothing sigma, reference point
within bounding boxes).
The primary function generated by this module is a `ClipLabeller` (defined in
`.types`), which takes a `ClipAnnotation` object and its corresponding
spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
defined in `LabelConfig`.
"""

import logging
from collections.abc import Iterable
from functools import partial
from typing import Callable, List, NamedTuple, Optional
from typing import Optional

import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
from soundevent import arrays, data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.classes import SoundEventEncoder
from batdetect2.targets.filtering import SoundEventFilter
from batdetect2.targets.transform import SoundEventTransformation
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import (
    ClipLabeller,
    Heatmaps,
)

__all__ = [
    "LabelConfig",
    "Heatmaps",
    "ClipLabeller",
    "build_clip_labeler",
    "generate_clip_label",
    "generate_heatmaps",

@@ -50,109 +49,45 @@ __all__ = [

]


SIZE_DIMENSION = "dimension"
"""Dimension name for the size heatmap."""

logger = logging.getLogger(__name__)


class Heatmaps(NamedTuple):
    """Structure holding the generated heatmap targets.

    Attributes
    ----------
    detection : xr.DataArray
        Heatmap indicating the probability of sound event presence. Typically
        smoothed with a Gaussian kernel centered on event reference points.
        Shape matches the input spectrogram. Values normalized [0, 1].
    classes : xr.DataArray
        Heatmap indicating the probability of specific class presence. Has an
        additional 'category' dimension corresponding to the target class
        names. Each category slice is typically smoothed with a Gaussian
        kernel. Values normalized [0, 1] per category.
    size : xr.DataArray
        Heatmap encoding the size (width, height) of detected events. Has an
        additional 'dimension' coordinate ('width', 'height'). Values represent
        scaled dimensions placed at the event reference points.
    """

    detection: xr.DataArray
    classes: xr.DataArray
    size: xr.DataArray


ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.

This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""


class LabelConfig(BaseConfig):
    """Configuration parameters for heatmap generation.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Specifies the reference point within each sound event's geometry
        (bounding box) that is used to place the 'peak' or value on the
        heatmaps. Options include 'center', 'bottom-left', 'top-right', etc.
        See `soundevent.geometry.operations.Positions` for valid options.
    sigma : float, default=3.0
        The standard deviation (in pixels/bins) of the Gaussian kernel applied
        to smooth the detection and class heatmaps. Larger values create more
        diffuse targets.
    time_scale : float, default=1000.0
        A scaling factor applied to the time duration (width) of sound event
        bounding boxes before storing them in the 'width' dimension of the
        `size` heatmap. The appropriate value depends on how the model input/
        output resolution relates to physical time units (e.g., if time axis
        is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs
        to match model expectations.
    frequency_scale : float, default=1/859.375
        A scaling factor applied to the frequency bandwidth (height) of sound
        event bounding boxes before storing them in the 'height' dimension of
        the `size` heatmap. The appropriate value depends on the relationship
        between the frequency axis resolution (e.g., kHz or Hz per bin) and
        the desired output units/scale for the model. Needs to match model
        expectations. (The default suggests input frequency might be in Hz
        and output is scaled relative to some reference).
    """

    position: Positions = "bottom-left"
    sigma: float = 3.0
    time_scale: float = 1000.0
    frequency_scale: float = 1 / 859.375


def build_clip_labeler(
    filter_fn: SoundEventFilter,
    transform_fn: SoundEventTransformation,
    encoder_fn: SoundEventEncoder,
    class_names: List[str],
    targets: TargetProtocol,
    config: LabelConfig,
) -> ClipLabeller:
    """Construct the clip labelling function.
    """Construct the final clip labelling function.

    This function takes the pre-built components from the previous target
    definition steps (filtering, transformation, encoding) and the label
    configuration, then returns a single callable (`ClipLabeller`) that
    performs the end-to-end heatmap generation for a given clip and
    spectrogram.
    This factory function prepares the callable that will perform the
    end-to-end heatmap generation for a given clip and spectrogram during
    training data loading. It takes the fully configured `targets` object and
    the `LabelConfig` and binds them to the `generate_clip_label` function.

    Parameters
    ----------
    filter_fn : SoundEventFilter
        Function to filter irrelevant sound event annotations.
    transform_fn : SoundEventTransformation
        Function to transform tags of sound event annotations.
    encoder_fn : SoundEventEncoder
        Function to encode a sound event annotation into a class name.
    class_names : List[str]
        Ordered list of unique target class names for the classification
        heatmap.
    targets : TargetProtocol
        An initialized object conforming to the `TargetProtocol`, providing all
        necessary methods for filtering, transforming, encoding, and ROI
        mapping.
    config : LabelConfig
        Configuration object containing heatmap generation parameters (sigma,
        etc.).
        Configuration object containing heatmap generation parameters.

    Returns
    -------

@@ -162,10 +97,7 @@ def build_clip_labeler(

    """
    return partial(
        generate_clip_label,
        filter_fn=filter_fn,
        transform_fn=transform_fn,
        encoder_fn=encoder_fn,
        class_names=class_names,
        targets=targets,
        config=config,
    )
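
In use, the factory is called once and the returned labeller applied per clip; a sketch (assumes a configured `targets` object from `batdetect2.targets` plus a clip's annotation and spectrogram):

labeller = build_clip_labeler(targets=targets, config=LabelConfig(sigma=3.0))
heatmaps = labeller(clip_annotation, spec)
# heatmaps.detection, heatmaps.classes, heatmaps.size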
|
||||
|
||||
@ -173,123 +105,97 @@ def build_clip_labeler(
|
||||
def generate_clip_label(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
spec: xr.DataArray,
|
||||
filter_fn: SoundEventFilter,
|
||||
transform_fn: SoundEventTransformation,
|
||||
encoder_fn: SoundEventEncoder,
|
||||
class_names: List[str],
|
||||
targets: TargetProtocol,
|
||||
config: LabelConfig,
|
||||
) -> Heatmaps:
|
||||
"""Generate heatmaps for a single clip by applying all processing steps.
|
||||
"""Generate training heatmaps for a single annotated clip.
|
||||
|
||||
This function orchestrates the process for one clip:
|
||||
1. Filters the sound events using `filter_fn`.
|
||||
2. Transforms the tags of filtered events using `transform_fn`.
|
||||
3. Passes the processed annotations and other parameters to
|
||||
`generate_heatmaps` to create the final target heatmaps.
|
||||
This function orchestrates the target generation process for one clip:
|
||||
1. Filters and transforms sound events using `targets.filter` and
|
||||
`targets.transform`.
|
||||
2. Passes the resulting processed annotations, along with the spectrogram,
|
||||
the `targets` object, and the Gaussian `sigma` from `config`, to the
|
||||
core `generate_heatmaps` function.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
clip_annotation : data.ClipAnnotation
|
||||
The complete annotation data for the audio clip.
|
||||
The complete annotation data for the audio clip, including the list
|
||||
of `sound_events` to process.
|
||||
spec : xr.DataArray
|
||||
The spectrogram corresponding to the `clip_annotation`.
|
||||
filter_fn : SoundEventFilter
|
||||
Function to filter sound event annotations.
|
||||
transform_fn : SoundEventTransformation
|
||||
Function to transform tags of sound event annotations.
|
||||
encoder_fn : SoundEventEncoder
|
||||
Function to encode a sound event annotation into a class name.
|
||||
class_names : List[str]
|
||||
Ordered list of unique target class names.
|
||||
The spectrogram corresponding to the `clip_annotation`. Must have
|
||||
'time' and 'frequency' dimensions/coordinates.
|
||||
targets : TargetProtocol
|
||||
The fully configured target definition object, providing methods for
|
||||
filtering, transformation, encoding, and ROI mapping.
|
||||
config : LabelConfig
|
||||
Configuration object containing heatmap generation parameters.
|
||||
Configuration object providing heatmap parameters (primarily `sigma`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Heatmaps
|
||||
The generated detection, classes, and size heatmaps for the clip.
|
||||
A NamedTuple containing the generated 'detection', 'classes', and 'size'
|
||||
heatmaps for this clip.
|
||||
"""
|
||||
return generate_heatmaps(
|
||||
(
|
||||
transform_fn(sound_event_annotation)
|
||||
targets.transform(sound_event_annotation)
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if filter_fn(sound_event_annotation)
|
||||
if targets.filter(sound_event_annotation)
|
||||
),
|
||||
spec=spec,
|
||||
class_names=class_names,
|
||||
encoder=encoder_fn,
|
||||
targets=targets,
|
||||
target_sigma=config.sigma,
|
||||
position=config.position,
|
||||
time_scale=config.time_scale,
|
||||
frequency_scale=config.frequency_scale,
|
||||
)
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
class_names: List[str],
|
||||
encoder: SoundEventEncoder,
|
||||
targets: TargetProtocol,
|
||||
target_sigma: float = 3.0,
|
||||
position: Positions = "bottom-left",
|
||||
time_scale: float = 1000.0,
|
||||
frequency_scale: float = 1 / 859.375,
|
||||
dtype=np.float32,
|
||||
) -> Heatmaps:
|
||||
"""Generate detection, class, and size heatmaps from sound events.
|
||||
|
||||
Processes an iterable of sound event annotations (assumed to be already
|
||||
filtered and transformed) and creates heatmap representations suitable
|
||||
for training models like BatDetect2.
|
||||
Creates heatmap representations from an iterable of sound event
|
||||
annotations. This function relies on the provided `targets` object to get
|
||||
the reference position, scaled size, and class encoding for each
|
||||
annotation.
|
||||
|
||||
The process involves:
|
||||
    1. Initializing empty heatmaps based on the spectrogram shape.
    2. Iterating through each sound event.
    3. For each event, finding its reference point and placing a '1.0'
       on the detection heatmap at that point.
    4. Calculating the scaled bounding box size and placing it on the size
       heatmap at the reference point.
    5. Encoding the event to get its class name and placing a '1.0' on the
       corresponding class heatmap slice at the reference point (if
       classified).
    6. Applying Gaussian smoothing to detection and class heatmaps.
    7. Normalizing detection and class heatmaps to the range [0, 1].
    Process:
    1. Initializes empty heatmap arrays based on `spec` shape and `targets`
       metadata.
    2. Iterates through `sound_events`.
    3. For each event:
       a. Gets geometry. Skips if missing.
       b. Gets reference position using `targets.get_position()`. Skips if out
          of bounds.
       c. Places a peak (1.0) on the detection heatmap at the position.
       d. Gets scaled size using `targets.get_size()` and places it on the
          size heatmap.
       e. Encodes class using `targets.encode()` and places a peak (1.0) on
          the corresponding class heatmap layer if a specific class is
          returned.
    4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
       heatmaps.
    5. Normalizes detection and class heatmaps to range [0, 1].

    Parameters
    ----------
    sound_events : Iterable[data.SoundEventAnnotation]
        An iterable of sound event annotations to include in the heatmaps.
        These should ideally be the result of prior filtering and tag
        transformation steps.
        An iterable of sound event annotations to render onto the heatmaps.
    spec : xr.DataArray
        The spectrogram array corresponding to the time/frequency range of
        the annotations. Used for shape and coordinate information. Must have
        'time' and 'frequency' dimensions.
    class_names : List[str]
        An ordered list of unique class names. The class heatmap will have
        a channel ('category' dimension) for each name in this list. Must not
        be empty.
    encoder : SoundEventEncoder
        A function that takes a SoundEventAnnotation and returns the
        corresponding class name (str) or None if it doesn't belong to a
        specific class (e.g., it falls into the generic 'Bat' category).
        'time' and 'frequency' dimensions/coordinates.
    targets : TargetProtocol
        The fully configured target definition object. Used to access class
        names, dimension names, and the methods `get_position`, `get_size`,
        `encode`.
    target_sigma : float, default=3.0
        Standard deviation (in pixels/bins) of the Gaussian kernel applied to
        smooth the detection and class heatmaps after initial point placement.
    position : Positions, default="bottom-left"
        The reference point within each annotation's geometry bounding box
        used to place the signal on the heatmaps (e.g., "center",
        "bottom-left"). See `soundevent.geometry.operations.Positions`.
    time_scale : float, default=1000.0
        Scaling factor applied to the time duration (width in seconds) of
        annotations when storing them in the size heatmap. The resulting
        value's unit depends on this scale (e.g., 1000.0 might convert seconds
        to ms).
    frequency_scale : float, default=1/859.375
        Scaling factor applied to the frequency bandwidth (height in Hz or kHz)
        of annotations when storing them in the size heatmap. The resulting
        value's unit depends on this scale and the input unit. (Default
        scaling relative to ~860 Hz).
        smooth the detection and class heatmaps.
    dtype : type, default=np.float32
        The data type for the generated heatmap arrays (e.g., `np.float32`).

@ -303,28 +209,10 @@ def generate_heatmaps(
    ------
    ValueError
        If the input spectrogram `spec` does not have both 'time' and
        'frequency' dimensions, or if `class_names` is empty.

    Notes
    -----
    * This function expects `sound_events` to be already filtered and
      transformed.
    * It includes error handling to skip individual annotations that cause
      issues (e.g., missing geometry, out-of-bounds coordinates, encoder
      errors) allowing the rest of the clip to be processed. Warnings or
      errors are logged.
    * The `time_scale` and `frequency_scale` parameters are crucial and must be
      set according to the expectations of the specific BatDetect2 model
      architecture being trained. Consult model documentation for required
      units/scales.
    * Gaussian filtering and normalization are applied only to detection and
      class heatmaps, not the size heatmap.
        'frequency' dimensions, or if `targets.class_names` is empty.
    """
    shape = dict(zip(spec.dims, spec.shape))

    if len(class_names) == 0:
        raise ValueError("No class names provided.")

    if "time" not in shape or "frequency" not in shape:
        raise ValueError(
            "Spectrogram must have time and frequency dimensions."
@ -333,18 +221,18 @@ def generate_heatmaps(
    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
        data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": [*class_names],
            "category": [*targets.class_names],
            **spec.coords,
        },
    )
    size_heatmap = xr.DataArray(
        data=np.zeros((2, *spec.shape), dtype=dtype),
        dims=["dimension", *spec.dims],
        dims=[SIZE_DIMENSION, *spec.dims],
        coords={
            "dimension": ["width", "height"],
            SIZE_DIMENSION: targets.dimension_names,
            **spec.coords,
        },
    )
@ -359,7 +247,7 @@ def generate_heatmaps(
            continue

        # Get the position of the sound event
        time, frequency = geometry.get_geometry_point(geom, position=position)
        time, frequency = targets.get_position(sound_event_annotation)

        # Set 1.0 at the position of the sound event in the detection heatmap
        try:
@ -379,17 +267,7 @@ def generate_heatmaps(
            )
            continue

        # Set the size of the sound event at the position in the size heatmap
        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
            geom
        )

        size = np.array(
            [
                (end_time - start_time) * time_scale,
                (high_freq - low_freq) * frequency_scale,
            ]
        )
        size = targets.get_size(sound_event_annotation)

        size_heatmap = arrays.set_value_at_pos(
            size_heatmap,
@ -400,7 +278,7 @@ def generate_heatmaps(

        # Get the class name of the sound event
        try:
            class_name = encoder(sound_event_annotation)
            class_name = targets.encode(sound_event_annotation)
        except ValueError as e:
            logger.warning(
                "Skipping annotation %s: Unexpected error while encoding "
@ -414,19 +292,6 @@ def generate_heatmaps(
            # If the label is None skip the sound event
            continue

        if class_name not in class_names:
            # If the label is not in the class names skip the sound event
            logger.warning(
                (
                    "Skipping annotation %s for class heatmap: "
                    "class name '%s' not in class names. Class names: %s"
                ),
                sound_event_annotation.uuid,
                class_name,
                class_names,
            )
            continue

        try:
            class_heatmap = arrays.set_value_at_pos(
                class_heatmap,
@ -453,7 +318,7 @@ def generate_heatmaps(
    )

    class_heatmap = class_heatmap.groupby("category").map(
        gaussian_filter,
        gaussian_filter,  # type: ignore
        args=(target_sigma,),
    )

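For orientation, the hunks above reduce `generate_heatmaps` to the scheme its docstring describes: place point targets, then smooth and normalize. Below is a minimal, self-contained sketch of that scheme using only NumPy/SciPy; the `events` tuple format and the function name are hypothetical illustrations, not the project's API.

import numpy as np
from scipy.ndimage import gaussian_filter


def sketch_heatmaps(events, spec_shape, class_names, target_sigma=3.0):
    """Toy version of the labelling scheme: peaks, sizes, smooth, normalize."""
    detection = np.zeros(spec_shape, dtype=np.float32)
    classes = np.zeros((len(class_names), *spec_shape), dtype=np.float32)
    size = np.zeros((2, *spec_shape), dtype=np.float32)

    for time_idx, freq_idx, (width, height), class_name in events:
        detection[freq_idx, time_idx] = 1.0  # detection peak at reference point
        size[:, freq_idx, time_idx] = (width, height)  # pre-scaled box size
        if class_name is not None:  # only events with a specific class
            classes[class_names.index(class_name), freq_idx, time_idx] = 1.0

    # Smooth and renormalize detection/class maps; the size map is left as-is.
    detection = gaussian_filter(detection, sigma=target_sigma)
    detection /= max(detection.max(), 1e-6)
    for i in range(len(class_names)):
        classes[i] = gaussian_filter(classes[i], sigma=target_sigma)
        classes[i] /= max(classes[i].max(), 1e-6)
    return detection, classes, size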
@ -1,98 +1,101 @@
"""Module for preprocessing data for training."""
"""Preprocesses datasets for BatDetect2 model training.

This module provides functions to take a collection of annotated audio clips
(`soundevent.data.ClipAnnotation`) and process them into the final format
required for training a BatDetect2 model. This typically involves:

1. Loading the relevant audio segment for each annotation using a configured
   `PreprocessorProtocol`.
2. Generating the corresponding input spectrogram using the
   `PreprocessorProtocol`.
3. Generating the target heatmaps (detection, classification, size) using a
   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
4. Packaging the input spectrogram, target heatmaps, and potentially the
   processed audio waveform into an `xarray.Dataset`.
5. Saving each processed `xarray.Dataset` to a separate file (typically NetCDF)
   in an output directory.

This offline preprocessing is often preferred for large datasets as it avoids
computationally intensive steps during the actual training loop. The module
includes utilities for parallel processing using `multiprocessing`.
"""

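In code, the five steps amount to roughly the following loop; the names mirror functions defined later in this file, but treat this as orientation, not the committed implementation, which parallelizes the loop and handles skip/error logic.

# Orientation only: see preprocess_annotations below for the real version.
for clip_annotation in clip_annotations:
    dataset = generate_train_example(
        clip_annotation,
        preprocessor=preprocessor,
        labeller=labeller,
    )
    _save_xr_dataset_to_file(dataset, output_dir / f"{clip_annotation.uuid}.nc")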
import os
from functools import partial
from multiprocessing import Pool
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from typing import Callable, Optional, Sequence

import xarray as xr
from pydantic import Field
from soundevent import data
from tqdm.auto import tqdm

from batdetect2.configs import BaseConfig
from batdetect2.preprocess import (
    PreprocessingConfig,
    compute_spectrogram,
    load_clip_audio,
)
from batdetect2.targets import (
    LabelConfig,
    TargetConfig,
    build_sound_event_filter,
    build_target_encoder,
    generate_heatmaps,
    get_class_names,
)

PathLike = Union[Path, str, os.PathLike]
FilenameFn = Callable[[data.ClipAnnotation], str]
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.train.types import ClipLabeller

__all__ = [
    "preprocess_annotations",
    "preprocess_single_annotation",
    "generate_train_example",
    "TrainPreprocessingConfig",
]


class TrainPreprocessingConfig(BaseConfig):
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )
    target: TargetConfig = Field(default_factory=TargetConfig)
    labels: LabelConfig = Field(default_factory=LabelConfig)
FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename."""


def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
) -> xr.Dataset:
    """Generate a training example."""
    config = TrainPreprocessingConfig(
        preprocessing=preprocessing_config or PreprocessingConfig(),
        target=target_config or TargetConfig(),
        labels=label_config or LabelConfig(),
    )
    """Generate a complete training example for one annotation.

    wave = load_clip_audio(
        clip_annotation.clip,
        config=config.preprocessing.audio,
    )
    This function takes a single `ClipAnnotation`, applies the configured
    preprocessing (`PreprocessorProtocol`) to get the processed waveform and
    input spectrogram, applies the configured target generation
    (`ClipLabeller`) to get the target heatmaps, and packages them all into a
    single `xr.Dataset`.

    spectrogram = compute_spectrogram(
        wave,
        config=config.preprocessing.spectrogram,
    )
    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        The annotated clip to process. Contains the reference to the `Clip`
        (audio segment) and the associated `SoundEventAnnotation` objects.
    preprocessor : PreprocessorProtocol
        An initialized preprocessor object responsible for loading/processing
        audio and computing the input spectrogram.
    labeller : ClipLabeller
        An initialized clip labeller function responsible for generating the
        target heatmaps (detection, class, size) from the `clip_annotation`
        and the computed spectrogram.

    filter_fn = build_sound_event_filter(
        include=config.target.include,
        exclude=config.target.exclude,
    )
    Returns
    -------
    xr.Dataset
        An xarray Dataset containing the following data variables:
        - `audio`: The preprocessed audio waveform (dims: 'audio_time').
        - `spectrogram`: The computed input spectrogram
          (dims: 'time', 'frequency').
        - `detection`: The target detection heatmap
          (dims: 'time', 'frequency').
        - `class`: The target class heatmap
          (dims: 'category', 'time', 'frequency').
        - `size`: The target size heatmap
          (dims: 'dimension', 'time', 'frequency').
        The Dataset also includes metadata in its attributes.

    selected_events = [
        event for event in clip_annotation.sound_events if filter_fn(event)
    ]
    Notes
    -----
    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
      within the output Dataset to avoid coordinate conflicts with the
      spectrogram's 'time' dimension when stored together.
    - The original `ClipAnnotation` metadata is stored as a JSON string in the
      Dataset's attributes for provenance.
    """
    wave = preprocessor.load_clip_audio(clip_annotation.clip)

    encoder = build_target_encoder(
        config.target.classes,
        replacement_rules=config.target.replace,
    )
    class_names = get_class_names(config.target.classes)
    spectrogram = preprocessor.compute_spectrogram(wave)

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        selected_events,
        spectrogram,
        class_names,
        encoder,
        target_sigma=config.labels.heatmaps.sigma,
        position=config.labels.heatmaps.position,
        time_scale=config.labels.heatmaps.time_scale,
        frequency_scale=config.labels.heatmaps.frequency_scale,
    )
    heatmaps = labeller(clip_annotation, spectrogram)

    dataset = xr.Dataset(
        {
@ -102,15 +105,14 @@ def generate_train_example(
            # as the waveform.
            "audio": wave.rename({"time": "audio_time"}),
            "spectrogram": spectrogram,
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
            "detection": heatmaps.detection,
            "class": heatmaps.classes,
            "size": heatmaps.size,
        }
    )

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        config=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(
            exclude_none=True,
            exclude_defaults=True,
@ -119,10 +121,21 @@ def generate_train_example(
    )


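Taken together, `generate_train_example` writes everything a training step needs into one `xr.Dataset`. A sketch of consuming a saved example, based on the docstring above (the path is illustrative; files are named by annotation UUID by default):

import json

import xarray as xr

example = xr.open_dataset("train_examples/1234.nc")

spec = example["spectrogram"]  # dims: ('time', 'frequency')
detection = example["detection"]  # same grid as the spectrogram
classes = example["class"]  # extra 'category' dimension, one layer per class
size = example["size"]  # extra 'dimension' coordinate: width/height

# The source annotation travels along as a JSON string attribute.
clip_annotation = json.loads(example.attrs["clip_annotation"])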
def save_to_file(
def _save_xr_dataset_to_file(
    dataset: xr.Dataset,
    path: PathLike,
    path: data.PathLike,
) -> None:
    """Save an xarray Dataset to a NetCDF file with compression.

    Internal helper function used by `preprocess_single_annotation`.

    Parameters
    ----------
    dataset : xr.Dataset
        The training example dataset to save.
    path : PathLike
        The output file path (e.g., 'output/uuid.nc').
    """
    dataset.to_netcdf(
        path,
        encoding={
@ -135,20 +148,60 @@ def save_to_file(
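The exact `encoding` passed to `to_netcdf` is elided by the hunk above. For illustration only, xarray supports per-variable compression along these lines (not necessarily the settings this commit uses; zlib compression requires the netCDF4 backend):

import numpy as np
import xarray as xr

dataset = xr.Dataset(
    {"spectrogram": (("time", "frequency"), np.random.rand(128, 64))}
)

# Hypothetical settings: zlib-compress every data variable at level 4.
dataset.to_netcdf(
    "example.nc",
    encoding={
        name: {"zlib": True, "complevel": 4} for name in dataset.data_vars
    },
)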


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    """Generate a default output filename based on the annotation UUID."""
    return f"{clip_annotation.uuid}.nc"


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    output_dir: data.PathLike,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess annotations and save to disk."""
    """Preprocess a sequence of ClipAnnotations and save results to disk.

    Generates the full training example (spectrogram, heatmaps, etc.) for each
    `ClipAnnotation` in the input sequence using the provided `preprocessor`
    and `labeller`. Saves each example as a separate NetCDF file in the
    `output_dir`. Utilizes multiprocessing for potentially faster processing.

    Parameters
    ----------
    clip_annotations : Sequence[data.ClipAnnotation]
        A sequence (e.g., list) of the clip annotations to preprocess.
    output_dir : PathLike
        Path to the directory where the processed NetCDF files will be saved.
        Will be created if it doesn't exist.
    preprocessor : PreprocessorProtocol
        Initialized preprocessor object to generate spectrograms.
    labeller : ClipLabeller
        Initialized labeller function to generate target heatmaps.
    filename_fn : FilenameFn, optional
        Function to generate the output filename for each `ClipAnnotation`.
        Defaults to using the annotation UUID via `_get_filename`.
    replace : bool, default=False
        If True, existing files in `output_dir` with the same generated name
        will be overwritten. If False (default), existing files are skipped.
    max_workers : int, optional
        Maximum number of worker processes to use for parallel processing.
        If None (default), uses the number of CPUs available (`os.cpu_count()`).

    Returns
    -------
    None
        This function does not return anything; its side effect is creating
        files in the `output_dir`.

    Raises
    ------
    RuntimeError
        If processing fails for any individual annotation when using
        multiprocessing. The original exception will be attached as the cause.
    """
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
@ -163,9 +216,8 @@ def preprocess_annotations(
                        output_dir=output_dir,
                        filename_fn=filename_fn,
                        replace=replace,
                        preprocessing_config=preprocessing_config,
                        target_config=target_config,
                        label_config=label_config,
                        preprocessor=preprocessor,
                        labeller=labeller,
                    ),
                    clip_annotations,
                ),
@ -176,13 +228,34 @@ def preprocess_annotations(

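The body elided by the hunk headers above evidently fans the work out over a process pool, given the `Pool`, `partial`, and `tqdm` imports and the visible `partial(...)` fragment. A generic sketch of that standard pattern (a hypothetical helper, not the committed code):

from functools import partial
from multiprocessing import Pool

from tqdm.auto import tqdm


def process_all(items, worker, max_workers=None, **kwargs):
    # Bind the fixed keyword arguments once; the pool then only ships items.
    fn = partial(worker, **kwargs)
    with Pool(max_workers) as pool:
        # imap keeps memory bounded; tqdm shows progress as results arrive.
        list(tqdm(pool.imap(fn, items), total=len(items)))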
def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    preprocessing_config: Optional[PreprocessingConfig] = None,
    target_config: Optional[TargetConfig] = None,
    label_config: Optional[LabelConfig] = None,
    output_dir: data.PathLike,
    preprocessor: PreprocessorProtocol,
    labeller: ClipLabeller,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
) -> None:
    """Process a single ClipAnnotation and save the result to a file.

    Internal function designed to be called by `preprocess_annotations`, often
    in parallel worker processes. It generates the training example using
    `generate_train_example` and saves it using `_save_xr_dataset_to_file`.
    Handles file existence checks based on the `replace` flag.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        The single annotation to process.
    output_dir : Path
        The directory where the output NetCDF file should be saved.
    preprocessor : PreprocessorProtocol
        Initialized preprocessor object.
    labeller : ClipLabeller
        Initialized labeller function.
    filename_fn : FilenameFn, default=_get_filename
        Function to determine the output filename.
    replace : bool, default=False
        Whether to overwrite existing output files.
    """
    output_dir = Path(output_dir)

    filename = filename_fn(clip_annotation)
@ -197,13 +270,12 @@ def preprocess_single_annotation(
    try:
        sample = generate_train_example(
            clip_annotation,
            preprocessing_config=preprocessing_config,
            target_config=target_config,
            label_config=label_config,
            preprocessor=preprocessor,
            labeller=labeller,
        )
    except Exception as error:
        raise RuntimeError(
            f"Failed to process annotation: {clip_annotation.uuid}"
        ) from error

    save_to_file(sample, path)
    _save_xr_dataset_to_file(sample, path)

48 batdetect2/train/types.py Normal file
@ -0,0 +1,48 @@
from typing import Callable, NamedTuple

import xarray as xr
from soundevent import data

__all__ = [
    "Heatmaps",
    "ClipLabeller",
    "Augmentation",
]


class Heatmaps(NamedTuple):
    """Structure holding the generated heatmap targets.

    Attributes
    ----------
    detection : xr.DataArray
        Heatmap indicating the probability of sound event presence. Typically
        smoothed with a Gaussian kernel centered on event reference points.
        Shape matches the input spectrogram. Values normalized [0, 1].
    classes : xr.DataArray
        Heatmap indicating the probability of specific class presence. Has an
        additional 'category' dimension corresponding to the target class
        names. Each category slice is typically smoothed with a Gaussian
        kernel. Values normalized [0, 1] per category.
    size : xr.DataArray
        Heatmap encoding the size (width, height) of detected events. Has an
        additional 'dimension' coordinate ('width', 'height'). Values represent
        scaled dimensions placed at the event reference points.
    """

    detection: xr.DataArray
    classes: xr.DataArray
    size: xr.DataArray


ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.

This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""

Augmentation = Callable[[xr.Dataset], xr.Dataset]

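Anything satisfying the `ClipLabeller` signature can be plugged into the training pipeline. A minimal placeholder implementation, assuming only the types above (the single 'bat' class name is a made-up example):

import xarray as xr
from soundevent import data

from batdetect2.train.types import Heatmaps


def trivial_labeller(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
) -> Heatmaps:
    """Placeholder labeller: all-zero targets shaped like `spec`."""
    detection = xr.zeros_like(spec)
    # One dummy class layer and the two size dimensions expected downstream.
    classes = detection.expand_dims(category=["bat"])
    size = detection.expand_dims(dimension=["width", "height"])
    return Heatmaps(detection=detection, classes=classes, size=size)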
@ -1,14 +1,10 @@
"""Types used in the code base."""

from typing import Any, List, NamedTuple, Optional

from typing import Any, List, NamedTuple, Optional, TypedDict

import numpy as np
import torch

from typing import TypedDict


try:
    from typing import Protocol
except ImportError:
@ -1,34 +0,0 @@
import torch
from hypothesis import given
from hypothesis import strategies as st

from batdetect2.models import ModelConfig, ModelType, build_architecture


@given(
    input_width=st.integers(min_value=50, max_value=1500),
    input_height=st.integers(min_value=1, max_value=16),
    model_type=st.sampled_from(ModelType),
)
def test_model_can_process_spectrograms_of_any_width(
    input_width,
    input_height,
    model_type,
):
    # Input height must be divisible by 8
    input_height = 8 * input_height

    input = torch.rand([1, 1, input_height, input_width])

    model = build_architecture(
        config=ModelConfig(
            name=model_type,  # type: ignore
            input_height=input_height,
        ),
    )

    output = model(input)
    assert output.shape[0] == 1
    assert output.shape[1] == model.out_channels
    assert output.shape[2] == input_height
    assert output.shape[3] == input_width
@ -4,7 +4,7 @@ import numpy as np
import xarray as xr
from soundevent import data

from batdetect2.targets import generate_heatmaps
from batdetect2.train.labels import generate_heatmaps

recording = data.Recording(
    samplerate=256_000,