mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat(plotting): add size heatmap label plotting
This commit is contained in:
parent
f82ec218f0
commit
d7e61ccd43
@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.heatmaps import (
|
||||
plot_classification_heatmap,
|
||||
plot_detection_heatmap,
|
||||
plot_size_heatmap,
|
||||
)
|
||||
from batdetect2.plotting.matches import (
|
||||
plot_cross_trigger_match,
|
||||
@ -25,5 +26,6 @@ __all__ = [
|
||||
"plot_true_positive_match",
|
||||
"plot_detection_heatmap",
|
||||
"plot_classification_heatmap",
|
||||
"plot_size_heatmap",
|
||||
"plot_match_gallery",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
"""Plot heatmaps"""
|
||||
"""Plot heatmaps."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
|
||||
|
||||
from batdetect2.plotting.common import create_ax
|
||||
|
||||
__all__ = [
|
||||
"plot_detection_heatmap",
|
||||
"plot_classification_heatmap",
|
||||
"plot_size_heatmap",
|
||||
]
|
||||
|
||||
|
||||
def plot_detection_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
@ -108,7 +114,91 @@ def plot_classification_heatmap(
|
||||
return ax
|
||||
|
||||
|
||||
def create_colormap(color: str) -> Colormap:
|
||||
def plot_size_heatmap(
|
||||
heatmap: torch.Tensor | np.ndarray,
|
||||
dimension_names: list[str],
|
||||
ax: axes.Axes | None = None,
|
||||
figsize: tuple[int, int] = (10, 10),
|
||||
color: str = "crimson",
|
||||
size: float = 20,
|
||||
fontsize: float = 8,
|
||||
) -> axes.Axes:
|
||||
"""Plot sparse size labels from a size heatmap.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
heatmap : torch.Tensor | np.ndarray
|
||||
Size heatmap with shape ``[num_dims, height, width]``. Entries are
|
||||
expected to be zero everywhere except at labelled positions.
|
||||
dimension_names : list[str]
|
||||
Names corresponding to the first heatmap dimension.
|
||||
ax : matplotlib.axes.Axes | None, default=None
|
||||
Axis to plot on. If ``None``, a new axis is created.
|
||||
figsize : tuple[int, int], default=(10, 10)
|
||||
Figure size used when creating a new axis.
|
||||
color : str, default="crimson"
|
||||
Color used for scatter points and text labels.
|
||||
size : float, default=20
|
||||
Marker size for plotted points.
|
||||
fontsize : float, default=8
|
||||
Font size used for the text labels.
|
||||
|
||||
Returns
|
||||
-------
|
||||
matplotlib.axes.Axes
|
||||
Axis containing the plotted size labels.
|
||||
"""
|
||||
ax = create_ax(ax, figsize=figsize)
|
||||
|
||||
if isinstance(heatmap, torch.Tensor):
|
||||
heatmap = heatmap.numpy()
|
||||
|
||||
if heatmap.ndim == 4:
|
||||
heatmap = heatmap[0]
|
||||
|
||||
if heatmap.ndim != 3:
|
||||
raise ValueError("Expecting a 3-dimensional array")
|
||||
|
||||
if len(dimension_names) != heatmap.shape[0]:
|
||||
raise ValueError("Inconsistent number of dimension names")
|
||||
|
||||
point_mask = np.any(heatmap != 0, axis=0)
|
||||
rows, cols = np.nonzero(point_mask)
|
||||
|
||||
if len(rows) == 0:
|
||||
return ax
|
||||
|
||||
ax.scatter(cols, rows, c=color, s=size)
|
||||
|
||||
for row, col in zip(rows, cols, strict=False):
|
||||
values = heatmap[:, row, col]
|
||||
labels = [
|
||||
f"{name}={value:.2f}"
|
||||
for name, value in zip(
|
||||
dimension_names,
|
||||
values,
|
||||
strict=False,
|
||||
)
|
||||
if value != 0
|
||||
]
|
||||
ax.text(
|
||||
float(col),
|
||||
float(row),
|
||||
"\n".join(labels),
|
||||
fontsize=fontsize,
|
||||
color=color,
|
||||
va="bottom",
|
||||
ha="left",
|
||||
)
|
||||
|
||||
ax.set_xlim(0, heatmap.shape[2])
|
||||
ax.set_ylim(0, heatmap.shape[1])
|
||||
return ax
|
||||
|
||||
|
||||
def create_colormap(
|
||||
color: str | tuple[float, float, float, float],
|
||||
) -> Colormap:
|
||||
(r, g, b, a) = to_rgba(color)
|
||||
return LinearSegmentedColormap.from_list(
|
||||
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user