batdetect2/batdetect2/train/labels.py
2025-04-22 00:36:34 +01:00

367 lines
12 KiB
Python

"""Generate heatmap training targets for BatDetect2 models.
This module is responsible for creating the target labels used for training
BatDetect2 models. It converts sound event annotations for an audio clip into
the specific multi-channel heatmap formats required by the neural network.
It uses a pre-configured object adhering to the `TargetProtocol` (from
`batdetect2.targets`) which encapsulates all the logic for filtering
annotations, transforming tags, encoding class names, and mapping annotation
geometry (ROIs) to target positions and sizes. This module then focuses on
rendering this information onto the heatmap grids.
The pipeline generates three core outputs for a given spectrogram:
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
2. **Class Heatmap**: Indicates location and class identity for specifically
classified events.
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
The primary callable produced by this module (via `build_clip_labeler`) is a
`ClipLabeller` (defined in `batdetect2.train.types`), which takes a
`ClipAnnotation` object and its corresponding spectrogram and returns the
calculated `Heatmaps` tuple. The only configurable parameter specific to this
module is the Gaussian smoothing sigma (`sigma`) defined in `LabelConfig`.
"""
import logging
from collections.abc import Iterable
from functools import partial
from typing import Optional
import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import (
ClipLabeller,
Heatmaps,
)
logger = logging.getLogger(__name__)

# Name of the leading dimension of the size heatmap. Its coordinate labels are
# taken from `targets.dimension_names` when the heatmap is built.
SIZE_DIMENSION = "dimension"

# Public API of this module.
__all__ = [
    "LabelConfig",
    "build_clip_labeler",
    "generate_clip_label",
    "generate_heatmaps",
    "load_label_config",
]
class LabelConfig(BaseConfig):
    """Settings that control how training heatmaps are rendered.

    Attributes
    ----------
    sigma : float, default=3.0
        Standard deviation, measured in spectrogram bins (pixels), of the
        Gaussian kernel used to blur the detection and class heatmaps. A
        larger value spreads each target peak over a wider neighbourhood.
    """

    # Gaussian blur width for detection/class targets, in bins.
    sigma: float = 3.0
def build_clip_labeler(
    targets: TargetProtocol,
    config: LabelConfig,
) -> ClipLabeller:
    """Create a ready-to-use clip labelling callable.

    The result is `generate_clip_label` with its `targets` and `config`
    keyword arguments already bound. At call time it therefore only needs
    the clip annotation and the matching spectrogram.

    Parameters
    ----------
    targets : TargetProtocol
        Fully initialised target definition. It supplies the filtering,
        transformation, class encoding and ROI (position/size) mapping used
        while rendering the heatmaps.
    config : LabelConfig
        Heatmap generation settings (currently the Gaussian `sigma`).

    Returns
    -------
    ClipLabeller
        Callable taking a `data.ClipAnnotation` and an `xr.DataArray`
        spectrogram and returning the generated `Heatmaps`.
    """
    # NOTE(review): `functools.partial` (rather than a closure) presumably
    # keeps the labeller picklable for multi-process data loading; confirm.
    labeller = partial(generate_clip_label, targets=targets, config=config)
    return labeller
def generate_clip_label(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    targets: TargetProtocol,
    config: LabelConfig,
) -> Heatmaps:
    """Render the training heatmaps for one annotated clip.

    Sound events in `clip_annotation` are first kept or dropped with
    `targets.filter`. The survivors are then rewritten with
    `targets.transform`. The resulting events are handed lazily to
    `generate_heatmaps`, together with the spectrogram, the `targets`
    object and the smoothing width `config.sigma`.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        Annotation for the clip. Its `sound_events` are the candidates to
        render.
    spec : xr.DataArray
        Spectrogram of the same clip. It must have 'time' and 'frequency'
        dimensions/coordinates.
    targets : TargetProtocol
        Target definition providing filtering, transformation, class
        encoding and ROI mapping.
    config : LabelConfig
        Heatmap settings; only `sigma` is used here.

    Returns
    -------
    Heatmaps
        NamedTuple with the 'detection', 'classes' and 'size' heatmaps.
    """
    # Generator expression: filtering/transforming happens lazily while
    # `generate_heatmaps` iterates, not up front.
    kept_events = (
        targets.transform(annotation)
        for annotation in clip_annotation.sound_events
        if targets.filter(annotation)
    )
    return generate_heatmaps(
        kept_events,
        spec=spec,
        targets=targets,
        target_sigma=config.sigma,
    )
