mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat(plotting): add size heatmap label plotting
This commit is contained in:
parent
f82ec218f0
commit
d7e61ccd43
@ -6,6 +6,7 @@ from batdetect2.plotting.gallery import plot_match_gallery
|
|||||||
from batdetect2.plotting.heatmaps import (
|
from batdetect2.plotting.heatmaps import (
|
||||||
plot_classification_heatmap,
|
plot_classification_heatmap,
|
||||||
plot_detection_heatmap,
|
plot_detection_heatmap,
|
||||||
|
plot_size_heatmap,
|
||||||
)
|
)
|
||||||
from batdetect2.plotting.matches import (
|
from batdetect2.plotting.matches import (
|
||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
@ -25,5 +26,6 @@ __all__ = [
|
|||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_detection_heatmap",
|
"plot_detection_heatmap",
|
||||||
"plot_classification_heatmap",
|
"plot_classification_heatmap",
|
||||||
|
"plot_size_heatmap",
|
||||||
"plot_match_gallery",
|
"plot_match_gallery",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
"""Plot heatmaps"""
|
"""Plot heatmaps."""
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -8,6 +8,12 @@ from matplotlib.colors import Colormap, LinearSegmentedColormap, to_rgba
|
|||||||
|
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"plot_detection_heatmap",
|
||||||
|
"plot_classification_heatmap",
|
||||||
|
"plot_size_heatmap",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def plot_detection_heatmap(
|
def plot_detection_heatmap(
|
||||||
heatmap: torch.Tensor | np.ndarray,
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
@ -108,7 +114,91 @@ def plot_classification_heatmap(
|
|||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def create_colormap(color: str) -> Colormap:
|
def plot_size_heatmap(
|
||||||
|
heatmap: torch.Tensor | np.ndarray,
|
||||||
|
dimension_names: list[str],
|
||||||
|
ax: axes.Axes | None = None,
|
||||||
|
figsize: tuple[int, int] = (10, 10),
|
||||||
|
color: str = "crimson",
|
||||||
|
size: float = 20,
|
||||||
|
fontsize: float = 8,
|
||||||
|
) -> axes.Axes:
|
||||||
|
"""Plot sparse size labels from a size heatmap.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
heatmap : torch.Tensor | np.ndarray
|
||||||
|
Size heatmap with shape ``[num_dims, height, width]``. Entries are
|
||||||
|
expected to be zero everywhere except at labelled positions.
|
||||||
|
dimension_names : list[str]
|
||||||
|
Names corresponding to the first heatmap dimension.
|
||||||
|
ax : matplotlib.axes.Axes | None, default=None
|
||||||
|
Axis to plot on. If ``None``, a new axis is created.
|
||||||
|
figsize : tuple[int, int], default=(10, 10)
|
||||||
|
Figure size used when creating a new axis.
|
||||||
|
color : str, default="crimson"
|
||||||
|
Color used for scatter points and text labels.
|
||||||
|
size : float, default=20
|
||||||
|
Marker size for plotted points.
|
||||||
|
fontsize : float, default=8
|
||||||
|
Font size used for the text labels.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
matplotlib.axes.Axes
|
||||||
|
Axis containing the plotted size labels.
|
||||||
|
"""
|
||||||
|
ax = create_ax(ax, figsize=figsize)
|
||||||
|
|
||||||
|
if isinstance(heatmap, torch.Tensor):
|
||||||
|
heatmap = heatmap.numpy()
|
||||||
|
|
||||||
|
if heatmap.ndim == 4:
|
||||||
|
heatmap = heatmap[0]
|
||||||
|
|
||||||
|
if heatmap.ndim != 3:
|
||||||
|
raise ValueError("Expecting a 3-dimensional array")
|
||||||
|
|
||||||
|
if len(dimension_names) != heatmap.shape[0]:
|
||||||
|
raise ValueError("Inconsistent number of dimension names")
|
||||||
|
|
||||||
|
point_mask = np.any(heatmap != 0, axis=0)
|
||||||
|
rows, cols = np.nonzero(point_mask)
|
||||||
|
|
||||||
|
if len(rows) == 0:
|
||||||
|
return ax
|
||||||
|
|
||||||
|
ax.scatter(cols, rows, c=color, s=size)
|
||||||
|
|
||||||
|
for row, col in zip(rows, cols, strict=False):
|
||||||
|
values = heatmap[:, row, col]
|
||||||
|
labels = [
|
||||||
|
f"{name}={value:.2f}"
|
||||||
|
for name, value in zip(
|
||||||
|
dimension_names,
|
||||||
|
values,
|
||||||
|
strict=False,
|
||||||
|
)
|
||||||
|
if value != 0
|
||||||
|
]
|
||||||
|
ax.text(
|
||||||
|
float(col),
|
||||||
|
float(row),
|
||||||
|
"\n".join(labels),
|
||||||
|
fontsize=fontsize,
|
||||||
|
color=color,
|
||||||
|
va="bottom",
|
||||||
|
ha="left",
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(0, heatmap.shape[2])
|
||||||
|
ax.set_ylim(0, heatmap.shape[1])
|
||||||
|
return ax
|
||||||
|
|
||||||
|
|
||||||
|
def create_colormap(
|
||||||
|
color: str | tuple[float, float, float, float],
|
||||||
|
) -> Colormap:
|
||||||
(r, g, b, a) = to_rgba(color)
|
(r, g, b, a) = to_rgba(color)
|
||||||
return LinearSegmentedColormap.from_list(
|
return LinearSegmentedColormap.from_list(
|
||||||
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
"cmap", colors=[(0, 0, 0, 0), (r, g, b, a)]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user