Re-org gallery example plots

This commit is contained in:
mbsantiago 2025-09-28 15:45:48 +01:00
parent 87ed44c8f7
commit 10865ee600
18 changed files with 785 additions and 901 deletions

View File

@ -30,6 +30,7 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]

View File

@ -28,6 +28,7 @@ __all__ = [
@dataclass
class MatchEval:
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]

View File

@ -1,560 +0,0 @@
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipMatches,
MatchEvaluation,
PlotterProtocol,
PreprocessorProtocol,
)
__all__ = [
"build_plotter",
"ExampleGallery",
"ExampleGalleryConfig",
]
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
class ExampleGalleryConfig(BaseConfig):
name: Literal["example_gallery"] = "example_gallery"
examples_per_class: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleGallery(PlotterProtocol):
def __init__(
self,
examples_per_class: int,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.examples_per_class = examples_per_class
self.preprocessor = preprocessor or build_preprocessor()
self.audio_loader = audio_loader or build_audio_loader()
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
per_class_matches = group_matches(clip_evaluations)
for class_name, matches in per_class_matches.items():
true_positives = get_binned_sample(
matches.true_positives,
n_examples=self.examples_per_class,
)
false_positives = get_binned_sample(
matches.false_positives,
n_examples=self.examples_per_class,
)
false_negatives = random.sample(
matches.false_negatives,
k=min(self.examples_per_class, len(matches.false_negatives)),
)
cross_triggers = get_binned_sample(
matches.cross_triggers,
n_examples=self.examples_per_class,
)
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.examples_per_class,
)
yield f"example_gallery/{class_name}", fig
plt.close(fig)
@classmethod
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
examples_per_class=config.examples_per_class,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
class ClipEvaluationPlotConfig(BaseConfig):
name: Literal["example_clip"] = "example_clip"
num_plots: int = 5
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class PlotClipEvaluation(PlotterProtocol):
def __init__(
self,
num_plots: int = 3,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
):
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.num_plots = num_plots
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
examples = random.sample(
clip_evaluations,
k=min(self.num_plots, len(clip_evaluations)),
)
for index, clip_evaluation in enumerate(examples):
fig, ax = plt.subplots()
plot_matches(
clip_evaluation.matches,
clip=clip_evaluation.clip,
audio_loader=self.audio_loader,
ax=ax,
)
yield f"clip_evaluation/example_{index}", fig
plt.close(fig)
@classmethod
def from_config(
cls,
config: ClipEvaluationPlotConfig,
class_names: List[str],
):
audio_loader = build_audio_loader(config.audio)
preprocessor = build_preprocessor(
config.preprocessing,
input_samplerate=audio_loader.samplerate,
)
return cls(
num_plots=config.num_plots,
preprocessor=preprocessor,
audio_loader=audio_loader,
)
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
class DetectionPRCurveConfig(BaseConfig):
name: Literal["detection_pr_curve"] = "detection_pr_curve"
class DetectionPRCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(recall, precision, label="Detector")
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend()
yield "detection_pr_curve", fig
@classmethod
def from_config(
cls,
config: DetectionPRCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
class ClassificationPRCurvesConfig(BaseConfig):
name: Literal["classification_pr_curves"] = "classification_pr_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationPRCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_pred_class = y_pred[:, class_index]
precision, recall, _ = metrics.precision_recall_curve(
y_true_class,
y_pred_class,
)
ax.plot(recall, precision, label=class_name)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_pr_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationPRCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
class DetectionROCCurveConfig(BaseConfig):
name: Literal["detection_roc_curve"] = "detection_roc_curve"
class DetectionROCCurve(PlotterProtocol):
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true, y_score = zip(
*[
(match.gt_det, match.pred_score)
for clip_eval in clip_evaluations
for match in clip_eval.matches
]
)
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
fig, ax = plt.subplots()
ax.plot(fpr, tpr, label="Detection")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend()
yield "detection_roc_curve", fig
@classmethod
def from_config(
cls,
config: DetectionROCCurveConfig,
class_names: List[str],
):
return cls()
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
class ClassificationROCCurvesConfig(BaseConfig):
name: Literal["classification_roc_curves"] = "classification_roc_curves"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClassificationROCCurves(PlotterProtocol):
def __init__(
self,
class_names: List[str],
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.class_names = class_names
self.selected = class_names
if include is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name in include
]
if exclude is not None:
self.selected = [
class_name
for class_name in self.selected
if class_name not in exclude
]
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else "__NONE__"
)
y_pred.append(
np.array(
[
match.pred_class_scores.get(name, 0)
for name in self.class_names
]
)
)
y_true = label_binarize(y_true, classes=self.class_names)
y_pred = np.stack(y_pred)
fig, ax = plt.subplots(figsize=(10, 10))
for class_index, class_name in enumerate(self.class_names):
if class_name not in self.selected:
continue
y_true_class = y_true[:, class_index]
y_roced_class = y_pred[:, class_index]
fpr, tpr, _ = metrics.roc_curve(
y_true_class,
y_roced_class,
)
ax.plot(fpr, tpr, label=class_name)
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
yield "classification_roc_curve", fig
@classmethod
def from_config(
cls,
config: ClassificationROCCurvesConfig,
class_names: List[str],
):
return cls(
class_names=class_names,
include=config.include,
exclude=config.exclude,
)
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.top_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[
ExampleGalleryConfig,
ClipEvaluationPlotConfig,
DetectionPRCurveConfig,
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_plotter(
config: PlotConfig, class_names: List[str]
) -> PlotterProtocol:
return plots_registry.build(config, class_names)
@dataclass
class ClassMatches:
false_positives: List[MatchEvaluation] = field(default_factory=list)
false_negatives: List[MatchEvaluation] = field(default_factory=list)
true_positives: List[MatchEvaluation] = field(default_factory=list)
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
def group_matches(
clip_evaluations: Sequence[ClipMatches],
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_class = match.gt_class
pred_class = match.top_class
if pred_class is None:
class_examples[gt_class].false_negatives.append(match)
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[gt_class].cross_triggers.append(match)
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[
(index, match.pred_class_scores[pred_class])
for index, match in enumerate(matches)
if (pred_class := match.top_class) is not None
]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]

View File

@ -11,7 +11,7 @@ class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (5, 5)
figsize: tuple[int, int] = (10, 10)
dpi: int = 100
@ -20,7 +20,7 @@ class BasePlot:
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (5, 5),
figsize: tuple[int, int] = (10, 10),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
@ -32,7 +32,7 @@ class BasePlot:
self.theme = theme
self.title = title
def get_figure(self) -> Figure:
def create_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)

