Adding plotting functions

This commit is contained in:
mbsantiago 2025-08-08 12:25:26 +01:00
parent e1908c35ca
commit 87ce2acd6f
6 changed files with 989 additions and 0 deletions

View File

@ -0,0 +1,21 @@
from batdetect2.plotting.clip_annotations import plot_clip_annotation
from batdetect2.plotting.clip_predictions import plot_clip_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
__all__ = [
"plot_clip",
"plot_clip_annotation",
"plot_clip_prediction",
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
]

View File

@ -0,0 +1,49 @@
from typing import Optional, Tuple
from matplotlib.axes import Axes
from soundevent import data, plot
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_annotation",
]
def plot_clip_annotation(
clip_annotation: data.ClipAnnotation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
cmap: str = "gray",
alpha: float = 1,
linewidth: float = 1,
fill: bool = False,
) -> Axes:
ax = plot_clip(
clip_annotation.clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=cmap,
)
plot.plot_annotations(
clip_annotation.sound_events,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
alpha=alpha,
linewidth=linewidth,
facecolor="none" if not fill else None,
)
return ax

View File

@ -0,0 +1,141 @@
from typing import Iterable, Optional, Tuple
from matplotlib.axes import Axes
from soundevent import data
from soundevent.geometry.operations import Positions, get_geometry_point
from soundevent.plot.common import create_axes
from soundevent.plot.geometries import plot_geometry
from soundevent.plot.tags import TagColorMapper, add_tags_legend, plot_tag
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import PreprocessorProtocol
__all__ = [
"plot_clip_prediction",
]
def plot_clip_prediction(
clip_prediction: data.ClipPrediction,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_legend: bool = False,
spec_cmap: str = "gray",
linewidth: float = 1,
fill: bool = False,
) -> Axes:
ax = plot_clip(
clip_prediction.clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot_predictions(
clip_prediction.sound_events,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else None,
legend=add_legend,
)
return ax
def plot_predictions(
predictions: Iterable[data.SoundEventPrediction],
ax: Optional[Axes] = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
legend: bool = True,
max_alpha: float = 0.5,
color: Optional[str] = None,
**kwargs,
):
"""Plot an prediction."""
if ax is None:
ax = create_axes(**kwargs)
if color_mapper is None:
color_mapper = TagColorMapper()
for prediction in predictions:
ax = plot_prediction(
prediction,
ax=ax,
position=position,
color_mapper=color_mapper,
time_offset=time_offset,
freq_offset=freq_offset,
max_alpha=max_alpha,
color=color,
**kwargs,
)
if legend:
ax = add_tags_legend(ax, color_mapper)
return ax
def plot_prediction(
prediction: data.SoundEventPrediction,
ax: Optional[Axes] = None,
position: Positions = "top-right",
color_mapper: Optional[TagColorMapper] = None,
time_offset: float = 0.001,
freq_offset: float = 1000,
max_alpha: float = 0.5,
alpha: Optional[float] = None,
color: Optional[str] = None,
**kwargs,
) -> Axes:
"""Plot an annotation."""
geometry = prediction.sound_event.geometry
if geometry is None:
raise ValueError("Annotation does not have a geometry.")
if ax is None:
ax = create_axes(**kwargs)
if color_mapper is None:
color_mapper = TagColorMapper()
if alpha is None:
alpha = min(prediction.score * max_alpha, 1)
ax = plot_geometry(
geometry,
ax=ax,
color=color,
alpha=alpha,
**kwargs,
)
x, y = get_geometry_point(geometry, position=position)
for index, tag in enumerate(prediction.tags):
color = color_mapper.get_color(tag.tag)
ax = plot_tag(
time=x + time_offset,
frequency=y - index * freq_offset,
color=color,
ax=ax,
alpha=min(tag.score, prediction.score),
**kwargs,
)
return ax

View File

@ -0,0 +1,44 @@
from typing import Optional, Tuple
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
__all__ = [
"plot_clip",
]
def plot_clip(
clip: data.Clip,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
add_colorbar: bool = False,
add_labels: bool = False,
spec_cmap: str = "gray",
) -> Axes:
if ax is None:
_, ax = plt.subplots(figsize=figsize)
if preprocessor is None:
preprocessor = get_default_preprocessor()
spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir)
spec.plot( # type: ignore
ax=ax,
add_colorbar=add_colorbar,
cmap=spec_cmap,
add_labels=add_labels,
vmin=spec.min().item(),
vmax=spec.max().item(),
)
return ax

View File

