mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add functional versions of metric and plotting utils
This commit is contained in:
parent
960b9a92e4
commit
a4498cfd83
@ -9,6 +9,7 @@ from typing import (
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
@ -18,7 +19,10 @@ from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.evaluate.metrics.common import (
|
||||
average_precision,
|
||||
compute_precision_recall,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -265,3 +269,24 @@ def _extract_per_class_metric_data(
|
||||
y_score[class_name].append(m.score)
|
||||
|
||||
return y_true, y_score, num_positives
|
||||
|
||||
|
||||
def compute_precision_recall_curves(
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
) -> Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=ignore_non_predictions,
|
||||
ignore_generic=ignore_generic,
|
||||
)
|
||||
|
||||
return {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives=num_positives[class_name],
|
||||
)
|
||||
for class_name in y_true
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@ from soundevent import data
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"TopClassMetricConfig",
|
||||
@ -312,3 +313,61 @@ TopClassMetricConfig = Annotated[
|
||||
|
||||
def build_top_class_metric(config: TopClassMetricConfig):
|
||||
return top_class_metrics.build(config)
|
||||
|
||||
|
||||
def compute_confusion_matrix(
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
targets: TargetProtocol,
|
||||
threshold: float = 0.2,
|
||||
normalize: Literal["true", "pred", "all", "none"] = "true",
|
||||
exclude_generic: bool = True,
|
||||
exclude_false_positives: bool = True,
|
||||
exclude_false_negatives: bool = True,
|
||||
noise_class: str = "noise",
|
||||
):
|
||||
y_true: List[str] = []
|
||||
y_pred: List[str] = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
true_class = m.true_class
|
||||
pred_class = m.pred_class
|
||||
|
||||
if not m.is_prediction and exclude_false_negatives:
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
continue
|
||||
|
||||
if not m.is_ground_truth and exclude_false_positives:
|
||||
# Ignore matches that don't correspond to a ground truth
|
||||
continue
|
||||
|
||||
if m.score < threshold:
|
||||
if exclude_false_negatives:
|
||||
continue
|
||||
|
||||
pred_class = noise_class
|
||||
|
||||
if m.is_generic:
|
||||
if exclude_generic:
|
||||
# Ignore gt sounds with unknown class
|
||||
continue
|
||||
|
||||
true_class = targets.detection_class_name
|
||||
|
||||
y_true.append(true_class or noise_class)
|
||||
y_pred.append(pred_class or noise_class)
|
||||
|
||||
labels = sorted(targets.class_names)
|
||||
|
||||
if not exclude_generic:
|
||||
labels.append(targets.detection_class_name)
|
||||
|
||||
if not exclude_false_positives or not exclude_false_negatives:
|
||||
labels.append(noise_class)
|
||||
|
||||
return metrics.confusion_matrix(
|
||||
y_true,
|
||||
y_pred,
|
||||
labels=labels,
|
||||
normalize=normalize,
|
||||
), labels
|
||||
|
||||
@ -18,8 +18,8 @@ from batdetect2.core import Registry
|
||||
from batdetect2.evaluate.metrics.classification import (
|
||||
ClipEval,
|
||||
_extract_per_class_metric_data,
|
||||
compute_precision_recall_curves,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.metrics import (
|
||||
plot_pr_curve,
|
||||
@ -69,21 +69,12 @@ class PRCurve(BasePlot):
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
data = compute_precision_recall_curves(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
data = {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives=num_positives[class_name],
|
||||
)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
@ -141,21 +132,12 @@ class ThresholdPrecisionCurve(BasePlot):
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
data = compute_precision_recall_curves(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
data = {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives[class_name],
|
||||
)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
@ -223,21 +205,12 @@ class ThresholdRecallCurve(BasePlot):
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
data = compute_precision_recall_curves(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
data = {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives[class_name],
|
||||
)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
@ -23,7 +23,11 @@ from sklearn import metrics
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
ClipEval,
|
||||
MatchEval,
|
||||
compute_confusion_matrix,
|
||||
)
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
@ -186,6 +190,8 @@ class ConfusionMatrix(BasePlot):
|
||||
self,
|
||||
*args,
|
||||
exclude_generic: bool = True,
|
||||
exclude_false_positives: bool = True,
|
||||
exclude_false_negatives: bool = True,
|
||||
exclude_noise: bool = False,
|
||||
noise_class: str = "noise",
|
||||
add_colorbar: bool = True,
|
||||
@ -196,9 +202,11 @@ class ConfusionMatrix(BasePlot):
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.exclude_generic = exclude_generic
|
||||
self.exclude_false_positives = exclude_false_positives
|
||||
self.exclude_false_negatives = exclude_false_negatives
|
||||
self.exclude_noise = exclude_noise
|
||||
self.noise_class = noise_class
|
||||
self.normalize = normalize
|
||||
self.normalize: Literal["true", "pred", "all", "none"] = normalize
|
||||
self.add_colorbar = add_colorbar
|
||||
self.threshold = threshold
|
||||
self.cmap = cmap
|
||||
@ -207,58 +215,25 @@ class ConfusionMatrix(BasePlot):
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true: List[str] = []
|
||||
y_pred: List[str] = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
true_class = m.true_class
|
||||
pred_class = m.pred_class
|
||||
|
||||
if not m.is_prediction and self.exclude_noise:
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
continue
|
||||
|
||||
if not m.is_ground_truth and self.exclude_noise:
|
||||
# Ignore matches that don't correspond to a ground truth
|
||||
continue
|
||||
|
||||
if m.score < self.threshold:
|
||||
if self.exclude_noise:
|
||||
continue
|
||||
|
||||
pred_class = self.noise_class
|
||||
|
||||
if m.is_generic:
|
||||
if self.exclude_generic:
|
||||
# Ignore gt sounds with unknown class
|
||||
continue
|
||||
|
||||
true_class = self.targets.detection_class_name
|
||||
|
||||
y_true.append(true_class or self.noise_class)
|
||||
y_pred.append(pred_class or self.noise_class)
|
||||
cm, labels = compute_confusion_matrix(
|
||||
clip_evaluations,
|
||||
self.targets,
|
||||
threshold=self.threshold,
|
||||
normalize=self.normalize,
|
||||
exclude_generic=self.exclude_generic,
|
||||
exclude_false_positives=self.exclude_false_positives,
|
||||
exclude_false_negatives=self.exclude_false_negatives,
|
||||
noise_class=self.noise_class,
|
||||
)
|
||||
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
class_names = [*self.targets.class_names]
|
||||
|
||||
if not self.exclude_generic:
|
||||
class_names.append(self.targets.detection_class_name)
|
||||
|
||||
if not self.exclude_noise:
|
||||
class_names.append(self.noise_class)
|
||||
|
||||
metrics.ConfusionMatrixDisplay.from_predictions(
|
||||
y_true,
|
||||
y_pred,
|
||||
labels=class_names,
|
||||
metrics.ConfusionMatrixDisplay(cm, display_labels=labels).plot(
|
||||
ax=ax,
|
||||
xticks_rotation="vertical",
|
||||
cmap=self.cmap,
|
||||
colorbar=self.add_colorbar,
|
||||
normalize=self.normalize if self.normalize != "none" else None,
|
||||
values_format=".2f",
|
||||
)
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional, Tuple
|
||||
from typing import Dict, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import seaborn as sns
|
||||
@ -34,8 +34,14 @@ def plot_pr_curve(
|
||||
thresholds: np.ndarray,
|
||||
ax: Optional[axes.Axes] = None,
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
color: Union[str, Tuple[float, float, float], None] = None,
|
||||
add_labels: bool = True,
|
||||
add_legend: bool = False,
|
||||
marker: Union[str, Tuple[int, int, float], None] = "o",
|
||||
markeredgecolor: Union[str, Tuple[float, float, float], None] = None,
|
||||
markersize: Optional[float] = None,
|
||||
linestyle: Union[str, Tuple[int, ...], None] = None,
|
||||
linewidth: Optional[float] = None,
|
||||
label: str = "PR Curve",
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
@ -45,9 +51,14 @@ def plot_pr_curve(
|
||||
ax.plot(
|
||||
recall,
|
||||
precision,
|
||||
color=color,
|
||||
label=label,
|
||||
marker="o",
|
||||
marker=marker,
|
||||
markeredgecolor=markeredgecolor,
|
||||
markevery=_get_marker_positions(thresholds),
|
||||
markersize=markersize,
|
||||
linestyle=linestyle,
|
||||
linewidth=linewidth,
|
||||
)
|
||||
|
||||
ax.set_xlim(0, 1.05)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user