From c36ef3ecb5b4463c6107b7fbb5e1a041242e2ae8 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 25 Aug 2025 12:43:34 +0100 Subject: [PATCH] Labels to torch --- src/batdetect2/preprocess/spectrogram.py | 12 +- src/batdetect2/train/labels.py | 198 +++++++---------------- src/batdetect2/typing/train.py | 8 +- 3 files changed, 71 insertions(+), 147 deletions(-) diff --git a/src/batdetect2/preprocess/spectrogram.py b/src/batdetect2/preprocess/spectrogram.py index e416501..2fd0ad0 100644 --- a/src/batdetect2/preprocess/spectrogram.py +++ b/src/batdetect2/preprocess/spectrogram.py @@ -257,10 +257,14 @@ class ResizeSpec(torch.nn.Module): def forward(self, spec: torch.Tensor) -> torch.Tensor: current_length = spec.shape[-1] target_length = int(self.time_factor * current_length) - return torch.nn.functional.interpolate( - spec.unsqueeze(0).unsqueeze(0), - size=(self.height, target_length), - mode="bilinear", + return ( + torch.nn.functional.interpolate( + spec.unsqueeze(0).unsqueeze(0), + size=(self.height, target_length), + mode="bilinear", + ) + .squeeze(0) + .squeeze(0) ) diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 9865ee7..9a668db 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -23,15 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`) defined in `LabelConfig`. """ -from collections.abc import Iterable from functools import partial from typing import Optional import numpy as np -import xarray as xr +import torch from loguru import logger -from scipy.ndimage import gaussian_filter -from soundevent import arrays, data +from soundevent import data from batdetect2.configs import BaseConfig, load_config from batdetect2.typing import ( @@ -69,6 +67,8 @@ class LabelConfig(BaseConfig): def build_clip_labeler( targets: TargetProtocol, + min_freq: float, + max_freq: float, config: Optional[LabelConfig] = None, ) -> ClipLabeller: """Construct the final clip labelling function. @@ -102,14 +102,18 @@ def build_clip_labeler( generate_clip_label, targets=targets, config=config, + min_freq=min_freq, + max_freq=max_freq, ) def generate_clip_label( clip_annotation: data.ClipAnnotation, - spec: xr.DataArray, + spec: torch.Tensor, targets: TargetProtocol, config: LabelConfig, + min_freq: float, + max_freq: float, ) -> Heatmaps: """Generate training heatmaps for a single annotated clip. @@ -160,102 +164,53 @@ def generate_clip_label( sound_events.append(targets.transform(sound_event_annotation)) return generate_heatmaps( - sound_events, + clip_annotation.model_copy(update=dict(sound_events=sound_events)), spec=spec, targets=targets, target_sigma=config.sigma, + min_freq=min_freq, + max_freq=max_freq, ) +def map_to_pixels(x, size, min_val, max_val) -> int: + return int(np.floor(np.interp(x, [min_val, max_val], [0, size]))) + + def generate_heatmaps( - sound_events: Iterable[data.SoundEventAnnotation], - spec: xr.DataArray, + clip_annotation: data.ClipAnnotation, + spec: torch.Tensor, targets: TargetProtocol, + min_freq: float, + max_freq: float, target_sigma: float = 3.0, - dtype=np.float32, + dtype=torch.float32, ) -> Heatmaps: - """Generate detection, class, and size heatmaps from sound events. - - Creates heatmap representations from an iterable of sound event - annotations. This function relies on the provided `targets` object to get - the reference position, scaled size, and class encoding for each - annotation. - - Process: - 1. Initializes empty heatmap arrays based on `spec` shape and `targets` - metadata. - 2. Iterates through `sound_events`. - 3. For each event: - a. Gets geometry. Skips if missing. - b. Gets reference position using `targets.get_position()`. Skips if out - of bounds. - c. Places a peak (1.0) on the detection heatmap at the position. - d. Gets scaled size using `targets.get_size()` and places it on the - size heatmap. - e. Encodes class using `targets.encode()` and places a peak (1.0) on - the corresponding class heatmap layer if a specific class is - returned. - 4. Applies Gaussian smoothing (using `target_sigma`) to detection and class - heatmaps. - 5. Normalizes detection and class heatmaps to range [0, 1]. - - Parameters - ---------- - sound_events : Iterable[data.SoundEventAnnotation] - An iterable of sound event annotations to render onto the heatmaps. - spec : xr.DataArray - The spectrogram array corresponding to the time/frequency range of - the annotations. Used for shape and coordinate information. Must have - 'time' and 'frequency' dimensions/coordinates. - targets : TargetProtocol - The fully configured target definition object. Used to access class - names, dimension names, and the methods `get_position`, `get_size`, - `encode`. - target_sigma : float, default=3.0 - Standard deviation (in pixels/bins) of the Gaussian kernel applied to - smooth the detection and class heatmaps. - dtype : type, default=np.float32 - The data type for the generated heatmap arrays (e.g., `np.float32`). - - Returns - ------- - Heatmaps - A NamedTuple containing the 'detection', 'classes', and 'size' - xarray DataArrays, ready for use in model training. - - Raises - ------ - ValueError - If the input spectrogram `spec` does not have both 'time' and - 'frequency' dimensions, or if `targets.class_names` is empty. - """ - shape = dict(zip(spec.dims, spec.shape)) - - if "time" not in shape or "frequency" not in shape: + if not spec.ndim == 2: raise ValueError( - "Spectrogram must have time and frequency dimensions." + "Expecting a 2-dimensional tensor of shape (H, W), " + "H is the height of the spectrogram " + "(frequency bins), and W is the width of the spectrogram " + f"(temporal bins). Instead got: {spec.shape}" ) + height, width = spec.shape + num_classes = len(targets.class_names) + num_dims = len(targets.dimension_names) + clip = clip_annotation.clip + # Initialize heatmaps - detection_heatmap = xr.zeros_like(spec, dtype=dtype) - class_heatmap = xr.DataArray( - data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype), - dims=["category", *spec.dims], - coords={ - "category": [*targets.class_names], - **spec.coords, - }, - ) - size_heatmap = xr.DataArray( - data=np.zeros((2, *spec.shape), dtype=dtype), - dims=[SIZE_DIMENSION, *spec.dims], - coords={ - SIZE_DIMENSION: targets.dimension_names, - **spec.coords, - }, + detection_heatmap = torch.zeros([height, width], dtype=dtype) + class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype) + size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype) + + freqs, times = torch.meshgrid( + torch.arange(height, dtype=dtype), + torch.arange(width, dtype=dtype), + indexing="ij", ) - for sound_event_annotation in sound_events: + for sound_event_annotation in clip_annotation.sound_events: geom = sound_event_annotation.sound_event.geometry if geom is None: logger.debug( @@ -267,16 +222,15 @@ def generate_heatmaps( # Get the position of the sound event (time, frequency), size = targets.encode_roi(sound_event_annotation) - # Set 1.0 at the position of the sound event in the detection heatmap - try: - detection_heatmap = arrays.set_value_at_pos( - detection_heatmap, - 1.0, - time=time, - frequency=frequency, - ) - except KeyError: - # Skip the sound event if the position is outside the spectrogram + time_index = map_to_pixels(time, width, clip.start_time, clip.end_time) + freq_index = map_to_pixels(frequency, height, min_freq, max_freq) + + if ( + time_index < 0 + or time_index >= width + or freq_index < 0 + or freq_index >= height + ): logger.debug( "Skipping annotation %s: position outside spectrogram. " "Pos: %s", @@ -285,12 +239,11 @@ def generate_heatmaps( ) continue - size_heatmap = arrays.set_value_at_pos( - size_heatmap, - size, - time=time, - frequency=frequency, - ) + distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2 + gaussian_blob = torch.exp(-distance / (2 * target_sigma**2)) + + detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob) + size_heatmap[:, freq_index, time_index] = torch.tensor(size[:]) # Get the class name of the sound event try: @@ -308,44 +261,11 @@ def generate_heatmaps( # If the label is None skip the sound event continue - try: - class_heatmap = arrays.set_value_at_pos( - class_heatmap, - 1.0, - time=time, - frequency=frequency, - category=class_name, - ) - except KeyError: - # Skip the sound event if the position is outside the spectrogram - logger.debug( - "Skipping annotation %s for class heatmap: " - "position outside spectrogram. Pos: %s", - sound_event_annotation.uuid, - (class_name, time, frequency), - ) - continue - - # Apply gaussian filters - detection_heatmap = xr.apply_ufunc( - gaussian_filter, - detection_heatmap, - target_sigma, - ) - - class_heatmap = class_heatmap.groupby("category").map( - gaussian_filter, # type: ignore - args=(target_sigma,), - ) - - # Normalize heatmaps - detection_heatmap = ( - detection_heatmap / detection_heatmap.max(dim=["time", "frequency"]) - ).fillna(0.0) - - class_heatmap = ( - class_heatmap / class_heatmap.max(dim=["time", "frequency"]) - ).fillna(0.0) + class_index = targets.class_names.index(class_name) + class_heatmap[class_index] = torch.maximum( + class_heatmap[class_index], + gaussian_blob, + ) return Heatmaps( detection=detection_heatmap, diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py index 950624a..7d6079c 100644 --- a/src/batdetect2/typing/train.py +++ b/src/batdetect2/typing/train.py @@ -37,12 +37,12 @@ class Heatmaps(NamedTuple): scaled dimensions placed at the event reference points. """ - detection: xr.DataArray - classes: xr.DataArray - size: xr.DataArray + detection: torch.Tensor + classes: torch.Tensor + size: torch.Tensor -ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps] +ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps] """Type alias for the final clip labelling function. This function takes the complete annotations for a clip and the corresponding