def generate_heatmaps(
    sound_events: Iterable[data.SoundEventAnnotation],
    spec: xr.DataArray,
    targets: TargetProtocol,
    target_sigma: float = 3.0,
    dtype=np.float32,
) -> Heatmaps:
    """Generate detection, class, and size heatmaps from sound events.

    Creates heatmap representations from an iterable of sound event
    annotations. This function relies on the provided `targets` object to get
    the reference position, scaled size, and class encoding for each
    annotation.

    Process:
    1. Validates `spec` dimensions and `targets.class_names`, then
       initializes empty heatmap arrays based on `spec` shape and `targets`
       metadata.
    2. Iterates through `sound_events`.
    3. For each event:
        a. Gets geometry. Skips if missing.
        b. Gets reference position using `targets.get_position()`. Skips if
           out of bounds.
        c. Places a peak (1.0) on the detection heatmap at the position.
        d. Gets scaled size using `targets.get_size()` and places it on the
           size heatmap.
        e. Encodes class using `targets.encode()` and places a peak (1.0) on
           the corresponding class heatmap layer if a specific class is
           returned.
    4. Applies Gaussian smoothing (using `target_sigma`) to the detection and
       class heatmaps, along the 'time' and 'frequency' axes only.
    5. Normalizes each detection/class map to peak at 1.0; all-zero maps
       stay zero.

    Parameters
    ----------
    sound_events : Iterable[data.SoundEventAnnotation]
        An iterable of sound event annotations to render onto the heatmaps.
    spec : xr.DataArray
        The spectrogram array corresponding to the time/frequency range of
        the annotations. Used for shape and coordinate information. Must have
        'time' and 'frequency' dimensions/coordinates.
    targets : TargetProtocol
        The fully configured target definition object. Used to access
        `class_names`, `dimension_names`, and the methods `get_position`,
        `get_size`, `encode`.
    target_sigma : float, default=3.0
        Standard deviation (in pixels/bins) of the Gaussian kernel applied to
        smooth the detection and class heatmaps.
    dtype : type, default=np.float32
        The data type for the generated heatmap arrays (e.g., `np.float32`).

    Returns
    -------
    Heatmaps
        A NamedTuple containing the 'detection', 'classes', and 'size'
        xarray DataArrays. The size heatmap has one leading channel per
        entry of `targets.dimension_names`.

    Raises
    ------
    ValueError
        If the input spectrogram `spec` does not have both 'time' and
        'frequency' dimensions, or if `targets.class_names` is empty.
    """
    if "time" not in spec.dims or "frequency" not in spec.dims:
        raise ValueError(
            "Spectrogram must have time and frequency dimensions."
        )

    class_names = list(targets.class_names)
    if not class_names:
        # Fail early with a clear message instead of deep inside xarray.
        raise ValueError("targets.class_names must not be empty.")

    dimension_names = list(targets.dimension_names)

    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": class_names,
            **spec.coords,
        },
    )
    # One channel per size dimension reported by `targets`, rather than a
    # hard-coded 2.
    size_heatmap = xr.DataArray(
        data=np.zeros((len(dimension_names), *spec.shape), dtype=dtype),
        dims=[SIZE_DIMENSION, *spec.dims],
        coords={
            SIZE_DIMENSION: dimension_names,
            **spec.coords,
        },
    )

    for sound_event_annotation in sound_events:
        geom = sound_event_annotation.sound_event.geometry
        if geom is None:
            logger.debug(
                "Skipping annotation %s: missing geometry.",
                sound_event_annotation.uuid,
            )
            continue

        # Get the position of the sound event
        time, frequency = targets.get_position(sound_event_annotation)

        # Set 1.0 at the position of the sound event in the detection heatmap
        try:
            detection_heatmap = arrays.set_value_at_pos(
                detection_heatmap,
                1.0,
                time=time,
                frequency=frequency,
            )
        except KeyError:
            # Skip the sound event if the position is outside the spectrogram
            logger.debug(
                "Skipping annotation %s: position outside spectrogram. "
                "Pos: %s",
                sound_event_annotation.uuid,
                (time, frequency),
            )
            continue

        size = targets.get_size(sound_event_annotation)
        size_heatmap = arrays.set_value_at_pos(
            size_heatmap,
            size,
            time=time,
            frequency=frequency,
        )

        # Get the class name of the sound event. Detection and size targets
        # have already been written at this point; only the class target is
        # skipped on failure.
        try:
            class_name = targets.encode(sound_event_annotation)
        except ValueError as e:
            logger.warning(
                "Skipping annotation %s: Unexpected error while encoding "
                "class name %s",
                sound_event_annotation.uuid,
                e,
            )
            continue

        if class_name is None:
            # If the label is None skip the sound event
            continue

        try:
            class_heatmap = arrays.set_value_at_pos(
                class_heatmap,
                1.0,
                time=time,
                frequency=frequency,
                category=class_name,
            )
        except KeyError:
            # The same time/frequency lookup already succeeded on the
            # detection heatmap (identical coordinates), so the failing key is
            # presumably the category, e.g. a class name missing from
            # `targets.class_names`.
            logger.debug(
                "Skipping annotation %s for class heatmap: "
                "could not place class %s at position %s.",
                sound_event_annotation.uuid,
                class_name,
                (time, frequency),
            )

    return Heatmaps(
        detection=_smooth_and_normalize(detection_heatmap, target_sigma),
        classes=_smooth_and_normalize(class_heatmap, target_sigma),
        size=size_heatmap,
    )


