From a4498cfd83b6fef85c6ee72e8bbc9da00f4722af Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Sun, 16 Nov 2025 21:37:47 +0000
Subject: [PATCH] Add functional versions of metric and plotting utils

---
Review notes (this section is not part of the commit message):

* compute_confusion_matrix() forwarded normalize="none" unchanged to
  sklearn.metrics.confusion_matrix(), which only accepts "true",
  "pred", "all" or None. The plotting code removed by this patch did
  that translation itself; it is now done inside the helper so that
  ConfusionMatrix(normalize="none") keeps working.
* Docstrings were added to compute_precision_recall_curves() and
  compute_confusion_matrix(); hunk sizes and the diffstat below were
  updated to match.
* Line breaks and indentation were rebuilt from a whitespace-collapsed
  copy of this patch. The nesting depth of the first context line of
  the last metrics/classification.py hunk is a guess, and the
  post-image blob ids on the `index` lines of the two metrics files
  are stale. If a context line fails to match, retrying with
  `git apply --ignore-whitespace` may help.

 .../evaluate/metrics/classification.py       | 37 +++++++++-
 src/batdetect2/evaluate/metrics/top_class.py | 72 +++++++++++++++++++
 .../evaluate/plots/classification.py         | 35 ++-------
 src/batdetect2/evaluate/plots/top_class.py   | 67 ++++++-----------
 src/batdetect2/plotting/metrics.py           | 15 +++-
 5 files changed, 146 insertions(+), 80 deletions(-)

diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py
index 404c57f..d55d27b 100644
--- a/src/batdetect2/evaluate/metrics/classification.py
+++ b/src/batdetect2/evaluate/metrics/classification.py
@@ -9,6 +9,7 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Union,
 )
 
@@ -18,7 +19,10 @@ from sklearn import metrics
 from soundevent import data
 
 from batdetect2.core import BaseConfig, Registry