View File

@ -1,5 +1,15 @@
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
@ -12,14 +22,20 @@ from batdetect2.evaluate.metrics.classification import (
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
plot_threshold_precision_curve,
plot_threshold_precision_curves,
plot_threshold_recall_curve,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
ClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
@ -29,8 +45,10 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Classification Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class PRCurve(BasePlot):
@ -39,25 +57,24 @@ class PRCurve(BasePlot):
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
fig = self.get_figure()
ax = fig.subplots()
data = {
class_name: compute_precision_recall(
y_true[class_name],
@ -67,9 +84,23 @@ class PRCurve(BasePlot):
for class_name in self.targets.class_names
}
plot_pr_curves(data, ax=ax)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
return self.label, fig
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)
@staticmethod
@ -79,33 +110,37 @@ class PRCurve(BasePlot):
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdPRCurveConfig(BasePlotConfig):
name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
label: str = "threshold_pr_curve"
figsize: tuple[int, int] = (10, 5)
class ThresholdPrecisionCurveConfig(BasePlotConfig):
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
label: str = "threshold_precision_curve"
title: Optional[str] = "Classification Threshold-Precision Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdPRCurve(BasePlot):
class ThresholdPrecisionCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
@ -121,30 +156,135 @@ class ThresholdPRCurve(BasePlot):
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
plot_threshold_precision_curves(data, ax=ax)
return self.label, fig
yield self.label, fig
@classification_plots.register(ThresholdPRCurveConfig)
return
for class_name, (precision, _, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_precision_curve(
thresholds,
precision,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdPrecisionCurveConfig)
@staticmethod
def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
return ThresholdPRCurve.build(
def from_config(
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
):
return ThresholdPrecisionCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ThresholdRecallCurveConfig(BasePlotConfig):
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
label: str = "threshold_recall_curve"
title: Optional[str] = "Classification Threshold-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ThresholdRecallCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
yield self.label, fig
return
for class_name, (_, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_threshold_recall_curve(
thresholds,
recall,
ax=ax,
)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ThresholdRecallCurveConfig)
@staticmethod
def from_config(
config: ThresholdRecallCurveConfig, targets: TargetProtocol
):
return ThresholdRecallCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Classification ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
separate_figures: bool = False
class ROCCurve(BasePlot):
@ -153,16 +293,18 @@ class ROCCurve(BasePlot):
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
@ -177,12 +319,26 @@ class ROCCurve(BasePlot):
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax = fig.subplots()
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
plot_roc_curves(data, ax=ax)
return self.label, fig
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(ROCCurveConfig)
@staticmethod
@ -192,6 +348,7 @@ class ROCCurve(BasePlot):
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
separate_figures=config.separate_figures,
)
@ -199,7 +356,8 @@ ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPRCurveConfig,
ThresholdPrecisionCurveConfig,
ThresholdRecallCurveConfig,
],
Field(discriminator="name"),
]

