From 87ce2acd6f9b899ffe5ca76e87b289e2884af69b Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 8 Aug 2025 12:25:26 +0100 Subject: [PATCH] Adding plotting functions --- src/batdetect2/plotting/__init__.py | 21 + src/batdetect2/plotting/clip_annotations.py | 49 +++ src/batdetect2/plotting/clip_predictions.py | 141 +++++++ src/batdetect2/plotting/clips.py | 44 +++ src/batdetect2/plotting/legacy/plot.py | 317 +++++++++++++++ src/batdetect2/plotting/matches.py | 417 ++++++++++++++++++++ 6 files changed, 989 insertions(+) create mode 100644 src/batdetect2/plotting/clip_annotations.py create mode 100644 src/batdetect2/plotting/clip_predictions.py create mode 100644 src/batdetect2/plotting/clips.py create mode 100644 src/batdetect2/plotting/legacy/plot.py create mode 100644 src/batdetect2/plotting/matches.py diff --git a/src/batdetect2/plotting/__init__.py b/src/batdetect2/plotting/__init__.py index e69de29..eab0a16 100644 --- a/src/batdetect2/plotting/__init__.py +++ b/src/batdetect2/plotting/__init__.py @@ -0,0 +1,21 @@ +from batdetect2.plotting.clip_annotations import plot_clip_annotation +from batdetect2.plotting.clip_predictions import plot_clip_prediction +from batdetect2.plotting.clips import plot_clip +from batdetect2.plotting.matches import ( + plot_cross_trigger_match, + plot_false_negative_match, + plot_false_positive_match, + plot_matches, + plot_true_positive_match, +) + +__all__ = [ + "plot_clip", + "plot_clip_annotation", + "plot_clip_prediction", + "plot_matches", + "plot_false_positive_match", + "plot_true_positive_match", + "plot_false_negative_match", + "plot_cross_trigger_match", +] diff --git a/src/batdetect2/plotting/clip_annotations.py b/src/batdetect2/plotting/clip_annotations.py new file mode 100644 index 0000000..40016e5 --- /dev/null +++ b/src/batdetect2/plotting/clip_annotations.py @@ -0,0 +1,49 @@ +from typing import Optional, Tuple + +from matplotlib.axes import Axes +from soundevent import data, plot + +from batdetect2.plotting.clips import plot_clip +from batdetect2.preprocess import PreprocessorProtocol + +__all__ = [ + "plot_clip_annotation", +] + + +def plot_clip_annotation( + clip_annotation: data.ClipAnnotation, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + add_colorbar: bool = False, + add_labels: bool = False, + add_points: bool = False, + cmap: str = "gray", + alpha: float = 1, + linewidth: float = 1, + fill: bool = False, +) -> Axes: + ax = plot_clip( + clip_annotation.clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=cmap, + ) + + plot.plot_annotations( + clip_annotation.sound_events, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=add_points, + alpha=alpha, + linewidth=linewidth, + facecolor="none" if not fill else None, + ) + return ax diff --git a/src/batdetect2/plotting/clip_predictions.py b/src/batdetect2/plotting/clip_predictions.py new file mode 100644 index 0000000..b741e61 --- /dev/null +++ b/src/batdetect2/plotting/clip_predictions.py @@ -0,0 +1,141 @@ +from typing import Iterable, Optional, Tuple + +from matplotlib.axes import Axes +from soundevent import data +from soundevent.geometry.operations import Positions, get_geometry_point +from soundevent.plot.common import create_axes +from soundevent.plot.geometries import plot_geometry +from soundevent.plot.tags import TagColorMapper, add_tags_legend, 
plot_tag + +from batdetect2.plotting.clips import plot_clip +from batdetect2.preprocess import PreprocessorProtocol + +__all__ = [ + "plot_clip_prediction", +] + + +def plot_clip_prediction( + clip_prediction: data.ClipPrediction, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + add_colorbar: bool = False, + add_labels: bool = False, + add_legend: bool = False, + spec_cmap: str = "gray", + linewidth: float = 1, + fill: bool = False, +) -> Axes: + ax = plot_clip( + clip_prediction.clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + plot_predictions( + clip_prediction.sound_events, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=False, + linewidth=linewidth, + facecolor="none" if not fill else None, + legend=add_legend, + ) + return ax + + +def plot_predictions( + predictions: Iterable[data.SoundEventPrediction], + ax: Optional[Axes] = None, + position: Positions = "top-right", + color_mapper: Optional[TagColorMapper] = None, + time_offset: float = 0.001, + freq_offset: float = 1000, + legend: bool = True, + max_alpha: float = 0.5, + color: Optional[str] = None, + **kwargs, +): + """Plot an prediction.""" + if ax is None: + ax = create_axes(**kwargs) + + if color_mapper is None: + color_mapper = TagColorMapper() + + for prediction in predictions: + ax = plot_prediction( + prediction, + ax=ax, + position=position, + color_mapper=color_mapper, + time_offset=time_offset, + freq_offset=freq_offset, + max_alpha=max_alpha, + color=color, + **kwargs, + ) + + if legend: + ax = add_tags_legend(ax, color_mapper) + + return ax + + +def plot_prediction( + prediction: data.SoundEventPrediction, + ax: Optional[Axes] = None, + position: Positions = "top-right", + color_mapper: Optional[TagColorMapper] = None, + time_offset: float = 0.001, + freq_offset: float = 1000, + max_alpha: float = 0.5, + alpha: Optional[float] = None, + color: Optional[str] = None, + **kwargs, +) -> Axes: + """Plot an annotation.""" + geometry = prediction.sound_event.geometry + + if geometry is None: + raise ValueError("Annotation does not have a geometry.") + + if ax is None: + ax = create_axes(**kwargs) + + if color_mapper is None: + color_mapper = TagColorMapper() + + if alpha is None: + alpha = min(prediction.score * max_alpha, 1) + + ax = plot_geometry( + geometry, + ax=ax, + color=color, + alpha=alpha, + **kwargs, + ) + + x, y = get_geometry_point(geometry, position=position) + + for index, tag in enumerate(prediction.tags): + color = color_mapper.get_color(tag.tag) + ax = plot_tag( + time=x + time_offset, + frequency=y - index * freq_offset, + color=color, + ax=ax, + alpha=min(tag.score, prediction.score), + **kwargs, + ) + + return ax diff --git a/src/batdetect2/plotting/clips.py b/src/batdetect2/plotting/clips.py new file mode 100644 index 0000000..df1fb16 --- /dev/null +++ b/src/batdetect2/plotting/clips.py @@ -0,0 +1,44 @@ +from typing import Optional, Tuple + +import matplotlib.pyplot as plt +from matplotlib.axes import Axes +from soundevent import data + +from batdetect2.preprocess import ( + PreprocessorProtocol, + get_default_preprocessor, +) + +__all__ = [ + "plot_clip", +] + + +def plot_clip( + clip: data.Clip, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: 
Optional[data.PathLike] = None, + add_colorbar: bool = False, + add_labels: bool = False, + spec_cmap: str = "gray", +) -> Axes: + if ax is None: + _, ax = plt.subplots(figsize=figsize) + + if preprocessor is None: + preprocessor = get_default_preprocessor() + + spec = preprocessor.preprocess_clip(clip, audio_dir=audio_dir) + + spec.plot( # type: ignore + ax=ax, + add_colorbar=add_colorbar, + cmap=spec_cmap, + add_labels=add_labels, + vmin=spec.min().item(), + vmax=spec.max().item(), + ) + + return ax diff --git a/src/batdetect2/plotting/legacy/plot.py b/src/batdetect2/plotting/legacy/plot.py new file mode 100644 index 0000000..b9e5d4e --- /dev/null +++ b/src/batdetect2/plotting/legacy/plot.py @@ -0,0 +1,317 @@ +"""Plot functions to visualize detections and spectrograms.""" + +from typing import List, Optional, Tuple, Union, cast + +import matplotlib.ticker as tick +import numpy as np +import torch +from matplotlib import axes, patches +from matplotlib import pyplot as plt + +from batdetect2.detector.parameters import DEFAULT_PROCESSING_CONFIGURATIONS +from batdetect2.types import ( + Annotation, + ProcessingConfiguration, + SpectrogramParameters, +) + +__all__ = [ + "spectrogram_with_detections", + "detection", + "detections", + "spectrogram", +] + + +def spectrogram( + spec: Union[torch.Tensor, np.ndarray], + config: Optional[ProcessingConfiguration] = None, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + cmap: str = "plasma", + start_time: float = 0, +) -> axes.Axes: + """Plot a spectrogram. + + Parameters + ---------- + spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. + config (Optional[ProcessingConfiguration], optional): Configuration + used to compute the spectrogram. Defaults to None. If None, + the default configuration will be used. + ax (Optional[axes.Axes], optional): Matplotlib axes object. + Defaults to None. if provided, the spectrogram will be plotted + on this axes. + figsize (Optional[Tuple[int, int]], optional): Figure size. + Defaults to None. If `ax` is None, this will be used to create + a new figure of the given size. + cmap (str, optional): Colormap to use. Defaults to "plasma". + start_time (float, optional): Start time of the spectrogram. + Defaults to 0. This is useful if plotting a spectrogram + of a segment of a longer audio file. + + Returns + ------- + axes.Axes: Matplotlib axes object. + + Raises + ------ + ValueError: If the spectrogram is not of + shape (1, T, F), (1, 1, T, F) or (T, F) + """ + # Convert to numpy array if needed + if isinstance(spec, torch.Tensor): + spec = spec.detach().cpu().numpy() + + # Remove batch and channel dimensions if present + spec = spec.squeeze() + + if spec.ndim != 2: + raise ValueError( + f"Expected a 2D tensor, got {spec.ndim}D tensor instead." + ) + + # Get config + if config is None: + config = DEFAULT_PROCESSING_CONFIGURATIONS.copy() + + # Frequency axis is reversed + spec = spec[::-1, :] + + if ax is None: + # Using cast to fix typing. pyplot subplots is not + # correctly typed. 
+ ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) + + # compute extent + extent = _compute_spec_extent(spec.shape, config) + + # add start time + extent = (extent[0] + start_time, extent[1] + start_time, *extent[2:]) + + ax.imshow(spec, aspect="auto", origin="lower", cmap=cmap, extent=extent) + + ax.set_xlabel("Time (s)") + ax.set_ylabel("Frequency (kHz)") + + def y_fmt(x, _): + return f"{x / 1000:.0f}" + + ax.yaxis.set_major_formatter(tick.FuncFormatter(y_fmt)) + + return ax + + +def spectrogram_with_detections( + spec: Union[torch.Tensor, np.ndarray], + dets: List[Annotation], + config: Optional[ProcessingConfiguration] = None, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + cmap: str = "plasma", + with_names: bool = True, + start_time: float = 0, + **kwargs, +) -> axes.Axes: + """Plot a spectrogram with detections. + + Parameters + ---------- + spec (Union[torch.Tensor, np.ndarray]): Spectrogram to plot. + detections (List[Annotation]): List of detections. + config (Optional[ProcessingConfiguration], optional): Configuration + used to compute the spectrogram. Defaults to None. If None, + the default configuration will be used. + ax (Optional[axes.Axes], optional): Matplotlib axes object. + Defaults to None. if provided, the spectrogram will be plotted + on this axes. + figsize (Optional[Tuple[int, int]], optional): Figure size. + Defaults to None. If `ax` is None, this will be used to create + a new figure of the given size. + cmap (str, optional): Colormap to use. Defaults to "plasma". + with_names (bool, optional): Whether to plot the name of the + predicted class next to the detection. Defaults to True. + start_time (float, optional): Start time of the spectrogram. + Defaults to 0. This is useful if plotting a spectrogram + of a segment of a longer audio file. + **kwargs: Additional keyword arguments to pass to the + `plot.detections` function. + + Returns + ------- + axes.Axes: Matplotlib axes object. + + Raises + ------ + ValueError: If the spectrogram is not of shape (1, F, T), + (1, 1, F, T) or (F, T). + """ + ax = spectrogram( + spec, + start_time=start_time, + config=config, + cmap=cmap, + ax=ax, + figsize=figsize, + ) + + ax = detections( + dets, + ax=ax, + figsize=figsize, + with_names=with_names, + **kwargs, + ) + + return ax + + +def detections( + dets: List[Annotation], + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + with_names: bool = True, + **kwargs, +) -> axes.Axes: + """Plot a list of detections. + + Parameters + ---------- + dets (List[Annotation]): List of detections. + ax (Optional[axes.Axes], optional): Matplotlib axes object. + Defaults to None. if provided, the spectrogram will be plotted + on this axes. + figsize (Optional[Tuple[int, int]], optional): Figure size. + Defaults to None. If `ax` is None, this will be used to create + a new figure of the given size. + with_names (bool, optional): Whether to plot the name of the + predicted class next to the detection. Defaults to True. + **kwargs: Additional keyword arguments to pass to the + `plot.detection` function. + + Returns + ------- + axes.Axes: Matplotlib axes object on which the detections + were plotted. + """ + if ax is None: + # Using cast to fix typing. pyplot subplots is not + # correctly typed. 
+ ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) + + for det in dets: + ax = detection( + det, + ax=ax, + figsize=figsize, + with_name=with_names, + **kwargs, + ) + + return ax + + +def detection( + det: Annotation, + ax: Optional[axes.Axes] = None, + figsize: Optional[Tuple[int, int]] = None, + linewidth: float = 1, + edgecolor: str = "w", + facecolor: str = "none", + with_name: bool = True, +) -> axes.Axes: + """Plot a single detection. + + Parameters + ---------- + det (Annotation): Detection to plot. + ax (Optional[axes.Axes], optional): Matplotlib axes object. Defaults + to None. If provided, the spectrogram will be plotted on this axes. + figsize (Optional[Tuple[int, int]], optional): Figure size. Defaults + to None. If `ax` is None, this will be used to create a new figure + of the given size. + linewidth (float, optional): Line width of the detection. + Defaults to 1. + edgecolor (str, optional): Edge color of the detection. + Defaults to "w", i.e. white. + facecolor (str, optional): Face color of the detection. + Defaults to "none", i.e. transparent. + with_name (bool, optional): Whether to plot the name of the + predicted class next to the detection. Defaults to True. + + Returns + ------- + axes.Axes: Matplotlib axes object on which the detection + was plotted. + """ + if ax is None: + # Using cast to fix typing. pyplot subplots is not + # correctly typed. + ax = cast(axes.Axes, plt.subplots(figsize=figsize)[1]) + + # Plot detection + rect = patches.Rectangle( + (det["start_time"], det["low_freq"]), + det["end_time"] - det["start_time"], + det["high_freq"] - det["low_freq"], + linewidth=linewidth, + edgecolor=edgecolor, + facecolor=facecolor, + alpha=det.get("det_prob", 1), + ) + ax.add_patch(rect) + + if with_name: + # Add class label + txt = " ".join([sp[:3] for sp in det["class"].split(" ")]) + font_info = { + "color": edgecolor, + "size": 10, + "weight": "bold", + "alpha": rect.get_alpha(), + } + y_pos = rect.get_xy()[1] + rect.get_height() + ax.text(rect.get_xy()[0], y_pos, txt, fontdict=font_info) + + return ax + + +def _compute_spec_extent( + shape: Tuple[int, int], + params: SpectrogramParameters, +) -> Tuple[float, float, float, float]: + """Compute the extent of a spectrogram. + + Parameters + ---------- + shape (Tuple[int, int]): Shape of the spectrogram. + The first dimension is the frequency axis and the second + dimension is the time axis. + params (SpectrogramParameters): Spectrogram parameters. + Should be the same as the ones used to compute the spectrogram. + + Returns + ------- + Tuple[float, float, float, float]: Extent of the spectrogram. + The first two values are the minimum and maximum time values, + the last two values are the minimum and maximum frequency values. + """ + fft_win_length = params["fft_win_length"] + fft_overlap = params["fft_overlap"] + max_freq = params["max_freq"] + min_freq = params["min_freq"] + + # compute duration based on spectrogram parameters + duration = (shape[1] + 1) * (fft_win_length * (1 - fft_overlap)) + + # If the spectrogram is not resized, the duration is correct + # but if it is resized, the duration needs to be adjusted + # NOTE: For now we can only detect if the spectrogram is resized + # by checking if the height is equal to the specified height, + # but this could fail. 
+ resize_factor = params["resize_factor"] + spec_height = params["spec_height"] + if spec_height * resize_factor == shape[0]: + duration = duration / resize_factor + + return 0, duration, min_freq, max_freq diff --git a/src/batdetect2/plotting/matches.py b/src/batdetect2/plotting/matches.py new file mode 100644 index 0000000..881a24e --- /dev/null +++ b/src/batdetect2/plotting/matches.py @@ -0,0 +1,417 @@ +from typing import List, Optional, Tuple, Union + +import matplotlib.pyplot as plt +from matplotlib.axes import Axes +from soundevent import data, plot +from soundevent.geometry import compute_bounds +from soundevent.plot.tags import TagColorMapper + +from batdetect2.evaluate.types import MatchEvaluation +from batdetect2.plotting.clip_predictions import plot_prediction +from batdetect2.plotting.clips import plot_clip +from batdetect2.preprocess import ( + PreprocessorProtocol, + get_default_preprocessor, +) + +__all__ = [ + "plot_matches", + "plot_false_positive_match", + "plot_true_positive_match", + "plot_false_negative_match", + "plot_cross_trigger_match", +] + +DEFAULT_FALSE_POSITIVE_COLOR = "orange" +DEFAULT_FALSE_NEGATIVE_COLOR = "red" +DEFAULT_TRUE_POSITIVE_COLOR = "green" +DEFAULT_CROSS_TRIGGER_COLOR = "orange" +DEFAULT_ANNOTATION_LINE_STYLE = "-" +DEFAULT_PREDICTION_LINE_STYLE = "--" + + +def plot_matches( + matches: List[data.Match], + clip: data.Clip, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + color_mapper: Optional[TagColorMapper] = None, + add_colorbar: bool = False, + add_labels: bool = False, + add_points: bool = False, + fill: bool = False, + spec_cmap: str = "gray", + false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR, + false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR, + true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR, + annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, + prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, +) -> Axes: + if preprocessor is None: + preprocessor = get_default_preprocessor() + + ax = plot_clip( + clip, + ax=ax, + figsize=figsize, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + if color_mapper is None: + color_mapper = TagColorMapper() + + for match in matches: + if match.source is None and match.target is not None: + plot.plot_annotation( + annotation=match.target, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + color=false_negative_color, + color_mapper=color_mapper, + linestyle=annotation_linestyle, + ) + elif match.target is None and match.source is not None: + plot_prediction( + prediction=match.source, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + color=false_positive_color, + color_mapper=color_mapper, + linestyle=prediction_linestyle, + ) + elif match.target is not None and match.source is not None: + plot.plot_annotation( + annotation=match.target, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + color=true_positive_color, + color_mapper=color_mapper, + linestyle=annotation_linestyle, + ) + plot_prediction( + prediction=match.source, + ax=ax, + time_offset=0.004, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + color=true_positive_color, + 
color_mapper=color_mapper, + linestyle=prediction_linestyle, + ) + else: + continue + + return ax + + +DEFAULT_DURATION = 0.05 + + +def plot_false_positive_match( + match: MatchEvaluation, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + duration: float = DEFAULT_DURATION, + add_colorbar: bool = False, + add_labels: bool = False, + add_points: bool = False, + fill: bool = False, + spec_cmap: str = "gray", + time_offset: float = 0, + color: str = DEFAULT_FALSE_POSITIVE_COLOR, + fontsize: Union[float, str] = "small", +) -> Axes: + assert match.match.source is not None + assert match.match.target is None + sound_event = match.match.source.sound_event + geometry = sound_event.geometry + assert geometry is not None + + start_time, _, _, high_freq = compute_bounds(geometry) + + clip = data.Clip( + start_time=max(start_time - duration / 2, 0), + end_time=min( + start_time + duration / 2, sound_event.recording.duration + ), + recording=sound_event.recording, + ) + + ax = plot_clip( + clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + plot_prediction( + match.match.source, + ax=ax, + time_offset=time_offset, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + ) + + plt.text( + start_time, + high_freq, + f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) + + return ax + + +def plot_false_negative_match( + match: MatchEvaluation, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + duration: float = DEFAULT_DURATION, + add_colorbar: bool = False, + add_labels: bool = False, + add_points: bool = False, + fill: bool = False, + spec_cmap: str = "gray", + color: str = DEFAULT_FALSE_NEGATIVE_COLOR, + fontsize: Union[float, str] = "small", +) -> Axes: + assert match.match.source is None + assert match.match.target is not None + sound_event = match.match.target.sound_event + geometry = sound_event.geometry + assert geometry is not None + + start_time, _, _, high_freq = compute_bounds(geometry) + + clip = data.Clip( + start_time=max(start_time - duration / 2, 0), + end_time=min( + start_time + duration / 2, sound_event.recording.duration + ), + recording=sound_event.recording, + ) + + ax = plot_clip( + clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + plot.plot_annotation( + match.match.target, + ax=ax, + time_offset=0.001, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + ) + + plt.text( + start_time, + high_freq, + f"False Negative \nClass: {match.gt_class} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) + + return ax + + +def plot_true_positive_match( + match: MatchEvaluation, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + duration: float = DEFAULT_DURATION, + add_colorbar: bool = False, + add_labels: bool = 
False, + add_points: bool = False, + fill: bool = False, + spec_cmap: str = "gray", + color: str = DEFAULT_TRUE_POSITIVE_COLOR, + fontsize: Union[float, str] = "small", + annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, + prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, +) -> Axes: + assert match.match.source is not None + assert match.match.target is not None + sound_event = match.match.target.sound_event + geometry = sound_event.geometry + assert geometry is not None + + start_time, _, _, high_freq = compute_bounds(geometry) + + clip = data.Clip( + start_time=max(start_time - duration / 2, 0), + end_time=min( + start_time + duration / 2, sound_event.recording.duration + ), + recording=sound_event.recording, + ) + + ax = plot_clip( + clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + plot.plot_annotation( + match.match.target, + ax=ax, + time_offset=0.001, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + linestyle=annotation_linestyle, + ) + + plot_prediction( + match.match.source, + ax=ax, + time_offset=0.001, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + linestyle=prediction_linestyle, + ) + + plt.text( + start_time, + high_freq, + f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", + va="top", + ha="right", + color=color, + fontsize=fontsize, + ) + + return ax + + +def plot_cross_trigger_match( + match: MatchEvaluation, + preprocessor: Optional[PreprocessorProtocol] = None, + figsize: Optional[Tuple[int, int]] = None, + ax: Optional[Axes] = None, + audio_dir: Optional[data.PathLike] = None, + duration: float = DEFAULT_DURATION, + add_colorbar: bool = False, + add_labels: bool = False, + add_points: bool = False, + fill: bool = False, + spec_cmap: str = "gray", + color: str = DEFAULT_CROSS_TRIGGER_COLOR, + fontsize: Union[float, str] = "small", + annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE, + prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE, +) -> Axes: + assert match.match.source is not None + assert match.match.target is not None + sound_event = match.match.source.sound_event + geometry = sound_event.geometry + assert geometry is not None + + start_time, _, _, high_freq = compute_bounds(geometry) + + clip = data.Clip( + start_time=max(start_time - duration / 2, 0), + end_time=min( + start_time + duration / 2, sound_event.recording.duration + ), + recording=sound_event.recording, + ) + + ax = plot_clip( + clip, + preprocessor=preprocessor, + figsize=figsize, + ax=ax, + audio_dir=audio_dir, + add_colorbar=add_colorbar, + add_labels=add_labels, + spec_cmap=spec_cmap, + ) + + plot.plot_annotation( + match.match.target, + ax=ax, + time_offset=0.001, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + linestyle=annotation_linestyle, + ) + + plot_prediction( + match.match.source, + ax=ax, + time_offset=0.001, + freq_offset=2_000, + add_points=add_points, + facecolor="none" if not fill else None, + alpha=1, + color=color, + linestyle=prediction_linestyle, + ) + + plt.text( + start_time, + high_freq, + f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ", + 
va="top", + ha="right", + color=color, + fontsize=fontsize, + ) + + return ax