Add functional versions of metric and plotting utils

mbsantiago 2025-11-16 21:37:47 +00:00
parent 960b9a92e4
commit a4498cfd83
5 changed files with 123 additions and 80 deletions

View File: batdetect2/evaluate/metrics/classification.py

@@ -9,6 +9,7 @@ from typing import (
     Mapping,
     Optional,
     Sequence,
+    Tuple,
     Union,
 )
@@ -18,7 +19,10 @@ from sklearn import metrics
 from soundevent import data
 
 from batdetect2.core import BaseConfig, Registry
-from batdetect2.evaluate.metrics.common import average_precision
+from batdetect2.evaluate.metrics.common import (
+    average_precision,
+    compute_precision_recall,
+)
 from batdetect2.typing import RawPrediction, TargetProtocol
 
 __all__ = [
@@ -265,3 +269,24 @@ def _extract_per_class_metric_data(
             y_score[class_name].append(m.score)
 
     return y_true, y_score, num_positives
+
+
+def compute_precision_recall_curves(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
+    y_true, y_score, num_positives = _extract_per_class_metric_data(
+        clip_evaluations,
+        ignore_non_predictions=ignore_non_predictions,
+        ignore_generic=ignore_generic,
+    )
+    return {
+        class_name: compute_precision_recall(
+            y_true[class_name],
+            y_score[class_name],
+            num_positives=num_positives[class_name],
+        )
+        for class_name in y_true
+    }
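
For reference, the new helper can be driven directly, without going through a plot class. A minimal usage sketch; `clip_evals` and `summarize_curves` are hypothetical names, and the (precision, recall, thresholds) ordering of the returned triples is assumed from the annotated return type rather than confirmed by this hunk:

from typing import Sequence

from batdetect2.evaluate.metrics.classification import (
    ClipEval,
    compute_precision_recall_curves,
)


def summarize_curves(clip_evals: Sequence[ClipEval]) -> None:
    # One (precision, recall, thresholds) triple per class found in
    # the evaluations (not per configured target class).
    curves = compute_precision_recall_curves(
        clip_evals,
        ignore_non_predictions=True,
        ignore_generic=True,
    )
    for class_name, (precision, recall, thresholds) in curves.items():
        print(class_name, len(thresholds), precision.max(), recall.max())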

View File: batdetect2/evaluate/metrics/top_class.py

@@ -18,6 +18,7 @@ from soundevent import data
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
 from batdetect2.typing import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
     "TopClassMetricConfig",
@@ -312,3 +313,61 @@ TopClassMetricConfig = Annotated[
 def build_top_class_metric(config: TopClassMetricConfig):
     return top_class_metrics.build(config)
+
+
+def compute_confusion_matrix(
+    clip_evaluations: Sequence[ClipEval],
+    targets: TargetProtocol,
+    threshold: float = 0.2,
+    normalize: Literal["true", "pred", "all", "none"] = "true",
+    exclude_generic: bool = True,
+    exclude_false_positives: bool = True,
+    exclude_false_negatives: bool = True,
+    noise_class: str = "noise",
+):
+    y_true: List[str] = []
+    y_pred: List[str] = []
+
+    for clip_eval in clip_evaluations:
+        for m in clip_eval.matches:
+            true_class = m.true_class
+            pred_class = m.pred_class
+
+            if not m.is_prediction and exclude_false_negatives:
+                # Ignore matches that don't correspond to a prediction
+                continue
+
+            if not m.is_ground_truth and exclude_false_positives:
+                # Ignore matches that don't correspond to a ground truth
+                continue
+
+            if m.score < threshold:
+                if exclude_false_negatives:
+                    continue
+                pred_class = noise_class
+
+            if m.is_generic:
+                if exclude_generic:
+                    # Ignore ground-truth sounds with unknown class
+                    continue
+                true_class = targets.detection_class_name
+
+            y_true.append(true_class or noise_class)
+            y_pred.append(pred_class or noise_class)
+
+    labels = sorted(targets.class_names)
+
+    if not exclude_generic:
+        labels.append(targets.detection_class_name)
+
+    if not exclude_false_positives or not exclude_false_negatives:
+        labels.append(noise_class)
+
+    # sklearn expects None rather than the string "none" here.
+    return metrics.confusion_matrix(
+        y_true,
+        y_pred,
+        labels=labels,
+        normalize=normalize if normalize != "none" else None,
+    ), labels
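
Likewise, the confusion matrix can now be computed without instantiating the ConfusionMatrix plot. A sketch, assuming `clip_evals` (a Sequence[ClipEval]) and `targets` (a TargetProtocol implementation) come from an existing evaluation run:

from batdetect2.evaluate.metrics.top_class import compute_confusion_matrix

# The returned matrix is aligned with `labels` on both axes;
# normalize="true" gives rates normalized over the true classes.
cm, labels = compute_confusion_matrix(
    clip_evals,
    targets,
    threshold=0.2,
    normalize="true",
    exclude_generic=True,
    exclude_false_positives=True,
    exclude_false_negatives=True,
)
print(labels)
print(cm.round(2))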

View File

@@ -18,8 +18,8 @@ from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.classification import (
     ClipEval,
-    _extract_per_class_metric_data,
+    compute_precision_recall_curves,
 )
-from batdetect2.evaluate.metrics.common import compute_precision_recall
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.metrics import (
     plot_pr_curve,
@@ -69,21 +69,12 @@ class PRCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives=num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
 
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -141,21 +132,12 @@ class ThresholdPrecisionCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
             clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
 
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
@@ -223,21 +205,12 @@ class ThresholdRecallCurve(BasePlot):
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true, y_score, num_positives = _extract_per_class_metric_data(
+        data = compute_precision_recall_curves(
            clip_evaluations,
             ignore_non_predictions=self.ignore_non_predictions,
             ignore_generic=self.ignore_generic,
         )
-        data = {
-            class_name: compute_precision_recall(
-                y_true[class_name],
-                y_score[class_name],
-                num_positives[class_name],
-            )
-            for class_name in self.targets.class_names
-        }
 
         if not self.separate_figures:
             fig = self.create_figure()
             ax = fig.subplots()
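
With the refactor, all three curve plots consume the same dictionary, which also makes ad-hoc figures easy to build outside the plot classes. A sketch; `clip_evals` is assumed as above, and the leading positional parameters of plot_pr_curve fall outside this hunk, so the (precision, recall, thresholds) call order is an assumption:

import matplotlib.pyplot as plt

from batdetect2.evaluate.metrics.classification import (
    compute_precision_recall_curves,
)
from batdetect2.plotting.metrics import plot_pr_curve

data = compute_precision_recall_curves(clip_evals)

fig, ax = plt.subplots()
for class_name, (precision, recall, thresholds) in data.items():
    # Overlay one curve per class on a shared axes.
    plot_pr_curve(precision, recall, thresholds, ax=ax, label=class_name)
ax.legend()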

View File

@@ -23,7 +23,11 @@ from sklearn import metrics
 from batdetect2.audio import AudioConfig, build_audio_loader
 from batdetect2.core import Registry
 from batdetect2.evaluate.metrics.common import compute_precision_recall
-from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
+from batdetect2.evaluate.metrics.top_class import (
+    ClipEval,
+    MatchEval,
+    compute_confusion_matrix,
+)
 from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
@@ -186,6 +190,8 @@ class ConfusionMatrix(BasePlot):
         self,
         *args,
         exclude_generic: bool = True,
+        exclude_false_positives: bool = True,
+        exclude_false_negatives: bool = True,
         exclude_noise: bool = False,
         noise_class: str = "noise",
         add_colorbar: bool = True,
@@ -196,9 +202,11 @@
     ):
         super().__init__(*args, **kwargs)
         self.exclude_generic = exclude_generic
+        self.exclude_false_positives = exclude_false_positives
+        self.exclude_false_negatives = exclude_false_negatives
         self.exclude_noise = exclude_noise
         self.noise_class = noise_class
-        self.normalize = normalize
+        self.normalize: Literal["true", "pred", "all", "none"] = normalize
         self.add_colorbar = add_colorbar
         self.threshold = threshold
         self.cmap = cmap
@@ -207,58 +215,25 @@
         self,
         clip_evaluations: Sequence[ClipEval],
     ) -> Iterable[Tuple[str, Figure]]:
-        y_true: List[str] = []
-        y_pred: List[str] = []
-
-        for clip_eval in clip_evaluations:
-            for m in clip_eval.matches:
-                true_class = m.true_class
-                pred_class = m.pred_class
-
-                if not m.is_prediction and self.exclude_noise:
-                    # Ignore matches that don't correspond to a prediction
-                    continue
-
-                if not m.is_ground_truth and self.exclude_noise:
-                    # Ignore matches that don't correspond to a ground truth
-                    continue
-
-                if m.score < self.threshold:
-                    if self.exclude_noise:
-                        continue
-                    pred_class = self.noise_class
-
-                if m.is_generic:
-                    if self.exclude_generic:
-                        # Ignore ground-truth sounds with unknown class
-                        continue
-                    true_class = self.targets.detection_class_name
-
-                y_true.append(true_class or self.noise_class)
-                y_pred.append(pred_class or self.noise_class)
+        cm, labels = compute_confusion_matrix(
+            clip_evaluations,
+            self.targets,
+            threshold=self.threshold,
+            normalize=self.normalize,
+            exclude_generic=self.exclude_generic,
+            exclude_false_positives=self.exclude_false_positives,
+            exclude_false_negatives=self.exclude_false_negatives,
+            noise_class=self.noise_class,
+        )
 
         fig = self.create_figure()
         ax = fig.subplots()
 
-        class_names = [*self.targets.class_names]
-
-        if not self.exclude_generic:
-            class_names.append(self.targets.detection_class_name)
-
-        if not self.exclude_noise:
-            class_names.append(self.noise_class)
-
-        metrics.ConfusionMatrixDisplay.from_predictions(
-            y_true,
-            y_pred,
-            labels=class_names,
+        metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
             ax=ax,
             xticks_rotation="vertical",
             cmap=self.cmap,
             colorbar=self.add_colorbar,
-            normalize=self.normalize if self.normalize != "none" else None,
             values_format=".2f",
         )
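
The display change is worth noting: ConfusionMatrixDisplay.from_predictions recomputes the matrix from raw label lists, while the new path renders the precomputed (and already normalized) matrix, which is why the normalize= keyword drops out of the plotting call; sklearn's ConfusionMatrixDisplay.plot does not accept one. A standalone sketch of the new pattern, with `clip_evals` and `targets` assumed as above:

import matplotlib.pyplot as plt
from sklearn import metrics

from batdetect2.evaluate.metrics.top_class import compute_confusion_matrix

cm, labels = compute_confusion_matrix(clip_evals, targets)

fig, ax = plt.subplots()
# Render the precomputed matrix; normalization already happened in
# compute_confusion_matrix, so only formatting is configured here.
metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
    ax=ax,
    xticks_rotation="vertical",
    values_format=".2f",
)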

View File: batdetect2/plotting/metrics.py

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, Union
 
 import numpy as np
 import seaborn as sns
@@ -34,8 +34,14 @@ def plot_pr_curve(
     thresholds: np.ndarray,
     ax: Optional[axes.Axes] = None,
     figsize: Optional[Tuple[int, int]] = None,
+    color: Union[str, Tuple[float, float, float], None] = None,
     add_labels: bool = True,
     add_legend: bool = False,
+    marker: Union[str, Tuple[int, int, float], None] = "o",
+    markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
+    markersize: Optional[float] = None,
+    linestyle: Union[str, Tuple[int, ...], None] = None,
+    linewidth: Optional[float] = None,
     label: str = "PR Curve",
 ) -> axes.Axes:
     ax = create_ax(ax=ax, figsize=figsize)
@@ -45,9 +51,14 @@
     ax.plot(
         recall,
         precision,
+        color=color,
         label=label,
-        marker="o",
+        marker=marker,
+        markeredgecolor=markeredgecolor,
         markevery=_get_marker_positions(thresholds),
+        markersize=markersize,
+        linestyle=linestyle,
+        linewidth=linewidth,
     )
 
     ax.set_xlim(0, 1.05)
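
The new keyword arguments let callers style each curve when overlaying several on one axes. A sketch with synthetic arrays; as above, the order of the leading curve parameters is assumed:

import numpy as np

from batdetect2.plotting.metrics import plot_pr_curve

# Synthetic, monotone arrays standing in for a real PR curve.
thresholds = np.linspace(0.1, 0.9, 9)
precision = np.linspace(1.0, 0.6, 9)
recall = np.linspace(0.1, 0.95, 9)

ax = plot_pr_curve(
    precision,
    recall,
    thresholds,
    color="tab:blue",
    marker="s",
    markersize=4.0,
    linestyle="--",
    linewidth=1.5,
    label="example class",
)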