feat(plotting): add size heatmap label plotting

This commit is contained in:
mbsantiago 2026-05-04 16:48:49 +01:00
parent f82ec218f0
commit d7e61ccd43
2 changed files with 94 additions and 2 deletions

View File

@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.heatmaps import ( from batdetect2.plotting.heatmaps import (
plot_classification_heatmap, plot_classification_heatmap,
plot_detection_heatmap, plot_detection_heatmap,
plot_size_heatmap,
) )
from batdetect2.plotting.matches import ( from batdetect2.plotting.matches import (
plot_cross_trigger_match, plot_cross_trigger_match,
@ -25,5 +26,6 @@ __all__ = [
"plot_true_positive_match", "plot_true_positive_match",
"plot_detection_heatmap", "plot_detection_heatmap",
"plot_classification_heatmap", "plot_classification_heatmap",
"plot_size_heatmap",
"plot_match_gallery", "plot_match_gallery",
] ]

View File

@ -1,4 +1,4 @@
"""Plot heatmaps""" """Plot heatmaps."""
import numpy as np import numpy as np
import torch import torch
@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
from batdetect2.plotting.common import create_ax from batdetect2.plotting.common import create_ax
__all__ = [
"plot_detection_heatmap",
"plot_classification_heatmap",
"plot_size_heatmap",
]
def plot_detection_heatmap( def plot_detection_heatmap(
heatmap: torch.Tensor | np.ndarray, heatmap: torch.Tensor | np.ndarray,
@ -108,7 +114,91 @@ def plot_classification_heatmap(
return ax return ax
def create_colormap(color: str) -> Colormap: def plot_size_heatmap(
heatmap: torch.Tensor | np.ndarray,
dimension_names: list[str],
ax: axes.Axes | None = None,
figsize: tuple[int, int] = (10, 10),
color: str = "crimson",
size: float = 20,
fontsize: float = 8,
) -> axes.Axes:
"""Plot sparse size labels from a size heatmap.
Parameters
----------
heatmap : torch.Tensor | np.ndarray
Size heatmap with shape ``[num_dims, height, width]``. Entries are
expected to be zero everywhere except at labelled positions.
dimension_names : list[str]
Names corresponding to the first heatmap dimension.
ax : matplotlib.axes.Axes | None, default=None
Axis to plot on. If ``None``, a new axis is created.
figsize : tuple[int, int], default=(10, 10)
Figure size used when creating a new axis.
color : str, default="crimson"
Color used for scatter points and text labels.
size : float, default=20
Marker size for plotted points.
fontsize : float, default=8
Font size used for the text labels.
Returns
-------
matplotlib.axes.Axes
Axis containing the plotted size labels.
"""
ax = create_ax(ax, figsize=figsize)
if isinstance(heatmap, torch.Tensor):
heatmap = heatmap.numpy()
if heatmap.ndim == 4:
heatmap = heatmap[0]
if heatmap.ndim != 3:
raise ValueError("Expecting a 3-dimensional array")
if len(dimension_names) != heatmap.shape[0]:
raise ValueError("Inconsistent number of dimension names")
point_mask = np.any(heatmap != 0, axis=0)
rows, cols = np.nonzero(point_mask)
if len(rows) == 0:
return ax
ax.scatter(cols, rows, c=color, s=size)
for row, col in zip(rows, cols, strict=False):
values = heatmap[:, row, col]
labels = [
f"{name}={value:.2f}"
for name, value in zip(
dimension_names,
values,
strict=False,
)
if value != 0
]
ax.text(
float(col),
float(row),
"\n".join(labels),
fontsize=fontsize,
color=color,
va="bottom",
ha="left",
)
ax.set_xlim(0, heatmap.shape[2])
ax.set_ylim(0, heatmap.shape[1])
return ax
def create_colormap(
color: str | tuple[float, float, float, float],
) -> Colormap:
(r, g, b, a) = to_rgba(color) (r, g, b, a) = to_rgba(color)
return LinearSegmentedColormap.from_list( return LinearSegmentedColormap.from_list(
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)] "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]