From d7e61ccd43eb7ad9fe31434550f12e35c4a1ceb8 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 4 May 2026 16:48:49 +0100 Subject: [PATCH] feat(plotting): add size heatmap label plotting --- src/batdetect2/plotting/__init__.py | 2 + src/batdetect2/plotting/heatmaps.py | 94 ++++++++++++++++++++++++++++- 2 files changed, 94 insertions(+), 2 deletions(-) diff --git a/src/batdetect2/plotting/__init__.py b/src/batdetect2/plotting/__init__.py index 08f8378..5287b79 100644 --- a/src/batdetect2/plotting/__init__.py +++ b/src/batdetect2/plotting/__init__.py @@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery from batdetect2.plotting.heatmaps import ( plot_classification_heatmap, plot_detection_heatmap, + plot_size_heatmap, ) from batdetect2.plotting.matches import ( plot_cross_trigger_match, @@ -25,5 +26,6 @@ __all__ = [ "plot_true_positive_match", "plot_detection_heatmap", "plot_classification_heatmap", + "plot_size_heatmap", "plot_match_gallery", ] diff --git a/src/batdetect2/plotting/heatmaps.py b/src/batdetect2/plotting/heatmaps.py index db29c82..9f4021f 100644 --- a/src/batdetect2/plotting/heatmaps.py +++ b/src/batdetect2/plotting/heatmaps.py @@ -1,4 +1,4 @@ -"""Plot heatmaps""" +"""Plot heatmaps.""" import numpy as np import torch @@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba from batdetect2.plotting.common import create_ax +__all__ = [ + "plot_detection_heatmap", + "plot_classification_heatmap", + "plot_size_heatmap", +] + def plot_detection_heatmap( heatmap: torch.Tensor | np.ndarray, @@ -108,7 +114,91 @@ def plot_classification_heatmap( return ax -def create_colormap(color: str) -> Colormap: +def plot_size_heatmap( + heatmap: torch.Tensor | np.ndarray, + dimension_names: list[str], + ax: axes.Axes | None = None, + figsize: tuple[int, int] = (10, 10), + color: str = "crimson", + size: float = 20, + fontsize: float = 8, +) -> axes.Axes: + """Plot sparse size labels from a size heatmap. + + Parameters + ---------- + heatmap : torch.Tensor | np.ndarray + Size heatmap with shape ``[num_dims, height, width]``. Entries are + expected to be zero everywhere except at labelled positions. + dimension_names : list[str] + Names corresponding to the first heatmap dimension. + ax : matplotlib.axes.Axes | None, default=None + Axis to plot on. If ``None``, a new axis is created. + figsize : tuple[int, int], default=(10, 10) + Figure size used when creating a new axis. + color : str, default="crimson" + Color used for scatter points and text labels. + size : float, default=20 + Marker size for plotted points. + fontsize : float, default=8 + Font size used for the text labels. + + Returns + ------- + matplotlib.axes.Axes + Axis containing the plotted size labels. + """ + ax = create_ax(ax, figsize=figsize) + + if isinstance(heatmap, torch.Tensor): + heatmap = heatmap.numpy() + + if heatmap.ndim == 4: + heatmap = heatmap[0] + + if heatmap.ndim != 3: + raise ValueError("Expecting a 3-dimensional array") + + if len(dimension_names) != heatmap.shape[0]: + raise ValueError("Inconsistent number of dimension names") + + point_mask = np.any(heatmap != 0, axis=0) + rows, cols = np.nonzero(point_mask) + + if len(rows) == 0: + return ax + + ax.scatter(cols, rows, c=color, s=size) + + for row, col in zip(rows, cols, strict=False): + values = heatmap[:, row, col] + labels = [ + f"{name}={value:.2f}" + for name, value in zip( + dimension_names, + values, + strict=False, + ) + if value != 0 + ] + ax.text( + float(col), + float(row), + "\n".join(labels), + fontsize=fontsize, + color=color, + va="bottom", + ha="left", + ) + + ax.set_xlim(0, heatmap.shape[2]) + ax.set_ylim(0, heatmap.shape[1]) + return ax + + +def create_colormap( + color: str | tuple[float, float, float, float], +) -> Colormap: (r, g, b, a) = to_rgba(color) return LinearSegmentedColormap.from_list( "cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]