View File

@ -1,6 +1,7 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
@ -8,6 +9,7 @@ from typing import (
Union,
)
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
@ -17,7 +19,9 @@ from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curve,
plot_pr_curves,
plot_roc_curve,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
@ -28,7 +32,9 @@ __all__ = [
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
ClipClassificationPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
@ -38,14 +44,24 @@ clip_classification_plots: Registry[
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
title: Optional[str] = "Clip Classification Precision-Recall Curve"
separate_figures: bool = False
class PRCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
@ -61,10 +77,26 @@ class PRCurve(BasePlot):
data[class_name] = (precision, recall, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
return self.label, fig
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (precision, recall, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
@ -72,20 +104,31 @@ class PRCurve(BasePlot):
return PRCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
title: Optional[str] = "Clip Classification ROC Curve"
separate_figures: bool = False
class ROCCurve(BasePlot):
def __init__(
self,
*args,
separate_figures: bool = False,
**kwargs,
):
super().__init__(*args, **kwargs)
self.separate_figures = separate_figures
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
data = {}
for class_name in self.targets.class_names:
@ -101,10 +144,24 @@ class ROCCurve(BasePlot):
data[class_name] = (fpr, tpr, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
return self.label, fig
if not self.separate_figures:
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
yield self.label, fig
return
for class_name, (fpr, tpr, thresholds) in data.items():
fig = self.create_figure()
ax = fig.subplots()
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
@ -112,6 +169,7 @@ class ROCCurve(BasePlot):
return ROCCurve.build(
config=config,
targets=targets,
separate_figures=config.separate_figures,
)

View File

@ -1,6 +1,7 @@
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
@ -27,7 +28,9 @@ __all__ = [
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
ClipDetectionPlotter = Callable[
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
@ -38,14 +41,14 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
title: Optional[str] = "Clip Detection Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
@ -54,10 +57,10 @@ class PRCurve(BasePlot):
y_score,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
@ -71,14 +74,14 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
title: Optional[str] = "Clip Detection ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
@ -87,10 +90,10 @@ class ROCCurve(BasePlot):
y_score,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
@ -104,18 +107,18 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Score Distribution"
title: Optional[str] = "Clip Detection Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
@ -130,7 +133,7 @@ class ScoreDistributionPlot(BasePlot):
common_norm=False,
)
return self.label, fig
yield self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod

View File

@ -1,26 +1,33 @@
import random
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
from typing import (
Annotated,
Callable,
Iterable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import patches
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from soundevent.plot import plot_geometry
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.detections import plot_clip_detections
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
@ -30,6 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Detection Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -49,7 +57,7 @@ class PRCurve(BasePlot):
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
@ -71,10 +79,12 @@ class PRCurve(BasePlot):
num_positives=num_positives,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
@ -90,6 +100,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Detection ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -109,7 +120,7 @@ class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
@ -127,10 +138,12 @@ class ROCCurve(BasePlot):
y_score,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
@ -146,6 +159,7 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Detection Score Distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -165,7 +179,7 @@ class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
@ -180,7 +194,7 @@ class ScoreDistributionPlot(BasePlot):
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
sns.histplot(
@ -194,7 +208,7 @@ class ScoreDistributionPlot(BasePlot):
common_norm=False,
)
return self.label, fig
yield self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
@ -212,7 +226,8 @@ class ScoreDistributionPlot(BasePlot):
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
figsize: tuple[int, int] = (10, 15)
title: Optional[str] = "Example Detection"
figsize: tuple[int, int] = (10, 4)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
@ -240,82 +255,26 @@ class ExampleDetectionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
fig = self.get_figure()
) -> Iterable[Tuple[str, Figure]]:
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
axes = fig.subplots(nrows=self.num_examples, ncols=1)
for num_example, clip_eval in enumerate(sample):
fig = self.create_figure()
ax = fig.subplots()
for ax, clip_eval in zip(axes, sample):
plot_clip(
clip_eval.clip,
plot_clip_detections(
clip_eval,
ax=ax,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None
and m.gt is not None
and m.score >= self.threshold
)
yield f"{self.label}/example_{num_example}", fig
if m.pred is not None:
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none",
alpha=m.pred.detection_score,
linestyle="-" if not is_match else "--",
color="red" if not is_match else "orange",
)
if m.gt is not None:
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
facecolor="none",
color="green" if not is_match else "orange",
)
ax.set_title(clip_eval.clip.recording.path.name)
# ax.legend(
# handles=[
# patches.Patch(
# edgecolor="green",
# label="Ground Truth (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Ground Truth (Matched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="red",
# label="Detection (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Detection (Matched)",
# facecolor="none",
# linestyle="--",
# ),
# ]
# )
plt.tight_layout()
return self.label, fig
plt.close(fig)
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod

View File

@ -1,17 +1,36 @@
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
import random
from collections import defaultdict
from dataclasses import dataclass, field
from typing import (
Annotated,
Callable,
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
@ -21,6 +40,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Top Class Precision-Recall Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -40,7 +60,7 @@ class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
num_positives = 0
@ -66,10 +86,12 @@ class PRCurve(BasePlot):
num_positives=num_positives,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
@ -85,6 +107,7 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "Top Class ROC Curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
@ -104,7 +127,7 @@ class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true = []
y_score = []
@ -126,10 +149,12 @@ class ROCCurve(BasePlot):
y_score,
)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
yield self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
@ -144,6 +169,7 @@ class ROCCurve(BasePlot):
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
title: Optional[str] = "Top Class Confusion Matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
@ -180,7 +206,7 @@ class ConfusionMatrix(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
) -> Iterable[Tuple[str, Figure]]:
y_true: List[str] = []
y_pred: List[str] = []
@ -213,7 +239,7 @@ class ConfusionMatrix(BasePlot):
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.get_figure()
fig = self.create_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
@ -236,7 +262,7 @@ class ConfusionMatrix(BasePlot):
values_format=".2f",
)
return self.label, fig
yield self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
@ -253,11 +279,105 @@ class ConfusionMatrix(BasePlot):
)
class ExampleClassificationPlotConfig(BasePlotConfig):
name: Literal["example_classification"] = "example_classification"
label: str = "example_classification"
title: Optional[str] = "Example Classification"
num_examples: int = 4
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleClassificationPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 4,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
self.num_examples = num_examples
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Iterable[Tuple[str, Figure]]:
grouped = group_matches(clip_evaluations, threshold=self.threshold)
for class_name, matches in grouped.items():
true_positives: List[MatchEval] = get_binned_sample(
matches.true_positives,
n_examples=self.num_examples,
)
false_positives: List[MatchEval] = get_binned_sample(
matches.false_positives,
n_examples=self.num_examples,
)
false_negatives: List[MatchEval] = random.sample(
matches.false_negatives,
k=min(self.num_examples, len(matches.false_negatives)),
)
cross_triggers: List[MatchEval] = get_binned_sample(
matches.cross_triggers, n_examples=self.num_examples
)
fig = self.create_figure()
fig = plot_match_gallery(
true_positives,
false_positives,
false_negatives,
cross_triggers,
preprocessor=self.preprocessor,
audio_loader=self.audio_loader,
n_examples=self.num_examples,
fig=fig,
)
if self.title is not None:
fig.suptitle(f"{self.title}: {class_name}")
else:
fig.suptitle(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@top_class_plots.register(ExampleClassificationPlotConfig)
@staticmethod
def from_config(
config: ExampleClassificationPlotConfig,
targets: TargetProtocol,
):
return ExampleClassificationPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
threshold=config.threshold,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
ExampleClassificationPlotConfig,
],
Field(discriminator="name"),
]
@ -268,3 +388,57 @@ def build_top_class_plotter(
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)
@dataclass
class ClassMatches:
false_positives: List[MatchEval] = field(default_factory=list)
false_negatives: List[MatchEval] = field(default_factory=list)
true_positives: List[MatchEval] = field(default_factory=list)
cross_triggers: List[MatchEval] = field(default_factory=list)
def group_matches(
clip_evals: Sequence[ClipEval],
threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
class_examples = defaultdict(ClassMatches)
for clip_eval in clip_evals:
for match in clip_eval.matches:
gt_class = match.true_class
pred_class = match.pred_class
is_pred = match.score >= threshold
if not is_pred and gt_class is not None:
class_examples[gt_class].false_negatives.append(match)
continue
if not is_pred:
continue
if gt_class is None:
class_examples[pred_class].false_positives.append(match)
continue
if gt_class != pred_class:
class_examples[pred_class].cross_triggers.append(match)
continue
class_examples[gt_class].true_positives.append(match)
return class_examples
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
if len(matches) < n_examples:
return matches
indices, pred_scores = zip(
*[(index, match.score) for index, match in enumerate(matches)]
)
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
df = pd.DataFrame({"indices": indices, "bins": bins})
sample = df.groupby("bins").sample(1)
return [matches[ind] for ind in sample["indices"]]

View File

@ -54,7 +54,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
ignore_start_end: float
@ -68,7 +68,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
):
self.matcher = matcher
@ -93,7 +93,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
yield plot(eval_outputs)
for name, fig in plot(eval_outputs):
yield f"{self.prefix}/{name}", fig
def evaluate(
self,
@ -147,7 +148,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
] = None,
**kwargs,
):

View File

@ -98,6 +98,7 @@ class ClassificationTask(BaseTask[ClipEval]):
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_prediction=pred is not None,

View File

@ -79,6 +79,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
matches.append(
MatchEval(
clip=clip,
gt=gt,
pred=pred,
is_ground_truth=gt is not None,

View File

@ -11,7 +11,6 @@ from batdetect2.plotting.matches import (
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_matches,
plot_true_positive_match,
)
@ -22,7 +21,6 @@ __all__ = [
"plot_cross_trigger_match",
"plot_false_negative_match",
"plot_false_positive_match",
"plot_matches",
"plot_spectrogram",
"plot_true_positive_match",
"plot_detection_heatmap",

View File

@ -66,6 +66,9 @@ def plot_spectrogram(
vmax=vmax,
)
ax.set_xlim(start_time, end_time)
ax.set_ylim(min_freq, max_freq)
if add_colorbar:
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))

View File

@ -0,0 +1,113 @@
from typing import Optional
from matplotlib import axes, patches
from soundevent.plot import plot_geometry
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
AudioLoader,
PreprocessorProtocol,
plot_clip,
)
from batdetect2.plotting.common import create_ax
__all__ = [
"plot_clip_detections",
]
def plot_clip_detections(
clip_eval: ClipEval,
figsize: tuple[int, int] = (10, 10),
ax: Optional[axes.Axes] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
threshold: float = 0.2,
add_legend: bool = True,
add_title: bool = True,
fill: bool = False,
linewidth: float = 1.0,
gt_color: str = "green",
gt_linestyle: str = "-",
true_pred_color: str = "yellow",
true_pred_linestyle: str = "--",
false_pred_color: str = "blue",
false_pred_linestyle: str = "-",
missed_gt_color: str = "red",
missed_gt_linestyle: str = "-",
) -> axes.Axes:
ax = create_ax(figsize=figsize, ax=ax)
plot_clip(
clip_eval.clip,
audio_loader=audio_loader,
preprocessor=preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None and m.gt is not None and m.score >= threshold
)
if m.pred is not None:
color = true_pred_color if is_match else false_pred_color
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none" if not fill else color,
alpha=m.pred.detection_score,
linewidth=linewidth,
linestyle=true_pred_linestyle
if is_match
else missed_gt_linestyle,
color=color,
)
if m.gt is not None:
color = gt_color if is_match else missed_gt_color
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
linewidth=linewidth,
facecolor="none" if not fill else color,
linestyle=gt_linestyle if is_match else false_pred_linestyle,
color=color,
)
if add_title:
ax.set_title(clip_eval.clip.recording.path.name)
if add_legend:
ax.legend(
handles=[
patches.Patch(
label="found GT",
edgecolor=gt_color,
facecolor="none" if not fill else gt_color,
linestyle=gt_linestyle,
),
patches.Patch(
label="missed GT",
edgecolor=missed_gt_color,
facecolor="none" if not fill else missed_gt_color,
linestyle=missed_gt_linestyle,
),
patches.Patch(
label="true Det",
edgecolor=true_pred_color,
facecolor="none" if not fill else true_pred_color,
linestyle=true_pred_linestyle,
),
patches.Patch(
label="false Det",
edgecolor=false_pred_color,
facecolor="none" if not fill else false_pred_color,
linestyle=false_pred_linestyle,
),
]
)
return ax

