Labels to torch

mbsantiago 2025-08-25 12:43:34 +01:00
parent 667b18a54d
commit c36ef3ecb5
3 changed files with 71 additions and 147 deletions

View File

@@ -257,10 +257,14 @@ class ResizeSpec(torch.nn.Module):
     def forward(self, spec: torch.Tensor) -> torch.Tensor:
         current_length = spec.shape[-1]
         target_length = int(self.time_factor * current_length)
-        return torch.nn.functional.interpolate(
-            spec.unsqueeze(0).unsqueeze(0),
-            size=(self.height, target_length),
-            mode="bilinear",
+        return (
+            torch.nn.functional.interpolate(
+                spec.unsqueeze(0).unsqueeze(0),
+                size=(self.height, target_length),
+                mode="bilinear",
+            )
+            .squeeze(0)
+            .squeeze(0)
         )
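
For reference, `torch.nn.functional.interpolate` with `mode="bilinear"` operates on batched, channelled 4-D input, which is what the `unsqueeze`/`squeeze` pair provides around the 2-D spectrogram. A minimal sketch of the same round trip, with illustrative sizes:

    import torch
    import torch.nn.functional as F

    spec = torch.rand(128, 512)  # (frequency bins, time bins)

    resized = (
        F.interpolate(
            spec.unsqueeze(0).unsqueeze(0),  # lift (H, W) -> (1, 1, H, W)
            size=(256, 1024),                # illustrative target (height, width)
            mode="bilinear",
        )
        .squeeze(0)
        .squeeze(0)                          # drop back to (H, W)
    )
    assert resized.shape == (256, 1024)

The previous version returned the 4-D tensor directly; the two added `squeeze(0)` calls restore the unbatched 2-D shape that downstream code expects.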

View File

@@ -23,15 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
 defined in `LabelConfig`.
 """
 
-from collections.abc import Iterable
 from functools import partial
 from typing import Optional
 
 import numpy as np
-import xarray as xr
+import torch
 from loguru import logger
-from scipy.ndimage import gaussian_filter
-from soundevent import arrays, data
+from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.typing import (
@@ -69,6 +67,8 @@ class LabelConfig(BaseConfig):
 
 def build_clip_labeler(
     targets: TargetProtocol,
+    min_freq: float,
+    max_freq: float,
     config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function.
@@ -102,14 +102,18 @@ def build_clip_labeler(
         generate_clip_label,
         targets=targets,
         config=config,
+        min_freq=min_freq,
+        max_freq=max_freq,
     )
 
 
 def generate_clip_label(
     clip_annotation: data.ClipAnnotation,
-    spec: xr.DataArray,
+    spec: torch.Tensor,
     targets: TargetProtocol,
     config: LabelConfig,
+    min_freq: float,
+    max_freq: float,
 ) -> Heatmaps:
     """Generate training heatmaps for a single annotated clip.
@@ -160,102 +164,53 @@ def generate_clip_label(
         sound_events.append(targets.transform(sound_event_annotation))
 
     return generate_heatmaps(
-        sound_events,
+        clip_annotation.model_copy(update=dict(sound_events=sound_events)),
         spec=spec,
         targets=targets,
         target_sigma=config.sigma,
+        min_freq=min_freq,
+        max_freq=max_freq,
     )
 
 
+def map_to_pixels(x, size, min_val, max_val) -> int:
+    return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))
+
+
 def generate_heatmaps(
-    sound_events: Iterable[data.SoundEventAnnotation],
-    spec: xr.DataArray,
+    clip_annotation: data.ClipAnnotation,
+    spec: torch.Tensor,
     targets: TargetProtocol,
+    min_freq: float,
+    max_freq: float,
     target_sigma: float = 3.0,
-    dtype=np.float32,
+    dtype=torch.float32,
 ) -> Heatmaps:
-    """Generate detection, class, and size heatmaps from sound events.
-
-    Creates heatmap representations from an iterable of sound event
-    annotations. This function relies on the provided `targets` object to get
-    the reference position, scaled size, and class encoding for each
-    annotation.
-
-    Process:
-    1. Initializes empty heatmap arrays based on `spec` shape and `targets`
-       metadata.
-    2. Iterates through `sound_events`.
-    3. For each event:
-       a. Gets geometry. Skips if missing.
-       b. Gets reference position using `targets.get_position()`. Skips if out
-          of bounds.
-       c. Places a peak (1.0) on the detection heatmap at the position.
-       d. Gets scaled size using `targets.get_size()` and places it on the
-          size heatmap.
-       e. Encodes class using `targets.encode()` and places a peak (1.0) on
-          the corresponding class heatmap layer if a specific class is
-          returned.
-    4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
-       heatmaps.
-    5. Normalizes detection and class heatmaps to range [0, 1].
-
-    Parameters
-    ----------
-    sound_events : Iterable[data.SoundEventAnnotation]
-        An iterable of sound event annotations to render onto the heatmaps.
-    spec : xr.DataArray
-        The spectrogram array corresponding to the time/frequency range of
-        the annotations. Used for shape and coordinate information. Must have
-        'time' and 'frequency' dimensions/coordinates.
-    targets : TargetProtocol
-        The fully configured target definition object. Used to access class
-        names, dimension names, and the methods `get_position`, `get_size`,
-        `encode`.
-    target_sigma : float, default=3.0
-        Standard deviation (in pixels/bins) of the Gaussian kernel applied to
-        smooth the detection and class heatmaps.
-    dtype : type, default=np.float32
-        The data type for the generated heatmap arrays (e.g., `np.float32`).
-
-    Returns
-    -------
-    Heatmaps
-        A NamedTuple containing the 'detection', 'classes', and 'size'
-        xarray DataArrays, ready for use in model training.
-
-    Raises
-    ------
-    ValueError
-        If the input spectrogram `spec` does not have both 'time' and
-        'frequency' dimensions, or if `targets.class_names` is empty.
-    """
-    shape = dict(zip(spec.dims, spec.shape))
-
-    if "time" not in shape or "frequency" not in shape:
+    if not spec.ndim == 2:
         raise ValueError(
-            "Spectrogram must have time and frequency dimensions."
+            "Expecting a 2-dimensional tensor of shape (H, W), "
+            "H is the height of the spectrogram "
+            "(frequency bins), and W is the width of the spectrogram "
+            f"(temporal bins). Instead got: {spec.shape}"
         )
 
+    height, width = spec.shape
+    num_classes = len(targets.class_names)
+    num_dims = len(targets.dimension_names)
+    clip = clip_annotation.clip
+
     # Initialize heatmaps
-    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
-    class_heatmap = xr.DataArray(
-        data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
-        dims=["category", *spec.dims],
-        coords={
-            "category": [*targets.class_names],
-            **spec.coords,
-        },
-    )
-    size_heatmap = xr.DataArray(
-        data=np.zeros((2, *spec.shape), dtype=dtype),
-        dims=[SIZE_DIMENSION, *spec.dims],
-        coords={
-            SIZE_DIMENSION: targets.dimension_names,
-            **spec.coords,
-        },
-    )
+    detection_heatmap = torch.zeros([height, width], dtype=dtype)
+    class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
+    size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
+
+    freqs, times = torch.meshgrid(
+        torch.arange(height, dtype=dtype),
+        torch.arange(width, dtype=dtype),
+        indexing="ij",
+    )
 
-    for sound_event_annotation in sound_events:
+    for sound_event_annotation in clip_annotation.sound_events:
         geom = sound_event_annotation.sound_event.geometry
         if geom is None:
             logger.debug(
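
Where the xarray version placed values via coordinate lookup (`arrays.set_value_at_pos`) and smoothed them afterwards with `scipy.ndimage.gaussian_filter`, the torch version maps each annotation's time/frequency onto pixel indices and evaluates an analytic Gaussian on the precomputed index grids. A self-contained sketch of that mapping, with illustrative clip and band bounds:

    import numpy as np
    import torch

    def map_to_pixels(x, size, min_val, max_val) -> int:
        # Linearly map a value in [min_val, max_val] onto the pixel range [0, size]
        return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))

    height, width, sigma = 128, 512, 3.0
    freqs, times = torch.meshgrid(
        torch.arange(height, dtype=torch.float32),
        torch.arange(width, dtype=torch.float32),
        indexing="ij",
    )

    # A call at 0.5 s / 40 kHz in a 1 s clip spanning 10-120 kHz (illustrative)
    time_index = map_to_pixels(0.5, width, 0.0, 1.0)
    freq_index = map_to_pixels(40_000.0, height, 10_000.0, 120_000.0)

    distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
    gaussian_blob = torch.exp(-distance / (2 * sigma**2))  # peaks at exactly 1.0

Because the blob is generated analytically with a peak of 1.0, the heatmaps are already in [0, 1], which is why the smoothing-and-normalisation pass at the end of the old function could be deleted (see the last hunk below).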
@@ -267,16 +222,15 @@ def generate_heatmaps(
         # Get the position of the sound event
         (time, frequency), size = targets.encode_roi(sound_event_annotation)
 
-        # Set 1.0 at the position of the sound event in the detection heatmap
-        try:
-            detection_heatmap = arrays.set_value_at_pos(
-                detection_heatmap,
-                1.0,
-                time=time,
-                frequency=frequency,
-            )
-        except KeyError:
-            # Skip the sound event if the position is outside the spectrogram
+        time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
+        freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
+
+        if (
+            time_index < 0
+            or time_index >= width
+            or freq_index < 0
+            or freq_index >= height
+        ):
             logger.debug(
                 "Skipping annotation %s: position outside spectrogram. "
                 "Pos: %s",
@@ -285,12 +239,11 @@ def generate_heatmaps(
             )
             continue
 
-        size_heatmap = arrays.set_value_at_pos(
-            size_heatmap,
-            size,
-            time=time,
-            frequency=frequency,
-        )
+        distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
+        gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
+
+        detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
+        size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
 
         # Get the class name of the sound event
         try:
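
Combining overlapping events with `torch.maximum`, rather than summing blobs and re-normalising afterwards, keeps every pixel of the detection heatmap in [0, 1]. Continuing the sketch above with two nearby events:

    blob_a = torch.exp(-((times - 100) ** 2 + (freqs - 30) ** 2) / (2 * sigma**2))
    blob_b = torch.exp(-((times - 103) ** 2 + (freqs - 30) ** 2) / (2 * sigma**2))

    detection = torch.maximum(blob_a, blob_b)
    assert detection.max() == 1.0  # overlap never pushes values above 1.0

Note that the size vector is still written only to the single peak pixel (`size_heatmap[:, freq_index, time_index]`), not smoothed, matching the old `set_value_at_pos` behaviour.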
@@ -308,44 +261,11 @@ def generate_heatmaps(
             # If the label is None skip the sound event
             continue
 
-        try:
-            class_heatmap = arrays.set_value_at_pos(
-                class_heatmap,
-                1.0,
-                time=time,
-                frequency=frequency,
-                category=class_name,
-            )
-        except KeyError:
-            # Skip the sound event if the position is outside the spectrogram
-            logger.debug(
-                "Skipping annotation %s for class heatmap: "
-                "position outside spectrogram. Pos: %s",
-                sound_event_annotation.uuid,
-                (class_name, time, frequency),
-            )
-            continue
-
-    # Apply gaussian filters
-    detection_heatmap = xr.apply_ufunc(
-        gaussian_filter,
-        detection_heatmap,
-        target_sigma,
-    )
-
-    class_heatmap = class_heatmap.groupby("category").map(
-        gaussian_filter,  # type: ignore
-        args=(target_sigma,),
-    )
-
-    # Normalize heatmaps
-    detection_heatmap = (
-        detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
-    ).fillna(0.0)
-
-    class_heatmap = (
-        class_heatmap / class_heatmap.max(dim=["time", "frequency"])
-    ).fillna(0.0)
+        class_index = targets.class_names.index(class_name)
+        class_heatmap[class_index] = torch.maximum(
+            class_heatmap[class_index],
+            gaussian_blob,
+        )
 
     return Heatmaps(
         detection=detection_heatmap,
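
With the xarray `category` coordinate gone, the class layer is now selected by position in `targets.class_names`. A minimal sketch of the same indexing, reusing `height`, `width`, and `gaussian_blob` from the sketches above and with illustrative class names:

    class_names = ["myotis", "pipistrellus"]  # illustrative
    class_heatmap = torch.zeros([len(class_names), height, width])

    class_index = class_names.index("pipistrellus")
    class_heatmap[class_index] = torch.maximum(
        class_heatmap[class_index], gaussian_blob
    )

Unlike the old `set_value_at_pos` call, `list.index` raises `ValueError` rather than `KeyError` for an unknown class name, so the deleted `except KeyError` guard has no direct replacement here; out-of-range positions are already filtered earlier in the loop.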

View File

@@ -37,12 +37,12 @@ class Heatmaps(NamedTuple):
     scaled dimensions placed at the event reference points.
     """
 
-    detection: xr.DataArray
-    classes: xr.DataArray
-    size: xr.DataArray
+    detection: torch.Tensor
+    classes: torch.Tensor
+    size: torch.Tensor
 
 
-ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
+ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
 """Type alias for the final clip labelling function.
 
 This function takes the complete annotations for a clip and the corresponding