-from batdetect2.evaluate.metrics.common import average_precision
+from batdetect2.evaluate.metrics.common import (
+    average_precision,
+    compute_precision_recall,
+)
 from batdetect2.typing import RawPrediction, TargetProtocol
 
 __all__ = [
@@ -265,3 +269,34 @@ def _extract_per_class_metric_data(
                 y_score[class_name].append(m.score)
 
     return y_true, y_score, num_positives
+
+
+def compute_precision_recall_curves(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+    """Compute a precision-recall curve for every evaluated class.
+
+    The per-class labels, scores and positive counts are collected by
+    ``_extract_per_class_metric_data`` and handed, class by class, to
+    ``compute_precision_recall``.
+
+    Returns a mapping from class name to the arrays produced by
+    ``compute_precision_recall``. Only classes present in the extracted
+    labels are included.
+    """
+    y_true, y_score, num_positives = _extract_per_class_metric_data(
+        clip_evaluations,
+        ignore_non_predictions=ignore_non_predictions,
+        ignore_generic=ignore_generic,
+    )
+
+    return {
+        class_name: compute_precision_recall(
+            y_true[class_name],
+            y_score[class_name],
+            num_positives=num_positives[class_name],
+        )
+        for class_name in y_true
+    }
diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py
index 0f76e2a..16b0fd7 100644
--- a/src/batdetect2/evaluate/metrics/top_class.py
+++ b/src/batdetect2/evaluate/metrics/top_class.py
@@ -18,6 +18,7 @@ from soundevent import data
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
 from batdetect2.typing import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "TopClassMetricConfig",
@@ -312,3 +313,74 @@ TopClassMetricConfig = Annotated[
 
 def build_top_class_metric(config: TopClassMetricConfig):
     return top_class_metrics.build(config)
+
+
+def compute_confusion_matrix(
+    clip_evaluations: Sequence[ClipEval],
+    targets: TargetProtocol,
+    threshold: float = 0.2,
+    normalize: Literal["true", "pred", "all", "none"] = "true",
+    exclude_generic: bool = True,
+    exclude_false_positives: bool = True,
+    exclude_false_negatives: bool = True,
+    noise_class: str = "noise",
+):
+    """Compute a confusion matrix from the matches of all clips.
+
+    Matches are filtered by the ``exclude_*`` flags. A match scoring
+    below ``threshold`` is dropped when ``exclude_false_negatives`` is
+    set and otherwise counted as a ``noise_class`` prediction. Missing
+    true or predicted classes are recorded as ``noise_class``.
+
+    ``normalize="none"`` requests raw counts; it is translated to
+    ``None`` because scikit-learn rejects the string ``"none"``.
+
+    Returns the matrix and the list of labels indexing its rows and
+    columns.
+    """
+    y_true: List[str] = []
+    y_pred: List[str] = []
+
+    for clip_eval in clip_evaluations:
+        for m in clip_eval.matches:
+            true_class = m.true_class
+            pred_class = m.pred_class
+
+            if not m.is_prediction and exclude_false_negatives:
+                # Ignore matches that don't correspond to a prediction
+                continue
+
+            if not m.is_ground_truth and exclude_false_positives:
+                # Ignore matches that don't correspond to a ground truth
+                continue
+
+            if m.score < threshold:
+                if exclude_false_negatives:
+                    continue
+
+                pred_class = noise_class
+
+            if m.is_generic:
+                if exclude_generic:
+                    # Ignore gt sounds with unknown class
+                    continue
+
+                true_class = targets.detection_class_name
+
+            y_true.append(true_class or noise_class)
+            y_pred.append(pred_class or noise_class)
+
+    labels = sorted(targets.class_names)
+
+    if not exclude_generic:
+        labels.append(targets.detection_class_name)
+
+    if not exclude_false_positives or not exclude_false_negatives:
+        labels.append(noise_class)
+
+    return metrics.confusion_matrix(
+        y_true,
+        y_pred,
+        labels=labels,
+        normalize=None if normalize == "none" else normalize,
+    ), labels
diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py
index bc3faac..dc9a1e4 100644
--- a/src/batdetect2/evaluate/plots/classification.py
+++ b/src/batdetect2/evaluate/plots/classification.py
@@ -18,8 +18,8 @@ from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.classification import (
     ClipEval,
     _extract_per_class_metric_data,
+    compute_precision_recall_curves,
 )
-from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.metrics import (
     plot_pr_curve,
@@ -69,21 +69,12 @@ class PRCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives=num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -141,21 +132,12 @@ class ThresholdPrecisionCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -223,21 +205,12 @@ class ThresholdRecallCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
 
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
-
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
diff --git a/src/batdetect2/evaluate/plots/top_class.py b/src/batdetect2/evaluate/plots/top_class.py
index 32f354c..fe43263 100644
--- a/src/batdetect2/evaluate/plots/top_class.py
+++ b/src/batdetect2/evaluate/plots/top_class.py
@@ -23,7 +23,11 @@ from sklearn import metrics
 from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.common import compute_precision_recall
-from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
+from batdetect2.evaluate.metrics.top_class import (
+    ClipEval,
+    MatchEval,
+    compute_confusion_matrix,
+)
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
@@ -186,6 +190,8 @@ class ConfusionMatrix(BasePlot):
         self,
         *args,
         exclude_generic: bool = True,
+        exclude_false_positives: bool = True,
+        exclude_false_negatives: bool = True,
         exclude_noise: bool = False,
         noise_class: str = "noise",
         add_colorbar: bool = True,
@@ -196,9 +202,11 @@ class ConfusionMatrix(BasePlot):
     ):
         super().__init__(*args, **kwargs)
         self.exclude_generic = exclude_generic
+        self.exclude_false_positives = exclude_false_positives
+        self.exclude_false_negatives = exclude_false_negatives
         self.exclude_noise = exclude_noise
         self.noise_class = noise_class
-        self.normalize = normalize
+        self.normalize: Literal["true", "pred", "all", "none"] = normalize
         self.add_colorbar = add_colorbar
         self.threshold = threshold
         self.cmap = cmap
@@ -207,58 +215,25 @@ class ConfusionMatrix(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true: List[str] = []
-        y_pred: List[str] = []
-
-        for clip_eval in clip_evaluations:
-            for m in clip_eval.matches:
-                true_class = m.true_class
-                pred_class = m.pred_class
-
-                if not m.is_prediction and self.exclude_noise:
-                    # Ignore matches that don't correspond to a prediction
-                    continue
-
-                if not m.is_ground_truth and self.exclude_noise:
-                    # Ignore matches that don't correspond to a ground truth
-                    continue
-
-                if m.score < self.threshold:
-                    if self.exclude_noise:
-                        continue
-
-                    pred_class = self.noise_class
-
-                if m.is_generic:
-                    if self.exclude_generic:
-                        # Ignore gt sounds with unknown class
-                        continue
-
-                    true_class = self.targets.detection_class_name
-
-                y_true.append(true_class or self.noise_class)
-                y_pred.append(pred_class or self.noise_class)
+        cm, labels = compute_confusion_matrix(
+            clip_evaluations,
+            self.targets,
+            threshold=self.threshold,
+            normalize=self.normalize,
+            exclude_generic=self.exclude_generic,
+            exclude_false_positives=self.exclude_false_positives,
+            exclude_false_negatives=self.exclude_false_negatives,
+            noise_class=self.noise_class,
+        )
 
         fig = self.create_figure()
         ax = fig.subplots()
 
-        class_names = [*self.targets.class_names]
-
-        if not self.exclude_generic:
-            class_names.append(self.targets.detection_class_name)
-
-        if not self.exclude_noise:
-            class_names.append(self.noise_class)
-
-        metrics.ConfusionMatrixDisplay.from_predictions(
-            y_true,
-            y_pred,
-            labels=class_names,
+        metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
             ax=ax,
             xticks_rotation="vertical",
             cmap=self.cmap,
             colorbar=self.add_colorbar,
-            normalize=self.normalize if self.normalize != "none" else None,
             values_format=".2f",
         )
 
diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py
index 52bf6fe..709c3b9 100644
--- a/src/batdetect2/plotting/metrics.py
+++ b/src/batdetect2/plotting/metrics.py
@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import seaborn as sns
@@ -34,8 +34,14 @@ def plot_pr_curve(
     thresholds: np.ndarray,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
+    color: Union[str, Tuple[float, float, float], None] = None,
     add_labels: bool = True,
     add_legend: bool = False,
+    marker: Union[str, Tuple[int, int, float], None] = "o",
+    markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
+    markersize: Optional[float] = None,
+    linestyle: Union[str, Tuple[int, ...], None] = None,
+    linewidth: Optional[float] = None,
     label: str = "PR Curve",
 ) -> axes.Axes:
     ax = create_ax(ax=ax, figsize=figsize)
@@ -45,9 +51,14 @@ def plot_pr_curve(
     ax.plot(
         recall,
         precision,
+        color=color,
         label=label,
-        marker="o",
+        marker=marker,
+        markeredgecolor=markeredgecolor,
         markevery=_get_marker_positions(thresholds),
+        markersize=markersize,
+        linestyle=linestyle,
+        linewidth=linewidth,
     )
 
     ax.set_xlim(0, 1.05)