mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add extensive documentation for the labels module
This commit is contained in:
parent
62471664fa
commit
0778663a2c
@ -1,47 +1,330 @@
|
||||
"""Generate heatmap training targets for BatDetect2 models.
|
||||
|
||||
This module represents the final step in the `batdetect2.targets` pipeline,
|
||||
converting processed sound event annotations from an audio clip into the
|
||||
specific heatmap formats required for training the BatDetect2 neural network.
|
||||
|
||||
It integrates the filtering, transformation, and class encoding logic defined
|
||||
in the preceding configuration steps (`filtering`, `transform`, `classes`)
|
||||
and applies them to generate three core outputs for a given spectrogram:
|
||||
|
||||
1. **Detection Heatmap**: Indicates the presence and location of relevant
|
||||
sound events.
|
||||
2. **Class Heatmap**: Indicates the location and predicted class label for
|
||||
events that match a specific target class.
|
||||
3. **Size Heatmap**: Encodes the dimensions (width/time duration,
|
||||
height/frequency bandwidth) of the detected sound events at their
|
||||
reference locations.
|
||||
|
||||
The primary function generated by this module is a `ClipLabeller`, which takes
|
||||
a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`)
|
||||
and returns the calculated `Heatmaps`. Configuration options allow tuning of
|
||||
the heatmap generation process (e.g., Gaussian smoothing sigma, reference point
|
||||
within bounding boxes).
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Iterable
|
||||
from typing import Callable, List, Optional, Sequence, Tuple
|
||||
from functools import partial
|
||||
from typing import Callable, List, NamedTuple, Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from soundevent import arrays, data, geometry
|
||||
from soundevent.geometry.operations import Positions
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.targets.classes import SoundEventEncoder
|
||||
from batdetect2.targets.filtering import SoundEventFilter
|
||||
from batdetect2.targets.transform import SoundEventTransformation
|
||||
|
||||
__all__ = [
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"Heatmaps",
|
||||
"ClipLabeller",
|
||||
"build_clip_labeler",
|
||||
"generate_clip_label",
|
||||
"generate_heatmaps",
|
||||
"load_label_config",
|
||||
]
|
||||
|
||||
|
||||
class HeatmapsConfig(BaseConfig):
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Heatmaps(NamedTuple):
    """Structure holding the generated heatmap targets.

    Attributes
    ----------
    detection : xr.DataArray
        Heatmap indicating the probability of sound event presence. Typically
        smoothed with a Gaussian kernel centered on event reference points.
        Shape matches the input spectrogram. Values normalized [0, 1].
    classes : xr.DataArray
        Heatmap indicating the probability of specific class presence. Has an
        additional 'category' dimension corresponding to the target class
        names. Each category slice is typically smoothed with a Gaussian
        kernel. Values normalized [0, 1] per category.
    size : xr.DataArray
        Heatmap encoding the size (width, height) of detected events. Has an
        additional 'dimension' coordinate ('width', 'height'). Values represent
        scaled dimensions placed at the event reference points.
    """

    # Per-pixel presence probability; same dims as the input spectrogram.
    detection: xr.DataArray
    # Per-class presence probability; extra 'category' dimension.
    classes: xr.DataArray
    # Scaled (width, height) values; extra 'dimension' coordinate.
    size: xr.DataArray
|
||||
|
||||
|
||||
# Signature shared by all clip labelling functions produced by this module.
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
"""Type alias for the final clip labelling function.

