mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Added plot module
This commit is contained in:
parent
0a22f1798e
commit
638b5c578e
308
batdetect2/plot.py
Normal file
308
batdetect2/plot.py
Normal file
@ -0,0 +1,308 @@
|
||||
"""Plot functions to visualize detections and spectrograms."""
|
||||
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from matplotlib import axes, patches
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
|
||||
from batdetect2.types import (
|
||||
Annotation,
|
||||
ProcessingConfiguration,
|
||||
SpectrogramParameters,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"spectrogram_with_detections",
|
||||
"detection",
|
||||
"detections",
|
||||
"spectrogram",
|
||||
]
|
||||
|
||||
|
||||
def spectrogram(
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
cmap: str = "plasma",
|
||||
start_time: float = 0,
|
||||
) -> axes.Axes:
|
||||
"""Plot a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
|
||||
Returns
|
||||
-------
|
||||
axes.Axes: Matplotlib axes object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the spectrogram is not of
|
||||
shape (1, T, F), (1, 1, T, F) or (T, F)
|
||||
"""
|
||||
# Convert to numpy array if needed
|
||||
if isinstance(spec, torch.Tensor):
|
||||
spec = spec.numpy()
|
||||
|
||||
# Remove batch and channel dimensions if present
|
||||
spec = spec.squeeze()
|
||||
|
||||
if spec.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Expected a 2D tensor, got {spec.ndim}D tensor instead."
|
||||
)
|
||||
|
||||
# Get config
|
||||
if config is None:
|
||||
config = DEFAULT_PROCESSING_CONFIGURATIONS.copy()
|
||||
|
||||
# Frequency axis is reversed
|
||||
spec = spec[::-1, :]
|
||||
|
||||
if ax is None:
|
||||
# Using cast to fix typing. pyplot subplots is not
|
||||
# correctly typed.
|
||||
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
|
||||
|
||||
# compute extent
|
||||
extent = _compute_spec_extent(spec.shape, config)
|
||||
|
||||
# add start time
|
||||
extent = (extent[0] + start_time, extent[1] + start_time, *extent[2:])
|
||||
|
||||
ax.imshow(spec, aspect="auto", origin="lower", cmap=cmap, extent=extent)
|
||||
|
||||
ax.set_xlabel("Time (s)")
|
||||
ax.set_ylabel("Frequency (Hz)")
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def spectrogram_with_detections(
|
||||
spec: Union[torch.Tensor, np.ndarray],
|
||||
dets: List[Annotation],
|
||||
config: Optional[ProcessingConfiguration] = None,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
cmap: str = "plasma",
|
||||
with_names: bool = True,
|
||||
start_time: float = 0,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
"""Plot a spectrogram with detections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
|
||||
detections (List[Annotation]): List of detections.
|
||||
config (Optional[ProcessingConfiguration], optional): Configuration
|
||||
used to compute the spectrogram. Defaults to None. If None,
|
||||
the default configuration will be used.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
cmap (str, optional): Colormap to use. Defaults to "plasma".
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
start_time (float, optional): Start time of the spectrogram.
|
||||
Defaults to 0. This is useful if plotting a spectrogram
|
||||
of a segment of a longer audio file.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
`plot.detections` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
axes.Axes: Matplotlib axes object.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError: If the spectrogram is not of shape (1, F, T),
|
||||
(1, 1, F, T) or (F, T).
|
||||
"""
|
||||
ax = spectrogram(
|
||||
spec,
|
||||
start_time=start_time,
|
||||
config=config,
|
||||
cmap=cmap,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
)
|
||||
|
||||
ax = detections(
|
||||
dets,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
with_names=with_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def detections(
|
||||
dets: List[Annotation],
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
with_names: bool = True,
|
||||
**kwargs,
|
||||
) -> axes.Axes:
|
||||
"""Plot a list of detections.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dets (List[Annotation]): List of detections.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object.
|
||||
Defaults to None. if provided, the spectrogram will be plotted
|
||||
on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size.
|
||||
Defaults to None. If `ax` is None, this will be used to create
|
||||
a new figure of the given size.
|
||||
with_names (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
**kwargs: Additional keyword arguments to pass to the
|
||||
`plot.detection` function.
|
||||
|
||||
Returns
|
||||
-------
|
||||
axes.Axes: Matplotlib axes object on which the detections
|
||||
were plotted.
|
||||
"""
|
||||
if ax is None:
|
||||
# Using cast to fix typing. pyplot subplots is not
|
||||
# correctly typed.
|
||||
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
|
||||
|
||||
for det in dets:
|
||||
ax = detection(
|
||||
det,
|
||||
ax=ax,
|
||||
figsize=figsize,
|
||||
with_name=with_names,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def detection(
|
||||
det: Annotation,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
linewidth: float = 1,
|
||||
edgecolor: str = "w",
|
||||
facecolor: str = "none",
|
||||
with_name: bool = True,
|
||||
) -> axes.Axes:
|
||||
"""Plot a single detection.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
det (Annotation): Detection to plot.
|
||||
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
|
||||
to None. If provided, the spectrogram will be plotted on this axes.
|
||||
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
|
||||
to None. If `ax` is None, this will be used to create a new figure
|
||||
of the given size.
|
||||
linewidth (float, optional): Line width of the detection.
|
||||
Defaults to 1.
|
||||
edgecolor (str, optional): Edge color of the detection.
|
||||
Defaults to "w", i.e. white.
|
||||
facecolor (str, optional): Face color of the detection.
|
||||
Defaults to "none", i.e. transparent.
|
||||
with_name (bool, optional): Whether to plot the name of the
|
||||
predicted class next to the detection. Defaults to True.
|
||||
|
||||
Returns
|
||||
-------
|
||||
axes.Axes: Matplotlib axes object on which the detection
|
||||
was plotted.
|
||||
"""
|
||||
if ax is None:
|
||||
# Using cast to fix typing. pyplot subplots is not
|
||||
# correctly typed.
|
||||
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
|
||||
|
||||
# Plot detection
|
||||
rect = patches.Rectangle(
|
||||
(det["start_time"], det["low_freq"]),
|
||||
det["end_time"] - det["start_time"],
|
||||
det["high_freq"] - det["low_freq"],
|
||||
linewidth=linewidth,
|
||||
edgecolor=edgecolor,
|
||||
facecolor=facecolor,
|
||||
alpha=det.get("det_prob", 1),
|
||||
)
|
||||
ax.add_patch(rect)
|
||||
|
||||
if with_name:
|
||||
# Add class label
|
||||
txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
|
||||
font_info = {
|
||||
"color": "white",
|
||||
"size": 10,
|
||||
"weight": "bold",
|
||||
"alpha": rect.get_alpha(),
|
||||
}
|
||||
y_pos = rect.get_xy()[1] + rect.get_height()
|
||||
ax.text(rect.get_xy()[0], y_pos, txt, fontdict=font_info)
|
||||
|
||||
return ax
|
||||
|
||||
|
||||
def _compute_spec_extent(
|
||||
shape: Tuple[int, int],
|
||||
params: SpectrogramParameters,
|
||||
) -> Tuple[float, float, float, float]:
|
||||
"""Compute the extent of a spectrogram.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape (Tuple[int, int]): Shape of the spectrogram.
|
||||
The first dimension is the frequency axis and the second
|
||||
dimension is the time axis.
|
||||
params (SpectrogramParameters): Spectrogram parameters.
|
||||
Should be the same as the ones used to compute the spectrogram.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[float, float, float, float]: Extent of the spectrogram.
|
||||
The first two values are the minimum and maximum time values,
|
||||
the last two values are the minimum and maximum frequency values.
|
||||
"""
|
||||
fft_win_length = params["fft_win_length"]
|
||||
fft_overlap = params["fft_overlap"]
|
||||
max_freq = params["max_freq"]
|
||||
min_freq = params["min_freq"]
|
||||
|
||||
# compute duration based on spectrogram parameters
|
||||
duration = (shape[1] + 1) * (fft_win_length * (1 - fft_overlap))
|
||||
|
||||
# If the spectrogram is not resized, the duration is correct
|
||||
# but if it is resized, the duration needs to be adjusted
|
||||
resize_factor = params["resize_factor"]
|
||||
spec_height = params["spec_height"]
|
||||
if spec_height * resize_factor == shape[0]:
|
||||
duration = duration / resize_factor
|
||||
|
||||
return 0, duration, min_freq, max_freq
|
Loading…
Reference in New Issue
Block a user