Labels to torch

This commit is contained in:
mbsantiago 2025-08-25 12:43:34 +01:00
parent 667b18a54d
commit c36ef3ecb5
3 changed files with 71 additions and 147 deletions

View File

@ -257,10 +257,14 @@ class ResizeSpec(torch.nn.Module):
def forward(self, spec: torch.Tensor) -> torch.Tensor:
current_length = spec.shape[-1]
target_length = int(self.time_factor * current_length)
return torch.nn.functional.interpolate(
spec.unsqueeze(0).unsqueeze(0),
size=(self.height, target_length),
mode="bilinear",
return (
torch.nn.functional.interpolate(
spec.unsqueeze(0).unsqueeze(0),
size=(self.height, target_length),
mode="bilinear",
)
.squeeze(0)
.squeeze(0)
)

View File

@ -23,15 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
defined in `LabelConfig`.
"""
from collections.abc import Iterable
from functools import partial
from typing import Optional
import numpy as np
import xarray as xr
import torch
from loguru import logger
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.typing import (
@ -69,6 +67,8 @@ class LabelConfig(BaseConfig):
def build_clip_labeler(
targets: TargetProtocol,
min_freq: float,
max_freq: float,
config: Optional[LabelConfig] = None,
) -> ClipLabeller:
"""Construct the final clip labelling function.
@ -102,14 +102,18 @@ def build_clip_labeler(
generate_clip_label,
targets=targets,
config=config,
min_freq=min_freq,
max_freq=max_freq,
)
def generate_clip_label(
clip_annotation: data.ClipAnnotation,
spec: xr.DataArray,
spec: torch.Tensor,
targets: TargetProtocol,
config: LabelConfig,
min_freq: float,
max_freq: float,
) -> Heatmaps:
"""Generate training heatmaps for a single annotated clip.
@ -160,102 +164,53 @@ def generate_clip_label(
sound_events.append(targets.transform(sound_event_annotation))
return generate_heatmaps(
sound_events,
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
spec=spec,
targets=targets,
target_sigma=config.sigma,
min_freq=min_freq,
max_freq=max_freq,
)
def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))
def generate_heatmaps(
sound_events: Iterable[data.SoundEventAnnotation],
spec: xr.DataArray,
clip_annotation: data.ClipAnnotation,
spec: torch.Tensor,
targets: TargetProtocol,
min_freq: float,
max_freq: float,
target_sigma: float = 3.0,
dtype=np.float32,
dtype=torch.float32,
) -> Heatmaps:
"""Generate detection, class, and size heatmaps from sound events.
Creates heatmap representations from an iterable of sound event
annotations. This function relies on the provided `targets` object to get
the reference position, scaled size, and class encoding for each
annotation.
Process:
1. Initializes empty heatmap arrays based on `spec` shape and `targets`
metadata.
2. Iterates through `sound_events`.
3. For each event:
a. Gets geometry. Skips if missing.
b. Gets reference position using `targets.get_position()`. Skips if out
of bounds.
c. Places a peak (1.0) on the detection heatmap at the position.
d. Gets scaled size using `targets.get_size()` and places it on the
size heatmap.
e. Encodes class using `targets.encode()` and places a peak (1.0) on
the corresponding class heatmap layer if a specific class is
returned.
4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
heatmaps.
5. Normalizes detection and class heatmaps to range [0, 1].
Parameters
----------
sound_events : Iterable[data.SoundEventAnnotation]
An iterable of sound event annotations to render onto the heatmaps.
spec : xr.DataArray
The spectrogram array corresponding to the time/frequency range of
the annotations. Used for shape and coordinate information. Must have
'time' and 'frequency' dimensions/coordinates.
targets : TargetProtocol
The fully configured target definition object. Used to access class
names, dimension names, and the methods `get_position`, `get_size`,
`encode`.
target_sigma : float, default=3.0
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
smooth the detection and class heatmaps.
dtype : type, default=np.float32
The data type for the generated heatmap arrays (e.g., `np.float32`).
Returns
-------
Heatmaps
A NamedTuple containing the 'detection', 'classes', and 'size'
xarray DataArrays, ready for use in model training.
Raises
------
ValueError
If the input spectrogram `spec` does not have both 'time' and
'frequency' dimensions, or if `targets.class_names` is empty.
"""
shape = dict(zip(spec.dims, spec.shape))
if "time" not in shape or "frequency" not in shape:
if not spec.ndim == 2:
raise ValueError(
"Spectrogram must have time and frequency dimensions."
"Expecting a 2-dimensional tensor of shape (H, W), "
"H is the height of the spectrogram "
"(frequency bins), and W is the width of the spectrogram "
f"(temporal bins). Instead got: {spec.shape}"
)
height, width = spec.shape
num_classes = len(targets.class_names)
num_dims = len(targets.dimension_names)
clip = clip_annotation.clip
# Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray(
data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
dims=["category", *spec.dims],
coords={
"category": [*targets.class_names],
**spec.coords,
},
)
size_heatmap = xr.DataArray(
data=np.zeros((2, *spec.shape), dtype=dtype),
dims=[SIZE_DIMENSION, *spec.dims],
coords={
SIZE_DIMENSION: targets.dimension_names,
**spec.coords,
},
detection_heatmap = torch.zeros([height, width], dtype=dtype)
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
freqs, times = torch.meshgrid(
torch.arange(height, dtype=dtype),
torch.arange(width, dtype=dtype),
indexing="ij",
)
for sound_event_annotation in sound_events:
for sound_event_annotation in clip_annotation.sound_events:
geom = sound_event_annotation.sound_event.geometry
if geom is None:
logger.debug(
@ -267,16 +222,15 @@ def generate_heatmaps(
# Get the position of the sound event
(time, frequency), size = targets.encode_roi(sound_event_annotation)
# Set 1.0 at the position of the sound event in the detection heatmap
try:
detection_heatmap = arrays.set_value_at_pos(
detection_heatmap,
1.0,
time=time,
frequency=frequency,
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
if (
time_index < 0
or time_index >= width
or freq_index < 0
or freq_index >= height
):
logger.debug(
"Skipping annotation %s: position outside spectrogram. "
"Pos: %s",
@ -285,12 +239,11 @@ def generate_heatmaps(
)
continue
size_heatmap = arrays.set_value_at_pos(
size_heatmap,
size,
time=time,
frequency=frequency,
)
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
# Get the class name of the sound event
try:
@ -308,44 +261,11 @@ def generate_heatmaps(
# If the label is None skip the sound event
continue
try:
class_heatmap = arrays.set_value_at_pos(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=class_name,
)
except KeyError:
# Skip the sound event if the position is outside the spectrogram
logger.debug(
"Skipping annotation %s for class heatmap: "
"position outside spectrogram. Pos: %s",
sound_event_annotation.uuid,
(class_name, time, frequency),
)
continue
# Apply gaussian filters
detection_heatmap = xr.apply_ufunc(
gaussian_filter,
detection_heatmap,
target_sigma,
)
class_heatmap = class_heatmap.groupby("category").map(
gaussian_filter, # type: ignore
args=(target_sigma,),
)
# Normalize heatmaps
detection_heatmap = (
detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
class_heatmap = (
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
class_index = targets.class_names.index(class_name)
class_heatmap[class_index] = torch.maximum(
class_heatmap[class_index],
gaussian_blob,
)
return Heatmaps(
detection=detection_heatmap,

View File

@ -37,12 +37,12 @@ class Heatmaps(NamedTuple):
scaled dimensions placed at the event reference points.
"""
detection: xr.DataArray
classes: xr.DataArray
size: xr.DataArray
detection: torch.Tensor
classes: torch.Tensor
size: torch.Tensor
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
"""Type alias for the final clip labelling function.
This function takes the complete annotations for a clip and the corresponding