mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Re-org gallery example plots
This commit is contained in:
parent
87ed44c8f7
commit
10865ee600
@ -30,6 +30,7 @@ __all__ = [
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
|
clip: data.Clip
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: Optional[data.SoundEventAnnotation]
|
||||||
pred: Optional[RawPrediction]
|
pred: Optional[RawPrediction]
|
||||||
|
|
||||||
|
|||||||
@ -28,6 +28,7 @@ __all__ = [
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MatchEval:
|
class MatchEval:
|
||||||
|
clip: data.Clip
|
||||||
gt: Optional[data.SoundEventAnnotation]
|
gt: Optional[data.SoundEventAnnotation]
|
||||||
pred: Optional[RawPrediction]
|
pred: Optional[RawPrediction]
|
||||||
|
|
||||||
|
|||||||
@ -1,560 +0,0 @@
|
|||||||
import random
|
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from pydantic import Field
|
|
||||||
from sklearn import metrics
|
|
||||||
from sklearn.preprocessing import label_binarize
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
|
||||||
from batdetect2.plotting.gallery import plot_match_gallery
|
|
||||||
from batdetect2.plotting.matches import plot_matches
|
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
|
||||||
from batdetect2.typing import (
|
|
||||||
AudioLoader,
|
|
||||||
ClipMatches,
|
|
||||||
MatchEvaluation,
|
|
||||||
PlotterProtocol,
|
|
||||||
PreprocessorProtocol,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"build_plotter",
|
|
||||||
"ExampleGallery",
|
|
||||||
"ExampleGalleryConfig",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry: Registry[PlotterProtocol, [List[str]]] = Registry("plot")
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleGalleryConfig(BaseConfig):
|
|
||||||
name: Literal["example_gallery"] = "example_gallery"
|
|
||||||
examples_per_class: int = 5
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
|
||||||
preprocessing: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ExampleGallery(PlotterProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
examples_per_class: int,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
):
|
|
||||||
self.examples_per_class = examples_per_class
|
|
||||||
self.preprocessor = preprocessor or build_preprocessor()
|
|
||||||
self.audio_loader = audio_loader or build_audio_loader()
|
|
||||||
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
per_class_matches = group_matches(clip_evaluations)
|
|
||||||
|
|
||||||
for class_name, matches in per_class_matches.items():
|
|
||||||
true_positives = get_binned_sample(
|
|
||||||
matches.true_positives,
|
|
||||||
n_examples=self.examples_per_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
false_positives = get_binned_sample(
|
|
||||||
matches.false_positives,
|
|
||||||
n_examples=self.examples_per_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
false_negatives = random.sample(
|
|
||||||
matches.false_negatives,
|
|
||||||
k=min(self.examples_per_class, len(matches.false_negatives)),
|
|
||||||
)
|
|
||||||
|
|
||||||
cross_triggers = get_binned_sample(
|
|
||||||
matches.cross_triggers,
|
|
||||||
n_examples=self.examples_per_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
fig = plot_match_gallery(
|
|
||||||
true_positives,
|
|
||||||
false_positives,
|
|
||||||
false_negatives,
|
|
||||||
cross_triggers,
|
|
||||||
preprocessor=self.preprocessor,
|
|
||||||
audio_loader=self.audio_loader,
|
|
||||||
n_examples=self.examples_per_class,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield f"example_gallery/{class_name}", fig
|
|
||||||
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(cls, config: ExampleGalleryConfig, class_names: List[str]):
|
|
||||||
audio_loader = build_audio_loader(config.audio)
|
|
||||||
preprocessor = build_preprocessor(
|
|
||||||
config.preprocessing,
|
|
||||||
input_samplerate=audio_loader.samplerate,
|
|
||||||
)
|
|
||||||
return cls(
|
|
||||||
examples_per_class=config.examples_per_class,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(ExampleGalleryConfig, ExampleGallery)
|
|
||||||
|
|
||||||
|
|
||||||
class ClipEvaluationPlotConfig(BaseConfig):
|
|
||||||
name: Literal["example_clip"] = "example_clip"
|
|
||||||
num_plots: int = 5
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
|
||||||
preprocessing: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PlotClipEvaluation(PlotterProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
num_plots: int = 3,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
):
|
|
||||||
self.preprocessor = preprocessor
|
|
||||||
self.audio_loader = audio_loader
|
|
||||||
self.num_plots = num_plots
|
|
||||||
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
examples = random.sample(
|
|
||||||
clip_evaluations,
|
|
||||||
k=min(self.num_plots, len(clip_evaluations)),
|
|
||||||
)
|
|
||||||
|
|
||||||
for index, clip_evaluation in enumerate(examples):
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
plot_matches(
|
|
||||||
clip_evaluation.matches,
|
|
||||||
clip=clip_evaluation.clip,
|
|
||||||
audio_loader=self.audio_loader,
|
|
||||||
ax=ax,
|
|
||||||
)
|
|
||||||
yield f"clip_evaluation/example_{index}", fig
|
|
||||||
plt.close(fig)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ClipEvaluationPlotConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
audio_loader = build_audio_loader(config.audio)
|
|
||||||
preprocessor = build_preprocessor(
|
|
||||||
config.preprocessing,
|
|
||||||
input_samplerate=audio_loader.samplerate,
|
|
||||||
)
|
|
||||||
return cls(
|
|
||||||
num_plots=config.num_plots,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(ClipEvaluationPlotConfig, PlotClipEvaluation)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionPRCurveConfig(BaseConfig):
|
|
||||||
name: Literal["detection_pr_curve"] = "detection_pr_curve"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionPRCurve(PlotterProtocol):
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
y_true, y_score = zip(
|
|
||||||
*[
|
|
||||||
(match.gt_det, match.pred_score)
|
|
||||||
for clip_eval in clip_evaluations
|
|
||||||
for match in clip_eval.matches
|
|
||||||
]
|
|
||||||
)
|
|
||||||
precision, recall, _ = metrics.precision_recall_curve(y_true, y_score)
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
|
|
||||||
ax.plot(recall, precision, label="Detector")
|
|
||||||
ax.set_xlabel("Recall")
|
|
||||||
ax.set_ylabel("Precision")
|
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
yield "detection_pr_curve", fig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: DetectionPRCurveConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(DetectionPRCurveConfig, DetectionPRCurve)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPRCurvesConfig(BaseConfig):
|
|
||||||
name: Literal["classification_pr_curves"] = "classification_pr_curves"
|
|
||||||
include: Optional[List[str]] = None
|
|
||||||
exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationPRCurves(PlotterProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
class_names: List[str],
|
|
||||||
include: Optional[List[str]] = None,
|
|
||||||
exclude: Optional[List[str]] = None,
|
|
||||||
):
|
|
||||||
self.class_names = class_names
|
|
||||||
self.selected = class_names
|
|
||||||
|
|
||||||
if include is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name in include
|
|
||||||
]
|
|
||||||
|
|
||||||
if exclude is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name not in exclude
|
|
||||||
]
|
|
||||||
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
y_true = []
|
|
||||||
y_pred = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(
|
|
||||||
match.gt_class
|
|
||||||
if match.gt_class is not None
|
|
||||||
else "__NONE__"
|
|
||||||
)
|
|
||||||
|
|
||||||
y_pred.append(
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
match.pred_class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
y_true = label_binarize(y_true, classes=self.class_names)
|
|
||||||
y_pred = np.stack(y_pred)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 10))
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
|
||||||
if class_name not in self.selected:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true_class = y_true[:, class_index]
|
|
||||||
y_pred_class = y_pred[:, class_index]
|
|
||||||
precision, recall, _ = metrics.precision_recall_curve(
|
|
||||||
y_true_class,
|
|
||||||
y_pred_class,
|
|
||||||
)
|
|
||||||
ax.plot(recall, precision, label=class_name)
|
|
||||||
|
|
||||||
ax.set_xlabel("Recall")
|
|
||||||
ax.set_ylabel("Precision")
|
|
||||||
ax.legend(
|
|
||||||
bbox_to_anchor=(1.05, 1),
|
|
||||||
loc="upper left",
|
|
||||||
borderaxespad=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "classification_pr_curve", fig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ClassificationPRCurvesConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
class_names=class_names,
|
|
||||||
include=config.include,
|
|
||||||
exclude=config.exclude,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(ClassificationPRCurvesConfig, ClassificationPRCurves)
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCCurveConfig(BaseConfig):
|
|
||||||
name: Literal["detection_roc_curve"] = "detection_roc_curve"
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionROCCurve(PlotterProtocol):
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
y_true, y_score = zip(
|
|
||||||
*[
|
|
||||||
(match.gt_det, match.pred_score)
|
|
||||||
for clip_eval in clip_evaluations
|
|
||||||
for match in clip_eval.matches
|
|
||||||
]
|
|
||||||
)
|
|
||||||
fpr, tpr, _ = metrics.roc_curve(y_true, y_score)
|
|
||||||
fig, ax = plt.subplots()
|
|
||||||
|
|
||||||
ax.plot(fpr, tpr, label="Detection")
|
|
||||||
ax.set_xlabel("False Positive Rate")
|
|
||||||
ax.set_ylabel("True Positive Rate")
|
|
||||||
ax.legend()
|
|
||||||
|
|
||||||
yield "detection_roc_curve", fig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: DetectionROCCurveConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(DetectionROCCurveConfig, DetectionROCCurve)
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCCurvesConfig(BaseConfig):
|
|
||||||
name: Literal["classification_roc_curves"] = "classification_roc_curves"
|
|
||||||
include: Optional[List[str]] = None
|
|
||||||
exclude: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class ClassificationROCCurves(PlotterProtocol):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
class_names: List[str],
|
|
||||||
include: Optional[List[str]] = None,
|
|
||||||
exclude: Optional[List[str]] = None,
|
|
||||||
):
|
|
||||||
self.class_names = class_names
|
|
||||||
self.selected = class_names
|
|
||||||
|
|
||||||
if include is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name in include
|
|
||||||
]
|
|
||||||
|
|
||||||
if exclude is not None:
|
|
||||||
self.selected = [
|
|
||||||
class_name
|
|
||||||
for class_name in self.selected
|
|
||||||
if class_name not in exclude
|
|
||||||
]
|
|
||||||
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
y_true = []
|
|
||||||
y_pred = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(
|
|
||||||
match.gt_class
|
|
||||||
if match.gt_class is not None
|
|
||||||
else "__NONE__"
|
|
||||||
)
|
|
||||||
|
|
||||||
y_pred.append(
|
|
||||||
np.array(
|
|
||||||
[
|
|
||||||
match.pred_class_scores.get(name, 0)
|
|
||||||
for name in self.class_names
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
y_true = label_binarize(y_true, classes=self.class_names)
|
|
||||||
y_pred = np.stack(y_pred)
|
|
||||||
|
|
||||||
fig, ax = plt.subplots(figsize=(10, 10))
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
|
||||||
if class_name not in self.selected:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true_class = y_true[:, class_index]
|
|
||||||
y_roced_class = y_pred[:, class_index]
|
|
||||||
fpr, tpr, _ = metrics.roc_curve(
|
|
||||||
y_true_class,
|
|
||||||
y_roced_class,
|
|
||||||
)
|
|
||||||
ax.plot(fpr, tpr, label=class_name)
|
|
||||||
|
|
||||||
ax.set_xlabel("False Positive Rate")
|
|
||||||
ax.set_ylabel("True Positive Rate")
|
|
||||||
ax.legend(
|
|
||||||
bbox_to_anchor=(1.05, 1),
|
|
||||||
loc="upper left",
|
|
||||||
borderaxespad=0.0,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "classification_roc_curve", fig
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ClassificationROCCurvesConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
class_names=class_names,
|
|
||||||
include=config.include,
|
|
||||||
exclude=config.exclude,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfusionMatrixConfig(BaseConfig):
|
|
||||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
|
||||||
background_class: str = "noise"
|
|
||||||
|
|
||||||
|
|
||||||
class ConfusionMatrix(PlotterProtocol):
|
|
||||||
def __init__(self, background_class: str, class_names: List[str]):
|
|
||||||
self.background_class = background_class
|
|
||||||
self.class_names = class_names
|
|
||||||
|
|
||||||
def __call__(self, clip_evaluations: Sequence[ClipMatches]):
|
|
||||||
y_true = []
|
|
||||||
y_pred = []
|
|
||||||
|
|
||||||
for clip_eval in clip_evaluations:
|
|
||||||
for match in clip_eval.matches:
|
|
||||||
# Ignore generic unclassified targets
|
|
||||||
if match.gt_det and match.gt_class is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
y_true.append(
|
|
||||||
match.gt_class
|
|
||||||
if match.gt_class is not None
|
|
||||||
else self.background_class
|
|
||||||
)
|
|
||||||
|
|
||||||
top_class = match.top_class
|
|
||||||
y_pred.append(
|
|
||||||
top_class
|
|
||||||
if top_class is not None
|
|
||||||
else self.background_class
|
|
||||||
)
|
|
||||||
|
|
||||||
display = metrics.ConfusionMatrixDisplay.from_predictions(
|
|
||||||
y_true,
|
|
||||||
y_pred,
|
|
||||||
labels=[*self.class_names, self.background_class],
|
|
||||||
)
|
|
||||||
|
|
||||||
yield "confusion_matrix", display.figure_
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_config(
|
|
||||||
cls,
|
|
||||||
config: ConfusionMatrixConfig,
|
|
||||||
class_names: List[str],
|
|
||||||
):
|
|
||||||
return cls(
|
|
||||||
background_class=config.background_class,
|
|
||||||
class_names=class_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
|
|
||||||
|
|
||||||
|
|
||||||
PlotConfig = Annotated[
|
|
||||||
Union[
|
|
||||||
ExampleGalleryConfig,
|
|
||||||
ClipEvaluationPlotConfig,
|
|
||||||
DetectionPRCurveConfig,
|
|
||||||
ClassificationPRCurvesConfig,
|
|
||||||
DetectionROCCurveConfig,
|
|
||||||
ClassificationROCCurvesConfig,
|
|
||||||
ConfusionMatrixConfig,
|
|
||||||
],
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_plotter(
|
|
||||||
config: PlotConfig, class_names: List[str]
|
|
||||||
) -> PlotterProtocol:
|
|
||||||
return plots_registry.build(config, class_names)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ClassMatches:
|
|
||||||
false_positives: List[MatchEvaluation] = field(default_factory=list)
|
|
||||||
false_negatives: List[MatchEvaluation] = field(default_factory=list)
|
|
||||||
true_positives: List[MatchEvaluation] = field(default_factory=list)
|
|
||||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
def group_matches(
|
|
||||||
clip_evaluations: Sequence[ClipMatches],
|
|
||||||
) -> Dict[str, ClassMatches]:
|
|
||||||
class_examples = defaultdict(ClassMatches)
|
|
||||||
|
|
||||||
for clip_evaluation in clip_evaluations:
|
|
||||||
for match in clip_evaluation.matches:
|
|
||||||
gt_class = match.gt_class
|
|
||||||
pred_class = match.top_class
|
|
||||||
|
|
||||||
if pred_class is None:
|
|
||||||
class_examples[gt_class].false_negatives.append(match)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if gt_class is None:
|
|
||||||
class_examples[pred_class].false_positives.append(match)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if gt_class != pred_class:
|
|
||||||
class_examples[gt_class].cross_triggers.append(match)
|
|
||||||
class_examples[pred_class].cross_triggers.append(match)
|
|
||||||
continue
|
|
||||||
|
|
||||||
class_examples[gt_class].true_positives.append(match)
|
|
||||||
|
|
||||||
return class_examples
|
|
||||||
|
|
||||||
|
|
||||||
def get_binned_sample(matches: List[MatchEvaluation], n_examples: int = 5):
|
|
||||||
if len(matches) < n_examples:
|
|
||||||
return matches
|
|
||||||
|
|
||||||
indices, pred_scores = zip(
|
|
||||||
*[
|
|
||||||
(index, match.pred_class_scores[pred_class])
|
|
||||||
for index, match in enumerate(matches)
|
|
||||||
if (pred_class := match.top_class) is not None
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
|
||||||
df = pd.DataFrame({"indices": indices, "bins": bins})
|
|
||||||
sample = df.groupby("bins").sample(1)
|
|
||||||
return [matches[ind] for ind in sample["indices"]]
|
|
||||||
@ -11,7 +11,7 @@ class BasePlotConfig(BaseConfig):
|
|||||||
label: str = "plot"
|
label: str = "plot"
|
||||||
theme: str = "default"
|
theme: str = "default"
|
||||||
title: Optional[str] = None
|
title: Optional[str] = None
|
||||||
figsize: tuple[int, int] = (5, 5)
|
figsize: tuple[int, int] = (10, 10)
|
||||||
dpi: int = 100
|
dpi: int = 100
|
||||||
|
|
||||||
|
|
||||||
@ -20,7 +20,7 @@ class BasePlot:
|
|||||||
self,
|
self,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
label: str = "plot",
|
label: str = "plot",
|
||||||
figsize: tuple[int, int] = (5, 5),
|
figsize: tuple[int, int] = (10, 10),
|
||||||
title: Optional[str] = None,
|
title: Optional[str] = None,
|
||||||
dpi: int = 100,
|
dpi: int = 100,
|
||||||
theme: str = "default",
|
theme: str = "default",
|
||||||
@ -32,7 +32,7 @@ class BasePlot:
|
|||||||
self.theme = theme
|
self.theme = theme
|
||||||
self.title = title
|
self.title = title
|
||||||
|
|
||||||
def get_figure(self) -> Figure:
|
def create_figure(self) -> Figure:
|
||||||
plt.style.use(self.theme)
|
plt.style.use(self.theme)
|
||||||
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
|
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
|
||||||
|
|
||||||
|
|||||||
@ -1,5 +1,15 @@
|
|||||||
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
@ -12,14 +22,20 @@ from batdetect2.evaluate.metrics.classification import (
|
|||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
from batdetect2.plotting.metrics import (
|
from batdetect2.plotting.metrics import (
|
||||||
|
plot_pr_curve,
|
||||||
plot_pr_curves,
|
plot_pr_curves,
|
||||||
|
plot_roc_curve,
|
||||||
plot_roc_curves,
|
plot_roc_curves,
|
||||||
|
plot_threshold_precision_curve,
|
||||||
plot_threshold_precision_curves,
|
plot_threshold_precision_curves,
|
||||||
|
plot_threshold_recall_curve,
|
||||||
plot_threshold_recall_curves,
|
plot_threshold_recall_curves,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
ClassificationPlotter = Callable[
|
||||||
|
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||||
|
]
|
||||||
|
|
||||||
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
||||||
Registry("classification_plot")
|
Registry("classification_plot")
|
||||||
@ -29,8 +45,10 @@ classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
|
title: Optional[str] = "Classification Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
class PRCurve(BasePlot):
|
class PRCurve(BasePlot):
|
||||||
@ -39,25 +57,24 @@ class PRCurve(BasePlot):
|
|||||||
*args,
|
*args,
|
||||||
ignore_non_predictions: bool = True,
|
ignore_non_predictions: bool = True,
|
||||||
ignore_generic: bool = True,
|
ignore_generic: bool = True,
|
||||||
|
separate_figures: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
self.ignore_generic = ignore_generic
|
self.ignore_generic = ignore_generic
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||||
clip_evaluations,
|
clip_evaluations,
|
||||||
ignore_non_predictions=self.ignore_non_predictions,
|
ignore_non_predictions=self.ignore_non_predictions,
|
||||||
ignore_generic=self.ignore_generic,
|
ignore_generic=self.ignore_generic,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
|
||||||
ax = fig.subplots()
|
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
class_name: compute_precision_recall(
|
class_name: compute_precision_recall(
|
||||||
y_true[class_name],
|
y_true[class_name],
|
||||||
@ -67,9 +84,23 @@ class PRCurve(BasePlot):
|
|||||||
for class_name in self.targets.class_names
|
for class_name in self.targets.class_names
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if not self.separate_figures:
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
plot_pr_curves(data, ax=ax)
|
plot_pr_curves(data, ax=ax)
|
||||||
|
yield self.label, fig
|
||||||
|
return
|
||||||
|
|
||||||
return self.label, fig
|
for class_name, (precision, recall, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
@classification_plots.register(PRCurveConfig)
|
@classification_plots.register(PRCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -79,33 +110,37 @@ class PRCurve(BasePlot):
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
ignore_non_predictions=config.ignore_non_predictions,
|
ignore_non_predictions=config.ignore_non_predictions,
|
||||||
ignore_generic=config.ignore_generic,
|
ignore_generic=config.ignore_generic,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ThresholdPRCurveConfig(BasePlotConfig):
|
class ThresholdPrecisionCurveConfig(BasePlotConfig):
|
||||||
name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
|
name: Literal["threshold_precision_curve"] = "threshold_precision_curve"
|
||||||
label: str = "threshold_pr_curve"
|
label: str = "threshold_precision_curve"
|
||||||
figsize: tuple[int, int] = (10, 5)
|
title: Optional[str] = "Classification Threshold-Precision Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ThresholdPRCurve(BasePlot):
|
class ThresholdPrecisionCurve(BasePlot):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
*args,
|
*args,
|
||||||
ignore_non_predictions: bool = True,
|
ignore_non_predictions: bool = True,
|
||||||
ignore_generic: bool = True,
|
ignore_generic: bool = True,
|
||||||
|
separate_figures: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
self.ignore_generic = ignore_generic
|
self.ignore_generic = ignore_generic
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||||
clip_evaluations,
|
clip_evaluations,
|
||||||
ignore_non_predictions=self.ignore_non_predictions,
|
ignore_non_predictions=self.ignore_non_predictions,
|
||||||
@ -121,30 +156,135 @@ class ThresholdPRCurve(BasePlot):
|
|||||||
for class_name in self.targets.class_names
|
for class_name in self.targets.class_names
|
||||||
}
|
}
|
||||||
|
|
||||||
fig = self.get_figure()
|
if not self.separate_figures:
|
||||||
ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
|
plot_threshold_precision_curves(data, ax=ax)
|
||||||
plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
|
|
||||||
|
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@classification_plots.register(ThresholdPRCurveConfig)
|
return
|
||||||
|
|
||||||
|
for class_name, (precision, _, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_threshold_precision_curve(
|
||||||
|
thresholds,
|
||||||
|
precision,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
@classification_plots.register(ThresholdPrecisionCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
|
def from_config(
|
||||||
return ThresholdPRCurve.build(
|
config: ThresholdPrecisionCurveConfig, targets: TargetProtocol
|
||||||
|
):
|
||||||
|
return ThresholdPrecisionCurve.build(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
ignore_non_predictions=config.ignore_non_predictions,
|
ignore_non_predictions=config.ignore_non_predictions,
|
||||||
ignore_generic=config.ignore_generic,
|
ignore_generic=config.ignore_generic,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdRecallCurveConfig(BasePlotConfig):
|
||||||
|
name: Literal["threshold_recall_curve"] = "threshold_recall_curve"
|
||||||
|
label: str = "threshold_recall_curve"
|
||||||
|
title: Optional[str] = "Classification Threshold-Recall Curve"
|
||||||
|
ignore_non_predictions: bool = True
|
||||||
|
ignore_generic: bool = True
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ThresholdRecallCurve(BasePlot):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
ignore_non_predictions: bool = True,
|
||||||
|
ignore_generic: bool = True,
|
||||||
|
separate_figures: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
|
self.ignore_generic = ignore_generic
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
clip_evaluations: Sequence[ClipEval],
|
||||||
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
|
y_true, y_score, num_positives = _extract_per_class_metric_data(
|
||||||
|
clip_evaluations,
|
||||||
|
ignore_non_predictions=self.ignore_non_predictions,
|
||||||
|
ignore_generic=self.ignore_generic,
|
||||||
|
)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
class_name: compute_precision_recall(
|
||||||
|
y_true[class_name],
|
||||||
|
y_score[class_name],
|
||||||
|
num_positives[class_name],
|
||||||
|
)
|
||||||
|
for class_name in self.targets.class_names
|
||||||
|
}
|
||||||
|
|
||||||
|
if not self.separate_figures:
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
plot_threshold_recall_curves(data, ax=ax, add_legend=True)
|
||||||
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
for class_name, (_, recall, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_threshold_recall_curve(
|
||||||
|
thresholds,
|
||||||
|
recall,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
@classification_plots.register(ThresholdRecallCurveConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: ThresholdRecallCurveConfig, targets: TargetProtocol
|
||||||
|
):
|
||||||
|
return ThresholdRecallCurve.build(
|
||||||
|
config=config,
|
||||||
|
targets=targets,
|
||||||
|
ignore_non_predictions=config.ignore_non_predictions,
|
||||||
|
ignore_generic=config.ignore_generic,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
|
title: Optional[str] = "Classification ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ROCCurve(BasePlot):
|
class ROCCurve(BasePlot):
|
||||||
@ -153,16 +293,18 @@ class ROCCurve(BasePlot):
|
|||||||
*args,
|
*args,
|
||||||
ignore_non_predictions: bool = True,
|
ignore_non_predictions: bool = True,
|
||||||
ignore_generic: bool = True,
|
ignore_generic: bool = True,
|
||||||
|
separate_figures: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.ignore_non_predictions = ignore_non_predictions
|
self.ignore_non_predictions = ignore_non_predictions
|
||||||
self.ignore_generic = ignore_generic
|
self.ignore_generic = ignore_generic
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true, y_score, _ = _extract_per_class_metric_data(
|
y_true, y_score, _ = _extract_per_class_metric_data(
|
||||||
clip_evaluations,
|
clip_evaluations,
|
||||||
ignore_non_predictions=self.ignore_non_predictions,
|
ignore_non_predictions=self.ignore_non_predictions,
|
||||||
@ -177,12 +319,26 @@ class ROCCurve(BasePlot):
|
|||||||
for class_name in self.targets.class_names
|
for class_name in self.targets.class_names
|
||||||
}
|
}
|
||||||
|
|
||||||
fig = self.get_figure()
|
if not self.separate_figures:
|
||||||
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_roc_curves(data, ax=ax)
|
plot_roc_curves(data, ax=ax)
|
||||||
|
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
for class_name, (fpr, tpr, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
@classification_plots.register(ROCCurveConfig)
|
@classification_plots.register(ROCCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -192,6 +348,7 @@ class ROCCurve(BasePlot):
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
ignore_non_predictions=config.ignore_non_predictions,
|
ignore_non_predictions=config.ignore_non_predictions,
|
||||||
ignore_generic=config.ignore_generic,
|
ignore_generic=config.ignore_generic,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -199,7 +356,8 @@ ClassificationPlotConfig = Annotated[
|
|||||||
Union[
|
Union[
|
||||||
PRCurveConfig,
|
PRCurveConfig,
|
||||||
ROCCurveConfig,
|
ROCCurveConfig,
|
||||||
ThresholdPRCurveConfig,
|
ThresholdPrecisionCurveConfig,
|
||||||
|
ThresholdRecallCurveConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Callable,
|
Callable,
|
||||||
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -8,6 +9,7 @@ from typing import (
|
|||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
@ -17,7 +19,9 @@ from batdetect2.evaluate.metrics.clip_classification import ClipEval
|
|||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
from batdetect2.plotting.metrics import (
|
from batdetect2.plotting.metrics import (
|
||||||
|
plot_pr_curve,
|
||||||
plot_pr_curves,
|
plot_pr_curves,
|
||||||
|
plot_roc_curve,
|
||||||
plot_roc_curves,
|
plot_roc_curves,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
@ -28,7 +32,9 @@ __all__ = [
|
|||||||
"build_clip_classification_plotter",
|
"build_clip_classification_plotter",
|
||||||
]
|
]
|
||||||
|
|
||||||
ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
ClipClassificationPlotter = Callable[
|
||||||
|
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||||
|
]
|
||||||
|
|
||||||
clip_classification_plots: Registry[
|
clip_classification_plots: Registry[
|
||||||
ClipClassificationPlotter, [TargetProtocol]
|
ClipClassificationPlotter, [TargetProtocol]
|
||||||
@ -38,14 +44,24 @@ clip_classification_plots: Registry[
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Precision-Recall Curve"
|
title: Optional[str] = "Clip Classification Precision-Recall Curve"
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
class PRCurve(BasePlot):
|
class PRCurve(BasePlot):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
separate_figures: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
for class_name in self.targets.class_names:
|
for class_name in self.targets.class_names:
|
||||||
@ -61,10 +77,26 @@ class PRCurve(BasePlot):
|
|||||||
|
|
||||||
data[class_name] = (precision, recall, thresholds)
|
data[class_name] = (precision, recall, thresholds)
|
||||||
|
|
||||||
fig = self.get_figure()
|
if not self.separate_figures:
|
||||||
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_pr_curves(data, ax=ax)
|
plot_pr_curves(data, ax=ax)
|
||||||
return self.label, fig
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
for class_name, (precision, recall, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
@clip_classification_plots.register(PRCurveConfig)
|
@clip_classification_plots.register(PRCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -72,20 +104,31 @@ class PRCurve(BasePlot):
|
|||||||
return PRCurve.build(
|
return PRCurve.build(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "ROC Curve"
|
title: Optional[str] = "Clip Classification ROC Curve"
|
||||||
|
separate_figures: bool = False
|
||||||
|
|
||||||
|
|
||||||
class ROCCurve(BasePlot):
|
class ROCCurve(BasePlot):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
separate_figures: bool = False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.separate_figures = separate_figures
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
data = {}
|
data = {}
|
||||||
|
|
||||||
for class_name in self.targets.class_names:
|
for class_name in self.targets.class_names:
|
||||||
@ -101,10 +144,24 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
data[class_name] = (fpr, tpr, thresholds)
|
data[class_name] = (fpr, tpr, thresholds)
|
||||||
|
|
||||||
fig = self.get_figure()
|
if not self.separate_figures:
|
||||||
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
plot_roc_curves(data, ax=ax)
|
plot_roc_curves(data, ax=ax)
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
for class_name, (fpr, tpr, thresholds) in data.items():
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
|
ax = plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||||
|
ax.set_title(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
@clip_classification_plots.register(ROCCurveConfig)
|
@clip_classification_plots.register(ROCCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -112,6 +169,7 @@ class ROCCurve(BasePlot):
|
|||||||
return ROCCurve.build(
|
return ROCCurve.build(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
|
separate_figures=config.separate_figures,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
Annotated,
|
Annotated,
|
||||||
Callable,
|
Callable,
|
||||||
|
Iterable,
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Sequence,
|
Sequence,
|
||||||
@ -27,7 +28,9 @@ __all__ = [
|
|||||||
"build_clip_detection_plotter",
|
"build_clip_detection_plotter",
|
||||||
]
|
]
|
||||||
|
|
||||||
ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
ClipDetectionPlotter = Callable[
|
||||||
|
[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
||||||
@ -38,14 +41,14 @@ clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
title: Optional[str] = "Precision-Recall Curve"
|
title: Optional[str] = "Clip Detection Precision-Recall Curve"
|
||||||
|
|
||||||
|
|
||||||
class PRCurve(BasePlot):
|
class PRCurve(BasePlot):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = [c.gt_det for c in clip_evaluations]
|
y_true = [c.gt_det for c in clip_evaluations]
|
||||||
y_score = [c.score for c in clip_evaluations]
|
y_score = [c.score for c in clip_evaluations]
|
||||||
|
|
||||||
@ -54,10 +57,10 @@ class PRCurve(BasePlot):
|
|||||||
y_score,
|
y_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@clip_detection_plots.register(PRCurveConfig)
|
@clip_detection_plots.register(PRCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -71,14 +74,14 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
title: Optional[str] = "ROC Curve"
|
title: Optional[str] = "Clip Detection ROC Curve"
|
||||||
|
|
||||||
|
|
||||||
class ROCCurve(BasePlot):
|
class ROCCurve(BasePlot):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = [c.gt_det for c in clip_evaluations]
|
y_true = [c.gt_det for c in clip_evaluations]
|
||||||
y_score = [c.score for c in clip_evaluations]
|
y_score = [c.score for c in clip_evaluations]
|
||||||
|
|
||||||
@ -87,10 +90,10 @@ class ROCCurve(BasePlot):
|
|||||||
y_score,
|
y_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@clip_detection_plots.register(ROCCurveConfig)
|
@clip_detection_plots.register(ROCCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -104,18 +107,18 @@ class ROCCurve(BasePlot):
|
|||||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["score_distribution"] = "score_distribution"
|
name: Literal["score_distribution"] = "score_distribution"
|
||||||
label: str = "score_distribution"
|
label: str = "score_distribution"
|
||||||
title: Optional[str] = "Score Distribution"
|
title: Optional[str] = "Clip Detection Score Distribution"
|
||||||
|
|
||||||
|
|
||||||
class ScoreDistributionPlot(BasePlot):
|
class ScoreDistributionPlot(BasePlot):
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = [c.gt_det for c in clip_evaluations]
|
y_true = [c.gt_det for c in clip_evaluations]
|
||||||
y_score = [c.score for c in clip_evaluations]
|
y_score = [c.score for c in clip_evaluations]
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
df = pd.DataFrame({"is_true": y_true, "score": y_score})
|
df = pd.DataFrame({"is_true": y_true, "score": y_score})
|
||||||
@ -130,7 +133,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
common_norm=False,
|
common_norm=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@clip_detection_plots.register(ScoreDistributionPlotConfig)
|
@clip_detection_plots.register(ScoreDistributionPlotConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1,26 +1,33 @@
|
|||||||
import random
|
import random
|
||||||
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Iterable,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import seaborn as sns
|
import seaborn as sns
|
||||||
from matplotlib import patches
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from soundevent.plot import plot_geometry
|
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.core import Registry
|
from batdetect2.core import Registry
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.metrics.detection import ClipEval
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
from batdetect2.plotting.clips import plot_clip
|
from batdetect2.plotting.detections import plot_clip_detections
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
DetectionPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||||
|
|
||||||
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
||||||
name="detection_plot"
|
name="detection_plot"
|
||||||
@ -30,6 +37,7 @@ detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
|
title: Optional[str] = "Detection Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -49,7 +57,7 @@ class PRCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evals: Sequence[ClipEval],
|
clip_evals: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
num_positives = 0
|
num_positives = 0
|
||||||
@ -71,10 +79,12 @@ class PRCurve(BasePlot):
|
|||||||
num_positives=num_positives,
|
num_positives=num_positives,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
return self.label, fig
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
@detection_plots.register(PRCurveConfig)
|
@detection_plots.register(PRCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -90,6 +100,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
|
title: Optional[str] = "Detection ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -109,7 +120,7 @@ class ROCCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
|
|
||||||
@ -127,10 +138,12 @@ class ROCCurve(BasePlot):
|
|||||||
y_score,
|
y_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||||
return self.label, fig
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
@detection_plots.register(ROCCurveConfig)
|
@detection_plots.register(ROCCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -146,6 +159,7 @@ class ROCCurve(BasePlot):
|
|||||||
class ScoreDistributionPlotConfig(BasePlotConfig):
|
class ScoreDistributionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["score_distribution"] = "score_distribution"
|
name: Literal["score_distribution"] = "score_distribution"
|
||||||
label: str = "score_distribution"
|
label: str = "score_distribution"
|
||||||
|
title: Optional[str] = "Detection Score Distribution"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -165,7 +179,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
|
|
||||||
@ -180,7 +194,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
|
|
||||||
df = pd.DataFrame({"is_true": y_true, "score": y_score})
|
df = pd.DataFrame({"is_true": y_true, "score": y_score})
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
sns.histplot(
|
sns.histplot(
|
||||||
@ -194,7 +208,7 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
common_norm=False,
|
common_norm=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@detection_plots.register(ScoreDistributionPlotConfig)
|
@detection_plots.register(ScoreDistributionPlotConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -212,7 +226,8 @@ class ScoreDistributionPlot(BasePlot):
|
|||||||
class ExampleDetectionPlotConfig(BasePlotConfig):
|
class ExampleDetectionPlotConfig(BasePlotConfig):
|
||||||
name: Literal["example_detection"] = "example_detection"
|
name: Literal["example_detection"] = "example_detection"
|
||||||
label: str = "example_detection"
|
label: str = "example_detection"
|
||||||
figsize: tuple[int, int] = (10, 15)
|
title: Optional[str] = "Example Detection"
|
||||||
|
figsize: tuple[int, int] = (10, 4)
|
||||||
num_examples: int = 5
|
num_examples: int = 5
|
||||||
threshold: float = 0.2
|
threshold: float = 0.2
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
@ -240,82 +255,26 @@ class ExampleDetectionPlot(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
fig = self.get_figure()
|
|
||||||
|
|
||||||
sample = clip_evaluations
|
sample = clip_evaluations
|
||||||
|
|
||||||
if self.num_examples < len(sample):
|
if self.num_examples < len(sample):
|
||||||
sample = random.sample(sample, self.num_examples)
|
sample = random.sample(sample, self.num_examples)
|
||||||
|
|
||||||
axes = fig.subplots(nrows=self.num_examples, ncols=1)
|
for num_example, clip_eval in enumerate(sample):
|
||||||
|
fig = self.create_figure()
|
||||||
|
ax = fig.subplots()
|
||||||
|
|
||||||
for ax, clip_eval in zip(axes, sample):
|
plot_clip_detections(
|
||||||
plot_clip(
|
clip_eval,
|
||||||
clip_eval.clip,
|
ax=ax,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
ax=ax,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
for m in clip_eval.matches:
|
yield f"{self.label}/example_{num_example}", fig
|
||||||
is_match = (
|
|
||||||
m.pred is not None
|
|
||||||
and m.gt is not None
|
|
||||||
and m.score >= self.threshold
|
|
||||||
)
|
|
||||||
|
|
||||||
if m.pred is not None:
|
plt.close(fig)
|
||||||
plot_geometry(
|
|
||||||
m.pred.geometry,
|
|
||||||
ax=ax,
|
|
||||||
add_points=False,
|
|
||||||
facecolor="none",
|
|
||||||
alpha=m.pred.detection_score,
|
|
||||||
linestyle="-" if not is_match else "--",
|
|
||||||
color="red" if not is_match else "orange",
|
|
||||||
)
|
|
||||||
|
|
||||||
if m.gt is not None:
|
|
||||||
plot_geometry(
|
|
||||||
m.gt.sound_event.geometry, # type: ignore
|
|
||||||
ax=ax,
|
|
||||||
add_points=False,
|
|
||||||
facecolor="none",
|
|
||||||
color="green" if not is_match else "orange",
|
|
||||||
)
|
|
||||||
|
|
||||||
ax.set_title(clip_eval.clip.recording.path.name)
|
|
||||||
|
|
||||||
# ax.legend(
|
|
||||||
# handles=[
|
|
||||||
# patches.Patch(
|
|
||||||
# edgecolor="green",
|
|
||||||
# label="Ground Truth (Unmatched)",
|
|
||||||
# facecolor="none",
|
|
||||||
# ),
|
|
||||||
# patches.Patch(
|
|
||||||
# edgecolor="orange",
|
|
||||||
# label="Ground Truth (Matched)",
|
|
||||||
# facecolor="none",
|
|
||||||
# ),
|
|
||||||
# patches.Patch(
|
|
||||||
# edgecolor="red",
|
|
||||||
# label="Detection (Unmatched)",
|
|
||||||
# facecolor="none",
|
|
||||||
# ),
|
|
||||||
# patches.Patch(
|
|
||||||
# edgecolor="orange",
|
|
||||||
# label="Detection (Matched)",
|
|
||||||
# facecolor="none",
|
|
||||||
# linestyle="--",
|
|
||||||
# ),
|
|
||||||
# ]
|
|
||||||
# )
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
return self.label, fig
|
|
||||||
|
|
||||||
@detection_plots.register(ExampleDetectionPlotConfig)
|
@detection_plots.register(ExampleDetectionPlotConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -1,17 +1,36 @@
|
|||||||
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
|
import random
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Tuple,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pandas as pd
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.core import Registry
|
from batdetect2.core import Registry
|
||||||
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
from batdetect2.evaluate.metrics.common import compute_precision_recall
|
||||||
from batdetect2.evaluate.metrics.top_class import ClipEval
|
from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval
|
||||||
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
|
||||||
|
from batdetect2.plotting.gallery import plot_match_gallery
|
||||||
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
|
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
|
||||||
|
|
||||||
TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
|
TopClassPlotter = Callable[[Sequence[ClipEval]], Iterable[Tuple[str, Figure]]]
|
||||||
|
|
||||||
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
||||||
name="top_class_plot"
|
name="top_class_plot"
|
||||||
@ -21,6 +40,7 @@ top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
|
|||||||
class PRCurveConfig(BasePlotConfig):
|
class PRCurveConfig(BasePlotConfig):
|
||||||
name: Literal["pr_curve"] = "pr_curve"
|
name: Literal["pr_curve"] = "pr_curve"
|
||||||
label: str = "pr_curve"
|
label: str = "pr_curve"
|
||||||
|
title: Optional[str] = "Top Class Precision-Recall Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -40,7 +60,7 @@ class PRCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
num_positives = 0
|
num_positives = 0
|
||||||
@ -66,10 +86,12 @@ class PRCurve(BasePlot):
|
|||||||
num_positives=num_positives,
|
num_positives=num_positives,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
return self.label, fig
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
@top_class_plots.register(PRCurveConfig)
|
@top_class_plots.register(PRCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -85,6 +107,7 @@ class PRCurve(BasePlot):
|
|||||||
class ROCCurveConfig(BasePlotConfig):
|
class ROCCurveConfig(BasePlotConfig):
|
||||||
name: Literal["roc_curve"] = "roc_curve"
|
name: Literal["roc_curve"] = "roc_curve"
|
||||||
label: str = "roc_curve"
|
label: str = "roc_curve"
|
||||||
|
title: Optional[str] = "Top Class ROC Curve"
|
||||||
ignore_non_predictions: bool = True
|
ignore_non_predictions: bool = True
|
||||||
ignore_generic: bool = True
|
ignore_generic: bool = True
|
||||||
|
|
||||||
@ -104,7 +127,7 @@ class ROCCurve(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true = []
|
y_true = []
|
||||||
y_score = []
|
y_score = []
|
||||||
|
|
||||||
@ -126,10 +149,12 @@ class ROCCurve(BasePlot):
|
|||||||
y_score,
|
y_score,
|
||||||
)
|
)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
|
||||||
return self.label, fig
|
|
||||||
|
yield self.label, fig
|
||||||
|
|
||||||
@top_class_plots.register(ROCCurveConfig)
|
@top_class_plots.register(ROCCurveConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -144,6 +169,7 @@ class ROCCurve(BasePlot):
|
|||||||
|
|
||||||
class ConfusionMatrixConfig(BasePlotConfig):
|
class ConfusionMatrixConfig(BasePlotConfig):
|
||||||
name: Literal["confusion_matrix"] = "confusion_matrix"
|
name: Literal["confusion_matrix"] = "confusion_matrix"
|
||||||
|
title: Optional[str] = "Top Class Confusion Matrix"
|
||||||
figsize: tuple[int, int] = (10, 10)
|
figsize: tuple[int, int] = (10, 10)
|
||||||
label: str = "confusion_matrix"
|
label: str = "confusion_matrix"
|
||||||
exclude_generic: bool = True
|
exclude_generic: bool = True
|
||||||
@ -180,7 +206,7 @@ class ConfusionMatrix(BasePlot):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
clip_evaluations: Sequence[ClipEval],
|
clip_evaluations: Sequence[ClipEval],
|
||||||
) -> Tuple[str, Figure]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
y_true: List[str] = []
|
y_true: List[str] = []
|
||||||
y_pred: List[str] = []
|
y_pred: List[str] = []
|
||||||
|
|
||||||
@ -213,7 +239,7 @@ class ConfusionMatrix(BasePlot):
|
|||||||
y_true.append(true_class or self.noise_class)
|
y_true.append(true_class or self.noise_class)
|
||||||
y_pred.append(pred_class or self.noise_class)
|
y_pred.append(pred_class or self.noise_class)
|
||||||
|
|
||||||
fig = self.get_figure()
|
fig = self.create_figure()
|
||||||
ax = fig.subplots()
|
ax = fig.subplots()
|
||||||
|
|
||||||
class_names = [*self.targets.class_names]
|
class_names = [*self.targets.class_names]
|
||||||
@ -236,7 +262,7 @@ class ConfusionMatrix(BasePlot):
|
|||||||
values_format=".2f",
|
values_format=".2f",
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.label, fig
|
yield self.label, fig
|
||||||
|
|
||||||
@top_class_plots.register(ConfusionMatrixConfig)
|
@top_class_plots.register(ConfusionMatrixConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -253,11 +279,105 @@ class ConfusionMatrix(BasePlot):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleClassificationPlotConfig(BasePlotConfig):
|
||||||
|
name: Literal["example_classification"] = "example_classification"
|
||||||
|
label: str = "example_classification"
|
||||||
|
title: Optional[str] = "Example Classification"
|
||||||
|
num_examples: int = 4
|
||||||
|
threshold: float = 0.2
|
||||||
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
|
preprocessing: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ExampleClassificationPlot(BasePlot):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*args,
|
||||||
|
num_examples: int = 4,
|
||||||
|
threshold: float = 0.2,
|
||||||
|
audio_loader: AudioLoader,
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.num_examples = num_examples
|
||||||
|
self.audio_loader = audio_loader
|
||||||
|
self.threshold = threshold
|
||||||
|
self.preprocessor = preprocessor
|
||||||
|
self.num_examples = num_examples
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
clip_evaluations: Sequence[ClipEval],
|
||||||
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
|
grouped = group_matches(clip_evaluations, threshold=self.threshold)
|
||||||
|
|
||||||
|
for class_name, matches in grouped.items():
|
||||||
|
true_positives: List[MatchEval] = get_binned_sample(
|
||||||
|
matches.true_positives,
|
||||||
|
n_examples=self.num_examples,
|
||||||
|
)
|
||||||
|
|
||||||
|
false_positives: List[MatchEval] = get_binned_sample(
|
||||||
|
matches.false_positives,
|
||||||
|
n_examples=self.num_examples,
|
||||||
|
)
|
||||||
|
|
||||||
|
false_negatives: List[MatchEval] = random.sample(
|
||||||
|
matches.false_negatives,
|
||||||
|
k=min(self.num_examples, len(matches.false_negatives)),
|
||||||
|
)
|
||||||
|
|
||||||
|
cross_triggers: List[MatchEval] = get_binned_sample(
|
||||||
|
matches.cross_triggers, n_examples=self.num_examples
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = self.create_figure()
|
||||||
|
|
||||||
|
fig = plot_match_gallery(
|
||||||
|
true_positives,
|
||||||
|
false_positives,
|
||||||
|
false_negatives,
|
||||||
|
cross_triggers,
|
||||||
|
preprocessor=self.preprocessor,
|
||||||
|
audio_loader=self.audio_loader,
|
||||||
|
n_examples=self.num_examples,
|
||||||
|
fig=fig,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.title is not None:
|
||||||
|
fig.suptitle(f"{self.title}: {class_name}")
|
||||||
|
else:
|
||||||
|
fig.suptitle(class_name)
|
||||||
|
|
||||||
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
|
plt.close(fig)
|
||||||
|
|
||||||
|
@top_class_plots.register(ExampleClassificationPlotConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(
|
||||||
|
config: ExampleClassificationPlotConfig,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
):
|
||||||
|
return ExampleClassificationPlot.build(
|
||||||
|
config=config,
|
||||||
|
targets=targets,
|
||||||
|
num_examples=config.num_examples,
|
||||||
|
threshold=config.threshold,
|
||||||
|
audio_loader=build_audio_loader(config.audio),
|
||||||
|
preprocessor=build_preprocessor(config.preprocessing),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
TopClassPlotConfig = Annotated[
|
TopClassPlotConfig = Annotated[
|
||||||
Union[
|
Union[
|
||||||
PRCurveConfig,
|
PRCurveConfig,
|
||||||
ROCCurveConfig,
|
ROCCurveConfig,
|
||||||
ConfusionMatrixConfig,
|
ConfusionMatrixConfig,
|
||||||
|
ExampleClassificationPlotConfig,
|
||||||
],
|
],
|
||||||
Field(discriminator="name"),
|
Field(discriminator="name"),
|
||||||
]
|
]
|
||||||
@ -268,3 +388,57 @@ def build_top_class_plotter(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> TopClassPlotter:
|
) -> TopClassPlotter:
|
||||||
return top_class_plots.build(config, targets)
|
return top_class_plots.build(config, targets)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ClassMatches:
|
||||||
|
false_positives: List[MatchEval] = field(default_factory=list)
|
||||||
|
false_negatives: List[MatchEval] = field(default_factory=list)
|
||||||
|
true_positives: List[MatchEval] = field(default_factory=list)
|
||||||
|
cross_triggers: List[MatchEval] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
def group_matches(
|
||||||
|
clip_evals: Sequence[ClipEval],
|
||||||
|
threshold: float = 0.2,
|
||||||
|
) -> Dict[str, ClassMatches]:
|
||||||
|
class_examples = defaultdict(ClassMatches)
|
||||||
|
|
||||||
|
for clip_eval in clip_evals:
|
||||||
|
for match in clip_eval.matches:
|
||||||
|
gt_class = match.true_class
|
||||||
|
pred_class = match.pred_class
|
||||||
|
is_pred = match.score >= threshold
|
||||||
|
|
||||||
|
if not is_pred and gt_class is not None:
|
||||||
|
class_examples[gt_class].false_negatives.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not is_pred:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if gt_class is None:
|
||||||
|
class_examples[pred_class].false_positives.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if gt_class != pred_class:
|
||||||
|
class_examples[pred_class].cross_triggers.append(match)
|
||||||
|
continue
|
||||||
|
|
||||||
|
class_examples[gt_class].true_positives.append(match)
|
||||||
|
|
||||||
|
return class_examples
|
||||||
|
|
||||||
|
|
||||||
|
def get_binned_sample(matches: List[MatchEval], n_examples: int = 5):
|
||||||
|
if len(matches) < n_examples:
|
||||||
|
return matches
|
||||||
|
|
||||||
|
indices, pred_scores = zip(
|
||||||
|
*[(index, match.score) for index, match in enumerate(matches)]
|
||||||
|
)
|
||||||
|
|
||||||
|
bins = pd.qcut(pred_scores, q=n_examples, labels=False, duplicates="drop")
|
||||||
|
df = pd.DataFrame({"indices": indices, "bins": bins})
|
||||||
|
sample = df.groupby("bins").sample(1)
|
||||||
|
return [matches[ind] for ind in sample["indices"]]
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
|
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
|
||||||
|
|
||||||
plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
|
plots: List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||||
|
|
||||||
ignore_start_end: float
|
ignore_start_end: float
|
||||||
|
|
||||||
@ -68,7 +68,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
prefix: str,
|
prefix: str,
|
||||||
ignore_start_end: float = 0.01,
|
ignore_start_end: float = 0.01,
|
||||||
plots: Optional[
|
plots: Optional[
|
||||||
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
|
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||||
] = None,
|
] = None,
|
||||||
):
|
):
|
||||||
self.matcher = matcher
|
self.matcher = matcher
|
||||||
@ -93,7 +93,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
self, eval_outputs: List[T_Output]
|
self, eval_outputs: List[T_Output]
|
||||||
) -> Iterable[Tuple[str, Figure]]:
|
) -> Iterable[Tuple[str, Figure]]:
|
||||||
for plot in self.plots:
|
for plot in self.plots:
|
||||||
yield plot(eval_outputs)
|
for name, fig in plot(eval_outputs):
|
||||||
|
yield f"{self.prefix}/{name}", fig
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
@ -147,7 +148,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
|
||||||
plots: Optional[
|
plots: Optional[
|
||||||
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
|
List[Callable[[Sequence[T_Output]], Iterable[Tuple[str, Figure]]]]
|
||||||
] = None,
|
] = None,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -98,6 +98,7 @@ class ClassificationTask(BaseTask[ClipEval]):
|
|||||||
|
|
||||||
matches.append(
|
matches.append(
|
||||||
MatchEval(
|
MatchEval(
|
||||||
|
clip=clip,
|
||||||
gt=gt,
|
gt=gt,
|
||||||
pred=pred,
|
pred=pred,
|
||||||
is_prediction=pred is not None,
|
is_prediction=pred is not None,
|
||||||
|
|||||||
@ -79,6 +79,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
|
|
||||||
matches.append(
|
matches.append(
|
||||||
MatchEval(
|
MatchEval(
|
||||||
|
clip=clip,
|
||||||
gt=gt,
|
gt=gt,
|
||||||
pred=pred,
|
pred=pred,
|
||||||
is_ground_truth=gt is not None,
|
is_ground_truth=gt is not None,
|
||||||
|
|||||||
@ -11,7 +11,6 @@ from batdetect2.plotting.matches import (
|
|||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
plot_false_negative_match,
|
plot_false_negative_match,
|
||||||
plot_false_positive_match,
|
plot_false_positive_match,
|
||||||
plot_matches,
|
|
||||||
plot_true_positive_match,
|
plot_true_positive_match,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -22,7 +21,6 @@ __all__ = [
|
|||||||
"plot_cross_trigger_match",
|
"plot_cross_trigger_match",
|
||||||
"plot_false_negative_match",
|
"plot_false_negative_match",
|
||||||
"plot_false_positive_match",
|
"plot_false_positive_match",
|
||||||
"plot_matches",
|
|
||||||
"plot_spectrogram",
|
"plot_spectrogram",
|
||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_detection_heatmap",
|
"plot_detection_heatmap",
|
||||||
|
|||||||
@ -66,6 +66,9 @@ def plot_spectrogram(
|
|||||||
vmax=vmax,
|
vmax=vmax,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ax.set_xlim(start_time, end_time)
|
||||||
|
ax.set_ylim(min_freq, max_freq)
|
||||||
|
|
||||||
if add_colorbar:
|
if add_colorbar:
|
||||||
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
|
plt.colorbar(mappable, ax=ax, **(colorbar_kwargs or {}))
|
||||||
|
|
||||||
|
|||||||
113
src/batdetect2/plotting/detections.py
Normal file
113
src/batdetect2/plotting/detections.py
Normal file
@ -0,0 +1,113 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from matplotlib import axes, patches
|
||||||
|
from soundevent.plot import plot_geometry
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.detection import ClipEval
|
||||||
|
from batdetect2.plotting.clips import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
plot_clip,
|
||||||
|
)
|
||||||
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"plot_clip_detections",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def plot_clip_detections(
|
||||||
|
clip_eval: ClipEval,
|
||||||
|
figsize: tuple[int, int] = (10, 10),
|
||||||
|
ax: Optional[axes.Axes] = None,
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
threshold: float = 0.2,
|
||||||
|
add_legend: bool = True,
|
||||||
|
add_title: bool = True,
|
||||||
|
fill: bool = False,
|
||||||
|
linewidth: float = 1.0,
|
||||||
|
gt_color: str = "green",
|
||||||
|
gt_linestyle: str = "-",
|
||||||
|
true_pred_color: str = "yellow",
|
||||||
|
true_pred_linestyle: str = "--",
|
||||||
|
false_pred_color: str = "blue",
|
||||||
|
false_pred_linestyle: str = "-",
|
||||||
|
missed_gt_color: str = "red",
|
||||||
|
missed_gt_linestyle: str = "-",
|
||||||
|
) -> axes.Axes:
|
||||||
|
ax = create_ax(figsize=figsize, ax=ax)
|
||||||
|
|
||||||
|
plot_clip(
|
||||||
|
clip_eval.clip,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
ax=ax,
|
||||||
|
)
|
||||||
|
|
||||||
|
for m in clip_eval.matches:
|
||||||
|
is_match = (
|
||||||
|
m.pred is not None and m.gt is not None and m.score >= threshold
|
||||||
|
)
|
||||||
|
|
||||||
|
if m.pred is not None:
|
||||||
|
color = true_pred_color if is_match else false_pred_color
|
||||||
|
plot_geometry(
|
||||||
|
m.pred.geometry,
|
||||||
|
ax=ax,
|
||||||
|
add_points=False,
|
||||||
|
facecolor="none" if not fill else color,
|
||||||
|
alpha=m.pred.detection_score,
|
||||||
|
linewidth=linewidth,
|
||||||
|
linestyle=true_pred_linestyle
|
||||||
|
if is_match
|
||||||
|
else missed_gt_linestyle,
|
||||||
|
color=color,
|
||||||
|
)
|
||||||
|
|
||||||
|
if m.gt is not None:
|
||||||
|
color = gt_color if is_match else missed_gt_color
|
||||||
|
plot_geometry(
|
||||||
|
m.gt.sound_event.geometry, # type: ignore
|
||||||
|
ax=ax,
|
||||||
|
add_points=False,
|
||||||
|
linewidth=linewidth,
|
||||||
|
facecolor="none" if not fill else color,
|
||||||
|
linestyle=gt_linestyle if is_match else false_pred_linestyle,
|
||||||
|
color=color,
|
||||||
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title(clip_eval.clip.recording.path.name)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend(
|
||||||
|
handles=[
|
||||||
|
patches.Patch(
|
||||||
|
label="found GT",
|
||||||
|
edgecolor=gt_color,
|
||||||
|
facecolor="none" if not fill else gt_color,
|
||||||
|
linestyle=gt_linestyle,
|
||||||
|
),
|
||||||
|
patches.Patch(
|
||||||
|
label="missed GT",
|
||||||
|
edgecolor=missed_gt_color,
|
||||||
|
facecolor="none" if not fill else missed_gt_color,
|
||||||
|
linestyle=missed_gt_linestyle,
|
||||||
|
),
|
||||||
|
patches.Patch(
|
||||||
|
label="true Det",
|
||||||
|
edgecolor=true_pred_color,
|
||||||
|
facecolor="none" if not fill else true_pred_color,
|
||||||
|
linestyle=true_pred_linestyle,
|
||||||
|
),
|
||||||
|
patches.Patch(
|
||||||
|
label="false Det",
|
||||||
|
edgecolor=false_pred_color,
|
||||||
|
facecolor="none" if not fill else false_pred_color,
|
||||||
|
linestyle=false_pred_linestyle,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
return ax
|
||||||
@ -1,81 +1,109 @@
|
|||||||
from typing import List, Optional
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
from matplotlib.figure import Figure
|
||||||
|
|
||||||
from batdetect2.plotting.matches import (
|
from batdetect2.plotting.matches import (
|
||||||
|
MatchProtocol,
|
||||||
plot_cross_trigger_match,
|
plot_cross_trigger_match,
|
||||||
plot_false_negative_match,
|
plot_false_negative_match,
|
||||||
plot_false_positive_match,
|
plot_false_positive_match,
|
||||||
plot_true_positive_match,
|
plot_true_positive_match,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.evaluate import MatchEvaluation
|
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = ["plot_match_gallery"]
|
__all__ = ["plot_match_gallery"]
|
||||||
|
|
||||||
|
|
||||||
def plot_match_gallery(
|
def plot_match_gallery(
|
||||||
true_positives: List[MatchEvaluation],
|
true_positives: Sequence[MatchProtocol],
|
||||||
false_positives: List[MatchEvaluation],
|
false_positives: Sequence[MatchProtocol],
|
||||||
false_negatives: List[MatchEvaluation],
|
false_negatives: Sequence[MatchProtocol],
|
||||||
cross_triggers: List[MatchEvaluation],
|
cross_triggers: Sequence[MatchProtocol],
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
n_examples: int = 5,
|
n_examples: int = 5,
|
||||||
duration: float = 0.1,
|
duration: float = 0.1,
|
||||||
|
fig: Optional[Figure] = None,
|
||||||
):
|
):
|
||||||
|
if fig is None:
|
||||||
fig = plt.figure(figsize=(20, 20))
|
fig = plt.figure(figsize=(20, 20))
|
||||||
|
|
||||||
for index, match in enumerate(true_positives[:n_examples]):
|
axes = fig.subplots(
|
||||||
ax = plt.subplot(4, n_examples, index + 1)
|
nrows=4,
|
||||||
|
ncols=n_examples,
|
||||||
|
sharex="none",
|
||||||
|
sharey="row",
|
||||||
|
)
|
||||||
|
|
||||||
|
for tp_ax, tp_match in zip(axes[0], true_positives[:n_examples]):
|
||||||
try:
|
try:
|
||||||
plot_true_positive_match(
|
plot_true_positive_match(
|
||||||
match,
|
tp_match,
|
||||||
ax=ax,
|
ax=tp_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_positives[:n_examples]):
|
for fp_ax, fp_match in zip(axes[1], false_positives[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_false_positive_match(
|
plot_false_positive_match(
|
||||||
match,
|
fp_match,
|
||||||
ax=ax,
|
ax=fp_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(false_negatives[:n_examples]):
|
for fn_ax, fn_match in zip(axes[2], false_negatives[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, 2 * n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_false_negative_match(
|
plot_false_negative_match(
|
||||||
match,
|
fn_match,
|
||||||
ax=ax,
|
ax=fn_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
for index, match in enumerate(cross_triggers[:n_examples]):
|
for ct_ax, ct_match in zip(axes[3], cross_triggers[:n_examples]):
|
||||||
ax = plt.subplot(4, n_examples, 3 * n_examples + index + 1)
|
|
||||||
try:
|
try:
|
||||||
plot_cross_trigger_match(
|
plot_cross_trigger_match(
|
||||||
match,
|
ct_match,
|
||||||
ax=ax,
|
ax=ct_ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
duration=duration,
|
duration=duration,
|
||||||
)
|
)
|
||||||
except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
|
except (
|
||||||
|
ValueError,
|
||||||
|
AssertionError,
|
||||||
|
RuntimeError,
|
||||||
|
FileNotFoundError,
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
fig.tight_layout()
|
||||||
|
|
||||||
return fig
|
return fig
|
||||||
|
|||||||
@ -1,16 +1,17 @@
|
|||||||
from typing import List, Optional, Tuple, Union
|
from typing import Optional, Protocol, Tuple, Union
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
from matplotlib.axes import Axes
|
from matplotlib.axes import Axes
|
||||||
from soundevent import data, plot
|
from soundevent import data, plot
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
from soundevent.plot.tags import TagColorMapper
|
|
||||||
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, plot_clip
|
from batdetect2.plotting.clips import plot_clip
|
||||||
from batdetect2.typing import MatchEvaluation, PreprocessorProtocol
|
from batdetect2.typing import (
|
||||||
|
AudioLoader,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"plot_matches",
|
|
||||||
"plot_false_positive_match",
|
"plot_false_positive_match",
|
||||||
"plot_true_positive_match",
|
"plot_true_positive_match",
|
||||||
"plot_false_negative_match",
|
"plot_false_negative_match",
|
||||||
@ -18,6 +19,14 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class MatchProtocol(Protocol):
|
||||||
|
clip: data.Clip
|
||||||
|
gt: Optional[data.SoundEventAnnotation]
|
||||||
|
pred: Optional[RawPrediction]
|
||||||
|
score: float
|
||||||
|
true_class: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_DURATION = 0.05
|
DEFAULT_DURATION = 0.05
|
||||||
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
DEFAULT_FALSE_POSITIVE_COLOR = "orange"
|
||||||
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
DEFAULT_FALSE_NEGATIVE_COLOR = "red"
|
||||||
@ -27,88 +36,8 @@ DEFAULT_ANNOTATION_LINE_STYLE = "-"
|
|||||||
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
DEFAULT_PREDICTION_LINE_STYLE = "--"
|
||||||
|
|
||||||
|
|
||||||
def plot_matches(
|
|
||||||
matches: List[MatchEvaluation],
|
|
||||||
clip: data.Clip,
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
|
||||||
ax: Optional[Axes] = None,
|
|
||||||
audio_dir: Optional[data.PathLike] = None,
|
|
||||||
color_mapper: Optional[TagColorMapper] = None,
|
|
||||||
add_points: bool = False,
|
|
||||||
fill: bool = False,
|
|
||||||
spec_cmap: str = "gray",
|
|
||||||
false_positive_color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
|
||||||
false_negative_color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
|
||||||
true_positive_color: str = DEFAULT_TRUE_POSITIVE_COLOR,
|
|
||||||
cross_trigger_color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
|
||||||
) -> Axes:
|
|
||||||
ax = plot_clip(
|
|
||||||
clip,
|
|
||||||
ax=ax,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
figsize=figsize,
|
|
||||||
audio_dir=audio_dir,
|
|
||||||
spec_cmap=spec_cmap,
|
|
||||||
)
|
|
||||||
|
|
||||||
if color_mapper is None:
|
|
||||||
color_mapper = TagColorMapper()
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
if match.is_cross_trigger():
|
|
||||||
plot_cross_trigger_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_points=add_points,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
color=cross_trigger_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_true_positive():
|
|
||||||
plot_true_positive_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
add_points=add_points,
|
|
||||||
color=true_positive_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_false_negative():
|
|
||||||
plot_false_negative_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
add_points=add_points,
|
|
||||||
color=false_negative_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
elif match.is_false_positive:
|
|
||||||
plot_false_positive_match(
|
|
||||||
match,
|
|
||||||
ax=ax,
|
|
||||||
fill=fill,
|
|
||||||
add_spectrogram=False,
|
|
||||||
use_score=True,
|
|
||||||
add_points=add_points,
|
|
||||||
color=false_positive_color,
|
|
||||||
add_text=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return ax
|
|
||||||
|
|
||||||
|
|
||||||
def plot_false_positive_match(
|
def plot_false_positive_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -119,21 +48,24 @@ def plot_false_positive_match(
|
|||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_text: bool = True,
|
add_text: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
color: str = DEFAULT_FALSE_POSITIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
fontsize: Union[float, str] = "small",
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
assert match.sound_event_annotation is None
|
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(match.pred_geometry)
|
start_time, _, _, high_freq = compute_bounds(match.pred.geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
|
start_time - duration / 2,
|
||||||
|
0,
|
||||||
|
),
|
||||||
end_time=min(
|
end_time=min(
|
||||||
start_time + duration / 2,
|
start_time + duration / 2,
|
||||||
match.clip.end_time,
|
match.clip.recording.duration,
|
||||||
),
|
),
|
||||||
recording=match.clip.recording,
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
@ -150,30 +82,33 @@ def plot_false_positive_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"False Positive \nScore: {match.pred_score:.2f} \nTop Class: {match.top_class} \nTop Class Score: {match.top_class_score:.2f} ",
|
f"score={match.score:.2f}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("False Positive")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_false_negative_match(
|
def plot_false_negative_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -182,26 +117,28 @@ def plot_false_negative_match(
|
|||||||
duration: float = DEFAULT_DURATION,
|
duration: float = DEFAULT_DURATION,
|
||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
add_text: bool = True,
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
color: str = DEFAULT_FALSE_NEGATIVE_COLOR,
|
||||||
fontsize: Union[float, str] = "small",
|
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.pred_geometry is None
|
assert match.gt is not None
|
||||||
assert match.sound_event_annotation is not None
|
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
geometry = match.gt.sound_event.geometry
|
||||||
geometry = sound_event.geometry
|
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time = compute_bounds(geometry)[0]
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
@ -215,33 +152,23 @@ def plot_false_negative_match(
|
|||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
color=color,
|
color=color,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_title:
|
||||||
plt.text(
|
ax.set_title("False Negative")
|
||||||
start_time,
|
|
||||||
high_freq,
|
|
||||||
f"False Negative \nClass: {match.gt_class} ",
|
|
||||||
va="top",
|
|
||||||
ha="right",
|
|
||||||
color=color,
|
|
||||||
fontsize=fontsize,
|
|
||||||
)
|
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_true_positive_match(
|
def plot_true_positive_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -258,39 +185,42 @@ def plot_true_positive_match(
|
|||||||
fontsize: Union[float, str] = "small",
|
fontsize: Union[float, str] = "small",
|
||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
|
add_title: bool = True,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.sound_event_annotation is not None
|
assert match.gt is not None
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
|
||||||
geometry = sound_event.geometry
|
geometry = match.gt.sound_event.geometry
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
ax = plot_clip(
|
ax = plot_clip(
|
||||||
clip,
|
clip,
|
||||||
|
ax=ax,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
figsize=figsize,
|
figsize=figsize,
|
||||||
ax=ax,
|
|
||||||
audio_dir=audio_dir,
|
audio_dir=audio_dir,
|
||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
@ -299,31 +229,34 @@ def plot_true_positive_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
plot.plot_geometry(
|
plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
|
f"score={match.score:.2f}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("True Positive")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|
||||||
def plot_cross_trigger_match(
|
def plot_cross_trigger_match(
|
||||||
match: MatchEvaluation,
|
match: MatchProtocol,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
@ -334,6 +267,7 @@ def plot_cross_trigger_match(
|
|||||||
add_spectrogram: bool = True,
|
add_spectrogram: bool = True,
|
||||||
add_points: bool = False,
|
add_points: bool = False,
|
||||||
add_text: bool = True,
|
add_text: bool = True,
|
||||||
|
add_title: bool = True,
|
||||||
fill: bool = False,
|
fill: bool = False,
|
||||||
spec_cmap: str = "gray",
|
spec_cmap: str = "gray",
|
||||||
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
color: str = DEFAULT_CROSS_TRIGGER_COLOR,
|
||||||
@ -341,20 +275,24 @@ def plot_cross_trigger_match(
|
|||||||
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
annotation_linestyle: str = DEFAULT_ANNOTATION_LINE_STYLE,
|
||||||
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
prediction_linestyle: str = DEFAULT_PREDICTION_LINE_STYLE,
|
||||||
) -> Axes:
|
) -> Axes:
|
||||||
assert match.sound_event_annotation is not None
|
assert match.gt is not None
|
||||||
assert match.pred_geometry is not None
|
assert match.pred is not None
|
||||||
sound_event = match.sound_event_annotation.sound_event
|
|
||||||
geometry = sound_event.geometry
|
geometry = match.gt.sound_event.geometry
|
||||||
assert geometry is not None
|
assert geometry is not None
|
||||||
|
|
||||||
start_time, _, _, high_freq = compute_bounds(geometry)
|
start_time, _, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
clip = data.Clip(
|
clip = data.Clip(
|
||||||
start_time=max(start_time - duration / 2, 0),
|
start_time=max(
|
||||||
end_time=min(
|
start_time - duration / 2,
|
||||||
start_time + duration / 2, sound_event.recording.duration
|
0,
|
||||||
),
|
),
|
||||||
recording=sound_event.recording,
|
end_time=min(
|
||||||
|
start_time + duration / 2,
|
||||||
|
match.clip.recording.duration,
|
||||||
|
),
|
||||||
|
recording=match.clip.recording,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_spectrogram:
|
if add_spectrogram:
|
||||||
@ -368,11 +306,9 @@ def plot_cross_trigger_match(
|
|||||||
spec_cmap=spec_cmap,
|
spec_cmap=spec_cmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_annotation(
|
ax = plot.plot_geometry(
|
||||||
match.sound_event_annotation,
|
geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
time_offset=0.001,
|
|
||||||
freq_offset=2_000,
|
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=1,
|
alpha=1,
|
||||||
@ -381,24 +317,28 @@ def plot_cross_trigger_match(
|
|||||||
)
|
)
|
||||||
|
|
||||||
ax = plot.plot_geometry(
|
ax = plot.plot_geometry(
|
||||||
match.pred_geometry,
|
match.pred.geometry,
|
||||||
ax=ax,
|
ax=ax,
|
||||||
add_points=add_points,
|
add_points=add_points,
|
||||||
facecolor="none" if not fill else None,
|
facecolor="none" if not fill else None,
|
||||||
alpha=match.pred_score if use_score else 1,
|
alpha=match.score if use_score else 1,
|
||||||
color=color,
|
color=color,
|
||||||
linestyle=prediction_linestyle,
|
linestyle=prediction_linestyle,
|
||||||
)
|
)
|
||||||
|
|
||||||
if add_text:
|
if add_text:
|
||||||
plt.text(
|
ax.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.top_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.top_class_score:.2f} ",
|
f"score={match.score:.2f}\nclass={match.true_class}",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
fontsize=fontsize,
|
fontsize=fontsize,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if add_title:
|
||||||
|
ax.set_title("Cross Trigger")
|
||||||
|
|
||||||
return ax
|
return ax
|
||||||
|
|
||||||
|
|||||||
@ -35,6 +35,8 @@ def plot_pr_curve(
|
|||||||
ax: Optional[axes.Axes] = None,
|
ax: Optional[axes.Axes] = None,
|
||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
|
add_legend: bool = False,
|
||||||
|
label: str = "PR Curve",
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
|
|
||||||
@ -43,7 +45,7 @@ def plot_pr_curve(
|
|||||||
ax.plot(
|
ax.plot(
|
||||||
recall,
|
recall,
|
||||||
precision,
|
precision,
|
||||||
label="PR Curve",
|
label=label,
|
||||||
marker="o",
|
marker="o",
|
||||||
markevery=_get_marker_positions(thresholds),
|
markevery=_get_marker_positions(thresholds),
|
||||||
)
|
)
|
||||||
@ -51,6 +53,9 @@ def plot_pr_curve(
|
|||||||
ax.set_xlim(0, 1.05)
|
ax.set_xlim(0, 1.05)
|
||||||
ax.set_ylim(0, 1.05)
|
ax.set_ylim(0, 1.05)
|
||||||
|
|
||||||
|
if add_legend:
|
||||||
|
ax.legend()
|
||||||
|
|
||||||
if add_labels:
|
if add_labels:
|
||||||
ax.set_xlabel("Recall")
|
ax.set_xlabel("Recall")
|
||||||
ax.set_ylabel("Precision")
|
ax.set_ylabel("Precision")
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user