"""Generate heatmap training targets for BatDetect2 models.
|
|
|
|
This module is responsible for creating the target labels used for training
|
|
BatDetect2 models. It converts sound event annotations for an audio clip into
|
|
the specific multi-channel heatmap formats required by the neural network.
|
|
|
|
It uses a pre-configured object adhering to the `TargetProtocol` (from
|
|
`batdetect2.targets`) which encapsulates all the logic for filtering
|
|
annotations, transforming tags, encoding class names, and mapping annotation
|
|
geometry (ROIs) to target positions and sizes. This module then focuses on
|
|
rendering this information onto the heatmap grids.
|
|
|
|
The pipeline generates three core outputs for a given spectrogram:
|
|
1. **Detection Heatmap**: Indicates presence/location of relevant sound events.
|
|
2. **Class Heatmap**: Indicates location and class identity for specifically
|
|
classified events.
|
|
3. **Size Heatmap**: Encodes the target dimensions (width, height) of events.
|
|
|
|
The primary function generated by this module is a `ClipLabeller` (defined in
|
|
`.types`), which takes a `ClipAnnotation` object and its corresponding
|
|
spectrogram and returns the calculated `Heatmaps` tuple. The main configurable
|
|
parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
|
defined in `LabelConfig`.
|
|
"""
|
|
|
|
import logging
from collections.abc import Iterable
from functools import partial
from typing import Optional

import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data

from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.types import (
    ClipLabeller,
    Heatmaps,
)

__all__ = [
    "LabelConfig",
    "build_clip_labeler",
    "generate_clip_label",
    "generate_heatmaps",
    "load_label_config",
]


SIZE_DIMENSION = "dimension"
"""Dimension name for the size heatmap."""

logger = logging.getLogger(__name__)


class LabelConfig(BaseConfig):
    """Configuration parameters for heatmap generation.

    Attributes
    ----------
    sigma : float, default=3.0
        The standard deviation (in pixels/bins) of the Gaussian kernel applied
        to smooth the detection and class heatmaps. Larger values create more
        diffuse targets.
    """

    sigma: float = 3.0


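# Illustrative sketch of what `sigma` controls (commented out; not part of the
# module's API). A single 1.0 peak smoothed with `gaussian_filter` and then
# renormalised to a maximum of 1.0 - the same two steps `generate_heatmaps`
# applies - spreads over roughly +/- 3*sigma bins, so larger values produce
# broader, more diffuse training targets:
#
#     peak = np.zeros((64, 64), dtype=np.float32)
#     peak[32, 32] = 1.0
#     smoothed = gaussian_filter(peak, sigma=LabelConfig().sigma)
#     smoothed /= smoothed.max()
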
def build_clip_labeler(
    targets: TargetProtocol,
    config: LabelConfig,
) -> ClipLabeller:
    """Construct the final clip labelling function.

    This factory function prepares the callable that will perform the
    end-to-end heatmap generation for a given clip and spectrogram during
    training data loading. It takes the fully configured `targets` object and
    the `LabelConfig` and binds them to the `generate_clip_label` function.

    Parameters
    ----------
    targets : TargetProtocol
        An initialized object conforming to the `TargetProtocol`, providing
        all necessary methods for filtering, transforming, encoding, and ROI
        mapping.
    config : LabelConfig
        Configuration object containing heatmap generation parameters.

    Returns
    -------
    ClipLabeller
        A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
        (spectrogram) and returns the generated `Heatmaps`.
    """
    return partial(
        generate_clip_label,
        targets=targets,
        config=config,
    )


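# Minimal usage sketch (commented out). `targets`, `clip_annotation` and
# `spec` are placeholders: a TargetProtocol implementation from
# `batdetect2.targets`, a `data.ClipAnnotation`, and its spectrogram, all
# assumed to exist already:
#
#     labeller = build_clip_labeler(targets, LabelConfig(sigma=3.0))
#     heatmaps = labeller(clip_annotation, spec)
#     # heatmaps.detection, heatmaps.classes and heatmaps.size are
#     # xr.DataArrays aligned with `spec`.
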
def generate_clip_label(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    targets: TargetProtocol,
    config: LabelConfig,
) -> Heatmaps:
    """Generate training heatmaps for a single annotated clip.

    This function orchestrates the target generation process for one clip:
    1. Filters and transforms sound events using `targets.filter` and
       `targets.transform`.
    2. Passes the resulting processed annotations, along with the spectrogram,
       the `targets` object, and the Gaussian `sigma` from `config`, to the
       core `generate_heatmaps` function.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        The complete annotation data for the audio clip, including the list
        of `sound_events` to process.
    spec : xr.DataArray
        The spectrogram corresponding to the `clip_annotation`. Must have
        'time' and 'frequency' dimensions/coordinates.
    targets : TargetProtocol
        The fully configured target definition object, providing methods for
        filtering, transformation, encoding, and ROI mapping.
    config : LabelConfig
        Configuration object providing heatmap parameters (primarily `sigma`).

    Returns
    -------
    Heatmaps
        A NamedTuple containing the generated 'detection', 'classes', and
        'size' heatmaps for this clip.
    """
    return generate_heatmaps(
        (
            targets.transform(sound_event_annotation)
            for sound_event_annotation in clip_annotation.sound_events
            if targets.filter(sound_event_annotation)
        ),
        spec=spec,
        targets=targets,
        target_sigma=config.sigma,
    )


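# The generator expression passed above is equivalent to the explicit loop
# below (a sketch, kept as a comment; it builds a list instead of a lazy
# generator but introduces no new behaviour):
#
#     selected = []
#     for annotation in clip_annotation.sound_events:
#         if not targets.filter(annotation):
#             continue  # drop events excluded by the target definition
#         selected.append(targets.transform(annotation))
#     heatmaps = generate_heatmaps(
#         selected, spec=spec, targets=targets, target_sigma=config.sigma
#     )
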
def generate_heatmaps(
    sound_events: Iterable[data.SoundEventAnnotation],
    spec: xr.DataArray,
    targets: TargetProtocol,
    target_sigma: float = 3.0,
    dtype=np.float32,
) -> Heatmaps:
    """Generate detection, class, and size heatmaps from sound events.

    Creates heatmap representations from an iterable of sound event
    annotations. This function relies on the provided `targets` object to get
    the reference position, scaled size, and class encoding for each
    annotation.

    Process:
    1. Initializes empty heatmap arrays based on `spec` shape and `targets`
       metadata.
    2. Iterates through `sound_events`.
    3. For each event:
       a. Gets geometry. Skips if missing.
       b. Gets reference position using `targets.get_position()`. Skips if
          out of bounds.
       c. Places a peak (1.0) on the detection heatmap at the position.
       d. Gets scaled size using `targets.get_size()` and places it on the
          size heatmap.
       e. Encodes class using `targets.encode()` and places a peak (1.0) on
          the corresponding class heatmap layer if a specific class is
          returned.
    4. Applies Gaussian smoothing (using `target_sigma`) to detection and
       class heatmaps.
    5. Normalizes detection and class heatmaps to range [0, 1].

    Parameters
    ----------
    sound_events : Iterable[data.SoundEventAnnotation]
        An iterable of sound event annotations to render onto the heatmaps.
    spec : xr.DataArray
        The spectrogram array corresponding to the time/frequency range of
        the annotations. Used for shape and coordinate information. Must have
        'time' and 'frequency' dimensions/coordinates.
    targets : TargetProtocol
        The fully configured target definition object. Used to access class
        names, dimension names, and the methods `get_position`, `get_size`,
        `encode`.
    target_sigma : float, default=3.0
        Standard deviation (in pixels/bins) of the Gaussian kernel applied to
        smooth the detection and class heatmaps.
    dtype : type, default=np.float32
        The data type for the generated heatmap arrays (e.g., `np.float32`).

    Returns
    -------
    Heatmaps
        A NamedTuple containing the 'detection', 'classes', and 'size'
        xarray DataArrays, ready for use in model training.

    Raises
    ------
    ValueError
        If the input spectrogram `spec` does not have both 'time' and
        'frequency' dimensions, or if `targets.class_names` is empty.
    """
    shape = dict(zip(spec.dims, spec.shape))

    if "time" not in shape or "frequency" not in shape:
        raise ValueError(
            "Spectrogram must have time and frequency dimensions."
        )

    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": [*targets.class_names],
            **spec.coords,
        },
    )
    size_heatmap = xr.DataArray(
        data=np.zeros((2, *spec.shape), dtype=dtype),
        dims=[SIZE_DIMENSION, *spec.dims],
        coords={
            SIZE_DIMENSION: targets.dimension_names,
            **spec.coords,
        },
    )

    for sound_event_annotation in sound_events:
        geom = sound_event_annotation.sound_event.geometry
        if geom is None:
            logger.debug(
                "Skipping annotation %s: missing geometry.",
                sound_event_annotation.uuid,
            )
            continue

        # Get the position of the sound event
        time, frequency = targets.get_position(sound_event_annotation)

        # Set 1.0 at the position of the sound event in the detection heatmap
        try:
            detection_heatmap = arrays.set_value_at_pos(
                detection_heatmap,
                1.0,
                time=time,
                frequency=frequency,
            )
        except KeyError:
            # Skip the sound event if the position is outside the spectrogram
            logger.debug(
                "Skipping annotation %s: position outside spectrogram. "
                "Pos: %s",
                sound_event_annotation.uuid,
                (time, frequency),
            )
            continue

        size = targets.get_size(sound_event_annotation)

        size_heatmap = arrays.set_value_at_pos(
            size_heatmap,
            size,
            time=time,
            frequency=frequency,
        )

        # Get the class name of the sound event
        try:
            class_name = targets.encode(sound_event_annotation)
        except ValueError as e:
            logger.warning(
                "Skipping annotation %s: Unexpected error while encoding "
                "class name %s",
                sound_event_annotation.uuid,
                e,
            )
            continue

        if class_name is None:
            # If the label is None skip the sound event
            continue

        try:
            class_heatmap = arrays.set_value_at_pos(
                class_heatmap,
                1.0,
                time=time,
                frequency=frequency,
                category=class_name,
            )
        except KeyError:
            # Skip the sound event if the position is outside the spectrogram
            logger.debug(
                "Skipping annotation %s for class heatmap: "
                "position outside spectrogram. Pos: %s",
                sound_event_annotation.uuid,
                (class_name, time, frequency),
            )
            continue

    # Apply gaussian filters
    detection_heatmap = xr.apply_ufunc(
        gaussian_filter,
        detection_heatmap,
        target_sigma,
    )

    class_heatmap = class_heatmap.groupby("category").map(
        gaussian_filter,  # type: ignore
        args=(target_sigma,),
    )

    # Normalize heatmaps
    detection_heatmap = (
        detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    class_heatmap = (
        class_heatmap / class_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    return Heatmaps(
        detection=detection_heatmap,
        classes=class_heatmap,
        size=size_heatmap,
    )


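# Shape sketch (commented out; `annotations`, `spec` and `targets` are
# placeholders). For a spectrogram with 'time' and 'frequency' dims, the
# detection map matches `spec`, the class map adds a leading 'category' axis
# with one layer per class, and the size map adds a 2-long dimension axis:
#
#     heatmaps = generate_heatmaps(annotations, spec, targets, target_sigma=3.0)
#     assert heatmaps.detection.shape == spec.shape
#     assert heatmaps.classes.sizes["category"] == len(targets.class_names)
#     assert heatmaps.size.sizes[SIZE_DIMENSION] == 2
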
def load_label_config(
    path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
    """Load the heatmap label generation configuration from a file.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file (e.g., YAML or JSON).
    field : str, optional
        If the label configuration is nested under a specific key in the
        file, specify the key here. Defaults to None.

    Returns
    -------
    LabelConfig
        The loaded and validated label configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the config file structure does not match the LabelConfig schema.
    """
    return load_config(path, schema=LabelConfig, field=field)

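
# Example of loading the config from a file (commented out; the file name and
# the "labels" nesting key are hypothetical - `sigma` is the only field this
# schema defines):
#
#     # labels.yaml
#     # labels:
#     #   sigma: 4.0
#     config = load_label_config("labels.yaml", field="labels")
#     assert config.sigma == 4.0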