View File

@ -1,81 +1,109 @@
from typing import List, Optional
from typing import Optional, Sequence
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.plotting.matches import (
MatchProtocol,
plot_cross_trigger_match,
plot_false_negative_match,
plot_false_positive_match,
plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
__all__ = ["plot_match_gallery"]
def plot_match_gallery(
true_positives: List[MatchEvaluation],
false_positives: List[MatchEvaluation],
false_negatives: List[MatchEvaluation],
cross_triggers: List[MatchEvaluation],
true_positives: Sequence[MatchProtocol],
false_positives: Sequence[MatchProtocol],
false_negatives: Sequence[MatchProtocol],
cross_triggers: Sequence[MatchProtocol],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
n_examples: int = 5,
duration: float = 0.1,
fig: Optional[Figure] = None,
):
fig = plt.figure(figsize=(20, 20))
if fig is None:
fig = plt.figure(figsize=(20, 20))
for index, match in enumerate(true_positives[:n_examples]):
ax = plt.subplot(4, n_examples, index + 1)
axes = fig.subplots(
nrows=4,
ncols=n_examples,
sharex="none",
sharey="row",
)
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
try:
plot_true_positive_match(
match,
ax=ax,
tp_match,
ax=tp_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for index, match in enumerate(false_positives[:n_examples]):
ax = plt.subplot(4, n_examples, n_examples + index + 1)
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
try:
plot_false_positive_match(
match,
ax=ax,
fp_match,
ax=fp_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for index, match in enumerate(false_negatives[:n_examples]):
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
try:
plot_false_negative_match(
match,
ax=ax,
fn_match,
ax=fn_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
for index, match in enumerate(cross_triggers[:n_examples]):
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
try:
plot_cross_trigger_match(
match,
ax=ax,
ct_match,
ax=ct_ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
duration=duration,
)
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
except (
ValueError,
AssertionError,
RuntimeError,
FileNotFoundError,
):
continue
fig.tight_layout()
return fig

View File

@ -1,16 +1,17 @@
from typing import List, Optional, Tuple, Union
from typing import Optional, Protocol, Tuple, Union
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper
from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
RawPrediction,
)
__all__ = [
"plot_matches",
"plot_false_positive_match",
"plot_true_positive_match",
"plot_false_negative_match",
@ -18,6 +19,14 @@ __all__ = [
]
class MatchProtocol(Protocol):
clip: data.Clip
gt: Optional[data.SoundEventAnnotation]
pred: Optional[RawPrediction]
score: float
true_class: Optional[str]
DEFAULT_DURATION = 0.05
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"
def plot_matches(
matches: List[MatchEvaluation],
clip: data.Clip,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
ax: Optional[Axes] = None,
audio_dir: Optional[data.PathLike] = None,
color_mapper: Optional[TagColorMapper] = None,
add_points: bool = False,
fill: bool = False,
spec_cmap: str = "gray",
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
if color_mapper is None:
color_mapper = TagColorMapper()
for match in matches:
if match.is_cross_trigger():
plot_cross_trigger_match(
match,
ax=ax,
fill=fill,
add_points=add_points,
add_spectrogram=False,
use_score=True,
color=cross_trigger_color,
add_text=False,
)
elif match.is_true_positive():
plot_true_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=true_positive_color,
add_text=False,
)
elif match.is_false_negative():
plot_false_negative_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
add_points=add_points,
color=false_negative_color,
add_text=False,
)
elif match.is_false_positive:
plot_false_positive_match(
match,
ax=ax,
fill=fill,
add_spectrogram=False,
use_score=True,
add_points=add_points,
color=false_positive_color,
add_text=False,
)
else:
continue
return ax
def plot_false_positive_match(
match: MatchEvaluation,
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -119,21 +48,24 @@ def plot_false_positive_match(
add_spectrogram: bool = True,
add_text: bool = True,
add_points: bool = False,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.pred_geometry is not None
assert match.sound_event_annotation is None
assert match.pred is not None
start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
start_time=max(
start_time - duration / 2,
0,
),
end_time=min(
start_time + duration / 2,
match.clip.end_time,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
@ -150,30 +82,33 @@ def plot_false_positive_match(
)
ax = plot.plot_geometry(
match.pred_geometry,
match.pred.geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1,
alpha=match.score if use_score else 1,
color=color,
)
if add_text:
plt.text(
ax.text(
start_time,
high_freq,
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ",
f"score={match.score:.2f}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("False Positive")
return ax
def plot_false_negative_match(
match: MatchEvaluation,
match: MatchProtocol,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -182,26 +117,28 @@ def plot_false_negative_match(
duration: float = DEFAULT_DURATION,
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
fontsize: Union[float, str] = "small",
) -> Axes:
assert match.pred_geometry is None
assert match.sound_event_annotation is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert match.gt is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
start_time = compute_bounds(geometry)[0]
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
if add_spectrogram:
@ -215,33 +152,23 @@ def plot_false_negative_match(
spec_cmap=spec_cmap,
)
ax = plot.plot_annotation(
match.sound_event_annotation,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
color=color,
)
if add_text:
plt.text(
start_time,
high_freq,
f"False Negative \nClass: {match.gt_class} ",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("False Negative")
return ax
def plot_true_positive_match(
match: MatchEvaluation,
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -258,39 +185,42 @@ def plot_true_positive_match(
fontsize: Union[float, str] = "small",
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
add_title: bool = True,
) -> Axes:
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
if add_spectrogram:
ax = plot_clip(
clip,
ax=ax,
audio_loader=audio_loader,
preprocessor=preprocessor,
figsize=figsize,
ax=ax,
audio_dir=audio_dir,
spec_cmap=spec_cmap,
)
ax = plot.plot_annotation(
match.sound_event_annotation,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -299,31 +229,34 @@ def plot_true_positive_match(
)
plot.plot_geometry(
match.pred_geometry,
match.pred.geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1,
alpha=match.score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
if add_text:
plt.text(
ax.text(
start_time,
high_freq,
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
f"score={match.score:.2f}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("True Positive")
return ax
def plot_cross_trigger_match(
match: MatchEvaluation,
match: MatchProtocol,
preprocessor: Optional[PreprocessorProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
figsize: Optional[Tuple[int, int]] = None,
@ -334,6 +267,7 @@ def plot_cross_trigger_match(
add_spectrogram: bool = True,
add_points: bool = False,
add_text: bool = True,
add_title: bool = True,
fill: bool = False,
spec_cmap: str = "gray",
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -341,20 +275,24 @@ def plot_cross_trigger_match(
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
assert match.sound_event_annotation is not None
assert match.pred_geometry is not None
sound_event = match.sound_event_annotation.sound_event
geometry = sound_event.geometry
assert match.gt is not None
assert match.pred is not None
geometry = match.gt.sound_event.geometry
assert geometry is not None
start_time, _, _, high_freq = compute_bounds(geometry)
clip = data.Clip(
start_time=max(start_time - duration / 2, 0),
end_time=min(
start_time + duration / 2, sound_event.recording.duration
start_time=max(
start_time - duration / 2,
0,
),
recording=sound_event.recording,
end_time=min(
start_time + duration / 2,
match.clip.recording.duration,
),
recording=match.clip.recording,
)
if add_spectrogram:
@ -368,11 +306,9 @@ def plot_cross_trigger_match(
spec_cmap=spec_cmap,
)
ax = plot.plot_annotation(
match.sound_event_annotation,
ax = plot.plot_geometry(
geometry,
ax=ax,
time_offset=0.001,
freq_offset=2_000,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=1,
@ -381,24 +317,28 @@ def plot_cross_trigger_match(
)
ax = plot.plot_geometry(
match.pred_geometry,
match.pred.geometry,
ax=ax,
add_points=add_points,
facecolor="none" if not fill else None,
alpha=match.pred_score if use_score else 1,
alpha=match.score if use_score else 1,
color=color,
linestyle=prediction_linestyle,
)
if add_text:
plt.text(
ax.text(
start_time,
high_freq,
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
f"score={match.score:.2f}\nclass={match.true_class}",
va="top",
ha="right",
color=color,
fontsize=fontsize,
)
if add_title:
ax.set_title("Cross Trigger")
return ax

View File

@ -35,6 +35,8 @@ def plot_pr_curve(
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
add_legend: bool = False,
label: str = "PR Curve",
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
@ -43,7 +45,7 @@ def plot_pr_curve(
ax.plot(
recall,
precision,
label="PR Curve",
label=label,
marker="o",
markevery=_get_marker_positions(thresholds),
)
@ -51,6 +53,9 @@ def plot_pr_curve(
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_legend:
ax.legend()
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")