def _smooth_and_normalize(
    heatmap: xr.DataArray,
    sigma: float,
) -> xr.DataArray:
    """Blur `heatmap` over time/frequency and rescale each map to peak at 1.

    The Gaussian uses `sigma` on the 'time' and 'frequency' axes and 0 on
    every other axis. Separate class channels (or any extra dimension)
    therefore never bleed into each other. Each remaining slice is divided
    by its own time/frequency maximum; slices that are entirely zero stay
    zero. Coordinates, attributes and dtype are preserved.
    """
    sigmas = [
        sigma if dim in ("time", "frequency") else 0.0 for dim in heatmap.dims
    ]
    smoothed = heatmap.copy(data=gaussian_filter(heatmap.values, sigma=sigmas))
    peak = smoothed.max(dim=["time", "frequency"])

    # All-zero maps give 0/0 = NaN here; silence numpy's warning and map the
    # NaNs back to 0.
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = smoothed / peak
    return normalized.fillna(0.0)
def load_label_config(
    path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
    """Read a `LabelConfig` from a configuration file.

    Parameters
    ----------
    path : data.PathLike
        Location of the configuration file (e.g. YAML or JSON).
    field : str, optional
        Key under which the label settings are nested inside the file.
        Leave as None when the file itself is the label configuration.

    Returns
    -------
    LabelConfig
        The parsed and validated label configuration.

    Raises
    ------
    FileNotFoundError
        If `path` does not point to an existing file.
    pydantic.ValidationError
        If the file contents do not match the `LabelConfig` schema.
    """
    label_config = load_config(path, schema=LabelConfig, field=field)
    return label_config