Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit 7c89e82579 (parent dcae411ccb): Fixing imports after restructuring
@@ -10,13 +10,13 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from batdetect2.configs import BaseConfig
 from batdetect2.models import (
+    BackboneConfig,
     BBoxHead,
     ClassifierHead,
-    ModelConfig,
     ModelOutput,
-    build_architecture,
+    build_backbone,
 )
-from batdetect2.post_process import (
+from batdetect2.postprocess import (
     PostprocessConfig,
     postprocess_model_outputs,
 )
@@ -37,7 +37,7 @@ __all__ = [
 class ModuleConfig(BaseConfig):
     train: TrainingConfig = Field(default_factory=TrainingConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
-    architecture: ModelConfig = Field(default_factory=ModelConfig)
+    architecture: BackboneConfig = Field(default_factory=BackboneConfig)
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
@@ -58,7 +58,7 @@ class DetectorModel(L.LightningModule):
         self.config = config or ModuleConfig()
         self.save_hyperparameters()

-        self.backbone = build_architecture(self.config.architecture)
+        self.backbone = build_model_backbone(self.config.architecture)

         self.classifier = ClassifierHead(
             num_classes=len(self.config.targets.classes),
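For orientation, a minimal sketch of how the renamed pieces might compose after this commit. Names are taken from the hunks above (the import brings in `build_backbone` while the call site uses `build_model_backbone`; the sketch assumes the imported name), and exact signatures in the repository may differ:

    from batdetect2.models import BackboneConfig, build_backbone

    config = BackboneConfig()          # replaces the old ModelConfig
    backbone = build_backbone(config)  # replaces build_architecture(ModelConfig())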
@@ -1,4 +1,30 @@
-from typing import Callable, Optional, Union
+"""Applies data augmentation techniques to BatDetect2 training examples.
+
+This module provides various functions and configurable components for applying
+data augmentation to the training examples (`xr.Dataset` containing audio,
+spectrogram, and target heatmaps) generated by the
+`batdetect2.train.preprocess` module.
+
+Data augmentation artificially increases the diversity of the training data by
+applying random transformations, which generally helps improve the robustness
+and generalization performance of trained models.
+
+Augmentations included:
+- Time-based: `select_subclip`, `warp_spectrogram`, `mask_time`.
+- Amplitude/Noise-based: `mix_examples`, `add_echo`, `scale_volume`,
+  `mask_frequency`.
+
+Some augmentations modify the audio waveform and require recomputing the
+spectrogram using the `PreprocessorProtocol`, while others operate directly
+on the spectrogram or target heatmaps. The entire augmentation pipeline can be
+configured using the `AugmentationsConfig` class, specifying a sequence of
+augmentation steps, each with its own parameters and application probability.
+The `build_augmentations` function constructs the final augmentation callable
+from this configuration.
+"""
+
+from functools import partial
+from typing import Annotated, Callable, List, Literal, Optional, Union

 import numpy as np
 import xarray as xr
@@ -6,15 +32,14 @@ from pydantic import Field
 from soundevent import arrays, data

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
-from batdetect2.preprocess.arrays import adjust_width
+from batdetect2.preprocess import PreprocessorProtocol
+from batdetect2.train.types import Augmentation
+from batdetect2.utils.arrays import adjust_width

-Augmentation = Callable[[xr.Dataset], xr.Dataset]

 __all__ = [
     "AugmentationsConfig",
-    "load_agumentation_config",
+    "load_augmentation_config",
+    "build_augmentations",
     "select_subclip",
     "mix_examples",
     "add_echo",
@@ -22,13 +47,21 @@ __all__ = [
     "warp_spectrogram",
     "mask_time",
     "mask_frequency",
-    "augment_example",
+    "MixAugmentationConfig",
+    "EchoAugmentationConfig",
+    "VolumeAugmentationConfig",
+    "WarpAugmentationConfig",
+    "TimeMaskAugmentationConfig",
+    "FrequencyMaskAugmentationConfig",
+    "AugmentationConfig",
+    "ExampleSource",
 ]


-class BaseAugmentationConfig(BaseConfig):
-    enable: bool = True
-    probability: float = 0.2
+ExampleSource = Callable[[], xr.Dataset]
+"""Type alias for a function that returns a training example (`xr.Dataset`).
+
+Used by the `mix_examples` augmentation to fetch another example to mix with.
+"""


 def select_subclip(
@@ -38,7 +71,47 @@ def select_subclip(
     width: Optional[int] = None,
     random: bool = False,
 ) -> xr.Dataset:
-    """Select a random subclip from a clip."""
+    """Extract a sub-clip (time segment) from a training example dataset.
+
+    Selects a portion of the 'time' dimension from all relevant DataArrays
+    (`audio`, `spectrogram`, `detection`, `class`, `size`) within the example
+    Dataset. The segment can be defined by a fixed start time and
+    duration/width, or a random start time can be chosen.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'audio', 'spectrogram', and
+        target heatmaps, all with compatible 'time' (or 'audio_time')
+        coordinates.
+    start_time : float, optional
+        Desired start time (seconds) of the subclip. If None and `random` is
+        False, starts from the beginning of the example. If None and `random`
+        is True, a random start time is chosen.
+    duration : float, optional
+        Desired duration (seconds) of the subclip. Either `duration` or `width`
+        must be provided.
+    width : int, optional
+        Desired width (number of time bins) of the subclip's
+        spectrogram/heatmaps. Either `duration` or `width` must be provided. If
+        both are given, `duration` takes precedence.
+    random : bool, default=False
+        If True and `start_time` is None, selects a random start time ensuring
+        the subclip fits within the original example's duration.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset containing only the selected time segment. Coordinates
+        are adjusted accordingly. Returns the original example if the requested
+        subclip cannot be extracted (e.g., duration too long).
+
+    Raises
+    ------
+    ValueError
+        If neither `duration` nor `width` is provided, or if time coordinates
+        are missing or invalid.
+    """
     step = arrays.get_dim_step(example, "time")  # type: ignore
     start, stop = arrays.get_dim_range(example, "time")  # type: ignore

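A brief usage sketch for the documented signature (the `example` dataset is assumed to come from `batdetect2.train.preprocess`; values are illustrative):

    # Random 250 ms segment:
    subclip = select_subclip(example, duration=0.25, random=True)

    # Fixed start time, fixed width of 256 time bins:
    subclip = select_subclip(example, start_time=0.5, width=256)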
@@ -77,22 +150,62 @@ def select_subclip(
     )


-class MixAugmentationConfig(BaseAugmentationConfig):
+class MixAugmentationConfig(BaseConfig):
+    """Configuration for MixUp augmentation (mixing two examples)."""
+
+    augmentation_type: Literal["mix_audio"] = "mix_audio"
+
+    probability: float = 0.2
+    """Probability of applying this augmentation to an example."""
+
     min_weight: float = 0.3
+    """Minimum mixing weight (lambda) applied to the primary example."""
+
     max_weight: float = 0.7
+    """Maximum mixing weight (lambda) applied to the primary example."""


 def mix_examples(
     example: xr.Dataset,
     other: xr.Dataset,
+    preprocessor: PreprocessorProtocol,
     weight: Optional[float] = None,
     min_weight: float = 0.3,
     max_weight: float = 0.7,
-    config: Optional[PreprocessingConfig] = None,
 ) -> xr.Dataset:
-    """Combine two audio clips."""
-    config = config or PreprocessingConfig()
+    """Combine two training examples using MixUp augmentation.
+
+    Performs a weighted linear combination of the audio waveforms from two
+    examples (`example` and `other`). The spectrogram is then *recomputed*
+    from the mixed audio using the provided `preprocessor`. Target heatmaps
+    (detection, class) are combined by taking the element-wise maximum. Target
+    size heatmaps are combined by element-wise addition.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The primary training example.
+    other : xr.Dataset
+        The second training example to mix with `example`.
+    preprocessor : PreprocessorProtocol
+        The preprocessor used to recompute the spectrogram from mixed audio.
+        Ensures consistency with original preprocessing.
+    weight : float, optional
+        The mixing weight (lambda) applied to the primary `example`. The weight
+        for `other` will be `(1 - weight)`. If None, a random weight is chosen
+        uniformly between `min_weight` and `max_weight`.
+    min_weight : float, default=0.3
+        Minimum value for the random weight lambda.
+    max_weight : float, default=0.7
+        Maximum value for the random weight lambda. Must be >= `min_weight`.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset representing the mixed example, with combined audio,
+        recomputed spectrogram, and combined target heatmaps. Attributes from
+        the primary `example` are preserved.
+    """
     if weight is None:
         weight = np.random.uniform(min_weight, max_weight)

@@ -101,9 +214,8 @@ def mix_examples(

     combined = weight * audio1 + (1 - weight) * audio2

-    spectrogram = compute_spectrogram(
-        combined.rename({"audio_time": "time"}),
-        config=config.spectrogram,
+    spectrogram = preprocessor.compute_spectrogram(
+        combined.rename({"audio_time": "time"})
     ).data

     # NOTE: The subclip's spectrogram might be slightly longer than the
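The combination rules described in the docstring above, sketched with plain NumPy arrays in place of the `xr.Dataset` variables (an illustration of the math, not the library's implementation):

    import numpy as np

    rng = np.random.default_rng(0)
    weight = rng.uniform(0.3, 0.7)  # lambda drawn from [min_weight, max_weight]

    audio1, audio2 = rng.normal(size=512), rng.normal(size=512)
    mixed_audio = weight * audio1 + (1 - weight) * audio2  # waveforms: weighted sum

    det1, det2 = rng.random((64, 128)), rng.random((64, 128))
    mixed_detection = np.maximum(det1, det2)  # detection/class heatmaps: element-wise max

    size1, size2 = rng.random((2, 64, 128)), rng.random((2, 64, 128))
    mixed_size = size1 + size2  # size heatmaps: element-wise addition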
@@ -147,7 +259,14 @@ def mix_examples(
     )


-class EchoAugmentationConfig(BaseAugmentationConfig):
+class EchoAugmentationConfig(BaseConfig):
+    """Configuration for adding synthetic echo/reverb."""
+
+    augmentation_type: Literal["add_echo"] = "add_echo"
+
+    probability: float = 0.2
+    """Probability of applying this augmentation."""
+
     max_delay: float = 0.005
     min_weight: float = 0.0
     max_weight: float = 1.0
@@ -155,15 +274,45 @@ class EchoAugmentationConfig(BaseAugmentationConfig):

 def add_echo(
     example: xr.Dataset,
+    preprocessor: PreprocessorProtocol,
     delay: Optional[float] = None,
     weight: Optional[float] = None,
     min_weight: float = 0.1,
     max_weight: float = 1.0,
     max_delay: float = 0.005,
-    config: Optional[PreprocessingConfig] = None,
 ) -> xr.Dataset:
-    """Add a delay to the audio."""
-    config = config or PreprocessingConfig()
+    """Add a synthetic echo to the audio waveform.
+
+    Creates an echo by adding a delayed and attenuated version of the original
+    audio waveform back onto itself. The spectrogram is then recomputed from
+    the modified audio. Target heatmaps remain unchanged.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'audio', 'spectrogram', etc.
+    preprocessor : PreprocessorProtocol
+        Preprocessor used to recompute the spectrogram from the modified audio.
+    delay : float, optional
+        The delay time in seconds for the echo. If None, chosen randomly
+        between 0 and `max_delay`.
+    weight : float, optional
+        The relative amplitude (weight) of the echo compared to the original.
+        If None, chosen randomly between `min_weight` and `max_weight`.
+    min_weight : float, default=0.1
+        Minimum value for the random echo weight.
+    max_weight : float, default=1.0
+        Maximum value for the random echo weight. Must be >= `min_weight`.
+    max_delay : float, default=0.005
+        Maximum value for the random echo delay in seconds. Must be >= 0.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with the echo added to the 'audio' variable and the
+        'spectrogram' variable recomputed. Other variables (targets, attrs)
+        are copied from the original example.
+    """
     if delay is None:
         delay = np.random.uniform(0, max_delay)
@@ -176,9 +325,8 @@ def add_echo(
     audio_delay = audio.shift(audio_time=int(delay / step), fill_value=0)
     audio = audio + weight * audio_delay

-    spectrogram = compute_spectrogram(
+    spectrogram = preprocessor.compute_spectrogram(
         audio.rename({"audio_time": "time"}),
-        config=config.spectrogram,
     ).data

     # NOTE: The subclip's spectrogram might be slightly longer than the
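The echo construction itself is simple: a delayed, zero-padded copy of the waveform scaled by `weight` and added back on. A NumPy stand-in for the `xr.DataArray.shift` call above:

    import numpy as np

    def add_echo_sketch(audio: np.ndarray, delay_samples: int, weight: float) -> np.ndarray:
        """Add a delayed, attenuated copy of `audio` onto itself (delay_samples > 0)."""
        delayed = np.zeros_like(audio)
        delayed[delay_samples:] = audio[:-delay_samples]  # shift right, zero-fill
        return audio + weight * delayed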
@@ -202,7 +350,11 @@ def add_echo(
     )


-class VolumeAugmentationConfig(BaseAugmentationConfig):
+class VolumeAugmentationConfig(BaseConfig):
+    """Configuration for random volume scaling of the spectrogram."""
+
+    augmentation_type: Literal["scale_volume"] = "scale_volume"
+    probability: float = 0.2
     min_scaling: float = 0.0
     max_scaling: float = 2.0

@@ -213,14 +365,44 @@ def scale_volume(
     max_scaling: float = 2,
     min_scaling: float = 0,
 ) -> xr.Dataset:
-    """Scale the volume of a spectrogram."""
+    """Scale the amplitude of the spectrogram by a random factor.
+
+    Multiplies the entire spectrogram DataArray by a scaling factor chosen
+    uniformly between `min_scaling` and `max_scaling`. This simulates changes
+    in recording volume or distance to the sound source directly in the
+    spectrogram domain. Audio and target heatmaps are unchanged.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        The input training example containing 'spectrogram'.
+    factor : float, optional
+        The scaling factor to apply. If None, chosen randomly between
+        `min_scaling` and `max_scaling`.
+    min_scaling : float, default=0.0
+        Minimum value for the random scaling factor. Must be non-negative.
+    max_scaling : float, default=2.0
+        Maximum value for the random scaling factor. Must be >= `min_scaling`.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with the 'spectrogram' variable scaled.
+
+    Raises
+    ------
+    ValueError
+        If `min_scaling` > `max_scaling` or if `min_scaling` is negative.
+    """
     if factor is None:
         factor = np.random.uniform(min_scaling, max_scaling)

     return example.assign(spectrogram=example["spectrogram"] * factor)


-class WarpAugmentationConfig(BaseAugmentationConfig):
+class WarpAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["warp"] = "warp"
+    probability: float = 0.2
     delta: float = 0.04

@@ -229,11 +411,39 @@ def warp_spectrogram(
     factor: Optional[float] = None,
     delta: float = 0.04,
 ) -> xr.Dataset:
-    """Warp a spectrogram."""
+    """Apply time warping by resampling the time axis.
+
+    Stretches or compresses the time axis of the spectrogram and all target
+    heatmaps by a random `factor` (chosen between `1-delta` and `1+delta`).
+    Uses linear interpolation for the spectrogram and nearest neighbor for the
+    discrete-like target heatmaps. Updates the 'time' coordinate accordingly.
+    The audio waveform is not modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example with 'spectrogram', 'detection', 'class',
+        'size', and 'time' coordinates.
+    factor : float, optional
+        The warping factor. If None, chosen randomly between `1-delta` and
+        `1+delta`. Values > 1 stretch time, < 1 compress time.
+    delta : float, default=0.04
+        Controls the range `[1-delta, 1+delta]` for the random warp factor.
+        Must be >= 0 and < 1.
+
+    Returns
+    -------
+    xr.Dataset
+        A new dataset with time-warped spectrogram and target heatmaps, and
+        updated 'time' coordinates.
+    """
     if factor is None:
         factor = np.random.uniform(1 - delta, 1 + delta)

-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )
     duration = end_time - start_time

     new_time = np.linspace(
@@ -296,6 +506,39 @@ def mask_axis(
     end: float,
     mask_value: Union[float, Callable[[xr.DataArray], float]] = np.mean,
 ) -> xr.DataArray:
+    """Mask values along a specified dimension.
+
+    Sets values in the DataArray to `mask_value` where the coordinate along
+    `dim` falls within the range [`start`, `end`). Values outside this range
+    are kept. Used as a helper for time/frequency masking.
+
+    Parameters
+    ----------
+    array : xr.DataArray
+        The input DataArray (e.g., spectrogram).
+    dim : str
+        The name of the dimension along which to mask
+        (e.g., "time", "frequency").
+    start : float
+        The starting coordinate value for the mask range.
+    end : float
+        The ending coordinate value for the mask range (exclusive, typically).
+        Values >= start and < end will be masked.
+    mask_value : float or Callable[[xr.DataArray], float], default=np.mean
+        The value to use for masking. Can be a fixed float (e.g., 0.0) or a
+        callable (like `np.mean`, `np.median`) that computes the value from
+        the input `array`.
+
+    Returns
+    -------
+    xr.DataArray
+        The DataArray with the specified range along `dim` masked.
+
+    Raises
+    ------
+    ValueError
+        If `dim` is not found in the array's dimensions or coordinates.
+    """
     if dim not in array.dims:
         raise ValueError(f"Axis {dim} not found in array")

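The masking logic documented above reduces to a coordinate-based `where`. A self-contained xarray sketch (array shape and coordinate values are illustrative):

    import numpy as np
    import xarray as xr

    spec = xr.DataArray(
        np.random.rand(4, 100),
        coords={
            "frequency": np.linspace(10_000, 120_000, 4),
            "time": np.linspace(0.0, 1.0, 100),
        },
        dims=("frequency", "time"),
    )

    start, end = 0.2, 0.3
    keep = (spec["time"] < start) | (spec["time"] >= end)  # True where values are kept
    masked = spec.where(keep, other=float(spec.mean()))    # masked bins get the mean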
@@ -308,7 +551,9 @@ def mask_axis(
     return array.where(condition, other=mask_value)


-class TimeMaskAugmentationConfig(BaseAugmentationConfig):
+class TimeMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_time"] = "mask_time"
+    probability: float = 0.2
     max_perc: float = 0.05
     max_masks: int = 3

@@ -318,9 +563,32 @@ def mask_time(
     max_perc: float = 0.05,
     max_mask: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the time axis."""
+    """Apply random time masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of time intervals (up to `max_masks`) and masks
+    (sets to the mean value) the spectrogram within those intervals. The width
+    of each mask is chosen randomly up to `max_perc` of the total duration.
+    Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'time' coordinate.
+    max_perc : float, default=0.05
+        Maximum width of a single mask as a fraction of total duration.
+    max_masks : int, default=3
+        Maximum number of time masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with time masking applied to the 'spectrogram'.
+    """
     num_masks = np.random.randint(1, max_mask + 1)
-    start_time, end_time = arrays.get_dim_range(example, "time")  # type: ignore
+    start_time, end_time = arrays.get_dim_range(
+        example,  # type: ignore
+        "time",
+    )

     spectrogram = example["spectrogram"]
     for _ in range(num_masks):
@@ -332,7 +600,9 @@ def mask_time(
     return example.assign(spectrogram=spectrogram)


-class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
+class FrequencyMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_freq"] = "mask_freq"
+    probability: float = 0.2
     max_perc: float = 0.10
     max_masks: int = 3

@@ -342,9 +612,38 @@ def mask_frequency(
     max_perc: float = 0.10,
     max_masks: int = 3,
 ) -> xr.Dataset:
-    """Mask a random section of the frequency axis."""
+    """Apply random frequency masking (SpecAugment) to the spectrogram.
+
+    Randomly selects a number of frequency intervals (up to `max_masks`) and
+    masks (sets to the mean value) the spectrogram within those intervals. The
+    height of each mask is chosen randomly up to `max_perc` of the total
+    frequency range. Only the 'spectrogram' variable is modified.
+
+    Parameters
+    ----------
+    example : xr.Dataset
+        Input training example containing 'spectrogram' and 'frequency'
+        coordinate.
+    max_perc : float, default=0.10
+        Maximum height of a single mask as a fraction of total frequency range.
+    max_masks : int, default=3
+        Maximum number of frequency masks to apply.
+
+    Returns
+    -------
+    xr.Dataset
+        Dataset with frequency masking applied to the 'spectrogram'.
+
+    Raises
+    ------
+    ValueError
+        If frequency coordinate info is missing or invalid.
+    """
     num_masks = np.random.randint(1, max_masks + 1)
-    min_freq, max_freq = arrays.get_dim_range(example, "frequency")  # type: ignore
+    min_freq, max_freq = arrays.get_dim_range(
+        example,  # type: ignore
+        "frequency",
+    )

     spectrogram = example["spectrogram"]
     for _ in range(num_masks):
@@ -356,88 +655,326 @@ def mask_frequency(
     return example.assign(spectrogram=spectrogram)


+AugmentationConfig = Annotated[
+    Union[
+        MixAugmentationConfig,
+        EchoAugmentationConfig,
+        VolumeAugmentationConfig,
+        WarpAugmentationConfig,
+        FrequencyMaskAugmentationConfig,
+        TimeMaskAugmentationConfig,
+    ],
+    Field(discriminator="augmentation_type"),
+]
+"""Type alias for the discriminated union of individual augmentation config."""
+
+
 class AugmentationsConfig(BaseConfig):
-    mix: MixAugmentationConfig = Field(default_factory=MixAugmentationConfig)
-    echo: EchoAugmentationConfig = Field(
-        default_factory=EchoAugmentationConfig
-    )
-    volume: VolumeAugmentationConfig = Field(
-        default_factory=VolumeAugmentationConfig
-    )
-    warp: WarpAugmentationConfig = Field(
-        default_factory=WarpAugmentationConfig
-    )
-    time_mask: TimeMaskAugmentationConfig = Field(
-        default_factory=TimeMaskAugmentationConfig
-    )
-    frequency_mask: FrequencyMaskAugmentationConfig = Field(
-        default_factory=FrequencyMaskAugmentationConfig
-    )
+    """Configuration for a sequence of data augmentations.
+
+    Attributes
+    ----------
+    steps : List[AugmentationConfig]
+        An ordered list of configuration objects, each defining a single
+        augmentation step (e.g., MixAugmentationConfig,
+        TimeMaskAugmentationConfig). Each step's configuration must include an
+        `augmentation_type` field and a `probability` field, along with
+        type-specific parameters. The augmentations will be applied
+        (probabilistically) in the sequence defined by this list.
+    """
+
+    steps: List[AugmentationConfig] = Field(default_factory=list)

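With the discriminated union, a pipeline is declared as an ordered list of steps. A hedged sketch of building such a config in Python (field values are illustrative; pydantic selects the model via `augmentation_type`):

    config = AugmentationsConfig(
        steps=[
            MixAugmentationConfig(probability=0.3),
            EchoAugmentationConfig(probability=0.2, max_delay=0.005),
            TimeMaskAugmentationConfig(probability=0.2, max_perc=0.05, max_masks=3),
            FrequencyMaskAugmentationConfig(probability=0.2, max_perc=0.10, max_masks=3),
        ]
    )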
+class MaybeApply:
+    """Applies an augmentation function probabilistically."""
+
+    def __init__(
+        self,
+        augmentation: Augmentation,
+        probability: float = 0.2,
+    ):
+        """Initialize the wrapper.
+
+        Parameters
+        ----------
+        augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
+            The augmentation function to potentially apply.
+        probability : float, default=0.2
+            The probability (0.0 to 1.0) of applying the augmentation.
+        """
+        self.augmentation = augmentation
+        self.probability = probability
+
+    def __call__(self, example: xr.Dataset) -> xr.Dataset:
+        """Apply the wrapped augmentation with configured probability.
+
+        Parameters
+        ----------
+        example : xr.Dataset
+            The input training example.
+
+        Returns
+        -------
+        xr.Dataset
+            The potentially augmented training example.
+        """
+        if np.random.random() > self.probability:
+            return example
+
+        return self.augmentation(example)
+
+
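A short usage sketch for the wrapper above, pairing it with one of the functional augmentations from this module (an `example` dataset is assumed to be in scope):

    masker = MaybeApply(
        augmentation=partial(mask_time, max_perc=0.05, max_mask=3),
        probability=0.2,
    )
    augmented = masker(example)  # returns `example` unchanged ~80% of the time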
+class AudioMixer:
+    """Callable class for MixUp augmentation, handling example fetching.
+
+    Wraps the `mix_examples` logic, providing the necessary `example_source`
+    to fetch a second example dynamically when called.
+
+    Parameters
+    ----------
+    min_weight : float
+        Minimum mixing weight (lambda) for the primary example.
+    max_weight : float
+        Maximum mixing weight (lambda) for the primary example.
+    example_source : ExampleSource (Callable[[], xr.Dataset])
+        A function that, when called, returns another random training example
+        dataset (`xr.Dataset`) to be mixed with the input example.
+    preprocessor : PreprocessorProtocol
+        The preprocessor needed to recompute the spectrogram after mixing
+        audio.
+    """
+
+    def __init__(
+        self,
+        min_weight: float,
+        max_weight: float,
+        example_source: ExampleSource,
+        preprocessor: PreprocessorProtocol,
+    ):
+        """Initialize the AudioMixer."""
+        self.min_weight = min_weight
+        self.example_source = example_source
+        self.max_weight = max_weight
+        self.preprocessor = preprocessor
+
+    def __call__(self, example: xr.Dataset) -> xr.Dataset:
+        """Fetch another example and perform mixup.
+
+        Parameters
+        ----------
+        example : xr.Dataset
+            The primary input training example.
+
+        Returns
+        -------
+        xr.Dataset
+            The mixed training example. Returns the original example if
+            fetching the second example fails.
+        """
+        other = self.example_source()
+        return mix_examples(
+            example,
+            other,
+            self.preprocessor,
+            min_weight=self.min_weight,
+            max_weight=self.max_weight,
+        )
+
+
+def build_augmentation_from_config(
+    config: AugmentationConfig,
+    preprocessor: PreprocessorProtocol,
+    example_source: Optional[ExampleSource] = None,
+) -> Augmentation:
+    """Factory function to build a single augmentation from its config.
+
+    Takes a configuration object for one augmentation step (which includes the
+    `augmentation_type` discriminator) and returns the corresponding functional
+    augmentation callable (e.g., a `partial` function or a callable class like
+    `AudioMixer`).
+
+    Parameters
+    ----------
+    config : AugmentationConfig
+        Configuration object for a single augmentation (e.g., instance of
+        `MixAugmentationConfig`, `EchoAugmentationConfig`, etc.).
+    preprocessor : PreprocessorProtocol
+        The preprocessor object, required by augmentations that modify audio
+        and need to recompute the spectrogram (e.g., mixup, echo).
+    example_source : ExampleSource, optional
+        A callable that provides other training examples. Required only if
+        the configuration includes `MixAugmentationConfig` (`augmentation_type`
+        is "mix_audio").
+
+    Returns
+    -------
+    Augmentation (Callable[[xr.Dataset], xr.Dataset])
+        A callable function that takes a training example (`xr.Dataset`) and
+        returns a potentially augmented version.
+
+    Raises
+    ------
+    ValueError
+        If `config.augmentation_type` is "mix_audio" but `example_source` is
+        None.
+    NotImplementedError
+        If `config.augmentation_type` does not match any known augmentation
+        type.
+    """
+    if config.augmentation_type == "mix_audio":
+        if example_source is None:
+            raise ValueError(
+                "Mix audio augmentation ('mix_audio') requires an "
+                "'example_source' callable to be provided."
+            )
+
+        return AudioMixer(
+            example_source=example_source,
+            preprocessor=preprocessor,
+            min_weight=config.min_weight,
+            max_weight=config.max_weight,
+        )
+
+    if config.augmentation_type == "add_echo":
+        return partial(
+            add_echo,
+            preprocessor=preprocessor,
+            max_delay=config.max_delay,
+            min_weight=config.min_weight,
+            max_weight=config.max_weight,
+        )
+
+    if config.augmentation_type == "scale_volume":
+        return partial(
+            scale_volume,
+            max_scaling=config.max_scaling,
+            min_scaling=config.min_scaling,
+        )
+
+    if config.augmentation_type == "warp":
+        return partial(
+            warp_spectrogram,
+            delta=config.delta,
+        )
+
+    if config.augmentation_type == "mask_time":
+        return partial(
+            mask_time,
+            max_perc=config.max_perc,
+            max_mask=config.max_masks,
+        )
+
+    if config.augmentation_type == "mask_freq":
+        return partial(
+            mask_frequency,
+            max_perc=config.max_perc,
+            max_masks=config.max_masks,
+        )
+
+    raise NotImplementedError(
+        "Invalid or unimplemented augmentation type: "
+        f"{config.augmentation_type}"
     )

-def load_agumentation_config(
+def build_augmentations(
+    config: AugmentationsConfig,
+    preprocessor: PreprocessorProtocol,
+    example_source: Optional[ExampleSource] = None,
+) -> Augmentation:
+    """Build a composite augmentation pipeline function from configuration.
+
+    Takes the overall `AugmentationsConfig` (containing a list of individual
+    augmentation steps), builds the callable function for each step using
+    `build_augmentation_from_config`, wraps each function with `MaybeApply`
+    to handle its application probability, and returns a single
+    `Augmentation` function that applies the entire sequence.
+
+    Parameters
+    ----------
+    config : AugmentationsConfig
+        The configuration object detailing the sequence of augmentation steps.
+    preprocessor : PreprocessorProtocol
+        The preprocessor object, needed for audio-modifying augmentations.
+    example_source : ExampleSource, optional
+        A callable providing other examples, required if 'mix_audio' is used.
+
+    Returns
+    -------
+    Augmentation (Callable[[xr.Dataset], xr.Dataset])
+        A single callable function that takes a training example (`xr.Dataset`)
+        and applies the configured sequence of augmentations probabilistically,
+        returning the augmented example. Returns the original example if
+        `config.steps` is empty.
+
+    Raises
+    ------
+    ValueError
+        If 'mix_audio' is configured but `example_source` is not provided.
+    NotImplementedError
+        If an unknown `augmentation_type` is encountered in `config.steps`.
+    """
+    augmentations = []
+
+    for step_config in config.steps:
+        augmentation = build_augmentation_from_config(
+            step_config,
+            preprocessor=preprocessor,
+            example_source=example_source,
+        )
+        augmentations.append(
+            MaybeApply(
+                augmentation=augmentation,
+                probability=step_config.probability,
+            )
+        )
+
+    return partial(_apply_augmentations, augmentations=augmentations)
+
+
+def load_augmentation_config(
     path: data.PathLike, field: Optional[str] = None
 ) -> AugmentationsConfig:
+    """Load the augmentations configuration from a file.
+
+    Reads a configuration file (YAML) and validates it against the
+    `AugmentationsConfig` schema, potentially extracting data from a nested
+    field.
+
+    Parameters
+    ----------
+    path : PathLike
+        Path to the configuration file.
+    field : str, optional
+        Dot-separated path to a nested section within the file containing the
+        augmentations configuration (e.g., "training.augmentations"). If None,
+        the entire file content is used.
+
+    Returns
+    -------
+    AugmentationsConfig
+        The loaded and validated augmentations configuration object.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the config file path does not exist.
+    yaml.YAMLError
+        If the file content is not valid YAML.
+    pydantic.ValidationError
+        If the loaded config data does not conform to `AugmentationsConfig`.
+    KeyError, TypeError
+        If `field` specifies an invalid path.
+    """
     return load_config(path, schema=AugmentationsConfig, field=field)

-def should_apply(config: BaseAugmentationConfig) -> bool:
-    if not config.enable:
-        return False
-
-    return np.random.uniform() < config.probability
-
-
-def augment_example(
+def _apply_augmentations(
     example: xr.Dataset,
-    config: AugmentationsConfig,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    others: Optional[Callable[[], xr.Dataset]] = None,
-) -> xr.Dataset:
-    if should_apply(config.mix) and (others is not None):
-        other = others()
-        example = mix_examples(
-            example,
-            other,
-            min_weight=config.mix.min_weight,
-            max_weight=config.mix.max_weight,
-            config=preprocessing_config,
-        )
-
-    if should_apply(config.echo):
-        example = add_echo(
-            example,
-            max_delay=config.echo.max_delay,
-            min_weight=config.echo.min_weight,
-            max_weight=config.echo.max_weight,
-            config=preprocessing_config,
-        )
-
-    if should_apply(config.volume):
-        example = scale_volume(
-            example,
-            max_scaling=config.volume.max_scaling,
-            min_scaling=config.volume.min_scaling,
-        )
-
-    if should_apply(config.warp):
-        example = warp_spectrogram(
-            example,
-            delta=config.warp.delta,
-        )
-
-    if should_apply(config.time_mask):
-        example = mask_time(
-            example,
-            max_perc=config.time_mask.max_perc,
-            max_mask=config.time_mask.max_masks,
-        )
-
-    if should_apply(config.frequency_mask):
-        example = mask_frequency(
-            example,
-            max_perc=config.frequency_mask.max_perc,
-            max_masks=config.frequency_mask.max_masks,
-        )
-
+    augmentations: List[Augmentation],
+):
+    """Apply a sequence of augmentation functions to an example."""
+    for augmentation in augmentations:
+        example = augmentation(example)
     return example
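Putting the pieces together, a hedged end-to-end sketch of the new pipeline (the `preprocessor` object and a `dataset` providing `get_random_example` are assumed to exist in the training code):

    augment = build_augmentations(
        config,
        preprocessor=preprocessor,
        example_source=dataset.get_random_example,  # needed only if "mix_audio" is configured
    )
    augmented_example = augment(example)  # applies each configured step probabilistically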
@@ -3,14 +3,15 @@ from lightning.pytorch.callbacks import Callback
 from torch.utils.data import DataLoader

 from batdetect2.evaluate import match_predictions_and_annotations
-from batdetect2.post_process import postprocess_model_outputs
+from batdetect2.postprocess import PostprocessorProtocol
 from batdetect2.train.dataset import LabeledDataset, TrainExample
 from batdetect2.types import ModelOutput


 class ValidationMetrics(Callback):
-    def __init__(self):
+    def __init__(self, postprocessor: PostprocessorProtocol):
         super().__init__()
+        self.postprocessor = postprocessor
         self.predictions = []

     def on_validation_epoch_start(
@@ -36,20 +37,20 @@ class ValidationMetrics(Callback):
         assert isinstance(dataset, LabeledDataset)
         clip_annotation = dataset.get_clip_annotation(batch_idx)

-        clip_prediction = postprocess_model_outputs(
-            outputs,
-            clips=[clip_annotation.clip],
-            classes=self.class_names,
-            decoder=self.decoder,
-            config=self.config.postprocessing,
-        )[0]
-
-        matches = match_predictions_and_annotations(
-            clip_annotation,
-            clip_prediction,
-        )
-
-        self.validation_predictions.extend(matches)
-        return super().on_validation_batch_end(
-            trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
-        )
+        # clip_prediction = postprocess_model_outputs(
+        #     outputs,
+        #     clips=[clip_annotation.clip],
+        #     classes=self.class_names,
+        #     decoder=self.decoder,
+        #     config=self.config.postprocessing,
+        # )[0]
+        #
+        # matches = match_predictions_and_annotations(
+        #     clip_annotation,
+        #     clip_prediction,
+        # )
+        #
+        # self.validation_predictions.extend(matches)
+        # return super().on_validation_batch_end(
+        #     trainer, pl_module, outputs, batch, batch_idx, dataloader_idx
+        # )
@@ -10,13 +10,13 @@ from soundevent import data
 from torch.utils.data import Dataset

 from batdetect2.configs import BaseConfig
-from batdetect2.preprocess.tensors import adjust_width
 from batdetect2.train.augmentations import (
     AugmentationsConfig,
     augment_example,
     select_subclip,
 )
-from batdetect2.train.preprocess import PreprocessingConfig
+from batdetect2.train.preprocess import PreprocessorProtocol
+from batdetect2.utils.tensors import adjust_width

 __all__ = [
     "TrainExample",
@@ -51,15 +51,15 @@ class DatasetConfig(BaseConfig):
 class LabeledDataset(Dataset):
     def __init__(
         self,
+        preprocessor: PreprocessorProtocol,
         filenames: Sequence[PathLike],
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
+        self.preprocessor = preprocessor
         self.filenames = filenames
         self.subclip = subclip
         self.augmentation = augmentation
-        self.preprocessing = preprocessing or PreprocessingConfig()

     def __len__(self):
         return len(self.filenames)
@@ -79,7 +79,7 @@ class LabeledDataset(Dataset):
             dataset = augment_example(
                 dataset,
                 self.augmentation,
-                preprocessing_config=self.preprocessing,
+                preprocessor=self.preprocessor,
                 others=self.get_random_example,
             )

@@ -95,16 +95,16 @@ class LabeledDataset(Dataset):
     def from_directory(
         cls,
         directory: PathLike,
+        preprocessor: PreprocessorProtocol,
         extension: str = ".nc",
         subclip: Optional[SubclipConfig] = None,
         augmentation: Optional[AugmentationsConfig] = None,
-        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         return cls(
-            get_preprocessed_files(directory, extension),
+            preprocessor=preprocessor,
+            filenames=get_preprocessed_files(directory, extension),
             subclip=subclip,
             augmentation=augmentation,
-            preprocessing=preprocessing,
         )

     def get_random_example(self) -> xr.Dataset:
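A hedged sketch of constructing the dataset with the now-required preprocessor (paths and objects are illustrative):

    dataset = LabeledDataset.from_directory(
        "path/to/preprocessed",      # directory of .nc training example files
        preprocessor=preprocessor,   # object satisfying PreprocessorProtocol
        augmentation=config,         # optional AugmentationsConfig
    )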
@@ -1,48 +1,47 @@
 """Generate heatmap training targets for BatDetect2 models.

-This module represents the final step in the `batdetect2.targets` pipeline,
-converting processed sound event annotations from an audio clip into the
-specific heatmap formats required for training the BatDetect2 neural network.
+This module is responsible for creating the target labels used for training
+BatDetect2 models. It converts sound event annotations for an audio clip into
+the specific multi-channel heatmap formats required by the neural network.

-It integrates the filtering, transformation, and class encoding logic defined
-in the preceding configuration steps (`filtering`, `transform`, `classes`)
-and applies them to generate three core outputs for a given spectrogram:
+It uses a pre-configured object adhering to the `TargetProtocol` (from
+`batdetect2.targets`) which encapsulates all the logic for filtering
+annotations, transforming tags, encoding class names, and mapping annotation
+geometry (ROIs) to target positions and sizes. This module then focuses on
+rendering this information onto the heatmap grids.

-1. **Detection Heatmap**: Indicates the presence and location of relevant
-   sound events.
-2. **Class Heatmap**: Indicates the location and predicted class label for
-   events that match a specific target class.
-3. **Size Heatmap**: Encodes the dimensions (width/time duration,
-   height/frequency bandwidth) of the detected sound events at their
-   reference locations.
+The pipeline generates three core outputs for a given spectrogram:
+1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
+2. **Class Heatmap**: Indicates location and class identity for specifically
+   classified events.
+3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.

-The primary function generated by this module is a `ClipLabeller`, which takes
-a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`)
-and returns the calculated `Heatmaps`. Configuration options allow tuning of
-the heatmap generation process (e.g., Gaussian smoothing sigma, reference point
-within bounding boxes).
+The primary function generated by this module is a `ClipLabeller` (defined in
+`.types`), which takes a `ClipAnnotation` object and its corresponding
+spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
+parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
+defined in `LabelConfig`.
 """

 import logging
 from collections.abc import Iterable
 from functools import partial
-from typing import Callable, List, NamedTuple, Optional
+from typing import Optional

 import numpy as np
 import xarray as xr
 from scipy.ndimage import gaussian_filter
-from soundevent import arrays, data, geometry
-from soundevent.geometry.operations import Positions
+from soundevent import arrays, data

 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.targets.classes import SoundEventEncoder
-from batdetect2.targets.filtering import SoundEventFilter
-from batdetect2.targets.transform import SoundEventTransformation
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.types import (
+    ClipLabeller,
+    Heatmaps,
+)

 __all__ = [
     "LabelConfig",
-    "Heatmaps",
-    "ClipLabeller",
     "build_clip_labeler",
     "generate_clip_label",
     "generate_heatmaps",
@@ -50,109 +49,45 @@ __all__ = [
 ]


+SIZE_DIMENSION = "dimension"
+"""Dimension name for the size heatmap."""
+
 logger = logging.getLogger(__name__)


-class Heatmaps(NamedTuple):
-    """Structure holding the generated heatmap targets.
-
-    Attributes
-    ----------
-    detection : xr.DataArray
-        Heatmap indicating the probability of sound event presence. Typically
-        smoothed with a Gaussian kernel centered on event reference points.
-        Shape matches the input spectrogram. Values normalized [0, 1].
-    classes : xr.DataArray
-        Heatmap indicating the probability of specific class presence. Has an
-        additional 'category' dimension corresponding to the target class
-        names. Each category slice is typically smoothed with a Gaussian
-        kernel. Values normalized [0, 1] per category.
-    size : xr.DataArray
-        Heatmap encoding the size (width, height) of detected events. Has an
-        additional 'dimension' coordinate ('width', 'height'). Values represent
-        scaled dimensions placed at the event reference points.
-    """
-
-    detection: xr.DataArray
-    classes: xr.DataArray
-    size: xr.DataArray
-
-
-ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
-"""Type alias for the final clip labelling function.
-
-This function takes the complete annotations for a clip and the corresponding
-spectrogram, applies all configured filtering, transformation, and encoding
-steps, and returns the final `Heatmaps` used for model training.
-"""
-
-
 class LabelConfig(BaseConfig):
     """Configuration parameters for heatmap generation.

     Attributes
     ----------
-    position : Positions, default="bottom-left"
-        Specifies the reference point within each sound event's geometry
-        (bounding box) that is used to place the 'peak' or value on the
-        heatmaps. Options include 'center', 'bottom-left', 'top-right', etc.
-        See `soundevent.geometry.operations.Positions` for valid options.
     sigma : float, default=3.0
         The standard deviation (in pixels/bins) of the Gaussian kernel applied
         to smooth the detection and class heatmaps. Larger values create more
         diffuse targets.
-    time_scale : float, default=1000.0
-        A scaling factor applied to the time duration (width) of sound event
-        bounding boxes before storing them in the 'width' dimension of the
-        `size` heatmap. The appropriate value depends on how the model input/
-        output resolution relates to physical time units (e.g., if time axis
-        is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs
-        to match model expectations.
-    frequency_scale : float, default=1/859.375
-        A scaling factor applied to the frequency bandwidth (height) of sound
-        event bounding boxes before storing them in the 'height' dimension of
-        the `size` heatmap. The appropriate value depends on the relationship
-        between the frequency axis resolution (e.g., kHz or Hz per bin) and
-        the desired output units/scale for the model. Needs to match model
-        expectations. (The default suggests input frequency might be in Hz
-        and output is scaled relative to some reference).
     """

-    position: Positions = "bottom-left"
     sigma: float = 3.0
-    time_scale: float = 1000.0
-    frequency_scale: float = 1 / 859.375


 def build_clip_labeler(
-    filter_fn: SoundEventFilter,
-    transform_fn: SoundEventTransformation,
-    encoder_fn: SoundEventEncoder,
-    class_names: List[str],
+    targets: TargetProtocol,
     config: LabelConfig,
 ) -> ClipLabeller:
-    """Construct the clip labelling function.
+    """Construct the final clip labelling function.

-    This function takes the pre-built components from the previous target
-    definition steps (filtering, transformation, encoding) and the label
-    configuration, then returns a single callable (`ClipLabeller`) that
-    performs the end-to-end heatmap generation for a given clip and
-    spectrogram.
+    This factory function prepares the callable that will perform the
+    end-to-end heatmap generation for a given clip and spectrogram during
+    training data loading. It takes the fully configured `targets` object and
+    the `LabelConfig` and binds them to the `generate_clip_label` function.

     Parameters
     ----------
-    filter_fn : SoundEventFilter
-        Function to filter irrelevant sound event annotations.
-    transform_fn : SoundEventTransformation
-        Function to transform tags of sound event annotations.
-    encoder_fn : SoundEventEncoder
-        Function to encode a sound event annotation into a class name.
-    class_names : List[str]
-        Ordered list of unique target class names for the classification
-        heatmap.
+    targets : TargetProtocol
+        An initialized object conforming to the `TargetProtocol`, providing all
+        necessary methods for filtering, transforming, encoding, and ROI
+        mapping.
     config : LabelConfig
-        Configuration object containing heatmap generation parameters (sigma,
-        etc.).
+        Configuration object containing heatmap generation parameters.

     Returns
     -------
@ -162,10 +97,7 @@ def build_clip_labeler(
|
|||||||
"""
|
"""
|
||||||
return partial(
|
return partial(
|
||||||
generate_clip_label,
|
generate_clip_label,
|
||||||
filter_fn=filter_fn,
|
targets=targets,
|
||||||
transform_fn=transform_fn,
|
|
||||||
encoder_fn=encoder_fn,
|
|
||||||
class_names=class_names,
|
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
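Taken together, the new signature means a labeller is built from a single `targets` object rather than four separate pieces. A hedged usage sketch (how the `targets` object is constructed is outside this diff; assume some factory provides it):

    from batdetect2.train.labels import LabelConfig, build_clip_labeler

    # `targets` must implement TargetProtocol: filter, transform, encode,
    # get_position, get_size, plus class_names and dimension_names.
    labeller = build_clip_labeler(targets, config=LabelConfig(sigma=3.0))

    # The result is a ClipLabeller: call it with annotations + spectrogram.
    heatmaps = labeller(clip_annotation, spectrogram)
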
@@ -173,123 +105,97 @@ def build_clip_labeler(
 def generate_clip_label(
     clip_annotation: data.ClipAnnotation,
     spec: xr.DataArray,
-    filter_fn: SoundEventFilter,
-    transform_fn: SoundEventTransformation,
-    encoder_fn: SoundEventEncoder,
-    class_names: List[str],
+    targets: TargetProtocol,
     config: LabelConfig,
 ) -> Heatmaps:
-    """Generate heatmaps for a single clip by applying all processing steps.
+    """Generate training heatmaps for a single annotated clip.

-    This function orchestrates the process for one clip:
-    1. Filters the sound events using `filter_fn`.
-    2. Transforms the tags of filtered events using `transform_fn`.
-    3. Passes the processed annotations and other parameters to
-       `generate_heatmaps` to create the final target heatmaps.
+    This function orchestrates the target generation process for one clip:
+    1. Filters and transforms sound events using `targets.filter` and
+       `targets.transform`.
+    2. Passes the resulting processed annotations, along with the spectrogram,
+       the `targets` object, and the Gaussian `sigma` from `config`, to the
+       core `generate_heatmaps` function.

     Parameters
     ----------
     clip_annotation : data.ClipAnnotation
-        The complete annotation data for the audio clip.
+        The complete annotation data for the audio clip, including the list
+        of `sound_events` to process.
     spec : xr.DataArray
-        The spectrogram corresponding to the `clip_annotation`.
-    filter_fn : SoundEventFilter
-        Function to filter sound event annotations.
-    transform_fn : SoundEventTransformation
-        Function to transform tags of sound event annotations.
-    encoder_fn : SoundEventEncoder
-        Function to encode a sound event annotation into a class name.
-    class_names : List[str]
-        Ordered list of unique target class names.
+        The spectrogram corresponding to the `clip_annotation`. Must have
+        'time' and 'frequency' dimensions/coordinates.
+    targets : TargetProtocol
+        The fully configured target definition object, providing methods for
+        filtering, transformation, encoding, and ROI mapping.
     config : LabelConfig
-        Configuration object containing heatmap generation parameters.
+        Configuration object providing heatmap parameters (primarily `sigma`).

     Returns
     -------
     Heatmaps
-        The generated detection, classes, and size heatmaps for the clip.
+        A NamedTuple containing the generated 'detection', 'classes', and
+        'size' heatmaps for this clip.
     """
     return generate_heatmaps(
         (
-            transform_fn(sound_event_annotation)
+            targets.transform(sound_event_annotation)
             for sound_event_annotation in clip_annotation.sound_events
-            if filter_fn(sound_event_annotation)
+            if targets.filter(sound_event_annotation)
         ),
         spec=spec,
-        class_names=class_names,
-        encoder=encoder_fn,
+        targets=targets,
         target_sigma=config.sigma,
-        position=config.position,
-        time_scale=config.time_scale,
-        frequency_scale=config.frequency_scale,
     )

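The body above is a single generator expression that lazily filters and then transforms annotations before handing them to `generate_heatmaps`. A standalone sketch of the same idiom, with plain stand-in functions (both are assumptions for illustration only):

    events = [1, -2, 3, -4]

    def keep(x):      # stands in for targets.filter
        return x > 0

    def adjust(x):    # stands in for targets.transform
        return x * 10

    # Lazily filter, then transform; nothing runs until consumed.
    processed = (adjust(e) for e in events if keep(e))
    assert list(processed) == [10, 30]
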
 def generate_heatmaps(
     sound_events: Iterable[data.SoundEventAnnotation],
     spec: xr.DataArray,
-    class_names: List[str],
-    encoder: SoundEventEncoder,
+    targets: TargetProtocol,
     target_sigma: float = 3.0,
-    position: Positions = "bottom-left",
-    time_scale: float = 1000.0,
-    frequency_scale: float = 1 / 859.375,
     dtype=np.float32,
 ) -> Heatmaps:
     """Generate detection, class, and size heatmaps from sound events.

-    Processes an iterable of sound event annotations (assumed to be already
-    filtered and transformed) and creates heatmap representations suitable
-    for training models like BatDetect2.
+    Creates heatmap representations from an iterable of sound event
+    annotations. This function relies on the provided `targets` object to get
+    the reference position, scaled size, and class encoding for each
+    annotation.

-    The process involves:
-    1. Initializing empty heatmaps based on the spectrogram shape.
-    2. Iterating through each sound event.
-    3. For each event, finding its reference point and placing a '1.0'
-       on the detection heatmap at that point.
-    4. Calculating the scaled bounding box size and placing it on the size
-       heatmap at the reference point.
-    5. Encoding the event to get its class name and placing a '1.0' on the
-       corresponding class heatmap slice at the reference point (if
-       classified).
-    6. Applying Gaussian smoothing to detection and class heatmaps.
-    7. Normalizing detection and class heatmaps to the range [0, 1].
+    Process:
+    1. Initializes empty heatmap arrays based on `spec` shape and `targets`
+       metadata.
+    2. Iterates through `sound_events`.
+    3. For each event:
+       a. Gets geometry. Skips if missing.
+       b. Gets reference position using `targets.get_position()`. Skips if
+          out of bounds.
+       c. Places a peak (1.0) on the detection heatmap at the position.
+       d. Gets scaled size using `targets.get_size()` and places it on the
+          size heatmap.
+       e. Encodes class using `targets.encode()` and places a peak (1.0) on
+          the corresponding class heatmap layer if a specific class is
+          returned.
+    4. Applies Gaussian smoothing (using `target_sigma`) to detection and
+       class heatmaps.
+    5. Normalizes detection and class heatmaps to range [0, 1].

     Parameters
     ----------
     sound_events : Iterable[data.SoundEventAnnotation]
-        An iterable of sound event annotations to include in the heatmaps.
-        These should ideally be the result of prior filtering and tag
-        transformation steps.
+        An iterable of sound event annotations to render onto the heatmaps.
     spec : xr.DataArray
         The spectrogram array corresponding to the time/frequency range of
         the annotations. Used for shape and coordinate information. Must have
-        'time' and 'frequency' dimensions.
-    class_names : List[str]
-        An ordered list of unique class names. The class heatmap will have
-        a channel ('category' dimension) for each name in this list. Must not
-        be empty.
-    encoder : SoundEventEncoder
-        A function that takes a SoundEventAnnotation and returns the
-        corresponding class name (str) or None if it doesn't belong to a
-        specific class (e.g., it falls into the generic 'Bat' category).
+        'time' and 'frequency' dimensions/coordinates.
+    targets : TargetProtocol
+        The fully configured target definition object. Used to access class
+        names, dimension names, and the methods `get_position`, `get_size`,
+        `encode`.
     target_sigma : float, default=3.0
         Standard deviation (in pixels/bins) of the Gaussian kernel applied to
-        smooth the detection and class heatmaps after initial point placement.
-    position : Positions, default="bottom-left"
-        The reference point within each annotation's geometry bounding box
-        used to place the signal on the heatmaps (e.g., "center",
-        "bottom-left"). See `soundevent.geometry.operations.Positions`.
-    time_scale : float, default=1000.0
-        Scaling factor applied to the time duration (width in seconds) of
-        annotations when storing them in the size heatmap. The resulting
-        value's unit depends on this scale (e.g., 1000.0 might convert seconds
-        to ms).
-    frequency_scale : float, default=1/859.375
-        Scaling factor applied to the frequency bandwidth (height in Hz or
-        kHz) of annotations when storing them in the size heatmap. The
-        resulting value's unit depends on this scale and the input unit.
-        (Default scaling relative to ~860 Hz).
+        smooth the detection and class heatmaps.
     dtype : type, default=np.float32
         The data type for the generated heatmap arrays (e.g., `np.float32`).

@@ -303,28 +209,10 @@ def generate_heatmaps(
     ------
     ValueError
         If the input spectrogram `spec` does not have both 'time' and
-        'frequency' dimensions, or if `class_names` is empty.
-
-    Notes
-    -----
-    * This function expects `sound_events` to be already filtered and
-      transformed.
-    * It includes error handling to skip individual annotations that cause
-      issues (e.g., missing geometry, out-of-bounds coordinates, encoder
-      errors) allowing the rest of the clip to be processed. Warnings or
-      errors are logged.
-    * The `time_scale` and `frequency_scale` parameters are crucial and must
-      be set according to the expectations of the specific BatDetect2 model
-      architecture being trained. Consult model documentation for required
-      units/scales.
-    * Gaussian filtering and normalization are applied only to detection and
-      class heatmaps, not the size heatmap.
+        'frequency' dimensions, or if `targets.class_names` is empty.
     """
     shape = dict(zip(spec.dims, spec.shape))

-    if len(class_names) == 0:
-        raise ValueError("No class names provided.")
-
     if "time" not in shape or "frequency" not in shape:
         raise ValueError(
             "Spectrogram must have time and frequency dimensions."
@@ -333,18 +221,18 @@ def generate_heatmaps(
     # Initialize heatmaps
     detection_heatmap = xr.zeros_like(spec, dtype=dtype)
     class_heatmap = xr.DataArray(
-        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
+        data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
         dims=["category", *spec.dims],
         coords={
-            "category": [*class_names],
+            "category": [*targets.class_names],
             **spec.coords,
         },
     )
     size_heatmap = xr.DataArray(
         data=np.zeros((2, *spec.shape), dtype=dtype),
-        dims=["dimension", *spec.dims],
+        dims=[SIZE_DIMENSION, *spec.dims],
         coords={
-            "dimension": ["width", "height"],
+            SIZE_DIMENSION: targets.dimension_names,
             **spec.coords,
         },
     )
@@ -359,7 +247,7 @@ def generate_heatmaps(
             continue

         # Get the position of the sound event
-        time, frequency = geometry.get_geometry_point(geom, position=position)
+        time, frequency = targets.get_position(sound_event_annotation)

         # Set 1.0 at the position of the sound event in the detection heatmap
         try:
@@ -379,17 +267,7 @@ def generate_heatmaps(
             )
             continue

-        # Set the size of the sound event at the position in the size heatmap
-        start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
-            geom
-        )
-
-        size = np.array(
-            [
-                (end_time - start_time) * time_scale,
-                (high_freq - low_freq) * frequency_scale,
-            ]
-        )
+        size = targets.get_size(sound_event_annotation)

         size_heatmap = arrays.set_value_at_pos(
             size_heatmap,
@@ -400,7 +278,7 @@ def generate_heatmaps(

         # Get the class name of the sound event
         try:
-            class_name = encoder(sound_event_annotation)
+            class_name = targets.encode(sound_event_annotation)
         except ValueError as e:
             logger.warning(
                 "Skipping annotation %s: Unexpected error while encoding "
@@ -414,19 +292,6 @@ def generate_heatmaps(
             # If the label is None skip the sound event
             continue

-        if class_name not in class_names:
-            # If the label is not in the class names skip the sound event
-            logger.warning(
-                (
-                    "Skipping annotation %s for class heatmap: "
-                    "class name '%s' not in class names. Class names: %s"
-                ),
-                sound_event_annotation.uuid,
-                class_name,
-                class_names,
-            )
-            continue
-
         try:
             class_heatmap = arrays.set_value_at_pos(
                 class_heatmap,
@@ -453,7 +318,7 @@ def generate_heatmaps(
             )

         class_heatmap = class_heatmap.groupby("category").map(
-            gaussian_filter,
+            gaussian_filter,  # type: ignore
             args=(target_sigma,),
         )

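To make the peak-place/smooth/normalize recipe concrete, here is a self-contained sketch on a plain array, using scipy's `gaussian_filter` directly rather than the project's wrappers:

    import numpy as np
    from scipy.ndimage import gaussian_filter

    # Empty detection heatmap the size of a tiny spectrogram.
    heatmap = np.zeros((64, 128), dtype=np.float32)

    # Place a unit peak at one event's (frequency_bin, time_bin) position.
    heatmap[20, 50] = 1.0

    # Smooth with the configured sigma, then renormalize to [0, 1]
    # (smoothing lowers the peak below 1.0, so rescale by the max).
    smoothed = gaussian_filter(heatmap, sigma=3.0)
    smoothed /= smoothed.max()

    assert smoothed.max() == 1.0 and smoothed[20, 50] == 1.0
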
@@ -1,98 +1,101 @@
-"""Module for preprocessing data for training."""
+"""Preprocesses datasets for BatDetect2 model training.
+
+This module provides functions to take a collection of annotated audio clips
+(`soundevent.data.ClipAnnotation`) and process them into the final format
+required for training a BatDetect2 model. This typically involves:
+
+1. Loading the relevant audio segment for each annotation using a configured
+   `PreprocessorProtocol`.
+2. Generating the corresponding input spectrogram using the
+   `PreprocessorProtocol`.
+3. Generating the target heatmaps (detection, classification, size) using a
+   configured `ClipLabeller` (which encapsulates the `TargetProtocol` logic).
+4. Packaging the input spectrogram, target heatmaps, and potentially the
+   processed audio waveform into an `xarray.Dataset`.
+5. Saving each processed `xarray.Dataset` to a separate file (typically
+   NetCDF) in an output directory.
+
+This offline preprocessing is often preferred for large datasets as it avoids
+computationally intensive steps during the actual training loop. The module
+includes utilities for parallel processing using `multiprocessing`.
+"""

-import os
 from functools import partial
 from multiprocessing import Pool
 from pathlib import Path
-from typing import Callable, Optional, Sequence, Union
+from typing import Callable, Optional, Sequence

 import xarray as xr
-from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm

-from batdetect2.configs import BaseConfig
-from batdetect2.preprocess import (
-    PreprocessingConfig,
-    compute_spectrogram,
-    load_clip_audio,
-)
-from batdetect2.targets import (
-    LabelConfig,
-    TargetConfig,
-    build_sound_event_filter,
-    build_target_encoder,
-    generate_heatmaps,
-    get_class_names,
-)
-
-PathLike = Union[Path, str, os.PathLike]
-FilenameFn = Callable[[data.ClipAnnotation], str]
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.train.types import ClipLabeller

 __all__ = [
     "preprocess_annotations",
     "preprocess_single_annotation",
     "generate_train_example",
-    "TrainPreprocessingConfig",
 ]

+FilenameFn = Callable[[data.ClipAnnotation], str]
+"""Type alias for a function that generates an output filename."""

-class TrainPreprocessingConfig(BaseConfig):
-    preprocessing: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    target: TargetConfig = Field(default_factory=TargetConfig)
-    labels: LabelConfig = Field(default_factory=LabelConfig)
-

 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_config: Optional[TargetConfig] = None,
-    label_config: Optional[LabelConfig] = None,
+    preprocessor: PreprocessorProtocol,
+    labeller: ClipLabeller,
 ) -> xr.Dataset:
-    """Generate a training example."""
-    config = TrainPreprocessingConfig(
-        preprocessing=preprocessing_config or PreprocessingConfig(),
-        target=target_config or TargetConfig(),
-        labels=label_config or LabelConfig(),
-    )
-
-    wave = load_clip_audio(
-        clip_annotation.clip,
-        config=config.preprocessing.audio,
-    )
-
-    spectrogram = compute_spectrogram(
-        wave,
-        config=config.preprocessing.spectrogram,
-    )
-
-    filter_fn = build_sound_event_filter(
-        include=config.target.include,
-        exclude=config.target.exclude,
-    )
-
-    selected_events = [
-        event for event in clip_annotation.sound_events if filter_fn(event)
-    ]
-
-    encoder = build_target_encoder(
-        config.target.classes,
-        replacement_rules=config.target.replace,
-    )
-    class_names = get_class_names(config.target.classes)
-
-    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
-        selected_events,
-        spectrogram,
-        class_names,
-        encoder,
-        target_sigma=config.labels.heatmaps.sigma,
-        position=config.labels.heatmaps.position,
-        time_scale=config.labels.heatmaps.time_scale,
-        frequency_scale=config.labels.heatmaps.frequency_scale,
-    )
+    """Generate a complete training example for one annotation.
+
+    This function takes a single `ClipAnnotation`, applies the configured
+    preprocessing (`PreprocessorProtocol`) to get the processed waveform and
+    input spectrogram, applies the configured target generation
+    (`ClipLabeller`) to get the target heatmaps, and packages them all into a
+    single `xr.Dataset`.
+
+    Parameters
+    ----------
+    clip_annotation : data.ClipAnnotation
+        The annotated clip to process. Contains the reference to the `Clip`
+        (audio segment) and the associated `SoundEventAnnotation` objects.
+    preprocessor : PreprocessorProtocol
+        An initialized preprocessor object responsible for loading/processing
+        audio and computing the input spectrogram.
+    labeller : ClipLabeller
+        An initialized clip labeller function responsible for generating the
+        target heatmaps (detection, class, size) from the `clip_annotation`
+        and the computed spectrogram.
+
+    Returns
+    -------
+    xr.Dataset
+        An xarray Dataset containing the following data variables:
+        - `audio`: The preprocessed audio waveform (dims: 'audio_time').
+        - `spectrogram`: The computed input spectrogram
+          (dims: 'time', 'frequency').
+        - `detection`: The target detection heatmap
+          (dims: 'time', 'frequency').
+        - `class`: The target class heatmap
+          (dims: 'category', 'time', 'frequency').
+        - `size`: The target size heatmap
+          (dims: 'dimension', 'time', 'frequency').
+        The Dataset also includes metadata in its attributes.
+
+    Notes
+    -----
+    - The 'time' dimension of the 'audio' DataArray is renamed to 'audio_time'
+      within the output Dataset to avoid coordinate conflicts with the
+      spectrogram's 'time' dimension when stored together.
+    - The original `ClipAnnotation` metadata is stored as a JSON string in the
+      Dataset's attributes for provenance.
+    """
+    wave = preprocessor.load_clip_audio(clip_annotation.clip)
+
+    spectrogram = preprocessor.compute_spectrogram(wave)
+
+    heatmaps = labeller(clip_annotation, spectrogram)

     dataset = xr.Dataset(
         {
@@ -102,15 +105,14 @@ def generate_train_example(
             # as the waveform.
             "audio": wave.rename({"time": "audio_time"}),
             "spectrogram": spectrogram,
-            "detection": detection_heatmap,
-            "class": class_heatmap,
-            "size": size_heatmap,
+            "detection": heatmaps.detection,
+            "class": heatmaps.classes,
+            "size": heatmaps.size,
         }
     )

     return dataset.assign_attrs(
         title=f"Training example for {clip_annotation.uuid}",
-        config=config.model_dump_json(),
         clip_annotation=clip_annotation.model_dump_json(
             exclude_none=True,
             exclude_defaults=True,
@@ -119,10 +121,21 @@ def generate_train_example(
     )

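With the three config objects gone, callers now pass prebuilt components. A hedged sketch of the new call (the `preprocessor`, `labeller`, and `clip_annotation` objects are assumed to be constructed elsewhere):

    from batdetect2.train.preprocess import generate_train_example

    # `preprocessor` must satisfy PreprocessorProtocol and `labeller` must be
    # a ClipLabeller (e.g. from build_clip_labeler above).
    example = generate_train_example(
        clip_annotation,
        preprocessor=preprocessor,
        labeller=labeller,
    )
    print(example.data_vars)  # audio, spectrogram, detection, class, size
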
-def save_to_file(
+def _save_xr_dataset_to_file(
     dataset: xr.Dataset,
-    path: PathLike,
+    path: data.PathLike,
 ) -> None:
+    """Save an xarray Dataset to a NetCDF file with compression.
+
+    Internal helper function used by `preprocess_single_annotation`.
+
+    Parameters
+    ----------
+    dataset : xr.Dataset
+        The training example dataset to save.
+    path : PathLike
+        The output file path (e.g., 'output/uuid.nc').
+    """
     dataset.to_netcdf(
         path,
         encoding={
@@ -135,20 +148,60 @@ def save_to_file(


 def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
+    """Generate a default output filename based on the annotation UUID."""
     return f"{clip_annotation.uuid}.nc"

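The `encoding` argument is cut off by the hunk boundary above; as a hedged illustration of what per-variable NetCDF compression typically looks like in xarray (the exact settings used by this helper are not shown in this diff):

    import numpy as np
    import xarray as xr

    ds = xr.Dataset(
        {"spectrogram": (("time", "frequency"), np.zeros((10, 4)))}
    )

    # Enable zlib compression for every data variable when writing NetCDF.
    ds.to_netcdf(
        "example.nc",
        encoding={
            name: {"zlib": True, "complevel": 4} for name in ds.data_vars
        },
    )
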
 def preprocess_annotations(
     clip_annotations: Sequence[data.ClipAnnotation],
-    output_dir: PathLike,
+    output_dir: data.PathLike,
+    preprocessor: PreprocessorProtocol,
+    labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_config: Optional[TargetConfig] = None,
-    label_config: Optional[LabelConfig] = None,
     max_workers: Optional[int] = None,
 ) -> None:
-    """Preprocess annotations and save to disk."""
+    """Preprocess a sequence of ClipAnnotations and save results to disk.
+
+    Generates the full training example (spectrogram, heatmaps, etc.) for each
+    `ClipAnnotation` in the input sequence using the provided `preprocessor`
+    and `labeller`. Saves each example as a separate NetCDF file in the
+    `output_dir`. Utilizes multiprocessing for potentially faster processing.
+
+    Parameters
+    ----------
+    clip_annotations : Sequence[data.ClipAnnotation]
+        A sequence (e.g., list) of the clip annotations to preprocess.
+    output_dir : PathLike
+        Path to the directory where the processed NetCDF files will be saved.
+        Will be created if it doesn't exist.
+    preprocessor : PreprocessorProtocol
+        Initialized preprocessor object to generate spectrograms.
+    labeller : ClipLabeller
+        Initialized labeller function to generate target heatmaps.
+    filename_fn : FilenameFn, optional
+        Function to generate the output filename for each `ClipAnnotation`.
+        Defaults to using the annotation UUID via `_get_filename`.
+    replace : bool, default=False
+        If True, existing files in `output_dir` with the same generated name
+        will be overwritten. If False (default), existing files are skipped.
+    max_workers : int, optional
+        Maximum number of worker processes to use for parallel processing.
+        If None (default), uses the number of CPUs available
+        (`os.cpu_count()`).
+
+    Returns
+    -------
+    None
+        This function does not return anything; its side effect is creating
+        files in the `output_dir`.
+
+    Raises
+    ------
+    RuntimeError
+        If processing fails for any individual annotation when using
+        multiprocessing. The original exception will be attached as the cause.
+    """
     output_dir = Path(output_dir)

     if not output_dir.is_dir():
@@ -163,9 +216,8 @@ def preprocess_annotations(
                 output_dir=output_dir,
                 filename_fn=filename_fn,
                 replace=replace,
-                preprocessing_config=preprocessing_config,
-                target_config=target_config,
-                label_config=label_config,
+                preprocessor=preprocessor,
+                labeller=labeller,
             ),
             clip_annotations,
         ),
@@ -176,13 +228,34 @@ def preprocess_annotations(

 def preprocess_single_annotation(
     clip_annotation: data.ClipAnnotation,
-    output_dir: PathLike,
-    preprocessing_config: Optional[PreprocessingConfig] = None,
-    target_config: Optional[TargetConfig] = None,
-    label_config: Optional[LabelConfig] = None,
+    output_dir: data.PathLike,
+    preprocessor: PreprocessorProtocol,
+    labeller: ClipLabeller,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
 ) -> None:
+    """Process a single ClipAnnotation and save the result to a file.
+
+    Internal function designed to be called by `preprocess_annotations`, often
+    in parallel worker processes. It generates the training example using
+    `generate_train_example` and saves it using `_save_xr_dataset_to_file`.
+    Handles file existence checks based on the `replace` flag.
+
+    Parameters
+    ----------
+    clip_annotation : data.ClipAnnotation
+        The single annotation to process.
+    output_dir : PathLike
+        The directory where the output NetCDF file should be saved.
+    preprocessor : PreprocessorProtocol
+        Initialized preprocessor object.
+    labeller : ClipLabeller
+        Initialized labeller function.
+    filename_fn : FilenameFn, default=_get_filename
+        Function to determine the output filename.
+    replace : bool, default=False
+        Whether to overwrite existing output files.
+    """
     output_dir = Path(output_dir)

     filename = filename_fn(clip_annotation)
@@ -197,13 +270,12 @@ def preprocess_single_annotation(
     try:
         sample = generate_train_example(
             clip_annotation,
-            preprocessing_config=preprocessing_config,
-            target_config=target_config,
-            label_config=label_config,
+            preprocessor=preprocessor,
+            labeller=labeller,
         )
     except Exception as error:
         raise RuntimeError(
             f"Failed to process annotation: {clip_annotation.uuid}"
         ) from error

-    save_to_file(sample, path)
+    _save_xr_dataset_to_file(sample, path)

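End-to-end, offline preprocessing with the new signature looks roughly like this (all input objects are assumed to be constructed elsewhere; only the keyword names come from this diff):

    from batdetect2.train.preprocess import preprocess_annotations

    preprocess_annotations(
        clip_annotations,            # Sequence[data.ClipAnnotation]
        output_dir="train_examples/",
        preprocessor=preprocessor,   # PreprocessorProtocol
        labeller=labeller,           # ClipLabeller
        replace=False,               # skip files that already exist
        max_workers=4,               # cap the multiprocessing pool
    )
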
batdetect2/train/types.py (new file)
@@ -0,0 +1,48 @@
+from typing import Callable, NamedTuple
+
+import xarray as xr
+from soundevent import data
+
+__all__ = [
+    "Heatmaps",
+    "ClipLabeller",
+    "Augmentation",
+]
+
+
+class Heatmaps(NamedTuple):
+    """Structure holding the generated heatmap targets.
+
+    Attributes
+    ----------
+    detection : xr.DataArray
+        Heatmap indicating the probability of sound event presence. Typically
+        smoothed with a Gaussian kernel centered on event reference points.
+        Shape matches the input spectrogram. Values normalized [0, 1].
+    classes : xr.DataArray
+        Heatmap indicating the probability of specific class presence. Has an
+        additional 'category' dimension corresponding to the target class
+        names. Each category slice is typically smoothed with a Gaussian
+        kernel. Values normalized [0, 1] per category.
+    size : xr.DataArray
+        Heatmap encoding the size (width, height) of detected events. Has an
+        additional 'dimension' coordinate ('width', 'height'). Values
+        represent scaled dimensions placed at the event reference points.
+    """
+
+    detection: xr.DataArray
+    classes: xr.DataArray
+    size: xr.DataArray
+
+
+ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
+"""Type alias for the final clip labelling function.
+
+This function takes the complete annotations for a clip and the corresponding
+spectrogram, applies all configured filtering, transformation, and encoding
+steps, and returns the final `Heatmaps` used for model training.
+"""
+
+Augmentation = Callable[[xr.Dataset], xr.Dataset]
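The `Augmentation` alias is just a Dataset-to-Dataset callable, so any function with that shape qualifies. A minimal sketch of one (volume scaling is used purely as an illustration here, not as the module's actual implementation):

    import numpy as np
    import xarray as xr

    def scale_volume(example: xr.Dataset) -> xr.Dataset:
        """Toy Augmentation: randomly rescale the spectrogram magnitude."""
        factor = np.random.uniform(0.8, 1.2)
        return example.assign(spectrogram=example["spectrogram"] * factor)

    # Any callable with this Dataset -> Dataset shape satisfies `Augmentation`.
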
@@ -1,14 +1,10 @@
 """Types used in the code base."""

-from typing import Any, List, NamedTuple, Optional
+from typing import Any, List, NamedTuple, Optional, TypedDict

 import numpy as np
 import torch

-from typing import TypedDict
-
 try:
     from typing import Protocol
 except ImportError:
@@ -1,34 +0,0 @@
-import torch
-from hypothesis import given
-from hypothesis import strategies as st
-
-from batdetect2.models import ModelConfig, ModelType, build_architecture
-
-
-@given(
-    input_width=st.integers(min_value=50, max_value=1500),
-    input_height=st.integers(min_value=1, max_value=16),
-    model_type=st.sampled_from(ModelType),
-)
-def test_model_can_process_spectrograms_of_any_width(
-    input_width,
-    input_height,
-    model_type,
-):
-    # Input height must be divisible by 8
-    input_height = 8 * input_height
-
-    input = torch.rand([1, 1, input_height, input_width])
-
-    model = build_architecture(
-        config=ModelConfig(
-            name=model_type,  # type: ignore
-            input_height=input_height,
-        ),
-    )
-
-    output = model(input)
-    assert output.shape[0] == 1
-    assert output.shape[1] == model.out_channels
-    assert output.shape[2] == input_height
-    assert output.shape[3] == input_width
@@ -4,7 +4,7 @@ import numpy as np
 import xarray as xr
 from soundevent import data

-from batdetect2.targets import generate_heatmaps
+from batdetect2.train.labels import generate_heatmaps

 recording = data.Recording(
     samplerate=256_000,