Add extensive documentation for the labels module

This commit is contained in:
mbsantiago 2025-04-15 19:25:58 +01:00
parent 62471664fa
commit 0778663a2c

View File

@ -1,47 +1,330 @@
"""Generate heatmap training targets for BatDetect2 models.
This module represents the final step in the `batdetect2.targets` pipeline,
converting processed sound event annotations from an audio clip into the
specific heatmap formats required for training the BatDetect2 neural network.
It integrates the filtering, transformation, and class encoding logic defined
in the preceding configuration steps (`filtering`, `transform`, `classes`)
and applies them to generate three core outputs for a given spectrogram:
1. **Detection Heatmap**: Indicates the presence and location of relevant
sound events.
2. **Class Heatmap**: Indicates the location and predicted class label for
events that match a specific target class.
3. **Size Heatmap**: Encodes the dimensions (width/time duration,
height/frequency bandwidth) of the detected sound events at their
reference locations.
The primary function generated by this module is a `ClipLabeller`, which takes
a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`)
and returns the calculated `Heatmaps`. Configuration options allow tuning of
the heatmap generation process (e.g., Gaussian smoothing sigma, reference point
within bounding boxes).
"""
import logging
from collections.abc import Iterable
from typing import Callable, List, Optional, Sequence, Tuple
from functools import partial
from typing import Callable, List, NamedTuple, Optional
import numpy as np
import xarray as xr
from pydantic import Field
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.classes import SoundEventEncoder
from batdetect2.targets.filtering import SoundEventFilter
from batdetect2.targets.transform import SoundEventTransformation
# NOTE(review): the previous revision exported a `HeatmapsConfig` class that no
# longer exists in this module (its options were folded into `LabelConfig`), and
# a stale, body-less `class HeatmapsConfig(BaseConfig):` header was left behind.
# Both remnants are removed here: keeping "HeatmapsConfig" in `__all__` would
# make `from batdetect2.targets.labels import *` fail with an AttributeError.
__all__ = [
    "LabelConfig",
    "Heatmaps",
    "ClipLabeller",
    "build_clip_labeler",
    "generate_clip_label",
    "generate_heatmaps",
    "load_label_config",
]

# Module-level logger used to report skipped annotations during heatmap
# generation (missing geometry, out-of-bounds positions, encoding errors).
logger = logging.getLogger(__name__)
class Heatmaps(NamedTuple):
    """Container bundling the three training-target heatmaps for one clip.

    All three arrays share the time/frequency grid of the spectrogram they
    were generated from.

    Attributes
    ----------
    detection : xr.DataArray
        Probability of sound-event presence at each time/frequency bin.
        Gaussian-smoothed around each event's reference point and
        normalized to [0, 1]. Same shape as the source spectrogram.
    classes : xr.DataArray
        Per-class presence probabilities with an extra 'category' dimension,
        one slice per target class name. Each slice is Gaussian-smoothed
        and normalized to [0, 1] independently.
    size : xr.DataArray
        Scaled event dimensions with an extra 'dimension' coordinate
        ('width' for time duration, 'height' for frequency bandwidth),
        written at each event's reference point. Not smoothed.
    """

    detection: xr.DataArray
    classes: xr.DataArray
    size: xr.DataArray
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.
This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
# NOTE(review): a stale duplicate `class LabelConfig(BaseConfig)` definition
# (holding a `heatmaps: HeatmapsConfig` field) followed this class in the
# previous revision. `HeatmapsConfig` no longer exists, so the duplicate would
# raise a NameError at import time and, if left, would shadow this class. The
# remnant is removed; this is now the single definition.
class LabelConfig(BaseConfig):
    """Configuration parameters for heatmap generation.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Specifies the reference point within each sound event's geometry
        (bounding box) that is used to place the 'peak' or value on the
        heatmaps. Options include 'center', 'bottom-left', 'top-right', etc.
        See `soundevent.geometry.operations.Positions` for valid options.
    sigma : float, default=3.0
        The standard deviation (in pixels/bins) of the Gaussian kernel applied
        to smooth the detection and class heatmaps. Larger values create more
        diffuse targets.
    time_scale : float, default=1000.0
        A scaling factor applied to the time duration (width) of sound event
        bounding boxes before storing them in the 'width' dimension of the
        `size` heatmap. The appropriate value depends on how the model input/
        output resolution relates to physical time units (e.g., if the time
        axis is in milliseconds the scale might be 1.0; if in seconds, maybe
        1000.0). Must match model expectations.
    frequency_scale : float, default=1/859.375
        A scaling factor applied to the frequency bandwidth (height) of sound
        event bounding boxes before storing them in the 'height' dimension of
        the `size` heatmap. The appropriate value depends on the relationship
        between the frequency axis resolution (e.g., kHz or Hz per bin) and
        the desired output units/scale for the model. Must match model
        expectations. (The default suggests input frequency might be in Hz
        and output is scaled relative to some reference.)
    """

    position: Positions = "bottom-left"
    sigma: float = 3.0
    time_scale: float = 1000.0
    frequency_scale: float = 1 / 859.375
