diff --git a/batdetect2/targets/labels.py b/batdetect2/targets/labels.py index 9f339b7..0c4bd1b 100644 --- a/batdetect2/targets/labels.py +++ b/batdetect2/targets/labels.py @@ -1,47 +1,330 @@ +"""Generate heatmap training targets for BatDetect2 models. + +This module represents the final step in the `batdetect2.targets` pipeline, +converting processed sound event annotations from an audio clip into the +specific heatmap formats required for training the BatDetect2 neural network. + +It integrates the filtering, transformation, and class encoding logic defined +in the preceding configuration steps (`filtering`, `transform`, `classes`) +and applies them to generate three core outputs for a given spectrogram: + +1. **Detection Heatmap**: Indicates the presence and location of relevant + sound events. +2. **Class Heatmap**: Indicates the location and predicted class label for + events that match a specific target class. +3. **Size Heatmap**: Encodes the dimensions (width/time duration, + height/frequency bandwidth) of the detected sound events at their + reference locations. + +The primary function generated by this module is a `ClipLabeller`, which takes +a `ClipAnnotation` object and its corresponding spectrogram (`xr.DataArray`) +and returns the calculated `Heatmaps`. Configuration options allow tuning of +the heatmap generation process (e.g., Gaussian smoothing sigma, reference point +within bounding boxes). 
+""" + +import logging from collections.abc import Iterable -from typing import Callable, List, Optional, Sequence, Tuple +from functools import partial +from typing import Callable, List, NamedTuple, Optional import numpy as np import xarray as xr -from pydantic import Field from scipy.ndimage import gaussian_filter from soundevent import arrays, data, geometry from soundevent.geometry.operations import Positions from batdetect2.configs import BaseConfig, load_config +from batdetect2.targets.classes import SoundEventEncoder +from batdetect2.targets.filtering import SoundEventFilter +from batdetect2.targets.transform import SoundEventTransformation __all__ = [ - "HeatmapsConfig", "LabelConfig", + "Heatmaps", + "ClipLabeller", + "build_clip_labeler", + "generate_clip_label", "generate_heatmaps", "load_label_config", ] -class HeatmapsConfig(BaseConfig): +logger = logging.getLogger(__name__) + + +class Heatmaps(NamedTuple): + """Structure holding the generated heatmap targets. + + Attributes + ---------- + detection : xr.DataArray + Heatmap indicating the probability of sound event presence. Typically + smoothed with a Gaussian kernel centered on event reference points. + Shape matches the input spectrogram. Values normalized [0, 1]. + classes : xr.DataArray + Heatmap indicating the probability of specific class presence. Has an + additional 'category' dimension corresponding to the target class + names. Each category slice is typically smoothed with a Gaussian + kernel. Values normalized [0, 1] per category. + size : xr.DataArray + Heatmap encoding the size (width, height) of detected events. Has an + additional 'dimension' coordinate ('width', 'height'). Values represent + scaled dimensions placed at the event reference points. + """ + + detection: xr.DataArray + classes: xr.DataArray + size: xr.DataArray + + +ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps] +"""Type alias for the final clip labelling function. 
+ +This function takes the complete annotations for a clip and the corresponding +spectrogram, applies all configured filtering, transformation, and encoding +steps, and returns the final `Heatmaps` used for model training. +""" + + +class LabelConfig(BaseConfig): + """Configuration parameters for heatmap generation. + + Attributes + ---------- + position : Positions, default="bottom-left" + Specifies the reference point within each sound event's geometry + (bounding box) that is used to place the 'peak' or value on the + heatmaps. Options include 'center', 'bottom-left', 'top-right', etc. + See `soundevent.geometry.operations.Positions` for valid options. + sigma : float, default=3.0 + The standard deviation (in pixels/bins) of the Gaussian kernel applied + to smooth the detection and class heatmaps. Larger values create more + diffuse targets. + time_scale : float, default=1000.0 + A scaling factor applied to the time duration (width) of sound event + bounding boxes before storing them in the 'width' dimension of the + `size` heatmap. The appropriate value depends on how the model input/ + output resolution relates to physical time units (e.g., if time axis + is milliseconds, scale might be 1.0; if seconds, maybe 1000.0). Needs + to match model expectations. + frequency_scale : float, default=1/859.375 + A scaling factor applied to the frequency bandwidth (height) of sound + event bounding boxes before storing them in the 'height' dimension of + the `size` heatmap. The appropriate value depends on the relationship + between the frequency axis resolution (e.g., kHz or Hz per bin) and + the desired output units/scale for the model. Needs to match model + expectations. (The default suggests input frequency might be in Hz + and output is scaled relative to some reference). 
+ """ + position: Positions = "bottom-left" sigma: float = 3.0 time_scale: float = 1000.0 frequency_scale: float = 1 / 859.375 -class LabelConfig(BaseConfig): - heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig) +def build_clip_labeler( + filter_fn: SoundEventFilter, + transform_fn: SoundEventTransformation, + encoder_fn: SoundEventEncoder, + class_names: List[str], + config: LabelConfig, +) -> ClipLabeller: + """Construct the clip labelling function. + + This function takes the pre-built components from the previous target + definition steps (filtering, transformation, encoding) and the label + configuration, then returns a single callable (`ClipLabeller`) that + performs the end-to-end heatmap generation for a given clip and + spectrogram. + + Parameters + ---------- + filter_fn : SoundEventFilter + Function to filter irrelevant sound event annotations. + transform_fn : SoundEventTransformation + Function to transform tags of sound event annotations. + encoder_fn : SoundEventEncoder + Function to encode a sound event annotation into a class name. + class_names : List[str] + Ordered list of unique target class names for the classification + heatmap. + config : LabelConfig + Configuration object containing heatmap generation parameters (sigma, + etc.). + + Returns + ------- + ClipLabeller + A function that accepts a `data.ClipAnnotation` and `xr.DataArray` + (spectrogram) and returns the generated `Heatmaps`. + """ + return partial( + generate_clip_label, + filter_fn=filter_fn, + transform_fn=transform_fn, + encoder_fn=encoder_fn, + class_names=class_names, + config=config, + ) + + +def generate_clip_label( + clip_annotation: data.ClipAnnotation, + spec: xr.DataArray, + filter_fn: SoundEventFilter, + transform_fn: SoundEventTransformation, + encoder_fn: SoundEventEncoder, + class_names: List[str], + config: LabelConfig, +) -> Heatmaps: + """Generate heatmaps for a single clip by applying all processing steps. 
+ + This function orchestrates the process for one clip: + 1. Filters the sound events using `filter_fn`. + 2. Transforms the tags of filtered events using `transform_fn`. + 3. Passes the processed annotations and other parameters to + `generate_heatmaps` to create the final target heatmaps. + + Parameters + ---------- + clip_annotation : data.ClipAnnotation + The complete annotation data for the audio clip. + spec : xr.DataArray + The spectrogram corresponding to the `clip_annotation`. + filter_fn : SoundEventFilter + Function to filter sound event annotations. + transform_fn : SoundEventTransformation + Function to transform tags of sound event annotations. + encoder_fn : SoundEventEncoder + Function to encode a sound event annotation into a class name. + class_names : List[str] + Ordered list of unique target class names. + config : LabelConfig + Configuration object containing heatmap generation parameters. + + Returns + ------- + Heatmaps + The generated detection, classes, and size heatmaps for the clip. + """ + return generate_heatmaps( + ( + transform_fn(sound_event_annotation) + for sound_event_annotation in clip_annotation.sound_events + if filter_fn(sound_event_annotation) + ), + spec=spec, + class_names=class_names, + encoder=encoder_fn, + target_sigma=config.sigma, + position=config.position, + time_scale=config.time_scale, + frequency_scale=config.frequency_scale, + ) def generate_heatmaps( - sound_events: Sequence[data.SoundEventAnnotation], + sound_events: Iterable[data.SoundEventAnnotation], spec: xr.DataArray, class_names: List[str], - encoder: Callable[[Iterable[data.Tag]], Optional[str]], + encoder: SoundEventEncoder, target_sigma: float = 3.0, position: Positions = "bottom-left", time_scale: float = 1000.0, frequency_scale: float = 1 / 859.375, dtype=np.float32, -) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]: +) -> Heatmaps: + """Generate detection, class, and size heatmaps from sound events. 
+ + Processes an iterable of sound event annotations (assumed to be already + filtered and transformed) and creates heatmap representations suitable + for training models like BatDetect2. + + The process involves: + 1. Initializing empty heatmaps based on the spectrogram shape. + 2. Iterating through each sound event. + 3. For each event, finding its reference point and placing a '1.0' + on the detection heatmap at that point. + 4. Calculating the scaled bounding box size and placing it on the size + heatmap at the reference point. + 5. Encoding the event to get its class name and placing a '1.0' on the + corresponding class heatmap slice at the reference point (if + classified). + 6. Applying Gaussian smoothing to detection and class heatmaps. + 7. Normalizing detection and class heatmaps to the range [0, 1]. + + Parameters + ---------- + sound_events : Iterable[data.SoundEventAnnotation] + An iterable of sound event annotations to include in the heatmaps. + These should ideally be the result of prior filtering and tag + transformation steps. + spec : xr.DataArray + The spectrogram array corresponding to the time/frequency range of + the annotations. Used for shape and coordinate information. Must have + 'time' and 'frequency' dimensions. + class_names : List[str] + An ordered list of unique class names. The class heatmap will have + a channel ('category' dimension) for each name in this list. Must not + be empty. + encoder : SoundEventEncoder + A function that takes a SoundEventAnnotation and returns the + corresponding class name (str) or None if it doesn't belong to a + specific class (e.g., it falls into the generic 'Bat' category). + target_sigma : float, default=3.0 + Standard deviation (in pixels/bins) of the Gaussian kernel applied to + smooth the detection and class heatmaps after initial point placement. 
+ position : Positions, default="bottom-left" + The reference point within each annotation's geometry bounding box + used to place the signal on the heatmaps (e.g., "center", + "bottom-left"). See `soundevent.geometry.operations.Positions`. + time_scale : float, default=1000.0 + Scaling factor applied to the time duration (width in seconds) of + annotations when storing them in the size heatmap. The resulting + value's unit depends on this scale (e.g., 1000.0 might convert seconds + to ms). + frequency_scale : float, default=1/859.375 + Scaling factor applied to the frequency bandwidth (height in Hz or kHz) + of annotations when storing them in the size heatmap. The resulting + value's unit depends on this scale and the input unit. (Default + scaling relative to ~860 Hz). + dtype : type, default=np.float32 + The data type for the generated heatmap arrays (e.g., `np.float32`). + + Returns + ------- + Heatmaps + A NamedTuple containing the 'detection', 'classes', and 'size' + xarray DataArrays, ready for use in model training. + + Raises + ------ + ValueError + If the input spectrogram `spec` does not have both 'time' and + 'frequency' dimensions, or if `class_names` is empty. + + Notes + ----- + * This function expects `sound_events` to be already filtered and + transformed. + * Annotations with missing geometry or a reference point outside the + spectrogram are skipped (logged at debug level). Encoder `ValueError`s + and class names not in `class_names` are logged as warnings and the + event is left out of the class heatmap. + * The `time_scale` and `frequency_scale` parameters are crucial and must be + set according to the expectations of the specific BatDetect2 model + architecture being trained. Consult model documentation for required + units/scales. + * Gaussian filtering and normalization are applied only to detection and + class heatmaps, not the size heatmap. 
+ """ shape = dict(zip(spec.dims, spec.shape)) + if len(class_names) == 0: + raise ValueError("No class names provided.") + if "time" not in shape or "frequency" not in shape: raise ValueError( "Spectrogram must have time and frequency dimensions." @@ -69,6 +352,10 @@ def generate_heatmaps( for sound_event_annotation in sound_events: geom = sound_event_annotation.sound_event.geometry if geom is None: + logger.debug( + "Skipping annotation %s: missing geometry.", + sound_event_annotation.uuid, + ) continue # Get the position of the sound event @@ -84,6 +371,12 @@ def generate_heatmaps( ) except KeyError: # Skip the sound event if the position is outside the spectrogram + logger.debug( + "Skipping annotation %s: position outside spectrogram. " + "Pos: %s", + sound_event_annotation.uuid, + (time, frequency), + ) continue # Set the size of the sound event at the position in the size heatmap @@ -106,19 +399,51 @@ def generate_heatmaps( ) # Get the class name of the sound event - class_name = encoder(sound_event_annotation.tags) + try: + class_name = encoder(sound_event_annotation) + except ValueError as e: + logger.warning( + "Skipping annotation %s: Unexpected error while encoding " + "class name %s", + sound_event_annotation.uuid, + e, + ) + continue if class_name is None: # If the label is None skip the sound event continue - class_heatmap = arrays.set_value_at_pos( - class_heatmap, - 1.0, - time=time, - frequency=frequency, - category=class_name, - ) + if class_name not in class_names: + # If the label is not in the class names skip the sound event + logger.warning( + ( + "Skipping annotation %s for class heatmap: " + "class name '%s' not in class names. 
Class names: %s" + ), + sound_event_annotation.uuid, + class_name, + class_names, + ) + continue + + try: + class_heatmap = arrays.set_value_at_pos( + class_heatmap, + 1.0, + time=time, + frequency=frequency, + category=class_name, + ) + except KeyError: + # Skip the sound event if the position is outside the spectrogram + logger.debug( + "Skipping annotation %s for class heatmap: " + "position outside spectrogram. Pos: %s", + sound_event_annotation.uuid, + (class_name, time, frequency), + ) + continue # Apply gaussian filters detection_heatmap = xr.apply_ufunc( @@ -128,7 +453,7 @@ def generate_heatmaps( ) class_heatmap = class_heatmap.groupby("category").map( - gaussian_filter, # type: ignore + gaussian_filter, args=(target_sigma,), ) @@ -141,10 +466,36 @@ def generate_heatmaps( class_heatmap / class_heatmap.max(dim=["time", "frequency"]) ).fillna(0.0) - return detection_heatmap, class_heatmap, size_heatmap + return Heatmaps( + detection=detection_heatmap, + classes=class_heatmap, + size=size_heatmap, + ) def load_label_config( path: data.PathLike, field: Optional[str] = None ) -> LabelConfig: + """Load the heatmap label generation configuration from a file. + + Parameters + ---------- + path : data.PathLike + Path to the configuration file (e.g., YAML or JSON). + field : str, optional + If the label configuration is nested under a specific key in the + file, specify the key here. Defaults to None. + + Returns + ------- + LabelConfig + The loaded and validated label configuration object. + + Raises + ------ + FileNotFoundError + If the config file path does not exist. + pydantic.ValidationError + If the config file structure does not match the LabelConfig schema. + """ return load_config(path, schema=LabelConfig, field=field)