@ -0,0 +1,317 @@
"""Plot functions to visualize detections and spectrograms."""
from typing import List, Optional, Tuple, Union, cast
import matplotlib.ticker as tick
import numpy as np
import torch
from matplotlib import axes, patches
from matplotlib import pyplot as plt
from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS
from batdetect2.types import (
Annotation,
ProcessingConfiguration,
SpectrogramParameters,
)
__all__ = [
"spectrogram_with_detections",
"detection",
"detections",
"spectrogram",
]
def spectrogram(
spec: Union[torch.Tensor, np.ndarray],
config: Optional[ProcessingConfiguration] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap: str = "plasma",
start_time: float = 0,
) -> axes.Axes:
"""Plot a spectrogram.
Parameters
----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
config (Optional[ProcessingConfiguration], optional): Configuration
used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma".
start_time (float, optional): Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file.
Returns
-------
axes.Axes: Matplotlib axes object.
Raises
------
ValueError: If the spectrogram is not of
shape (1, T, F), (1, 1, T, F) or (T, F)
"""
# Convert to numpy array if needed
if isinstance(spec, torch.Tensor):
spec = spec.detach().cpu().numpy()
# Remove batch and channel dimensions if present
spec = spec.squeeze()
if spec.ndim != 2:
raise ValueError(
f"Expected a 2D tensor, got {spec.ndim}D tensor instead."
)
# Get config
if config is None:
config = DEFAULT_PROCESSING_CONFIGURATIONS.copy()
# Frequency axis is reversed
spec = spec[::-1, :]
if ax is None:
# Using cast to fix typing. pyplot subplots is not
# correctly typed.
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
# compute extent
extent = _compute_spec_extent(spec.shape, config)
# add start time
extent = (extent[0] + start_time, extent[1] + start_time, *extent[2:])
ax.imshow(spec, aspect="auto", origin="lower", cmap=cmap, extent=extent)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Frequency (kHz)")
def y_fmt(x, _):
return f"{x / 1000:.0f}"
ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt))
return ax
def spectrogram_with_detections(
spec: Union[torch.Tensor, np.ndarray],
dets: List[Annotation],
config: Optional[ProcessingConfiguration] = None,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
cmap: str = "plasma",
with_names: bool = True,
start_time: float = 0,
**kwargs,
) -> axes.Axes:
"""Plot a spectrogram with detections.
Parameters
----------
spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot.
detections (List[Annotation]): List of detections.
config (Optional[ProcessingConfiguration], optional): Configuration
used to compute the spectrogram. Defaults to None. If None,
the default configuration will be used.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
cmap (str, optional): Colormap to use. Defaults to "plasma".
with_names (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True.
start_time (float, optional): Start time of the spectrogram.
Defaults to 0. This is useful if plotting a spectrogram
of a segment of a longer audio file.
**kwargs: Additional keyword arguments to pass to the
`plot.detections` function.
Returns
-------
axes.Axes: Matplotlib axes object.
Raises
------
ValueError: If the spectrogram is not of shape (1, F, T),
(1, 1, F, T) or (F, T).
"""
ax = spectrogram(
spec,
start_time=start_time,
config=config,
cmap=cmap,
ax=ax,
figsize=figsize,
)
ax = detections(
dets,
ax=ax,
figsize=figsize,
with_names=with_names,
**kwargs,
)
return ax
def detections(
dets: List[Annotation],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
with_names: bool = True,
**kwargs,
) -> axes.Axes:
"""Plot a list of detections.
Parameters
----------
dets (List[Annotation]): List of detections.
ax (Optional[axes.Axes], optional): Matplotlib axes object.
Defaults to None. if provided, the spectrogram will be plotted
on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size.
Defaults to None. If `ax` is None, this will be used to create
a new figure of the given size.
with_names (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True.
**kwargs: Additional keyword arguments to pass to the
`plot.detection` function.
Returns
-------
axes.Axes: Matplotlib axes object on which the detections
were plotted.
"""
if ax is None:
# Using cast to fix typing. pyplot subplots is not
# correctly typed.
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
for det in dets:
ax = detection(
det,
ax=ax,
figsize=figsize,
with_name=with_names,
**kwargs,
)
return ax
def detection(
det: Annotation,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
linewidth: float = 1,
edgecolor: str = "w",
facecolor: str = "none",
with_name: bool = True,
) -> axes.Axes:
"""Plot a single detection.
Parameters
----------
det (Annotation): Detection to plot.
ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults
to None. If provided, the spectrogram will be plotted on this axes.
figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults
to None. If `ax` is None, this will be used to create a new figure
of the given size.
linewidth (float, optional): Line width of the detection.
Defaults to 1.
edgecolor (str, optional): Edge color of the detection.
Defaults to "w", i.e. white.
facecolor (str, optional): Face color of the detection.
Defaults to "none", i.e. transparent.
with_name (bool, optional): Whether to plot the name of the
predicted class next to the detection. Defaults to True.
Returns
-------
axes.Axes: Matplotlib axes object on which the detection
was plotted.
"""
if ax is None:
# Using cast to fix typing. pyplot subplots is not
# correctly typed.
ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1])
# Plot detection
rect = patches.Rectangle(
(det["start_time"], det["low_freq"]),
det["end_time"] - det["start_time"],
det["high_freq"] - det["low_freq"],
linewidth=linewidth,
edgecolor=edgecolor,
facecolor=facecolor,
alpha=det.get("det_prob", 1),
)
ax.add_patch(rect)
if with_name:
# Add class label
txt = " ".join([sp[:3] for sp in det["class"].split(" ")])
font_info = {
"color": edgecolor,
"size": 10,
"weight": "bold",
"alpha": rect.get_alpha(),
}
y_pos = rect.get_xy()[1] + rect.get_height()
ax.text(rect.get_xy()[0], y_pos, txt, fontdict=font_info)
return ax
def _compute_spec_extent(
shape: Tuple[int, int],
params: SpectrogramParameters,
) -> Tuple[float, float, float, float]:
"""Compute the extent of a spectrogram.
Parameters
----------
shape (Tuple[int, int]): Shape of the spectrogram.
The first dimension is the frequency axis and the second
dimension is the time axis.
params (SpectrogramParameters): Spectrogram parameters.
Should be the same as the ones used to compute the spectrogram.
Returns
-------
Tuple[float, float, float, float]: Extent of the spectrogram.
The first two values are the minimum and maximum time values,
the last two values are the minimum and maximum frequency values.
"""
fft_win_length = params["fft_win_length"]
fft_overlap = params["fft_overlap"]
max_freq = params["max_freq"]
min_freq = params["min_freq"]
# compute duration based on spectrogram parameters
duration = (shape[1] + 1) * (fft_win_length * (1 - fft_overlap))
# If the spectrogram is not resized, the duration is correct
# but if it is resized, the duration needs to be adjusted
# NOTE: For now we can only detect if the spectrogram is resized
# by checking if the height is equal to the specified height,
# but this could fail.
resize_factor = params["resize_factor"]
spec_height = params["spec_height"]
if spec_height * resize_factor == shape[0]:
duration = duration / resize_factor
return 0, duration, min_freq, max_freq

