Plot anchor points

This commit is contained in:
mbsantiago 2025-08-27 18:13:40 +01:00
parent d25efdad10
commit ed76ec24b6
3 changed files with 40 additions and 7 deletions

View File

@ -4,7 +4,9 @@ from matplotlib.axes import Axes
from soundevent import data, plot from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.common import create_ax
from batdetect2.typing.preprocess import PreprocessorProtocol from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [ __all__ = [
"plot_clip_annotation", "plot_clip_annotation",
@ -43,3 +45,31 @@ def plot_clip_annotation(
facecolor="none" if not fill else None, facecolor="none" if not fill else None,
) )
return ax return ax
def plot_anchor_points(
clip_annotation: data.ClipAnnotation,
targets: TargetProtocol,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
size: int = 1,
color: str = "red",
marker: str = "x",
alpha: float = 1,
) -> Axes:
ax = create_ax(ax=ax, figsize=figsize)
positions = []
for sound_event in clip_annotation.sound_events:
if not targets.filter(sound_event):
continue
sound_event = targets.transform(sound_event)
position, _ = targets.encode_roi(sound_event)
positions.append(position)
X, Y = zip(*positions)
ax.scatter(X, Y, s=size, c=color, marker=marker, alpha=alpha)
return ax

View File

@ -1,6 +1,6 @@
"""General plotting utilities.""" """General plotting utilities."""
from typing import Optional, Tuple from typing import Optional, Tuple, Union
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import numpy as np import numpy as np
@ -25,7 +25,7 @@ def create_ax(
def plot_spectrogram( def plot_spectrogram(
spec: torch.Tensor, spec: Union[torch.Tensor, np.ndarray],
start_time: float, start_time: float,
end_time: float, end_time: float,
min_freq: float, min_freq: float,
@ -34,12 +34,15 @@ def plot_spectrogram(
figsize: Optional[Tuple[int, int]] = None, figsize: Optional[Tuple[int, int]] = None,
cmap="gray", cmap="gray",
) -> axes.Axes: ) -> axes.Axes:
if isinstance(spec, torch.Tensor):
spec = spec.numpy()
ax = create_ax(ax=ax, figsize=figsize) ax = create_ax(ax=ax, figsize=figsize)
ax.pcolormesh( ax.pcolormesh(
np.linspace(start_time, end_time, spec.shape[-1], endpoint=False), np.linspace(start_time, end_time, spec.shape[-1] + 1, endpoint=True),
np.linspace(min_freq, max_freq, spec.shape[-2], endpoint=False), np.linspace(min_freq, max_freq, spec.shape[-2] + 1, endpoint=True),
spec.numpy(), spec,
cmap=cmap, cmap=cmap,
) )
return ax return ax

View File

@ -62,7 +62,7 @@ class LabelConfig(BaseConfig):
diffuse targets. diffuse targets.
""" """
sigma: float = 3.0 sigma: float = 2.0
def build_clip_labeler( def build_clip_labeler(
@ -174,7 +174,7 @@ def generate_clip_label(
def map_to_pixels(x, size, min_val, max_val) -> int: def map_to_pixels(x, size, min_val, max_val) -> int:
return int(np.floor(np.interp(x, [min_val, max_val], [0, size]))) return int(np.interp(x, [min_val, max_val], [0, size]))
def generate_heatmaps( def generate_heatmaps(