mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Labels to torch
This commit is contained in:
parent
667b18a54d
commit
c36ef3ecb5
@ -257,10 +257,14 @@ class ResizeSpec(torch.nn.Module):
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
current_length = spec.shape[-1]
|
||||
target_length = int(self.time_factor * current_length)
|
||||
return torch.nn.functional.interpolate(
|
||||
spec.unsqueeze(0).unsqueeze(0),
|
||||
size=(self.height, target_length),
|
||||
mode="bilinear",
|
||||
return (
|
||||
torch.nn.functional.interpolate(
|
||||
spec.unsqueeze(0).unsqueeze(0),
|
||||
size=(self.height, target_length),
|
||||
mode="bilinear",
|
||||
)
|
||||
.squeeze(0)
|
||||
.squeeze(0)
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -23,15 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
||||
defined in `LabelConfig`.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
import torch
|
||||
from loguru import logger
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from soundevent import arrays, data
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.typing import (
|
||||
@ -69,6 +67,8 @@ class LabelConfig(BaseConfig):
|
||||
|
||||
def build_clip_labeler(
|
||||
targets: TargetProtocol,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
config: Optional[LabelConfig] = None,
|
||||
) -> ClipLabeller:
|
||||
"""Construct the final clip labelling function.
|
||||
@ -102,14 +102,18 @@ def build_clip_labeler(
|
||||
generate_clip_label,
|
||||
targets=targets,
|
||||
config=config,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
def generate_clip_label(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
spec: xr.DataArray,
|
||||
spec: torch.Tensor,
|
||||
targets: TargetProtocol,
|
||||
config: LabelConfig,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
) -> Heatmaps:
|
||||
"""Generate training heatmaps for a single annotated clip.
|
||||
|
||||
@ -160,102 +164,53 @@ def generate_clip_label(
|
||||
sound_events.append(targets.transform(sound_event_annotation))
|
||||
|
||||
return generate_heatmaps(
|
||||
sound_events,
|
||||
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
|
||||
spec=spec,
|
||||
targets=targets,
|
||||
target_sigma=config.sigma,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
)
|
||||
|
||||
|
||||
def map_to_pixels(x, size, min_val, max_val) -> int:
|
||||
return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
sound_events: Iterable[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
spec: torch.Tensor,
|
||||
targets: TargetProtocol,
|
||||
min_freq: float,
|
||||
max_freq: float,
|
||||
target_sigma: float = 3.0,
|
||||
dtype=np.float32,
|
||||
dtype=torch.float32,
|
||||
) -> Heatmaps:
|
||||
"""Generate detection, class, and size heatmaps from sound events.
|
||||
|
||||
Creates heatmap representations from an iterable of sound event
|
||||
annotations. This function relies on the provided `targets` object to get
|
||||
the reference position, scaled size, and class encoding for each
|
||||
annotation.
|
||||
|
||||
Process:
|
||||
1. Initializes empty heatmap arrays based on `spec` shape and `targets`
|
||||
metadata.
|
||||
2. Iterates through `sound_events`.
|
||||
3. For each event:
|
||||
a. Gets geometry. Skips if missing.
|
||||
b. Gets reference position using `targets.get_position()`. Skips if out
|
||||
of bounds.
|
||||
c. Places a peak (1.0) on the detection heatmap at the position.
|
||||
d. Gets scaled size using `targets.get_size()` and places it on the
|
||||
size heatmap.
|
||||
e. Encodes class using `targets.encode()` and places a peak (1.0) on
|
||||
the corresponding class heatmap layer if a specific class is
|
||||
returned.
|
||||
4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
|
||||
heatmaps.
|
||||
5. Normalizes detection and class heatmaps to range [0, 1].
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sound_events : Iterable[data.SoundEventAnnotation]
|
||||
An iterable of sound event annotations to render onto the heatmaps.
|
||||
spec : xr.DataArray
|
||||
The spectrogram array corresponding to the time/frequency range of
|
||||
the annotations. Used for shape and coordinate information. Must have
|
||||
'time' and 'frequency' dimensions/coordinates.
|
||||
targets : TargetProtocol
|
||||
The fully configured target definition object. Used to access class
|
||||
names, dimension names, and the methods `get_position`, `get_size`,
|
||||
`encode`.
|
||||
target_sigma : float, default=3.0
|
||||
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
|
||||
smooth the detection and class heatmaps.
|
||||
dtype : type, default=np.float32
|
||||
The data type for the generated heatmap arrays (e.g., `np.float32`).
|
||||
|
||||
Returns
|
||||
-------
|
||||
Heatmaps
|
||||
A NamedTuple containing the 'detection', 'classes', and 'size'
|
||||
xarray DataArrays, ready for use in model training.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the input spectrogram `spec` does not have both 'time' and
|
||||
'frequency' dimensions, or if `targets.class_names` is empty.
|
||||
"""
|
||||
shape = dict(zip(spec.dims, spec.shape))
|
||||
|
||||
if "time" not in shape or "frequency" not in shape:
|
||||
if not spec.ndim == 2:
|
||||
raise ValueError(
|
||||
"Spectrogram must have time and frequency dimensions."
|
||||
"Expecting a 2-dimensional tensor of shape (H, W), "
|
||||
"H is the height of the spectrogram "
|
||||
"(frequency bins), and W is the width of the spectrogram "
|
||||
f"(temporal bins). Instead got: {spec.shape}"
|
||||
)
|
||||
|
||||
height, width = spec.shape
|
||||
num_classes = len(targets.class_names)
|
||||
num_dims = len(targets.dimension_names)
|
||||
clip = clip_annotation.clip
|
||||
|
||||
# Initialize heatmaps
|
||||
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
|
||||
class_heatmap = xr.DataArray(
|
||||
data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
|
||||
dims=["category", *spec.dims],
|
||||
coords={
|
||||
"category": [*targets.class_names],
|
||||
**spec.coords,
|
||||
},
|
||||
)
|
||||
size_heatmap = xr.DataArray(
|
||||
data=np.zeros((2, *spec.shape), dtype=dtype),
|
||||
dims=[SIZE_DIMENSION, *spec.dims],
|
||||
coords={
|
||||
SIZE_DIMENSION: targets.dimension_names,
|
||||
**spec.coords,
|
||||
},
|
||||
detection_heatmap = torch.zeros([height, width], dtype=dtype)
|
||||
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
|
||||
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
|
||||
|
||||
freqs, times = torch.meshgrid(
|
||||
torch.arange(height, dtype=dtype),
|
||||
torch.arange(width, dtype=dtype),
|
||||
indexing="ij",
|
||||
)
|
||||
|
||||
for sound_event_annotation in sound_events:
|
||||
for sound_event_annotation in clip_annotation.sound_events:
|
||||
geom = sound_event_annotation.sound_event.geometry
|
||||
if geom is None:
|
||||
logger.debug(
|
||||
@ -267,16 +222,15 @@ def generate_heatmaps(
|
||||
# Get the position of the sound event
|
||||
(time, frequency), size = targets.encode_roi(sound_event_annotation)
|
||||
|
||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
||||
try:
|
||||
detection_heatmap = arrays.set_value_at_pos(
|
||||
detection_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
|
||||
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
|
||||
|
||||
if (
|
||||
time_index < 0
|
||||
or time_index >= width
|
||||
or freq_index < 0
|
||||
or freq_index >= height
|
||||
):
|
||||
logger.debug(
|
||||
"Skipping annotation %s: position outside spectrogram. "
|
||||
"Pos: %s",
|
||||
@ -285,12 +239,11 @@ def generate_heatmaps(
|
||||
)
|
||||
continue
|
||||
|
||||
size_heatmap = arrays.set_value_at_pos(
|
||||
size_heatmap,
|
||||
size,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
)
|
||||
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
|
||||
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
|
||||
|
||||
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
|
||||
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
|
||||
|
||||
# Get the class name of the sound event
|
||||
try:
|
||||
@ -308,44 +261,11 @@ def generate_heatmaps(
|
||||
# If the label is None skip the sound event
|
||||
continue
|
||||
|
||||
try:
|
||||
class_heatmap = arrays.set_value_at_pos(
|
||||
class_heatmap,
|
||||
1.0,
|
||||
time=time,
|
||||
frequency=frequency,
|
||||
category=class_name,
|
||||
)
|
||||
except KeyError:
|
||||
# Skip the sound event if the position is outside the spectrogram
|
||||
logger.debug(
|
||||
"Skipping annotation %s for class heatmap: "
|
||||
"position outside spectrogram. Pos: %s",
|
||||
sound_event_annotation.uuid,
|
||||
(class_name, time, frequency),
|
||||
)
|
||||
continue
|
||||
|
||||
# Apply gaussian filters
|
||||
detection_heatmap = xr.apply_ufunc(
|
||||
gaussian_filter,
|
||||
detection_heatmap,
|
||||
target_sigma,
|
||||
)
|
||||
|
||||
class_heatmap = class_heatmap.groupby("category").map(
|
||||
gaussian_filter, # type: ignore
|
||||
args=(target_sigma,),
|
||||
)
|
||||
|
||||
# Normalize heatmaps
|
||||
detection_heatmap = (
|
||||
detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
|
||||
).fillna(0.0)
|
||||
|
||||
class_heatmap = (
|
||||
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
|
||||
).fillna(0.0)
|
||||
class_index = targets.class_names.index(class_name)
|
||||
class_heatmap[class_index] = torch.maximum(
|
||||
class_heatmap[class_index],
|
||||
gaussian_blob,
|
||||
)
|
||||
|
||||
return Heatmaps(
|
||||
detection=detection_heatmap,
|
||||
|
||||
@ -37,12 +37,12 @@ class Heatmaps(NamedTuple):
|
||||
scaled dimensions placed at the event reference points.
|
||||
"""
|
||||
|
||||
detection: xr.DataArray
|
||||
classes: xr.DataArray
|
||||
size: xr.DataArray
|
||||
detection: torch.Tensor
|
||||
classes: torch.Tensor
|
||||
size: torch.Tensor
|
||||
|
||||
|
||||
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
|
||||
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
|
||||
"""Type alias for the final clip labelling function.
|
||||
|
||||
This function takes the complete annotations for a clip and the corresponding
|
||||
|
||||
Loading…
Reference in New Issue
Block a user