This function takes the complete annotations for a clip and the corresponding
spectrogram, applies all configured filtering, transformation, and encoding
steps, and returns the final `Heatmaps` used for model training.
"""
|
||||
|
||||
|
||||
class LabelConfig(BaseConfig):
    """Configuration parameters for heatmap generation.

    Attributes
    ----------
    position : Positions, default="bottom-left"
        Specifies the reference point within each sound event's geometry
        (bounding box) that is used to place the 'peak' or value on the
        heatmaps. Options include 'center', 'bottom-left', 'top-right', etc.
        See `soundevent.geometry.operations.Positions` for valid options.
    sigma : float, default=3.0
        The standard deviation (in pixels/bins) of the Gaussian kernel applied
        to smooth the detection and class heatmaps. Larger values create more
        diffuse targets.
    time_scale : float, default=1000.0
        A scaling factor applied to the time duration (width) of sound event
        bounding boxes before storing them in the 'width' dimension of the
        `size` heatmap. The appropriate value depends on how the model input/
        output resolution relates to physical time units (e.g., if time axis
        is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs
        to match model expectations.
    frequency_scale : float, default=1/859.375
        A scaling factor applied to the frequency bandwidth (height) of sound
        event bounding boxes before storing them in the 'height' dimension of
        the `size` heatmap. The appropriate value depends on the relationship
        between the frequency axis resolution (e.g., kHz or Hz per bin) and
        the desired output units/scale for the model. Needs to match model
        expectations. (The default suggests input frequency might be in Hz
        and output is scaled relative to some reference).
    """

    # Reference point inside each bounding box used to anchor heatmap values.
    position: Positions = "bottom-left"
    # Gaussian smoothing std-dev (pixels) for detection/class heatmaps.
    sigma: float = 3.0
    # Width (time duration) multiplier for the size heatmap.
    time_scale: float = 1000.0
    # Height (frequency bandwidth) multiplier for the size heatmap.
    frequency_scale: float = 1 / 859.375
|
||||
|
||||
|
||||
class LabelConfig(BaseConfig):
|
||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||
def build_clip_labeler(
    filter_fn: SoundEventFilter,
    transform_fn: SoundEventTransformation,
    encoder_fn: SoundEventEncoder,
    class_names: List[str],
    config: LabelConfig,
) -> ClipLabeller:
    """Construct the clip labelling function.

    Bundles the pre-built components from the previous target definition
    steps (filtering, transformation, encoding) together with the label
    configuration into a single callable (`ClipLabeller`) that performs the
    end-to-end heatmap generation for a given clip and spectrogram.

    Parameters
    ----------
    filter_fn : SoundEventFilter
        Function to filter irrelevant sound event annotations.
    transform_fn : SoundEventTransformation
        Function to transform tags of sound event annotations.
    encoder_fn : SoundEventEncoder
        Function to encode a sound event annotation into a class name.
    class_names : List[str]
        Ordered list of unique target class names for the classification
        heatmap.
    config : LabelConfig
        Configuration object containing heatmap generation parameters (sigma,
        etc.).

    Returns
    -------
    ClipLabeller
        A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
        (spectrogram) and returns the generated `Heatmaps`.
    """

    def labeller(
        clip_annotation: data.ClipAnnotation, spec: xr.DataArray
    ) -> Heatmaps:
        # Delegate to the full pipeline with the captured components.
        return generate_clip_label(
            clip_annotation,
            spec,
            filter_fn=filter_fn,
            transform_fn=transform_fn,
            encoder_fn=encoder_fn,
            class_names=class_names,
            config=config,
        )

    return labeller
|
||||
|
||||
|
||||
def generate_clip_label(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    filter_fn: SoundEventFilter,
    transform_fn: SoundEventTransformation,
    encoder_fn: SoundEventEncoder,
    class_names: List[str],
    config: LabelConfig,
) -> Heatmaps:
    """Generate heatmaps for a single clip by applying all processing steps.

    Orchestrates the per-clip pipeline:
    1. Filters the sound events using `filter_fn`.
    2. Transforms the tags of filtered events using `transform_fn`.
    3. Hands the processed annotations, along with the encoder and the
       configured heatmap parameters, to `generate_heatmaps`.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        The complete annotation data for the audio clip.
    spec : xr.DataArray
        The spectrogram corresponding to the `clip_annotation`.
    filter_fn : SoundEventFilter
        Function to filter sound event annotations.
    transform_fn : SoundEventTransformation
        Function to transform tags of sound event annotations.
    encoder_fn : SoundEventEncoder
        Function to encode a sound event annotation into a class name.
    class_names : List[str]
        Ordered list of unique target class names.
    config : LabelConfig
        Configuration object containing heatmap generation parameters.

    Returns
    -------
    Heatmaps
        The generated detection, classes, and size heatmaps for the clip.
    """
    # Lazily drop unwanted annotations, then transform the survivors' tags.
    # Keeping this a generator avoids materialising the intermediate list.
    processed_events = (
        transform_fn(annotation)
        for annotation in clip_annotation.sound_events
        if filter_fn(annotation)
    )

    return generate_heatmaps(
        processed_events,
        spec=spec,
        class_names=class_names,
        encoder=encoder_fn,
        target_sigma=config.sigma,
        position=config.position,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
sound_events: Sequence[data.SoundEventAnnotation],
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
class_names: List[str],
|
||||
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
|
||||
encoder: SoundEventEncoder,
|
||||
target_sigma: float = 3.0,
|
||||
position: Positions = "bottom-left",
|
||||
time_scale: float = 1000.0,
|
||||
frequency_scale: float = 1 / 859.375,
|
||||
dtype=np.float32,
|
||||
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
||||
) -> Heatmaps:
|
||||
"""Generate detection, class, and size heatmaps from sound events.
|
||||
|
||||
Processes an iterable of sound event annotations (assumed to be already
|
||||
filtered and transformed) and creates heatmap representations suitable
|
||||
for training models like BatDetect2.
|
||||
|
||||
The process involves:
|
||||
1. Initializing empty heatmaps based on the spectrogram shape.
|
||||
2. Iterating through each sound event.
|
||||
3. For each event, finding its reference point and placing a '1.0'
|
||||
on the detection heatmap at that point.
|
||||
4. Calculating the scaled bounding box size and placing it on the size
|
||||
heatmap at the reference point.
|
||||
5. Encoding the event to get its class name and placing a '1.0' on the
|
||||
corresponding class heatmap slice at the reference point (if
|
||||
classified).
|
||||
6. Applying Gaussian smoothing to detection and class heatmaps.
|
||||
7. Normalizing detection and class heatmaps to the range [0, 1].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_events : Iterable[data.SoundEventAnnotation]
|
||||
An iterable of sound event annotations to include in the heatmaps.
|
||||
These should ideally be the result of prior filtering and tag
|
||||
transformation steps.
|
||||
spec : xr.DataArray
|
||||
The spectrogram array corresponding to the time/frequency range of
|
||||
the annotations. Used for shape and coordinate information. Must have
|
||||
'time' and 'frequency' dimensions.
|
||||
class_names : List[str]
|
||||
An ordered list of unique class names. The class heatmap will have
|
||||
a channel ('category' dimension) for each name in this list. Must not
|
||||
be empty.
|
||||
encoder : SoundEventEncoder
|
||||
A function that takes a SoundEventAnnotation and returns the
|
||||
corresponding class name (str) or None if it doesn't belong to a
|
||||
specific class (e.g., it falls into the generic 'Bat' category).
|
||||
target_sigma : float, default=3.0
|
||||
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
|
||||
smooth the detection and class heatmaps after initial point placement.
|
||||
position : Positions, default="bottom-left"
|
||||
The reference point within each annotation's geometry bounding box
|
||||
used to place the signal on the heatmaps (e.g., "center",
|
||||
"bottom-left"). See `soundevent.geometry.operations.Positions`.
|
||||
time_scale : float, default=1000.0
|
||||
Scaling factor applied to the time duration (width in seconds) of
|
||||
annotations when storing them in the size heatmap. The resulting
|
||||
value's unit depends on this scale (e.g., 1000.0 might convert seconds
|
||||
to ms).
|
||||
frequency_scale : float, default=1/859.375
|
||||
Scaling factor applied to the frequency bandwidth (height in Hz or kHz)
|
||||
of annotations when storing them in the size heatmap. The resulting
|
||||
value's unit depends on this scale and the input unit. (Default
|
||||
scaling relative to ~860 Hz).
|
||||
dtype : type, default=np.float32
|
||||
The data type for the generated heatmap arrays (e.g., `np.float32`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Heatmaps
|
||||
A NamedTuple containing the 'detection', 'classes', and 'size'
|
||||
xarray DataArrays, ready for use in model training.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input spectrogram `spec` does not have both 'time' and
|
||||
'frequency' dimensions, or if `class_names` is empty.
|
||||
|
||||
Notes
|
||||
-----
|
||||
* This function expects `sound_events` to be already filtered and
|
||||
transformed.
|
||||
* It includes error handling to skip individual annotations that cause
|
||||
issues (e.g., missing geometry, out-of-bounds coordinates, encoder
|
||||
errors) allowing the rest of the clip to be processed. Warnings or
|
||||
errors are logged.
|
||||
* The `time_scale` and `frequency_scale` parameters are crucial and must be
|
||||
set according to the expectations of the specific BatDetect2 model
|
||||
architecture being trained. Consult model documentation for required
|
||||
units/scales.
|
||||
* Gaussian filtering and normalization are applied only to detection and
|
||||
class heatmaps, not the size heatmap.
|
||||
"""
|
||||
shape = dict(zip(spec.dims, spec.shape))
|
||||
|
||||
if len(class_names) == 0:
|
||||
raise ValueError("No class names provided.")
|
||||
|
||||
if "time" not in shape or "frequency" not in shape:
|
||||
raise ValueError(
|
||||
"Spectrogram must have time and frequency dimensions."
|
||||
@ -69,6 +352,10 @@ def generate_heatmaps(
|
||||
for sound_event_annotation in sound_events:
|
||||
geom = sound_event_annotation.sound_event.geometry
|
||||
if geom is None:
|
||||
logger.debug(
|
||||
"Skipping annotation %s: missing geometry.",
|
||||
sound_event_annotation.uuid,
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the position of the sound event
|
||||
@ -84,6 +371,12 @@ def generate_heatmaps(
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
logger.debug(
|
||||
"Skipping annotation %s: position outside spectrogram. "
|
||||
"Pos: %s",
|
||||
sound_event_annotation.uuid,
|
||||
(time, frequency),
|
||||
)
|
||||
continue
|
||||
|
||||
# Set the size of the sound event at the position in the size heatmap
|
||||
@ -106,19 +399,51 @@ def generate_heatmaps(
|
||||
)
|
||||
|
||||
# Get the class name of the sound event
|
||||
class_name = encoder(sound_event_annotation.tags)
|
||||
try:
|
||||
class_name = encoder(sound_event_annotation)
|
||||
except ValueError as e:
|
||||
logger.warning(
|
||||
"Skipping annotation %s: Unexpected error while encoding "
|
||||
"class name %s",
|
||||
sound_event_annotation.uuid,
|
||||
e,
|
||||
)
|
||||
continue
|
||||
|
||||
if class_name is None:
|
||||
# If the label is None skip the sound event
|
||||
continue
|
||||
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
category=class_name,
|
||||
)
|
||||
if class_name not in class_names:
|
||||
# If the label is not in the class names skip the sound event
|
||||
logger.warning(
|
||||
(
|
||||
"Skipping annotation %s for class heatmap: "
|
||||
"class name '%s' not in class names. Class names: %s"
|
||||
),
|
||||
sound_event_annotation.uuid,
|
||||
class_name,
|
||||
class_names,
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
category=class_name,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
logger.debug(
|
||||
"Skipping annotation %s for class heatmap: "
|
||||
"position outside spectrogram. Pos: %s",
|
||||
sound_event_annotation.uuid,
|
||||
(class_name, time, frequency),
|
||||
)
|
||||
continue
|
||||
|
||||
# Apply gaussian filters
|
||||
detection_heatmap = xr.apply_ufunc(
|
||||
@ -128,7 +453,7 @@ def generate_heatmaps(
|
||||
)
|
||||
|
||||
class_heatmap = class_heatmap.groupby("category").map(
|
||||
gaussian_filter, # type: ignore
|
||||
gaussian_filter,
|
||||
args=(target_sigma,),
|
||||
)
|
||||
|
||||
@ -141,10 +466,36 @@ def generate_heatmaps(
|
||||
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
|
||||
).fillna(0.0)
|
||||
|
||||
return detection_heatmap, class_heatmap, size_heatmap
|
||||
return Heatmaps(
|
||||
detection=detection_heatmap,
|
||||
classes=class_heatmap,
|
||||
size=size_heatmap,
|
||||
)
|
||||
|
||||
|
||||
def load_label_config(
    path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
    """Load the heatmap label generation configuration from a file.

    Parameters
    ----------
    path : data.PathLike
        Path to the configuration file (e.g., YAML or JSON).
    field : str, optional
        If the label configuration is nested under a specific key in the
        file, specify the key here. Defaults to None.

    Returns
    -------
    LabelConfig
        The loaded and validated label configuration object.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the config file structure does not match the LabelConfig schema.
    """
    # Parsing and schema validation are delegated to the shared loader.
    loaded = load_config(path, schema=LabelConfig, field=field)
    return loaded
|
||||
|
Loading…
Reference in New Issue
Block a user