View File

@ -0,0 +1,417 @@
from typing import List, Optional, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.plotting.clip_predictions import plot_prediction
from batdetect2.plotting.clips import plot_clip
from batdetect2.preprocess import (
PreprocessorProtocol,
get_default_preprocessor,
)
__all__ = [
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
"plot_cross_trigger_match",
]
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
DEFAULT_TRUE_POSITIVE_COLOR = "green"
DEFAULT_CROSS_TRIGGER_COLOR = "orange"
DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[data.Match],
clip: data.Clip,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
if preprocessor is None:
preprocessor = get_default_preprocessor()
ax = plot_clip(
clip,
ax=ax,
figsize=figsize,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.source is None and match.target is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_negative_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
elif match.target is None and match.source is not None:
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=false_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
elif match.target is not None and match.source is not None:
plot.plot_annotation(
annotation=match.target,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=annotation_linestyle,
)
plot_prediction(
prediction=match.source,
ax=ax,
time_offset=0.004,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
color=true_positive_color,
color_mapper=color_mapper,
linestyle=prediction_linestyle,
)
else:
continue
return ax
DEFAULT_DURATION = 0.05
def plot_false_positive_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
time_offset: float = 0,
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is not None
assert match.match.target is None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=time_offset,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_false_negative_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.match.source is None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_true_positive_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_TRUE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.target.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax
def plot_cross_trigger_match(
match: MatchEvaluation,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
duration: float = DEFAULT_DURATION,
add_colorbar: bool = False,
add_labels: bool = False,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.match.source is not None
assert match.match.target is not None
sound_event = match.match.source.sound_event
geometry = sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
),
recording=sound_event.recording,
)
ax = plot_clip(
clip,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
add_colorbar=add_colorbar,
add_labels=add_labels,
spec_cmap=spec_cmap,
)
plot.plot_annotation(
match.match.target,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=annotation_linestyle,
)
plot_prediction(
match.match.source,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
linestyle=prediction_linestyle,
)
plt.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
return ax