mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
151 lines
4.2 KiB
Python
151 lines
4.2 KiB
Python
from collections.abc import Iterable
|
|
from typing import Callable, List, Optional, Sequence, Tuple
|
|
|
|
import numpy as np
|
|
import xarray as xr
|
|
from pydantic import Field
|
|
from scipy.ndimage import gaussian_filter
|
|
from soundevent import arrays, data, geometry
|
|
from soundevent.geometry.operations import Positions
|
|
|
|
from batdetect2.configs import BaseConfig, load_config
|
|
|
|
__all__ = [
|
|
"HeatmapsConfig",
|
|
"LabelConfig",
|
|
"generate_heatmaps",
|
|
"load_label_config",
|
|
]
|
|
|
|
|
|
class HeatmapsConfig(BaseConfig):
|
|
position: Positions = "bottom-left"
|
|
sigma: float = 3.0
|
|
time_scale: float = 1000.0
|
|
frequency_scale: float = 1 / 859.375
|
|
|
|
|
|
class LabelConfig(BaseConfig):
|
|
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
|
|
|
|
|
def generate_heatmaps(
|
|
sound_events: Sequence[data.SoundEventAnnotation],
|
|
spec: xr.DataArray,
|
|
class_names: List[str],
|
|
encoder: Callable[[Iterable[data.Tag]], Optional[str]],
|
|
target_sigma: float = 3.0,
|
|
position: Positions = "bottom-left",
|
|
time_scale: float = 1000.0,
|
|
frequency_scale: float = 1 / 859.375,
|
|
dtype=np.float32,
|
|
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
|
shape = dict(zip(spec.dims, spec.shape))
|
|
|
|
if "time" not in shape or "frequency" not in shape:
|
|
raise ValueError(
|
|
"Spectrogram must have time and frequency dimensions."
|
|
)
|
|
|
|
# Initialize heatmaps
|
|
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
|
|
class_heatmap = xr.DataArray(
|
|
data=np.zeros((len(class_names), *spec.shape), dtype=dtype),
|
|
dims=["category", *spec.dims],
|
|
coords={
|
|
"category": [*class_names],
|
|
**spec.coords,
|
|
},
|
|
)
|
|
size_heatmap = xr.DataArray(
|
|
data=np.zeros((2, *spec.shape), dtype=dtype),
|
|
dims=["dimension", *spec.dims],
|
|
coords={
|
|
"dimension": ["width", "height"],
|
|
**spec.coords,
|
|
},
|
|
)
|
|
|
|
for sound_event_annotation in sound_events:
|
|
geom = sound_event_annotation.sound_event.geometry
|
|
if geom is None:
|
|
continue
|
|
|
|
# Get the position of the sound event
|
|
time, frequency = geometry.get_geometry_point(geom, position=position)
|
|
|
|
# Set 1.0 at the position of the sound event in the detection heatmap
|
|
try:
|
|
detection_heatmap = arrays.set_value_at_pos(
|
|
detection_heatmap,
|
|
1.0,
|
|
time=time,
|
|
frequency=frequency,
|
|
)
|
|
except KeyError:
|
|
# Skip the sound event if the position is outside the spectrogram
|
|
continue
|
|
|
|
# Set the size of the sound event at the position in the size heatmap
|
|
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
|
|
geom
|
|
)
|
|
|
|
size = np.array(
|
|
[
|
|
(end_time - start_time) * time_scale,
|
|
(high_freq - low_freq) * frequency_scale,
|
|
]
|
|
)
|
|
|
|
size_heatmap = arrays.set_value_at_pos(
|
|
size_heatmap,
|
|
size,
|
|
time=time,
|
|
frequency=frequency,
|
|
)
|
|
|
|
# Get the class name of the sound event
|
|
class_name = encoder(sound_event_annotation.tags)
|
|
|
|
if class_name is None:
|
|
# If the label is None skip the sound event
|
|
continue
|
|
|
|
class_heatmap = arrays.set_value_at_pos(
|
|
class_heatmap,
|
|
1.0,
|
|
time=time,
|
|
frequency=frequency,
|
|
category=class_name,
|
|
)
|
|
|
|
# Apply gaussian filters
|
|
detection_heatmap = xr.apply_ufunc(
|
|
gaussian_filter,
|
|
detection_heatmap,
|
|
target_sigma,
|
|
)
|
|
|
|
class_heatmap = class_heatmap.groupby("category").map(
|
|
gaussian_filter, # type: ignore
|
|
args=(target_sigma,),
|
|
)
|
|
|
|
# Normalize heatmaps
|
|
detection_heatmap = (
|
|
detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
|
|
).fillna(0.0)
|
|
|
|
class_heatmap = (
|
|
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
|
|
).fillna(0.0)
|
|
|
|
return detection_heatmap, class_heatmap, size_heatmap
|
|
|
|
|
|
def load_label_config(
|
|
path: data.PathLike, field: Optional[str] = None
|
|
) -> LabelConfig:
|
|
return load_config(path, schema=LabelConfig, field=field)
|