def build_clip_labeler(
    filter_fn: SoundEventFilter,
    transform_fn: SoundEventTransformation,
    encoder_fn: SoundEventEncoder,
    class_names: List[str],
    config: LabelConfig,
) -> ClipLabeller:
    """Assemble the end-to-end clip labelling function.

    Binds the components produced by the earlier target-definition steps
    (filtering, tag transformation, class encoding) together with the label
    configuration into a single callable that maps a clip annotation and its
    spectrogram to the training `Heatmaps`.

    Parameters
    ----------
    filter_fn : SoundEventFilter
        Function to filter irrelevant sound event annotations.
    transform_fn : SoundEventTransformation
        Function to transform tags of sound event annotations.
    encoder_fn : SoundEventEncoder
        Function to encode a sound event annotation into a class name.
    class_names : List[str]
        Ordered list of unique target class names for the classification
        heatmap.
    config : LabelConfig
        Configuration object containing heatmap generation parameters
        (sigma, position, scales).

    Returns
    -------
    ClipLabeller
        A function that accepts a `data.ClipAnnotation` and an `xr.DataArray`
        (spectrogram) and returns the generated `Heatmaps`.
    """

    def labeller(
        clip_annotation: data.ClipAnnotation, spec: xr.DataArray
    ) -> Heatmaps:
        # Delegate to generate_clip_label with all components bound.
        return generate_clip_label(
            clip_annotation,
            spec,
            filter_fn=filter_fn,
            transform_fn=transform_fn,
            encoder_fn=encoder_fn,
            class_names=class_names,
            config=config,
        )

    return labeller
