Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-09 16:59:33 +01:00)

Re-org gallery example plots

This commit is contained in:
  parent 87ed44c8f7
  commit 10865ee600
@@ -30,6 +30,7 @@ __all__ = [

@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

@@ -28,6 +28,7 @@ __all__ = [

@dataclass
class MatchEval:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

@ -1,560 +0,0 @@
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.matches import plot_matches
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
ClipMatches,
|
||||
MatchEvaluation,
|
||||
PlotterProtocol,
|
||||
PreprocessorProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"build_plotter",
|
||||
"ExampleGallery",
|
||||
"ExampleGalleryConfig",
|
||||
]
|
||||
|
||||
|
||||
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
||||
|
||||
|
||||
class ExampleGalleryConfig(BaseConfig):
|
||||
name: Literal["example_gallery"] = "example_gallery"
|
||||
examples_per_class: int = 5
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
|
||||
|
||||
class ExampleGallery(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
examples_per_class: int,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
):
|
||||
self.examples_per_class = examples_per_class
|
||||
self.preprocessor = preprocessor or build_preprocessor()
|
||||
self.audio_loader = audio_loader or build_audio_loader()
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
per_class_matches = group_matches(clip_evaluations)
|
||||
|
||||
for class_name, matches in per_class_matches.items():
|
||||
true_positives = get_binned_sample(
|
||||
matches.true_positives,
|
||||
n_examples=self.examples_per_class,
|
||||
)
|
||||
|
||||
false_positives = get_binned_sample(
|
||||
matches.false_positives,
|
||||
n_examples=self.examples_per_class,
|
||||
)
|
||||
|
||||
false_negatives = random.sample(
|
||||
matches.false_negatives,
|
||||
k=min(self.examples_per_class, len(matches.false_negatives)),
|
||||
)
|
||||
|
||||
cross_triggers = get_binned_sample(
|
||||
matches.cross_triggers,
|
||||
n_examples=self.examples_per_class,
|
||||
)
|
||||
|
||||
fig = plot_match_gallery(
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
cross_triggers,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
n_examples=self.examples_per_class,
|
||||
)
|
||||
|
||||
yield f"example_gallery/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
|
||||
audio_loader = build_audio_loader(config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return cls(
|
||||
examples_per_class=config.examples_per_class,
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
||||
|
||||
|
||||
class ClipEvaluationPlotConfig(BaseConfig):
|
||||
name: Literal["example_clip"] = "example_clip"
|
||||
num_plots: int = 5
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
|
||||
|
||||
class PlotClipEvaluation(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
num_plots: int = 3,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
audio_loader: Optional[AudioLoader] = None,
|
||||
):
|
||||
self.preprocessor = preprocessor
|
||||
self.audio_loader = audio_loader
|
||||
self.num_plots = num_plots
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
examples = random.sample(
|
||||
clip_evaluations,
|
||||
k=min(self.num_plots, len(clip_evaluations)),
|
||||
)
|
||||
|
||||
for index, clip_evaluation in enumerate(examples):
|
||||
fig, ax = plt.subplots()
|
||||
plot_matches(
|
||||
clip_evaluation.matches,
|
||||
clip=clip_evaluation.clip,
|
||||
audio_loader=self.audio_loader,
|
||||
ax=ax,
|
||||
)
|
||||
yield f"clip_evaluation/example_{index}", fig
|
||||
plt.close(fig)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClipEvaluationPlotConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
audio_loader = build_audio_loader(config.audio)
|
||||
preprocessor = build_preprocessor(
|
||||
config.preprocessing,
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
)
|
||||
return cls(
|
||||
num_plots=config.num_plots,
|
||||
preprocessor=preprocessor,
|
||||
audio_loader=audio_loader,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
||||
|
||||
|
||||
class DetectionPRCurveConfig(BaseConfig):
|
||||
name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
||||
|
||||
|
||||
class DetectionPRCurve(PlotterProtocol):
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
y_true, y_score = zip(
|
||||
*[
|
||||
(match.gt_det, match.pred_score)
|
||||
for clip_eval in clip_evaluations
|
||||
for match in clip_eval.matches
|
||||
]
|
||||
)
|
||||
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(recall, precision, label="Detector")
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.legend()
|
||||
|
||||
yield "detection_pr_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: DetectionPRCurveConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls()
|
||||
|
||||
|
||||
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
||||
|
||||
|
||||
class ClassificationPRCurvesConfig(BaseConfig):
|
||||
name: Literal["classification_pr_curves"] = "classification_pr_curves"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClassificationPRCurves(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else "__NONE__"
|
||||
)
|
||||
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 10))
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
if class_name not in self.selected:
|
||||
continue
|
||||
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_pred_class = y_pred[:, class_index]
|
||||
precision, recall, _ = metrics.precision_recall_curve(
|
||||
y_true_class,
|
||||
y_pred_class,
|
||||
)
|
||||
ax.plot(recall, precision, label=class_name)
|
||||
|
||||
ax.set_xlabel("Recall")
|
||||
ax.set_ylabel("Precision")
|
||||
ax.legend(
|
||||
bbox_to_anchor=(1.05, 1),
|
||||
loc="upper left",
|
||||
borderaxespad=0.0,
|
||||
)
|
||||
|
||||
yield "classification_pr_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationPRCurvesConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names=class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
||||
|
||||
|
||||
class DetectionROCCurveConfig(BaseConfig):
|
||||
name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
||||
|
||||
|
||||
class DetectionROCCurve(PlotterProtocol):
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
y_true, y_score = zip(
|
||||
*[
|
||||
(match.gt_det, match.pred_score)
|
||||
for clip_eval in clip_evaluations
|
||||
for match in clip_eval.matches
|
||||
]
|
||||
)
|
||||
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
|
||||
fig, ax = plt.subplots()
|
||||
|
||||
ax.plot(fpr, tpr, label="Detection")
|
||||
ax.set_xlabel("False Positive Rate")
|
||||
ax.set_ylabel("True Positive Rate")
|
||||
ax.legend()
|
||||
|
||||
yield "detection_roc_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: DetectionROCCurveConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls()
|
||||
|
||||
|
||||
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
||||
|
||||
|
||||
class ClassificationROCCurvesConfig(BaseConfig):
|
||||
name: Literal["classification_roc_curves"] = "classification_roc_curves"
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ClassificationROCCurves(PlotterProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
class_names: List[str],
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.class_names = class_names
|
||||
self.selected = class_names
|
||||
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else "__NONE__"
|
||||
)
|
||||
|
||||
y_pred.append(
|
||||
np.array(
|
||||
[
|
||||
match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
y_true = label_binarize(y_true, classes=self.class_names)
|
||||
y_pred = np.stack(y_pred)
|
||||
|
||||
fig, ax = plt.subplots(figsize=(10, 10))
|
||||
for class_index, class_name in enumerate(self.class_names):
|
||||
if class_name not in self.selected:
|
||||
continue
|
||||
|
||||
y_true_class = y_true[:, class_index]
|
||||
y_roced_class = y_pred[:, class_index]
|
||||
fpr, tpr, _ = metrics.roc_curve(
|
||||
y_true_class,
|
||||
y_roced_class,
|
||||
)
|
||||
ax.plot(fpr, tpr, label=class_name)
|
||||
|
||||
ax.set_xlabel("False Positive Rate")
|
||||
ax.set_ylabel("True Positive Rate")
|
||||
ax.legend(
|
||||
bbox_to_anchor=(1.05, 1),
|
||||
loc="upper left",
|
||||
borderaxespad=0.0,
|
||||
)
|
||||
|
||||
yield "classification_roc_curve", fig
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ClassificationROCCurvesConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
class_names=class_names,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
||||
|
||||
|
||||
class ConfusionMatrixConfig(BaseConfig):
|
||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||
background_class: str = "noise"
|
||||
|
||||
|
||||
class ConfusionMatrix(PlotterProtocol):
|
||||
def __init__(self, background_class: str, class_names: List[str]):
|
||||
self.background_class = background_class
|
||||
self.class_names = class_names
|
||||
|
||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
||||
y_true = []
|
||||
y_pred = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for match in clip_eval.matches:
|
||||
# Ignore generic unclassified targets
|
||||
if match.gt_det and match.gt_class is None:
|
||||
continue
|
||||
|
||||
y_true.append(
|
||||
match.gt_class
|
||||
if match.gt_class is not None
|
||||
else self.background_class
|
||||
)
|
||||
|
||||
top_class = match.top_class
|
||||
y_pred.append(
|
||||
top_class
|
||||
if top_class is not None
|
||||
else self.background_class
|
||||
)
|
||||
|
||||
display = metrics.ConfusionMatrixDisplay.from_predictions(
|
||||
y_true,
|
||||
y_pred,
|
||||
labels=[*self.class_names, self.background_class],
|
||||
)
|
||||
|
||||
yield "confusion_matrix", display.figure_
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: ConfusionMatrixConfig,
|
||||
class_names: List[str],
|
||||
):
|
||||
return cls(
|
||||
background_class=config.background_class,
|
||||
class_names=class_names,
|
||||
)
|
||||
|
||||
|
||||
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
|
||||
|
||||
|
||||
PlotConfig = Annotated[
|
||||
Union[
|
||||
ExampleGalleryConfig,
|
||||
ClipEvaluationPlotConfig,
|
||||
DetectionPRCurveConfig,
|
||||
ClassificationPRCurvesConfig,
|
||||
DetectionROCCurveConfig,
|
||||
ClassificationROCCurvesConfig,
|
||||
ConfusionMatrixConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_plotter(
|
||||
config: PlotConfig, class_names: List[str]
|
||||
) -> PlotterProtocol:
|
||||
return plots_registry.build(config, class_names)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClassMatches:
|
||||
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
||||
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||
|
||||
|
||||
def group_matches(
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
) -> Dict[str, ClassMatches]:
|
||||
class_examples = defaultdict(ClassMatches)
|
||||
|
||||
for clip_evaluation in clip_evaluations:
|
||||
for match in clip_evaluation.matches:
|
||||
gt_class = match.gt_class
|
||||
pred_class = match.top_class
|
||||
|
||||
if pred_class is None:
|
||||
class_examples[gt_class].false_negatives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class is None:
|
||||
class_examples[pred_class].false_positives.append(match)
|
||||
continue
|
||||
|
||||
if gt_class != pred_class:
|
||||
class_examples[gt_class].cross_triggers.append(match)
|
||||
class_examples[pred_class].cross_triggers.append(match)
|
||||
continue
|
||||
|
||||
class_examples[gt_class].true_positives.append(match)
|
||||
|
||||
return class_examples
|
||||
|
||||
|
||||
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
|
||||
if len(matches) < n_examples:
|
||||
return matches
|
||||
|
||||
indices, pred_scores = zip(
|
||||
*[
|
||||
(index, match.pred_class_scores[pred_class])
|
||||
for index, match in enumerate(matches)
|
||||
if (pred_class := match.top_class) is not None
|
||||
]
|
||||
)
|
||||
|
||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||
df = pd.DataFrame({"indices": indices, "bins": bins})
|
||||
sample = df.groupby("bins").sample(1)
|
||||
return [matches[ind] for ind in sample["indices"]]
|
||||
@@ -11,7 +11,7 @@ class BasePlotConfig(BaseConfig):
    label: str = "plot"
    theme: str = "default"
    title: Optional[str] = None
    figsize: tuple[int, int] = (5, 5)
    figsize: tuple[int, int] = (10, 10)
    dpi: int = 100


@@ -20,7 +20,7 @@ class BasePlot:
        self,
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (5, 5),
        figsize: tuple[int, int] = (10, 10),
        title: Optional[str] = None,
        dpi: int = 100,
        theme: str = "default",
@@ -32,7 +32,7 @@ class BasePlot:
        self.theme = theme
        self.title = title

    def get_figure(self) -> Figure:
    def create_figure(self) -> Figure:
        plt.style.use(self.theme)
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)

@ -1,5 +1,15 @@
|
||||
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
@ -12,14 +22,20 @@ from batdetect2.evaluate.metrics.classification import (
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.metrics import (
|
||||
plot_pr_curve,
|
||||
plot_pr_curves,
|
||||
plot_roc_curve,
|
||||
plot_roc_curves,
|
||||
plot_threshold_precision_curve,
|
||||
plot_threshold_precision_curves,
|
||||
plot_threshold_recall_curve,
|
||||
plot_threshold_recall_curves,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
|
||||
ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
||||
ClassificationPlotter = Callable[
|
||||
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||
]
|
||||
|
||||
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
||||
Registry("classification_plot")
|
||||
@ -29,8 +45,10 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Classification Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class PRCurve(BasePlot):
|
||||
@ -39,25 +57,24 @@ class PRCurve(BasePlot):
|
||||
*args,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
fig = self.get_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
data = {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
@ -67,9 +84,23 @@ class PRCurve(BasePlot):
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
plot_pr_curves(data, ax=ax)
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
plot_pr_curves(data, ax=ax)
|
||||
yield self.label, fig
|
||||
return
|
||||
|
||||
return self.label, fig
|
||||
for class_name, (precision, recall, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classification_plots.register(PRCurveConfig)
|
||||
@staticmethod
|
||||
@ -79,33 +110,37 @@ class PRCurve(BasePlot):
|
||||
targets=targets,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
class ThresholdPRCurveConfig(BasePlotConfig):
|
||||
name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
|
||||
label: str = "threshold_pr_curve"
|
||||
figsize: tuple[int, int] = (10, 5)
|
||||
class ThresholdPrecisionCurveConfig(BasePlotConfig):
|
||||
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
|
||||
label: str = "threshold_precision_curve"
|
||||
title: Optional[str] = "Classification Threshold-Precision Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class ThresholdPRCurve(BasePlot):
|
||||
class ThresholdPrecisionCurve(BasePlot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
@ -121,30 +156,135 @@ class ThresholdPRCurve(BasePlot):
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
fig = self.get_figure()
|
||||
ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
|
||||
plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
|
||||
plot_threshold_precision_curves(data, ax=ax)
|
||||
|
||||
return self.label, fig
|
||||
yield self.label, fig
|
||||
|
||||
@classification_plots.register(ThresholdPRCurveConfig)
|
||||
return
|
||||
|
||||
for class_name, (precision, _, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_threshold_precision_curve(
|
||||
thresholds,
|
||||
precision,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classification_plots.register(ThresholdPrecisionCurveConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
|
||||
return ThresholdPRCurve.build(
|
||||
def from_config(
|
||||
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
|
||||
):
|
||||
return ThresholdPrecisionCurve.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
class ThresholdRecallCurveConfig(BasePlotConfig):
|
||||
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
|
||||
label: str = "threshold_recall_curve"
|
||||
title: Optional[str] = "Classification Threshold-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class ThresholdRecallCurve(BasePlot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
data = {
|
||||
class_name: compute_precision_recall(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives[class_name],
|
||||
)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
return
|
||||
|
||||
for class_name, (_, recall, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_threshold_recall_curve(
|
||||
thresholds,
|
||||
recall,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classification_plots.register(ThresholdRecallCurveConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: ThresholdRecallCurveConfig, targets: TargetProtocol
|
||||
):
|
||||
return ThresholdRecallCurve.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Classification ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class ROCCurve(BasePlot):
|
||||
@ -153,16 +293,18 @@ class ROCCurve(BasePlot):
|
||||
*args,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true, y_score, _ = _extract_per_class_metric_data(
|
||||
clip_evaluations,
|
||||
ignore_non_predictions=self.ignore_non_predictions,
|
||||
@ -177,12 +319,26 @@ class ROCCurve(BasePlot):
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
fig = self.get_figure()
|
||||
ax = fig.subplots()
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_roc_curves(data, ax=ax)
|
||||
plot_roc_curves(data, ax=ax)
|
||||
|
||||
return self.label, fig
|
||||
yield self.label, fig
|
||||
|
||||
return
|
||||
|
||||
for class_name, (fpr, tpr, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classification_plots.register(ROCCurveConfig)
|
||||
@staticmethod
|
||||
@ -192,6 +348,7 @@ class ROCCurve(BasePlot):
|
||||
targets=targets,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
@ -199,7 +356,8 @@ ClassificationPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ThresholdPRCurveConfig,
|
||||
ThresholdPrecisionCurveConfig,
|
||||
ThresholdRecallCurveConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
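The plot configs throughout this commit share one pattern: each config carries a `name: Literal[...]` discriminator, the configs are collected into an `Annotated[Union[...], Field(discriminator="name")]`, and a registry maps config types to builders. A minimal sketch of that pattern, assuming Pydantic v2 and using a plain dict in place of batdetect2's own `Registry`/`BaseConfig` helpers:

from typing import Annotated, Callable, Dict, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter


class PRCurveConfig(BaseModel):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"


class ROCCurveConfig(BaseModel):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"


# The "name" field picks the concrete config class when parsing raw data.
PlotConfig = Annotated[
    Union[PRCurveConfig, ROCCurveConfig],
    Field(discriminator="name"),
]

# Simplified stand-in for the project's Registry: config type -> builder.
_builders: Dict[type, Callable] = {}


def build_pr_curve(config: PRCurveConfig) -> Callable:
    def plot(clip_evaluations):
        print(f"{config.label}: PR curve over {len(clip_evaluations)} clips")

    return plot


def build_roc_curve(config: ROCCurveConfig) -> Callable:
    def plot(clip_evaluations):
        print(f"{config.label}: ROC curve over {len(clip_evaluations)} clips")

    return plot


_builders[PRCurveConfig] = build_pr_curve
_builders[ROCCurveConfig] = build_roc_curve


def build_plotter(config: PlotConfig) -> Callable:
    return _builders[type(config)](config)


# Parsing a plain dict selects ROCCurveConfig via the discriminator.
config = TypeAdapter(PlotConfig).validate_python({"name": "roc_curve"})
build_plotter(config)([])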
@ -1,6 +1,7 @@
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
@ -8,6 +9,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
@ -17,7 +19,9 @@ from batdetect2.evaluate.metrics.clip_classification import ClipEval
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.metrics import (
|
||||
plot_pr_curve,
|
||||
plot_pr_curves,
|
||||
plot_roc_curve,
|
||||
plot_roc_curves,
|
||||
)
|
||||
from batdetect2.typing import TargetProtocol
|
||||
@ -28,7 +32,9 @@ __all__ = [
|
||||
"build_clip_classification_plotter",
|
||||
]
|
||||
|
||||
ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
||||
ClipClassificationPlotter = Callable[
|
||||
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||
]
|
||||
|
||||
clip_classification_plots: Registry[
|
||||
ClipClassificationPlotter, [TargetProtocol]
|
||||
@ -38,14 +44,24 @@ clip_classification_plots: Registry[
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Precision-Recall Curve"
|
||||
title: Optional[str] = "Clip Classification Precision-Recall Curve"
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class PRCurve(BasePlot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
data = {}
|
||||
|
||||
for class_name in self.targets.class_names:
|
||||
@ -61,10 +77,26 @@ class PRCurve(BasePlot):
|
||||
|
||||
data[class_name] = (precision, recall, thresholds)
|
||||
|
||||
fig = self.get_figure()
|
||||
ax = fig.subplots()
|
||||
plot_pr_curves(data, ax=ax)
|
||||
return self.label, fig
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_pr_curves(data, ax=ax)
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
return
|
||||
|
||||
for class_name, (precision, recall, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@clip_classification_plots.register(PRCurveConfig)
|
||||
@staticmethod
|
||||
@ -72,20 +104,31 @@ class PRCurve(BasePlot):
|
||||
return PRCurve.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "ROC Curve"
|
||||
title: Optional[str] = "Clip Classification ROC Curve"
|
||||
separate_figures: bool = False
|
||||
|
||||
|
||||
class ROCCurve(BasePlot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
separate_figures: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.separate_figures = separate_figures
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
data = {}
|
||||
|
||||
for class_name in self.targets.class_names:
|
||||
@ -101,10 +144,24 @@ class ROCCurve(BasePlot):
|
||||
|
||||
data[class_name] = (fpr, tpr, thresholds)
|
||||
|
||||
fig = self.get_figure()
|
||||
ax = fig.subplots()
|
||||
plot_roc_curves(data, ax=ax)
|
||||
return self.label, fig
|
||||
if not self.separate_figures:
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
plot_roc_curves(data, ax=ax)
|
||||
yield self.label, fig
|
||||
|
||||
return
|
||||
|
||||
for class_name, (fpr, tpr, thresholds) in data.items():
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@clip_classification_plots.register(ROCCurveConfig)
|
||||
@staticmethod
|
||||
@ -112,6 +169,7 @@ class ROCCurve(BasePlot):
|
||||
return ROCCurve.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
separate_figures=config.separate_figures,
|
||||
)
|
||||
|
||||
|
||||
|
||||
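The `separate_figures` flag added to these plotters switches between yielding one combined figure and yielding a figure per class. A self-contained sketch of that branching; the class names and curve values below are fabricated, and only matplotlib and numpy are assumed:

from typing import Dict, Iterable, Tuple

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.figure import Figure


def yield_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray]],
    label: str = "pr_curve",
    separate_figures: bool = False,
) -> Iterable[Tuple[str, Figure]]:
    """Yield (name, figure) pairs, either combined or one figure per class."""
    if not separate_figures:
        fig, ax = plt.subplots()
        for class_name, (recall, precision) in data.items():
            ax.plot(recall, precision, label=class_name)
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        ax.legend()
        yield label, fig
        return

    for class_name, (recall, precision) in data.items():
        fig, ax = plt.subplots()
        ax.plot(recall, precision)
        ax.set_title(class_name)
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")
        yield f"{label}/{class_name}", fig
        plt.close(fig)


# Fabricated example curves for two classes.
curves = {
    "class_a": (np.linspace(0, 1, 50), np.linspace(1.0, 0.6, 50)),
    "class_b": (np.linspace(0, 1, 50), np.linspace(1.0, 0.4, 50)),
}
for name, fig in yield_curves(curves, separate_figures=True):
    fig.savefig(f"{name.replace('/', '_')}.png")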
@@ -1,6 +1,7 @@
from typing import (
    Annotated,
    Callable,
    Iterable,
    Literal,
    Optional,
    Sequence,
@@ -27,7 +28,9 @@ __all__ = [
    "build_clip_detection_plotter",
]

ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
ClipDetectionPlotter = Callable[
    [Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
]


clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
@@ -38,14 +41,14 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Precision-Recall Curve"
    title: Optional[str] = "Clip Detection Precision-Recall Curve"


class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

@@ -54,10 +57,10 @@ class PRCurve(BasePlot):
            y_score,
        )

        fig = self.get_figure()
        fig = self.create_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        return self.label, fig
        yield self.label, fig

    @clip_detection_plots.register(PRCurveConfig)
    @staticmethod
@@ -71,14 +74,14 @@ class PRCurve(BasePlot):
class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "ROC Curve"
    title: Optional[str] = "Clip Detection ROC Curve"


class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

@@ -87,10 +90,10 @@ class ROCCurve(BasePlot):
            y_score,
        )

        fig = self.get_figure()
        fig = self.create_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        return self.label, fig
        yield self.label, fig

    @clip_detection_plots.register(ROCCurveConfig)
    @staticmethod
@@ -104,18 +107,18 @@ class ROCCurve(BasePlot):
class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Score Distribution"
    title: Optional[str] = "Clip Detection Score Distribution"


class ScoreDistributionPlot(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
    ) -> Iterable[Tuple[str, Figure]]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fig = self.get_figure()
        fig = self.create_figure()
        ax = fig.subplots()

        df = pd.DataFrame({"is_true": y_true, "score": y_score})
@@ -130,7 +133,7 @@ class ScoreDistributionPlot(BasePlot):
            common_norm=False,
        )

        return self.label, fig
        yield self.label, fig

    @clip_detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
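`ScoreDistributionPlot` above collects detection scores labelled by whether they matched a ground-truth event and hands them to seaborn. A standalone sketch of the same idea on synthetic scores; only pandas, numpy, matplotlib and seaborn are assumed:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)

# Synthetic detection scores: true detections skew high, false ones low.
df = pd.DataFrame(
    {
        "is_true": [True] * 200 + [False] * 200,
        "score": np.concatenate([rng.beta(5, 2, 200), rng.beta(2, 5, 200)]),
    }
)

fig, ax = plt.subplots()
sns.histplot(
    data=df,
    x="score",
    hue="is_true",
    stat="density",
    common_norm=False,  # normalise each class separately, as in the plot above
    ax=ax,
)
ax.set_xlabel("Detection score")
fig.savefig("score_distribution.png")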
@ -1,26 +1,33 @@
|
||||
import random
|
||||
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Iterable,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
from matplotlib import patches
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from soundevent.plot import plot_geometry
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.clips import plot_clip
|
||||
from batdetect2.plotting.detections import plot_clip_detections
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||
|
||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||
|
||||
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
||||
name="detection_plot"
|
||||
@ -30,6 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Detection Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -49,7 +57,7 @@ class PRCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evals: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
@ -71,10 +79,12 @@ class PRCurve(BasePlot):
|
||||
num_positives=num_positives,
|
||||
)
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||
return self.label, fig
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
@detection_plots.register(PRCurveConfig)
|
||||
@staticmethod
|
||||
@ -90,6 +100,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Detection ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -109,7 +120,7 @@ class ROCCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
@ -127,10 +138,12 @@ class ROCCurve(BasePlot):
|
||||
y_score,
|
||||
)
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||
return self.label, fig
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
@detection_plots.register(ROCCurveConfig)
|
||||
@staticmethod
|
||||
@ -146,6 +159,7 @@ class ROCCurve(BasePlot):
|
||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||
name: Literal["score_distribution"] = "score_distribution"
|
||||
label: str = "score_distribution"
|
||||
title: Optional[str] = "Detection Score Distribution"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -165,7 +179,7 @@ class ScoreDistributionPlot(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
@ -180,7 +194,7 @@ class ScoreDistributionPlot(BasePlot):
|
||||
|
||||
df = pd.DataFrame({"is_true": y_true, "score": y_score})
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
sns.histplot(
|
||||
@ -194,7 +208,7 @@ class ScoreDistributionPlot(BasePlot):
|
||||
common_norm=False,
|
||||
)
|
||||
|
||||
return self.label, fig
|
||||
yield self.label, fig
|
||||
|
||||
@detection_plots.register(ScoreDistributionPlotConfig)
|
||||
@staticmethod
|
||||
@ -212,7 +226,8 @@ class ScoreDistributionPlot(BasePlot):
|
||||
class ExampleDetectionPlotConfig(BasePlotConfig):
|
||||
name: Literal["example_detection"] = "example_detection"
|
||||
label: str = "example_detection"
|
||||
figsize: tuple[int, int] = (10, 15)
|
||||
title: Optional[str] = "Example Detection"
|
||||
figsize: tuple[int, int] = (10, 4)
|
||||
num_examples: int = 5
|
||||
threshold: float = 0.2
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
@ -240,82 +255,26 @@ class ExampleDetectionPlot(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
fig = self.get_figure()
|
||||
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
sample = clip_evaluations
|
||||
|
||||
if self.num_examples < len(sample):
|
||||
sample = random.sample(sample, self.num_examples)
|
||||
|
||||
axes = fig.subplots(nrows=self.num_examples, ncols=1)
|
||||
for num_example, clip_eval in enumerate(sample):
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
for ax, clip_eval in zip(axes, sample):
|
||||
plot_clip(
|
||||
clip_eval.clip,
|
||||
plot_clip_detections(
|
||||
clip_eval,
|
||||
ax=ax,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
ax=ax,
|
||||
)
|
||||
|
||||
for m in clip_eval.matches:
|
||||
is_match = (
|
||||
m.pred is not None
|
||||
and m.gt is not None
|
||||
and m.score >= self.threshold
|
||||
)
|
||||
yield f"{self.label}/example_{num_example}", fig
|
||||
|
||||
if m.pred is not None:
|
||||
plot_geometry(
|
||||
m.pred.geometry,
|
||||
ax=ax,
|
||||
add_points=False,
|
||||
facecolor="none",
|
||||
alpha=m.pred.detection_score,
|
||||
linestyle="-" if not is_match else "--",
|
||||
color="red" if not is_match else "orange",
|
||||
)
|
||||
|
||||
if m.gt is not None:
|
||||
plot_geometry(
|
||||
m.gt.sound_event.geometry, # type: ignore
|
||||
ax=ax,
|
||||
add_points=False,
|
||||
facecolor="none",
|
||||
color="green" if not is_match else "orange",
|
||||
)
|
||||
|
||||
ax.set_title(clip_eval.clip.recording.path.name)
|
||||
|
||||
# ax.legend(
|
||||
# handles=[
|
||||
# patches.Patch(
|
||||
# edgecolor="green",
|
||||
# label="Ground Truth (Unmatched)",
|
||||
# facecolor="none",
|
||||
# ),
|
||||
# patches.Patch(
|
||||
# edgecolor="orange",
|
||||
# label="Ground Truth (Matched)",
|
||||
# facecolor="none",
|
||||
# ),
|
||||
# patches.Patch(
|
||||
# edgecolor="red",
|
||||
# label="Detection (Unmatched)",
|
||||
# facecolor="none",
|
||||
# ),
|
||||
# patches.Patch(
|
||||
# edgecolor="orange",
|
||||
# label="Detection (Matched)",
|
||||
# facecolor="none",
|
||||
# linestyle="--",
|
||||
# ),
|
||||
# ]
|
||||
# )
|
||||
|
||||
plt.tight_layout()
|
||||
|
||||
return self.label, fig
|
||||
plt.close(fig)
|
||||
|
||||
@detection_plots.register(ExampleDetectionPlotConfig)
|
||||
@staticmethod
|
||||
|
||||
@ -1,17 +1,36 @@
|
||||
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
|
||||
import random
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||
from batdetect2.evaluate.metrics.top_class import ClipEval
|
||||
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
|
||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||
from batdetect2.plotting.gallery import plot_match_gallery
|
||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||
|
||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||
|
||||
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||
name="top_class_plot"
|
||||
@ -21,6 +40,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||
class PRCurveConfig(BasePlotConfig):
|
||||
name: Literal["pr_curve"] = "pr_curve"
|
||||
label: str = "pr_curve"
|
||||
title: Optional[str] = "Top Class Precision-Recall Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -40,7 +60,7 @@ class PRCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
@ -66,10 +86,12 @@ class PRCurve(BasePlot):
|
||||
num_positives=num_positives,
|
||||
)
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||
return self.label, fig
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
@top_class_plots.register(PRCurveConfig)
|
||||
@staticmethod
|
||||
@ -85,6 +107,7 @@ class PRCurve(BasePlot):
|
||||
class ROCCurveConfig(BasePlotConfig):
|
||||
name: Literal["roc_curve"] = "roc_curve"
|
||||
label: str = "roc_curve"
|
||||
title: Optional[str] = "Top Class ROC Curve"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
@ -104,7 +127,7 @@ class ROCCurve(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
@ -126,10 +149,12 @@ class ROCCurve(BasePlot):
|
||||
y_score,
|
||||
)
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||
return self.label, fig
|
||||
|
||||
yield self.label, fig
|
||||
|
||||
@top_class_plots.register(ROCCurveConfig)
|
||||
@staticmethod
|
||||
@ -144,6 +169,7 @@ class ROCCurve(BasePlot):
|
||||
|
||||
class ConfusionMatrixConfig(BasePlotConfig):
|
||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||
title: Optional[str] = "Top Class Confusion Matrix"
|
||||
figsize: tuple[int, int] = (10, 10)
|
||||
label: str = "confusion_matrix"
|
||||
exclude_generic: bool = True
|
||||
@ -180,7 +206,7 @@ class ConfusionMatrix(BasePlot):
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Tuple[str, Figure]:
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
y_true: List[str] = []
|
||||
y_pred: List[str] = []
|
||||
|
||||
@ -213,7 +239,7 @@ class ConfusionMatrix(BasePlot):
|
||||
y_true.append(true_class or self.noise_class)
|
||||
y_pred.append(pred_class or self.noise_class)
|
||||
|
||||
fig = self.get_figure()
|
||||
fig = self.create_figure()
|
||||
ax = fig.subplots()
|
||||
|
||||
class_names = [*self.targets.class_names]
|
||||
@ -236,7 +262,7 @@ class ConfusionMatrix(BasePlot):
|
||||
values_format=".2f",
|
||||
)
|
||||
|
||||
return self.label, fig
|
||||
yield self.label, fig
|
||||
|
||||
@top_class_plots.register(ConfusionMatrixConfig)
|
||||
@staticmethod
|
||||
@ -253,11 +279,105 @@ class ConfusionMatrix(BasePlot):
|
||||
)
|
||||
|
||||
|
||||
class ExampleClassificationPlotConfig(BasePlotConfig):
|
||||
name: Literal["example_classification"] = "example_classification"
|
||||
label: str = "example_classification"
|
||||
title: Optional[str] = "Example Classification"
|
||||
num_examples: int = 4
|
||||
threshold: float = 0.2
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
|
||||
|
||||
class ExampleClassificationPlot(BasePlot):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
num_examples: int = 4,
|
||||
threshold: float = 0.2,
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.num_examples = num_examples
|
||||
self.audio_loader = audio_loader
|
||||
self.threshold = threshold
|
||||
self.preprocessor = preprocessor
|
||||
self.num_examples = num_examples
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
grouped = group_matches(clip_evaluations, threshold=self.threshold)
|
||||
|
||||
for class_name, matches in grouped.items():
|
||||
true_positives: List[MatchEval] = get_binned_sample(
|
||||
matches.true_positives,
|
||||
n_examples=self.num_examples,
|
||||
)
|
||||
|
||||
false_positives: List[MatchEval] = get_binned_sample(
|
||||
matches.false_positives,
|
||||
n_examples=self.num_examples,
|
||||
)
|
||||
|
||||
false_negatives: List[MatchEval] = random.sample(
|
||||
matches.false_negatives,
|
||||
k=min(self.num_examples, len(matches.false_negatives)),
|
||||
)
|
||||
|
||||
cross_triggers: List[MatchEval] = get_binned_sample(
|
||||
matches.cross_triggers, n_examples=self.num_examples
|
||||
)
|
||||
|
||||
fig = self.create_figure()
|
||||
|
||||
fig = plot_match_gallery(
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
cross_triggers,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
n_examples=self.num_examples,
|
||||
fig=fig,
|
||||
)
|
||||
|
||||
if self.title is not None:
|
||||
fig.suptitle(f"{self.title}: {class_name}")
|
||||
else:
|
||||
fig.suptitle(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@top_class_plots.register(ExampleClassificationPlotConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: ExampleClassificationPlotConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
return ExampleClassificationPlot.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
num_examples=config.num_examples,
|
||||
threshold=config.threshold,
|
||||
audio_loader=build_audio_loader(config.audio),
|
||||
preprocessor=build_preprocessor(config.preprocessing),
|
||||
)
|
||||
|
||||
|
||||
TopClassPlotConfig = Annotated[
|
||||
Union[
|
||||
PRCurveConfig,
|
||||
ROCCurveConfig,
|
||||
ConfusionMatrixConfig,
|
||||
ExampleClassificationPlotConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
@@ -268,3 +388,57 @@ def build_top_class_plotter(
    targets: TargetProtocol,
) -> TopClassPlotter:
    return top_class_plots.build(config, targets)


@dataclass
class ClassMatches:
    false_positives: List[MatchEval] = field(default_factory=list)
    false_negatives: List[MatchEval] = field(default_factory=list)
    true_positives: List[MatchEval] = field(default_factory=list)
    cross_triggers: List[MatchEval] = field(default_factory=list)


def group_matches(
    clip_evals: Sequence[ClipEval],
    threshold: float = 0.2,
) -> Dict[str, ClassMatches]:
    class_examples = defaultdict(ClassMatches)

    for clip_eval in clip_evals:
        for match in clip_eval.matches:
            gt_class = match.true_class
            pred_class = match.pred_class
            is_pred = match.score >= threshold

            if not is_pred and gt_class is not None:
                class_examples[gt_class].false_negatives.append(match)
                continue

            if not is_pred:
                continue

            if gt_class is None:
                class_examples[pred_class].false_positives.append(match)
                continue

            if gt_class != pred_class:
                class_examples[pred_class].cross_triggers.append(match)
                continue

            class_examples[gt_class].true_positives.append(match)

    return class_examples


def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
    if len(matches) < n_examples:
        return matches

    indices, pred_scores = zip(
        *[(index, match.score) for index, match in enumerate(matches)]
    )

    bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"indices": indices, "bins": bins})
    sample = df.groupby("bins").sample(1)
    return [matches[ind] for ind in sample["indices"]]
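`get_binned_sample` above picks gallery examples across score quantiles with `pandas.qcut` rather than purely at random, so the gallery spans low- to high-confidence matches. A self-contained sketch of that sampling step; `SimpleMatch` is an illustrative stand-in for the project's `MatchEval`:

import random
from dataclasses import dataclass
from typing import List, Optional

import pandas as pd


@dataclass
class SimpleMatch:  # illustrative stand-in for MatchEval
    true_class: Optional[str]
    pred_class: Optional[str]
    score: float


def binned_sample(matches: List[SimpleMatch], n_examples: int = 5) -> List[SimpleMatch]:
    """Pick roughly one match per score quantile instead of purely at random."""
    if len(matches) < n_examples:
        return matches

    scores = [m.score for m in matches]
    bins = pd.qcut(scores, q=n_examples, labels=False, duplicates="drop")
    df = pd.DataFrame({"index": range(len(matches)), "bin": bins})
    picked = df.groupby("bin").sample(1)
    return [matches[i] for i in picked["index"]]


random.seed(0)
matches = [SimpleMatch("bat", "bat", random.random()) for _ in range(100)]
sample = binned_sample(matches, n_examples=5)
print([round(m.score, 2) for m in sample])  # one example per score quantile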
@@ -54,7 +54,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):

    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
    plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]

    ignore_start_end: float

@@ -68,7 +68,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
        prefix: str,
        ignore_start_end: float = 0.01,
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
    ):
        self.matcher = matcher
@@ -93,7 +93,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
        self, eval_outputs: List[T_Output]
    ) -> Iterable[Tuple[str, Figure]]:
        for plot in self.plots:
            yield plot(eval_outputs)
            for name, fig in plot(eval_outputs):
                yield f"{self.prefix}/{name}", fig

    def evaluate(
        self,
@@ -147,7 +148,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        plots: Optional[
            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
            List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
        ] = None,
        **kwargs,
    ):
@ -98,6 +98,7 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
clip=clip,
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_prediction=pred is not None,
|
||||
|
||||
@ -79,6 +79,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
|
||||
matches.append(
|
||||
MatchEval(
|
||||
clip=clip,
|
||||
gt=gt,
|
||||
pred=pred,
|
||||
is_ground_truth=gt is not None,
|
||||
|
||||
@ -11,7 +11,6 @@ from batdetect2.plotting.matches import (
|
||||
plot_cross_trigger_match,
|
||||
plot_false_negative_match,
|
||||
plot_false_positive_match,
|
||||
plot_matches,
|
||||
plot_true_positive_match,
|
||||
)
|
||||
|
||||
@ -22,7 +21,6 @@ __all__ = [
|
||||
"plot_cross_trigger_match",
|
||||
"plot_false_negative_match",
|
||||
"plot_false_positive_match",
|
||||
"plot_matches",
|
||||
"plot_spectrogram",
|
||||
"plot_true_positive_match",
|
||||
"plot_detection_heatmap",
|
||||
|
||||
@ -66,6 +66,9 @@ def plot_spectrogram(
|
||||
vmax=vmax,
|
||||
)
|
||||
|
||||
ax.set_xlim(start_time, end_time)
|
||||
ax.set_ylim(min_freq, max_freq)
|
||||
|
||||
if add_colorbar:
|
||||
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
|
||||
|
||||
|
||||
113 src/batdetect2/plotting/detections.py Normal file

@ -0,0 +1,113 @@
from typing import Optional

from matplotlib import axes, patches
from soundevent.plot import plot_geometry

from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.plotting.clips import (
    AudioLoader,
    PreprocessorProtocol,
    plot_clip,
)
from batdetect2.plotting.common import create_ax

__all__ = [
    "plot_clip_detections",
]


def plot_clip_detections(
    clip_eval: ClipEval,
    figsize: tuple[int, int] = (10, 10),
    ax: Optional[axes.Axes] = None,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    threshold: float = 0.2,
    add_legend: bool = True,
    add_title: bool = True,
    fill: bool = False,
    linewidth: float = 1.0,
    gt_color: str = "green",
    gt_linestyle: str = "-",
    true_pred_color: str = "yellow",
    true_pred_linestyle: str = "--",
    false_pred_color: str = "blue",
    false_pred_linestyle: str = "-",
    missed_gt_color: str = "red",
    missed_gt_linestyle: str = "-",
) -> axes.Axes:
    ax = create_ax(figsize=figsize, ax=ax)

    plot_clip(
        clip_eval.clip,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        ax=ax,
    )

    for m in clip_eval.matches:
        is_match = (
            m.pred is not None and m.gt is not None and m.score >= threshold
        )

        if m.pred is not None:
            color = true_pred_color if is_match else false_pred_color
            plot_geometry(
                m.pred.geometry,
                ax=ax,
                add_points=False,
                facecolor="none" if not fill else color,
                alpha=m.pred.detection_score,
                linewidth=linewidth,
                linestyle=true_pred_linestyle
                if is_match
                else missed_gt_linestyle,
                color=color,
            )

        if m.gt is not None:
            color = gt_color if is_match else missed_gt_color
            plot_geometry(
                m.gt.sound_event.geometry,  # type: ignore
                ax=ax,
                add_points=False,
                linewidth=linewidth,
                facecolor="none" if not fill else color,
                linestyle=gt_linestyle if is_match else false_pred_linestyle,
                color=color,
            )

    if add_title:
        ax.set_title(clip_eval.clip.recording.path.name)

    if add_legend:
        ax.legend(
            handles=[
                patches.Patch(
                    label="found GT",
                    edgecolor=gt_color,
                    facecolor="none" if not fill else gt_color,
                    linestyle=gt_linestyle,
                ),
                patches.Patch(
                    label="missed GT",
                    edgecolor=missed_gt_color,
                    facecolor="none" if not fill else missed_gt_color,
                    linestyle=missed_gt_linestyle,
                ),
                patches.Patch(
                    label="true Det",
                    edgecolor=true_pred_color,
                    facecolor="none" if not fill else true_pred_color,
                    linestyle=true_pred_linestyle,
                ),
                patches.Patch(
                    label="false Det",
                    edgecolor=false_pred_color,
                    facecolor="none" if not fill else false_pred_color,
                    linestyle=false_pred_linestyle,
                ),
            ]
        )

    return ax
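A short usage sketch for the new helper, assuming a `clip_eval` produced by a detection task and the default audio loader and preprocessor; none of these assumptions are part of the diff itself.

# Hypothetical usage (not part of this commit): overlay ground truth and
# detections on one clip spectrogram and save the result to disk.
import matplotlib.pyplot as plt

ax = plot_clip_detections(clip_eval, threshold=0.2, add_legend=True)
ax.figure.savefig("clip_detections.png", dpi=150)
plt.close(ax.figure)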
@ -1,81 +1,109 @@
from typing import List, Optional
from typing import Optional, Sequence

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.plotting.matches import (
    MatchProtocol,
    plot_cross_trigger_match,
    plot_false_negative_match,
    plot_false_positive_match,
    plot_true_positive_match,
)
from batdetect2.typing.evaluate import MatchEvaluation
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol

__all__ = ["plot_match_gallery"]


def plot_match_gallery(
    true_positives: List[MatchEvaluation],
    false_positives: List[MatchEvaluation],
    false_negatives: List[MatchEvaluation],
    cross_triggers: List[MatchEvaluation],
    true_positives: Sequence[MatchProtocol],
    false_positives: Sequence[MatchProtocol],
    false_negatives: Sequence[MatchProtocol],
    cross_triggers: Sequence[MatchProtocol],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    n_examples: int = 5,
    duration: float = 0.1,
    fig: Optional[Figure] = None,
):
    fig = plt.figure(figsize=(20, 20))
    if fig is None:
        fig = plt.figure(figsize=(20, 20))

    for index, match in enumerate(true_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, index + 1)
    axes = fig.subplots(
        nrows=4,
        ncols=n_examples,
        sharex="none",
        sharey="row",
    )

    for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
        try:
            plot_true_positive_match(
                match,
                ax=ax,
                tp_match,
                ax=tp_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for index, match in enumerate(false_positives[:n_examples]):
        ax = plt.subplot(4, n_examples, n_examples + index + 1)
    for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
        try:
            plot_false_positive_match(
                match,
                ax=ax,
                fp_match,
                ax=fp_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for index, match in enumerate(false_negatives[:n_examples]):
        ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
    for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
        try:
            plot_false_negative_match(
                match,
                ax=ax,
                fn_match,
                ax=fn_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    for index, match in enumerate(cross_triggers[:n_examples]):
        ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
    for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
        try:
            plot_cross_trigger_match(
                match,
                ax=ax,
                ct_match,
                ax=ct_ax,
                audio_loader=audio_loader,
                preprocessor=preprocessor,
                duration=duration,
            )
        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
        except (
            ValueError,
            AssertionError,
            RuntimeError,
            FileNotFoundError,
        ):
            continue

    fig.tight_layout()

    return fig
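For reference, a minimal sketch of how the reworked gallery might be driven. The per-class match lists are assumed to come from group_matches and get_binned_sample earlier in this diff and are not defined here.

# Hypothetical usage (not part of this commit): one 4 x n_examples gallery
# figure, rows ordered TP / FP / FN / cross-trigger.
fig = plot_match_gallery(
    true_positives,
    false_positives,
    false_negatives,
    cross_triggers,
    n_examples=5,
    duration=0.1,
)
fig.savefig("gallery_example.png", dpi=150)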
@ -1,16 +1,17 @@
from typing import List, Optional, Tuple, Union
from typing import Optional, Protocol, Tuple, Union

import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from soundevent import data, plot
from soundevent.geometry import compute_bounds
from soundevent.plot.tags import TagColorMapper

from batdetect2.plotting.clips import AudioLoader, plot_clip
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
from batdetect2.plotting.clips import plot_clip
from batdetect2.typing import (
    AudioLoader,
    PreprocessorProtocol,
    RawPrediction,
)

__all__ = [
    "plot_matches",
    "plot_false_positive_match",
    "plot_true_positive_match",
    "plot_false_negative_match",
@ -18,6 +19,14 @@ __all__ = [
]


class MatchProtocol(Protocol):
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]
    score: float
    true_class: Optional[str]


DEFAULT_DURATION = 0.05
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
DEFAULT_PREDICTION_LINE_STYLE = "--"


def plot_matches(
    matches: List[MatchEvaluation],
    clip: data.Clip,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
    ax: Optional[Axes] = None,
    audio_dir: Optional[data.PathLike] = None,
    color_mapper: Optional[TagColorMapper] = None,
    add_points: bool = False,
    fill: bool = False,
    spec_cmap: str = "gray",
    false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
    false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
    true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
    cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
) -> Axes:
    ax = plot_clip(
        clip,
        ax=ax,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        figsize=figsize,
        audio_dir=audio_dir,
        spec_cmap=spec_cmap,
    )

    if color_mapper is None:
        color_mapper = TagColorMapper()

    for match in matches:
        if match.is_cross_trigger():
            plot_cross_trigger_match(
                match,
                ax=ax,
                fill=fill,
                add_points=add_points,
                add_spectrogram=False,
                use_score=True,
                color=cross_trigger_color,
                add_text=False,
            )
        elif match.is_true_positive():
            plot_true_positive_match(
                match,
                ax=ax,
                fill=fill,
                add_spectrogram=False,
                use_score=True,
                add_points=add_points,
                color=true_positive_color,
                add_text=False,
            )
        elif match.is_false_negative():
            plot_false_negative_match(
                match,
                ax=ax,
                fill=fill,
                add_spectrogram=False,
                add_points=add_points,
                color=false_negative_color,
                add_text=False,
            )
        elif match.is_false_positive:
            plot_false_positive_match(
                match,
                ax=ax,
                fill=fill,
                add_spectrogram=False,
                use_score=True,
                add_points=add_points,
                color=false_positive_color,
                add_text=False,
            )
        else:
            continue

    return ax
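Because MatchProtocol is a structural (duck-typed) protocol, any object exposing these attributes can be passed to the plotting helpers below. A minimal illustrative stand-in, hypothetical and not part of the diff, could look like:

# Hypothetical MatchProtocol-compatible container for illustration only.
from dataclasses import dataclass
from typing import Optional

from soundevent import data

from batdetect2.typing import RawPrediction


@dataclass
class SimpleMatch:
    clip: data.Clip
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]
    score: float
    true_class: Optional[str]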
def plot_false_positive_match(
    match: MatchEvaluation,
    match: MatchProtocol,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
@ -119,21 +48,24 @@ def plot_false_positive_match(
    add_spectrogram: bool = True,
    add_text: bool = True,
    add_points: bool = False,
    add_title: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_FALSE_POSITIVE_COLOR,
    fontsize: Union[float, str] = "small",
) -> Axes:
    assert match.pred_geometry is not None
    assert match.sound_event_annotation is None
    assert match.pred is not None

    start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
    start_time, _, _, high_freq = compute_bounds(match.pred.geometry)

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        end_time=min(
            start_time + duration / 2,
            match.clip.end_time,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )
@ -150,30 +82,33 @@ def plot_false_positive_match(
    )

    ax = plot.plot_geometry(
        match.pred_geometry,
        match.pred.geometry,
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=match.pred_score if use_score else 1,
        alpha=match.score if use_score else 1,
        color=color,
    )

    if add_text:
        plt.text(
        ax.text(
            start_time,
            high_freq,
            f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ",
            f"score={match.score:.2f}",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    if add_title:
        ax.set_title("False Positive")

    return ax
def plot_false_negative_match(
    match: MatchEvaluation,
    match: MatchProtocol,
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    figsize: Optional[Tuple[int, int]] = None,
@ -182,26 +117,28 @@ def plot_false_negative_match(
    duration: float = DEFAULT_DURATION,
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    add_title: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
    fontsize: Union[float, str] = "small",
) -> Axes:
    assert match.pred_geometry is None
    assert match.sound_event_annotation is not None
    sound_event = match.sound_event_annotation.sound_event
    geometry = sound_event.geometry
    assert match.gt is not None

    geometry = match.gt.sound_event.geometry
    assert geometry is not None

    start_time, _, _, high_freq = compute_bounds(geometry)
    start_time = compute_bounds(geometry)[0]

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        end_time=min(
            start_time + duration / 2, sound_event.recording.duration
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        recording=sound_event.recording,
        end_time=min(
            start_time + duration / 2,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )

    if add_spectrogram:
@ -215,33 +152,23 @@ def plot_false_negative_match(
            spec_cmap=spec_cmap,
        )

    ax = plot.plot_annotation(
        match.sound_event_annotation,
    ax = plot.plot_geometry(
        geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
        color=color,
    )

    if add_text:
        plt.text(
            start_time,
            high_freq,
            f"False Negative \nClass: {match.gt_class} ",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )
    if add_title:
        ax.set_title("False Negative")

    return ax
def plot_true_positive_match(
    match: MatchEvaluation,
    match: MatchProtocol,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    figsize: Optional[Tuple[int, int]] = None,
@ -258,39 +185,42 @@ def plot_true_positive_match(
    fontsize: Union[float, str] = "small",
    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
    add_title: bool = True,
) -> Axes:
    assert match.sound_event_annotation is not None
    assert match.pred_geometry is not None
    sound_event = match.sound_event_annotation.sound_event
    geometry = sound_event.geometry
    assert match.gt is not None
    assert match.pred is not None

    geometry = match.gt.sound_event.geometry
    assert geometry is not None

    start_time, _, _, high_freq = compute_bounds(geometry)

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        end_time=min(
            start_time + duration / 2, sound_event.recording.duration
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        recording=sound_event.recording,
        end_time=min(
            start_time + duration / 2,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )

    if add_spectrogram:
        ax = plot_clip(
            clip,
            ax=ax,
            audio_loader=audio_loader,
            preprocessor=preprocessor,
            figsize=figsize,
            ax=ax,
            audio_dir=audio_dir,
            spec_cmap=spec_cmap,
        )

    ax = plot.plot_annotation(
        match.sound_event_annotation,
    ax = plot.plot_geometry(
        geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
@ -299,31 +229,34 @@ def plot_true_positive_match(
    )

    plot.plot_geometry(
        match.pred_geometry,
        match.pred.geometry,
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=match.pred_score if use_score else 1,
        alpha=match.score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    if add_text:
        plt.text(
        ax.text(
            start_time,
            high_freq,
            f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
            f"score={match.score:.2f}",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    if add_title:
        ax.set_title("True Positive")

    return ax
def plot_cross_trigger_match(
    match: MatchEvaluation,
    match: MatchProtocol,
    preprocessor: Optional[PreprocessorProtocol] = None,
    audio_loader: Optional[AudioLoader] = None,
    figsize: Optional[Tuple[int, int]] = None,
@ -334,6 +267,7 @@ def plot_cross_trigger_match(
    add_spectrogram: bool = True,
    add_points: bool = False,
    add_text: bool = True,
    add_title: bool = True,
    fill: bool = False,
    spec_cmap: str = "gray",
    color: str = DEFAULT_CROSS_TRIGGER_COLOR,
@ -341,20 +275,24 @@ def plot_cross_trigger_match(
    annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
    prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
) -> Axes:
    assert match.sound_event_annotation is not None
    assert match.pred_geometry is not None
    sound_event = match.sound_event_annotation.sound_event
    geometry = sound_event.geometry
    assert match.gt is not None
    assert match.pred is not None

    geometry = match.gt.sound_event.geometry
    assert geometry is not None

    start_time, _, _, high_freq = compute_bounds(geometry)

    clip = data.Clip(
        start_time=max(start_time - duration / 2, 0),
        end_time=min(
            start_time + duration / 2, sound_event.recording.duration
        start_time=max(
            start_time - duration / 2,
            0,
        ),
        recording=sound_event.recording,
        end_time=min(
            start_time + duration / 2,
            match.clip.recording.duration,
        ),
        recording=match.clip.recording,
    )

    if add_spectrogram:
@ -368,11 +306,9 @@ def plot_cross_trigger_match(
            spec_cmap=spec_cmap,
        )

    ax = plot.plot_annotation(
        match.sound_event_annotation,
    ax = plot.plot_geometry(
        geometry,
        ax=ax,
        time_offset=0.001,
        freq_offset=2_000,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=1,
@ -381,24 +317,28 @@ def plot_cross_trigger_match(
    )

    ax = plot.plot_geometry(
        match.pred_geometry,
        match.pred.geometry,
        ax=ax,
        add_points=add_points,
        facecolor="none" if not fill else None,
        alpha=match.pred_score if use_score else 1,
        alpha=match.score if use_score else 1,
        color=color,
        linestyle=prediction_linestyle,
    )

    if add_text:
        plt.text(
        ax.text(
            start_time,
            high_freq,
            f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
            f"score={match.score:.2f}\nclass={match.true_class}",
            va="top",
            ha="right",
            color=color,
            fontsize=fontsize,
        )

    if add_title:
        ax.set_title("Cross Trigger")

    return ax
@ -35,6 +35,8 @@ def plot_pr_curve(
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
    add_legend: bool = False,
    label: str = "PR Curve",
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)

@ -43,7 +45,7 @@ def plot_pr_curve(
    ax.plot(
        recall,
        precision,
        label="PR Curve",
        label=label,
        marker="o",
        markevery=_get_marker_positions(thresholds),
    )
@ -51,6 +53,9 @@ def plot_pr_curve(
    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_legend:
        ax.legend()

    if add_labels:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")