mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Labels to torch
This commit is contained in:
parent
667b18a54d
commit
c36ef3ecb5
@ -257,10 +257,14 @@ class ResizeSpec(torch.nn.Module):
|
|||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||||
current_length = spec.shape[-1]
|
current_length = spec.shape[-1]
|
||||||
target_length = int(self.time_factor * current_length)
|
target_length = int(self.time_factor * current_length)
|
||||||
return torch.nn.functional.interpolate(
|
return (
|
||||||
spec.unsqueeze(0).unsqueeze(0),
|
torch.nn.functional.interpolate(
|
||||||
size=(self.height, target_length),
|
spec.unsqueeze(0).unsqueeze(0),
|
||||||
mode="bilinear",
|
size=(self.height, target_length),
|
||||||
|
mode="bilinear",
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.squeeze(0)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -23,15 +23,13 @@ parameter specific to this module is the Gaussian smoothing sigma (`sigma`)
|
|||||||
defined in `LabelConfig`.
|
defined in `LabelConfig`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Iterable
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import torch
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from scipy.ndimage import gaussian_filter
|
from soundevent import data
|
||||||
from soundevent import arrays, data
|
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -69,6 +67,8 @@ class LabelConfig(BaseConfig):
|
|||||||
|
|
||||||
def build_clip_labeler(
|
def build_clip_labeler(
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
config: Optional[LabelConfig] = None,
|
config: Optional[LabelConfig] = None,
|
||||||
) -> ClipLabeller:
|
) -> ClipLabeller:
|
||||||
"""Construct the final clip labelling function.
|
"""Construct the final clip labelling function.
|
||||||
@ -102,14 +102,18 @@ def build_clip_labeler(
|
|||||||
generate_clip_label,
|
generate_clip_label,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
config=config,
|
config=config,
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_clip_label(
|
def generate_clip_label(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
spec: xr.DataArray,
|
spec: torch.Tensor,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: LabelConfig,
|
config: LabelConfig,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
) -> Heatmaps:
|
) -> Heatmaps:
|
||||||
"""Generate training heatmaps for a single annotated clip.
|
"""Generate training heatmaps for a single annotated clip.
|
||||||
|
|
||||||
@ -160,102 +164,53 @@ def generate_clip_label(
|
|||||||
sound_events.append(targets.transform(sound_event_annotation))
|
sound_events.append(targets.transform(sound_event_annotation))
|
||||||
|
|
||||||
return generate_heatmaps(
|
return generate_heatmaps(
|
||||||
sound_events,
|
clip_annotation.model_copy(update=dict(sound_events=sound_events)),
|
||||||
spec=spec,
|
spec=spec,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
target_sigma=config.sigma,
|
target_sigma=config.sigma,
|
||||||
|
min_freq=min_freq,
|
||||||
|
max_freq=max_freq,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def map_to_pixels(x, size, min_val, max_val) -> int:
|
||||||
|
return int(np.floor(np.interp(x, [min_val, max_val], [0, size])))
|
||||||
|
|
||||||
|
|
||||||
def generate_heatmaps(
|
def generate_heatmaps(
|
||||||
sound_events: Iterable[data.SoundEventAnnotation],
|
clip_annotation: data.ClipAnnotation,
|
||||||
spec: xr.DataArray,
|
spec: torch.Tensor,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
min_freq: float,
|
||||||
|
max_freq: float,
|
||||||
target_sigma: float = 3.0,
|
target_sigma: float = 3.0,
|
||||||
dtype=np.float32,
|
dtype=torch.float32,
|
||||||
) -> Heatmaps:
|
) -> Heatmaps:
|
||||||
"""Generate detection, class, and size heatmaps from sound events.
|
if not spec.ndim == 2:
|
||||||
|
|
||||||
Creates heatmap representations from an iterable of sound event
|
|
||||||
annotations. This function relies on the provided `targets` object to get
|
|
||||||
the reference position, scaled size, and class encoding for each
|
|
||||||
annotation.
|
|
||||||
|
|
||||||
Process:
|
|
||||||
1. Initializes empty heatmap arrays based on `spec` shape and `targets`
|
|
||||||
metadata.
|
|
||||||
2. Iterates through `sound_events`.
|
|
||||||
3. For each event:
|
|
||||||
a. Gets geometry. Skips if missing.
|
|
||||||
b. Gets reference position using `targets.get_position()`. Skips if out
|
|
||||||
of bounds.
|
|
||||||
c. Places a peak (1.0) on the detection heatmap at the position.
|
|
||||||
d. Gets scaled size using `targets.get_size()` and places it on the
|
|
||||||
size heatmap.
|
|
||||||
e. Encodes class using `targets.encode()` and places a peak (1.0) on
|
|
||||||
the corresponding class heatmap layer if a specific class is
|
|
||||||
returned.
|
|
||||||
4. Applies Gaussian smoothing (using `target_sigma`) to detection and class
|
|
||||||
heatmaps.
|
|
||||||
5. Normalizes detection and class heatmaps to range [0, 1].
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
sound_events : Iterable[data.SoundEventAnnotation]
|
|
||||||
An iterable of sound event annotations to render onto the heatmaps.
|
|
||||||
spec : xr.DataArray
|
|
||||||
The spectrogram array corresponding to the time/frequency range of
|
|
||||||
the annotations. Used for shape and coordinate information. Must have
|
|
||||||
'time' and 'frequency' dimensions/coordinates.
|
|
||||||
targets : TargetProtocol
|
|
||||||
The fully configured target definition object. Used to access class
|
|
||||||
names, dimension names, and the methods `get_position`, `get_size`,
|
|
||||||
`encode`.
|
|
||||||
target_sigma : float, default=3.0
|
|
||||||
Standard deviation (in pixels/bins) of the Gaussian kernel applied to
|
|
||||||
smooth the detection and class heatmaps.
|
|
||||||
dtype : type, default=np.float32
|
|
||||||
The data type for the generated heatmap arrays (e.g., `np.float32`).
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Heatmaps
|
|
||||||
A NamedTuple containing the 'detection', 'classes', and 'size'
|
|
||||||
xarray DataArrays, ready for use in model training.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
ValueError
|
|
||||||
If the input spectrogram `spec` does not have both 'time' and
|
|
||||||
'frequency' dimensions, or if `targets.class_names` is empty.
|
|
||||||
"""
|
|
||||||
shape = dict(zip(spec.dims, spec.shape))
|
|
||||||
|
|
||||||
if "time" not in shape or "frequency" not in shape:
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Spectrogram must have time and frequency dimensions."
|
"Expecting a 2-dimensional tensor of shape (H, W), "
|
||||||
|
"H is the height of the spectrogram "
|
||||||
|
"(frequency bins), and W is the width of the spectrogram "
|
||||||
|
f"(temporal bins). Instead got: {spec.shape}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
height, width = spec.shape
|
||||||
|
num_classes = len(targets.class_names)
|
||||||
|
num_dims = len(targets.dimension_names)
|
||||||
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
# Initialize heatmaps
|
# Initialize heatmaps
|
||||||
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
|
detection_heatmap = torch.zeros([height, width], dtype=dtype)
|
||||||
class_heatmap = xr.DataArray(
|
class_heatmap = torch.zeros([num_classes, height, width], dtype=dtype)
|
||||||
data=np.zeros((len(targets.class_names), *spec.shape), dtype=dtype),
|
size_heatmap = torch.zeros([num_dims, height, width], dtype=dtype)
|
||||||
dims=["category", *spec.dims],
|
|
||||||
coords={
|
freqs, times = torch.meshgrid(
|
||||||
"category": [*targets.class_names],
|
torch.arange(height, dtype=dtype),
|
||||||
**spec.coords,
|
torch.arange(width, dtype=dtype),
|
||||||
},
|
indexing="ij",
|
||||||
)
|
|
||||||
size_heatmap = xr.DataArray(
|
|
||||||
data=np.zeros((2, *spec.shape), dtype=dtype),
|
|
||||||
dims=[SIZE_DIMENSION, *spec.dims],
|
|
||||||
coords={
|
|
||||||
SIZE_DIMENSION: targets.dimension_names,
|
|
||||||
**spec.coords,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for sound_event_annotation in sound_events:
|
for sound_event_annotation in clip_annotation.sound_events:
|
||||||
geom = sound_event_annotation.sound_event.geometry
|
geom = sound_event_annotation.sound_event.geometry
|
||||||
if geom is None:
|
if geom is None:
|
||||||
logger.debug(
|
logger.debug(
|
||||||
@ -267,16 +222,15 @@ def generate_heatmaps(
|
|||||||
# Get the position of the sound event
|
# Get the position of the sound event
|
||||||
(time, frequency), size = targets.encode_roi(sound_event_annotation)
|
(time, frequency), size = targets.encode_roi(sound_event_annotation)
|
||||||
|
|
||||||
# Set 1.0 at the position of the sound event in the detection heatmap
|
time_index = map_to_pixels(time, width, clip.start_time, clip.end_time)
|
||||||
try:
|
freq_index = map_to_pixels(frequency, height, min_freq, max_freq)
|
||||||
detection_heatmap = arrays.set_value_at_pos(
|
|
||||||
detection_heatmap,
|
if (
|
||||||
1.0,
|
time_index < 0
|
||||||
time=time,
|
or time_index >= width
|
||||||
frequency=frequency,
|
or freq_index < 0
|
||||||
)
|
or freq_index >= height
|
||||||
except KeyError:
|
):
|
||||||
# Skip the sound event if the position is outside the spectrogram
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
"Skipping annotation %s: position outside spectrogram. "
|
"Skipping annotation %s: position outside spectrogram. "
|
||||||
"Pos: %s",
|
"Pos: %s",
|
||||||
@ -285,12 +239,11 @@ def generate_heatmaps(
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
size_heatmap = arrays.set_value_at_pos(
|
distance = (times - time_index) ** 2 + (freqs - freq_index) ** 2
|
||||||
size_heatmap,
|
gaussian_blob = torch.exp(-distance / (2 * target_sigma**2))
|
||||||
size,
|
|
||||||
time=time,
|
detection_heatmap = torch.maximum(detection_heatmap, gaussian_blob)
|
||||||
frequency=frequency,
|
size_heatmap[:, freq_index, time_index] = torch.tensor(size[:])
|
||||||
)
|
|
||||||
|
|
||||||
# Get the class name of the sound event
|
# Get the class name of the sound event
|
||||||
try:
|
try:
|
||||||
@ -308,44 +261,11 @@ def generate_heatmaps(
|
|||||||
# If the label is None skip the sound event
|
# If the label is None skip the sound event
|
||||||
continue
|
continue
|
||||||
|
|
||||||
try:
|
class_index = targets.class_names.index(class_name)
|
||||||
class_heatmap = arrays.set_value_at_pos(
|
class_heatmap[class_index] = torch.maximum(
|
||||||
class_heatmap,
|
class_heatmap[class_index],
|
||||||
1.0,
|
gaussian_blob,
|
||||||
time=time,
|
)
|
||||||
frequency=frequency,
|
|
||||||
category=class_name,
|
|
||||||
)
|
|
||||||
except KeyError:
|
|
||||||
# Skip the sound event if the position is outside the spectrogram
|
|
||||||
logger.debug(
|
|
||||||
"Skipping annotation %s for class heatmap: "
|
|
||||||
"position outside spectrogram. Pos: %s",
|
|
||||||
sound_event_annotation.uuid,
|
|
||||||
(class_name, time, frequency),
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Apply gaussian filters
|
|
||||||
detection_heatmap = xr.apply_ufunc(
|
|
||||||
gaussian_filter,
|
|
||||||
detection_heatmap,
|
|
||||||
target_sigma,
|
|
||||||
)
|
|
||||||
|
|
||||||
class_heatmap = class_heatmap.groupby("category").map(
|
|
||||||
gaussian_filter, # type: ignore
|
|
||||||
args=(target_sigma,),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Normalize heatmaps
|
|
||||||
detection_heatmap = (
|
|
||||||
detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
|
|
||||||
).fillna(0.0)
|
|
||||||
|
|
||||||
class_heatmap = (
|
|
||||||
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
|
|
||||||
).fillna(0.0)
|
|
||||||
|
|
||||||
return Heatmaps(
|
return Heatmaps(
|
||||||
detection=detection_heatmap,
|
detection=detection_heatmap,
|
||||||
|
|||||||
@ -37,12 +37,12 @@ class Heatmaps(NamedTuple):
|
|||||||
scaled dimensions placed at the event reference points.
|
scaled dimensions placed at the event reference points.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detection: xr.DataArray
|
detection: torch.Tensor
|
||||||
classes: xr.DataArray
|
classes: torch.Tensor
|
||||||
size: xr.DataArray
|
size: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
ClipLabeller = Callable[[data.ClipAnnotation, xr.DataArray], Heatmaps]
|
ClipLabeller = Callable[[data.ClipAnnotation, torch.Tensor], Heatmaps]
|
||||||
"""Type alias for the final clip labelling function.
|
"""Type alias for the final clip labelling function.
|
||||||
|
|
||||||
This function takes the complete annotations for a clip and the corresponding
|
This function takes the complete annotations for a clip and the corresponding
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user