def generate_clip_label(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    filter_fn: SoundEventFilter,
    transform_fn: SoundEventTransformation,
    encoder_fn: SoundEventEncoder,
    class_names: List[str],
    config: LabelConfig,
) -> Heatmaps:
    """Run the full labelling pipeline for a single clip.

    The clip's sound events are filtered with `filter_fn`, the survivors'
    tags are transformed with `transform_fn`, and the resulting annotations
    are handed to `generate_heatmaps` together with the configured encoding
    and heatmap parameters.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        The complete annotation data for the audio clip.
    spec : xr.DataArray
        The spectrogram corresponding to the `clip_annotation`.
    filter_fn : SoundEventFilter
        Function to filter sound event annotations.
    transform_fn : SoundEventTransformation
        Function to transform tags of sound event annotations.
    encoder_fn : SoundEventEncoder
        Function to encode a sound event annotation into a class name.
    class_names : List[str]
        Ordered list of unique target class names.
    config : LabelConfig
        Configuration object containing heatmap generation parameters.

    Returns
    -------
    Heatmaps
        The generated detection, classes, and size heatmaps for the clip.
    """
    # Lazily filter first, then transform only the annotations that survive.
    kept = filter(filter_fn, clip_annotation.sound_events)
    processed = map(transform_fn, kept)
    return generate_heatmaps(
        processed,
        spec=spec,
        class_names=class_names,
        encoder=encoder_fn,
        target_sigma=config.sigma,
        position=config.position,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
def generate_heatmaps(
sound_events: Sequence[data.SoundEventAnnotation],
sound_events: Iterable[data.SoundEventAnnotation],
spec: xr.DataArray,
class_names: List[str],
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
encoder: SoundEventEncoder,
target_sigma: float = 3.0,
position: Positions = "bottom-left",
time_scale: float = 1000.0,
frequency_scale: float = 1 / 859.375,
dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
) -> Heatmaps:
"""Generate detection, class, and size heatmaps from sound events.
Processes an iterable of sound event annotations (assumed to be already
filtered and transformed) and creates heatmap representations suitable
for training models like BatDetect2.
The process involves:
1. Initializing empty heatmaps based on the spectrogram shape.
2. Iterating through each sound event.
3. For each event, finding its reference point and placing a '1.0'
on the detection heatmap at that point.
4. Calculating the scaled bounding box size and placing it on the size
heatmap at the reference point.
5. Encoding the event to get its class name and placing a '1.0' on the
corresponding class heatmap slice at the reference point (if
classified).
6. Applying Gaussian smoothing to detection and class heatmaps.
7. Normalizing detection and class heatmaps to the range [0, 1].
Parameters
----------
sound_events : Iterable[data.SoundEventAnnotation]
An iterable of sound event annotations to include in the heatmaps.
These should ideally be the result of prior filtering and tag
transformation steps.
spec : xr.DataArray
The spectrogram array corresponding to the time/frequency range of
the annotations. Used for shape and coordinate information. Must have
'time' and 'frequency' dimensions.
class_names : List[str]
An ordered list of unique class names. The class heatmap will have
a channel ('category' dimension) for each name in this list. Must not
be empty.
encoder : SoundEventEncoder
A function that takes a SoundEventAnnotation and returns the
corresponding class name (str) or None if it doesn't belong to a
specific class (e.g., it falls into the generic 'Bat' category).
target_sigma : float, default=3.0
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
smooth the detection and class heatmaps after initial point placement.
position : Positions, default="bottom-left"
The reference point within each annotation's geometry bounding box
used to place the signal on the heatmaps (e.g., "center",
"bottom-left"). See `soundevent.geometry.operations.Positions`.
time_scale : float, default=1000.0
Scaling factor applied to the time duration (width in seconds) of
annotations when storing them in the size heatmap. The resulting
value's unit depends on this scale (e.g., 1000.0 might convert seconds
to ms).
frequency_scale : float, default=1/859.375
Scaling factor applied to the frequency bandwidth (height in Hz or kHz)
of annotations when storing them in the size heatmap. The resulting
value's unit depends on this scale and the input unit. (Default
scaling relative to ~860 Hz).
dtype : type, default=np.float32
The data type for the generated heatmap arrays (e.g., `np.float32`).
Returns
-------
Heatmaps
A NamedTuple containing the 'detection', 'classes', and 'size'
xarray DataArrays, ready for use in model training.
Raises
------
ValueError
If the input spectrogram `spec` does not have both 'time' and
'frequency' dimensions, or if `class_names` is empty.
Notes
-----
* This function expects `sound_events` to be already filtered and
transformed.
* It includes error handling to skip individual annotations that cause
issues (e.g., missing geometry, out-of-bounds coordinates, encoder
errors) allowing the rest of the clip to be processed. Warnings or
errors are logged.
* The `time_scale` and `frequency_scale` parameters are crucial and must be
set according to the expectations of the specific BatDetect2 model
architecture being trained. Consult model documentation for required
units/scales.
* Gaussian filtering and normalization are applied only to detection and
class heatmaps, not the size heatmap.
"""
shape = dict(zip(spec.dims, spec.shape))
if len(class_names) == 0:
raise ValueError("No class names provided.")
if "time" not in shape or "frequency" not in shape:
raise ValueError(
"Spectrogram must have time and frequency dimensions."
@ -69,6 +352,10 @@ def generate_heatmaps(
for sound_event_annotation in sound_events:
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
"Skipping annotation %s: missing geometry.",
sound_event_annotation.uuid,
)
continue
# Get the position of the sound event
@ -84,6 +371,12 @@ def generate_heatmaps(
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
logger.debug(
"Skipping annotation %s: position outside spectrogram. "
"Pos: %s",
sound_event_annotation.uuid,
(time, frequency),
)
continue
# Set the size of the sound event at the position in the size heatmap
@ -106,19 +399,51 @@ def generate_heatmaps(
)
# Get the class name of the sound event
class_name = encoder(sound_event_annotation.tags)
try:
class_name = encoder(sound_event_annotation)
except ValueError as e:
logger.warning(
"Skipping annotation %s: Unexpected error while encoding "
"class name %s",
sound_event_annotation.uuid,
e,
)
continue
if class_name is None:
# If the label is None skip the sound event
continue
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=class_name,
)
if class_name not in class_names:
# If the label is not in the class names skip the sound event
logger.warning(
(
"Skipping annotation %s for class heatmap: "
"class name '%s' not in class names. Class names: %s"
),
sound_event_annotation.uuid,
class_name,
class_names,
)
continue
try:
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=class_name,
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
logger.debug(
"Skipping annotation %s for class heatmap: "
"position outside spectrogram. Pos: %s",
sound_event_annotation.uuid,
(class_name, time, frequency),
)
continue
# Apply gaussian filters
detection_heatmap = xr.apply_ufunc(
@ -128,7 +453,7 @@ def generate_heatmaps(
)
class_heatmap = class_heatmap.groupby("category").map(
gaussian_filter, # type: ignore
gaussian_filter,
args=(target_sigma,),
)
@ -141,10 +466,36 @@ def generate_heatmaps(
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap
return Heatmaps(
detection=detection_heatmap,
classes=class_heatmap,
size=size_heatmap,
)
def load_label_config(
    path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
    """Load and validate a `LabelConfig` from a configuration file.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file (e.g., YAML or JSON).
    field : str, optional
        Key under which the label configuration is nested within the file,
        if any. Defaults to None (configuration at the top level).

    Returns
    -------
    LabelConfig
        The loaded and validated label configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the config file structure does not match the LabelConfig schema.
    """
    config = load_config(path, schema=LabelConfig, field=field)
    return config