Plotting reorganised

mbsantiago 2025-09-27 23:58:06 +01:00
parent df2abff654
commit 87ed44c8f7
18 changed files with 1626 additions and 77 deletions

View File

@@ -19,9 +19,13 @@ from soundevent import data
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction
from batdetect2.typing import RawPrediction, TargetProtocol
__all__ = []
__all__ = [
"ClassificationMetric",
"ClassificationMetricConfig",
"build_classification_metric",
]
@dataclass
@@ -45,8 +49,8 @@ class ClipEval:
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
classification_metrics: Registry[ClassificationMetric, []] = Registry(
"classification_metric"
classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
Registry("classification_metric")
)
@@ -58,9 +62,11 @@ class BaseClassificationConfig(BaseConfig):
class BaseClassificationMetric:
def __init__(
self,
targets: TargetProtocol,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.include = include
self.exclude = exclude
@@ -84,13 +90,14 @@ class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
class ClassificationAveragePrecision(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "average_precision",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
super().__init__(include=include, exclude=exclude)
super().__init__(include=include, exclude=exclude, targets=targets)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
@@ -98,33 +105,11 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
num_positives = defaultdict(lambda: 0)
class_names = set()
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
class_names.add(class_name)
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and self.ignore_generic:
continue
is_class = m.true_class == class_name
if is_class:
num_positives[class_name] += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true[class_name].append(is_class)
y_score[class_name].append(m.score)
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: average_precision(
@@ -132,7 +117,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in class_names
for class_name in self.targets.class_names
}
mean_score = float(
@@ -150,8 +135,12 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
@classification_metrics.register(ClassificationAveragePrecisionConfig)
@staticmethod
def from_config(config: ClassificationAveragePrecisionConfig):
def from_config(
config: ClassificationAveragePrecisionConfig,
targets: TargetProtocol,
):
return ClassificationAveragePrecision(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
@@ -170,12 +159,14 @@ class ClassificationROCAUCConfig(BaseClassificationConfig):
class ClassificationROCAUC(BaseClassificationMetric):
def __init__(
self,
targets: TargetProtocol,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
label: str = "roc_auc",
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
):
self.targets = targets
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
self.label = label
@@ -185,27 +176,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
def __call__(
self, clip_evaluations: Sequence[ClipEval]
) -> Dict[str, float]:
y_true = defaultdict(list)
y_score = defaultdict(list)
class_names = set()
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
class_names.add(class_name)
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and self.ignore_generic:
continue
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true[class_name].append(m.true_class == class_name)
y_score[class_name].append(m.score)
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
@@ -214,7 +189,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
y_score[class_name],
)
)
for class_name in class_names
for class_name in self.targets.class_names
}
mean_score = float(
@@ -232,8 +207,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
@classification_metrics.register(ClassificationROCAUCConfig)
@staticmethod
def from_config(config: ClassificationROCAUCConfig):
def from_config(
config: ClassificationROCAUCConfig, targets: TargetProtocol
):
return ClassificationROCAUC(
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
label=config.label,
@@ -249,5 +227,40 @@ ClassificationMetricConfig = Annotated[
]
def build_classification_metrics(config: ClassificationMetricConfig):
return classification_metrics.build(config)
def build_classification_metric(
config: ClassificationMetricConfig,
targets: TargetProtocol,
) -> ClassificationMetric:
return classification_metrics.build(config, targets)
def _extract_per_class_metric_data(
clip_evaluations: Sequence[ClipEval],
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
):
y_true = defaultdict(list)
y_score = defaultdict(list)
num_positives = defaultdict(lambda: 0)
for clip_eval in clip_evaluations:
for class_name, matches in clip_eval.matches.items():
for m in matches:
# Exclude matches with ground truth sounds where the class
# is unknown
if m.is_generic and ignore_generic:
continue
is_class = m.true_class == class_name
if is_class:
num_positives[class_name] += 1
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and ignore_non_predictions:
continue
y_true[class_name].append(is_class)
y_score[class_name].append(m.score)
return y_true, y_score, num_positives

View File

@@ -12,7 +12,7 @@ def compute_precision_recall(
y_true,
y_score,
num_positives: Optional[int] = None,
) -> Tuple[np.ndarray, np.ndarray]:
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
y_true = np.array(y_true)
y_score = np.array(y_score)
@@ -22,6 +22,7 @@ def compute_precision_recall(
# Sort by score
sort_ind = np.argsort(y_score)[::-1]
y_true_sorted = y_true[sort_ind]
y_score_sorted = y_score[sort_ind]
false_pos_c = np.cumsum(1 - y_true_sorted)
true_pos_c = np.cumsum(y_true_sorted)
@@ -34,7 +35,7 @@ def compute_precision_recall(
precision[np.isnan(precision)] = 0
recall[np.isnan(recall)] = 0
return precision, recall
return precision, recall, y_score_sorted
def average_precision(
@@ -42,7 +43,7 @@ def average_precision(
y_score,
num_positives: Optional[int] = None,
) -> float:
precision, recall = compute_precision_recall(
precision, recall, _ = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
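For context, compute_precision_recall now also returns the scores sorted in descending order, which the new plotting helpers treat as per-point thresholds, while average_precision simply discards the third value. A minimal sketch with synthetic labels and scores (the numbers are illustrative only):

from batdetect2.evaluate.metrics.common import (
    average_precision,
    compute_precision_recall,
)

# Synthetic outcomes: 1 = matched a ground truth, 0 = false positive.
y_true = [1, 0, 1, 1, 0]
y_score = [0.9, 0.8, 0.7, 0.4, 0.2]

# The third return value is the score array sorted in descending order;
# the plotting helpers use it as the threshold associated with each point.
precision, recall, thresholds = compute_precision_recall(y_true, y_score)

# average_precision keeps its scalar interface and ignores the thresholds.
ap = average_precision(y_true, y_score)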

View File

@@ -0,0 +1,54 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (5, 5)
dpi: int = 100
class BasePlot:
def __init__(
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (5, 5),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
):
self.targets = targets
self.label = label
self.figsize = figsize
self.dpi = dpi
self.theme = theme
self.title = title
def get_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
if self.title is not None:
fig.suptitle(self.title)
return fig
@classmethod
def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
return cls(
targets=targets,
figsize=config.figsize,
dpi=config.dpi,
theme=config.theme,
label=config.label,
title=config.title,
**kwargs,
)
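To make the intended pattern concrete, here is a small hypothetical subclass; the DummyPlot and DummyPlotConfig names and their bodies are illustrative only and not part of this commit. The concrete plot classes in the files below follow the same shape.

from typing import Sequence, Tuple

from matplotlib.figure import Figure

from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.typing import TargetProtocol


class DummyPlotConfig(BasePlotConfig):
    label: str = "dummy_plot"


class DummyPlot(BasePlot):
    def __call__(self, outputs: Sequence) -> Tuple[str, Figure]:
        # get_figure applies the configured theme, figsize, dpi and title.
        fig = self.get_figure()
        ax = fig.subplots()
        ax.plot([0.0, 1.0], [0.0, 1.0])
        return self.label, fig


def build_dummy_plot(config: DummyPlotConfig, targets: TargetProtocol) -> DummyPlot:
    # BasePlot.build forwards the shared config fields to __init__.
    return DummyPlot.build(config=config, targets=targets)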

View File

@@ -0,0 +1,212 @@
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curves,
plot_roc_curves,
plot_threshold_precision_curves,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
fig = self.get_figure()
ax = fig.subplots()
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
plot_pr_curves(data, ax=ax)
return self.label, fig
@classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ThresholdPRCurveConfig(BasePlotConfig):
name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
label: str = "threshold_pr_curve"
figsize: tuple[int, int] = (10, 5)
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ThresholdPRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
return self.label, fig
@classification_plots.register(ThresholdPRCurveConfig)
@staticmethod
def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
return ThresholdPRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: metrics.roc_curve(
y_true[class_name],
y_score[class_name],
)
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
return self.label, fig
@classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPRCurveConfig,
],
Field(discriminator="name"),
]
def build_classification_plotter(
config: ClassificationPlotConfig,
targets: TargetProtocol,
) -> ClassificationPlotter:
return classification_plots.build(config, targets)
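As a usage sketch, a plot can be declared through its config and built against the training targets. Here targets and clip_evaluations are assumed to come from the surrounding evaluation pipeline (a TargetProtocol instance and the classification task's per-clip evaluations) and are not constructed here:

from batdetect2.evaluate.plots.classification import (
    PRCurveConfig,
    build_classification_plotter,
)

# The `name` discriminator ("pr_curve") selects the registered plotter;
# the remaining fields come from BasePlotConfig plus the plot-specific flags.
config = PRCurveConfig(label="val_pr_curve", ignore_generic=False)

# `targets` must satisfy TargetProtocol (assumed available in scope).
plotter = build_classification_plotter(config, targets)

# Plotters return a (label, Figure) pair when called on the per-clip
# evaluations produced by the classification task.
label, fig = plotter(clip_evaluations)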

View File

@@ -0,0 +1,131 @@
from typing import (
Annotated,
Callable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curves,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
data[class_name] = (precision, recall, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
return self.label, fig
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
data[class_name] = (fpr, tpr, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
return self.label, fig
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"),
]
def build_clip_classification_plotter(
config: ClipClassificationPlotConfig,
targets: TargetProtocol,
) -> ClipClassificationPlotter:
return clip_classification_plots.build(config, targets)

View File

@@ -0,0 +1,160 @@
from typing import (
Annotated,
Callable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
Registry("clip_detection_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.get_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
return self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
)
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"),
]
def build_clip_detection_plotter(
config: ClipDetectionPlotConfig,
targets: TargetProtocol,
) -> ClipDetectionPlotter:
return clip_detection_plots.build(config, targets)

View File

@@ -0,0 +1,350 @@
import random
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import patches
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from soundevent.plot import plot_geometry
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ScoreDistributionPlot(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.get_figure()
ax = fig.subplots()
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
return self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
figsize: tuple[int, int] = (10, 15)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleDetectionPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 5,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
fig = self.get_figure()
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
axes = fig.subplots(nrows=self.num_examples, ncols=1)
for ax, clip_eval in zip(axes, sample):
plot_clip(
clip_eval.clip,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None
and m.gt is not None
and m.score >= self.threshold
)
if m.pred is not None:
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none",
alpha=m.pred.detection_score,
linestyle="-" if not is_match else "--",
color="red" if not is_match else "orange",
)
if m.gt is not None:
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
facecolor="none",
color="green" if not is_match else "orange",
)
ax.set_title(clip_eval.clip.recording.path.name)
# ax.legend(
# handles=[
# patches.Patch(
# edgecolor="green",
# label="Ground Truth (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Ground Truth (Matched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="red",
# label="Detection (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Detection (Matched)",
# facecolor="none",
# linestyle="--",
# ),
# ]
# )
plt.tight_layout()
return self.label, fig
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod
def from_config(
config: ExampleDetectionPlotConfig,
targets: TargetProtocol,
):
return ExampleDetectionPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"),
]
def build_detection_plotter(
config: DetectionPlotConfig,
targets: TargetProtocol,
) -> DetectionPlotter:
return detection_plots.build(config, targets)

View File

@@ -0,0 +1,270 @@
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
exclude_noise: bool = False
noise_class: str = "noise"
normalize: Literal["true", "pred", "all", "none"] = "true"
threshold: float = 0.2
add_colorbar: bool = True
cmap: str = "Blues"
class ConfusionMatrix(BasePlot):
def __init__(
self,
*args,
exclude_generic: bool = True,
exclude_noise: bool = False,
noise_class: str = "noise",
add_colorbar: bool = True,
normalize: Literal["true", "pred", "all", "none"] = "true",
cmap: str = "Blues",
threshold: float = 0.2,
**kwargs,
):
super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.normalize = normalize
self.add_colorbar = add_colorbar
self.threshold = threshold
self.cmap = cmap
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and self.exclude_noise:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.get_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax,
xticks_rotation="vertical",
cmap=self.cmap,
colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f",
)
return self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
return ConfusionMatrix.build(
config=config,
targets=targets,
exclude_generic=config.exclude_generic,
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
add_colorbar=config.add_colorbar,
normalize=config.normalize,
cmap=config.cmap,
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_top_class_plotter(
config: TopClassPlotConfig,
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)

View File

@@ -1,5 +1,16 @@
from typing import Callable, Dict, Generic, List, Sequence, TypeVar
from typing import (
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
TypeVar,
)
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
@@ -43,6 +54,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
ignore_start_end: float
prefix: str
@@ -54,9 +67,13 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
prefix: str,
ignore_start_end: float = 0.01,
plots: Optional[
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
] = None,
):
self.matcher = matcher
self.metrics = metrics
self.plots = plots or []
self.targets = targets
self.prefix = prefix
self.ignore_start_end = ignore_start_end
@@ -72,6 +89,12 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
for name, score in metric_output.items()
}
def generate_plots(
self, eval_outputs: List[T_Output]
) -> Iterable[Tuple[str, Figure]]:
for plot in self.plots:
yield plot(eval_outputs)
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
@@ -123,6 +146,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
config: BaseTaskConfig,
targets: TargetProtocol,
metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
plots: Optional[
List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
] = None,
**kwargs,
):
matcher = build_matcher(config.matching_strategy)
@@ -130,6 +156,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
matcher=matcher,
targets=targets,
metrics=metrics,
plots=plots,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
**kwargs,

View File

@@ -12,7 +12,11 @@ from batdetect2.evaluate.metrics.classification import (
ClassificationMetricConfig,
ClipEval,
MatchEval,
build_classification_metrics,
build_classification_metric,
)
from batdetect2.evaluate.plots.classification import (
ClassificationPlotConfig,
build_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
@@ -28,6 +32,7 @@ class ClassificationTaskConfig(BaseTaskConfig):
metrics: List[ClassificationMetricConfig] = Field(
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
)
plots: List[ClassificationPlotConfig] = Field(default_factory=list)
include_generics: bool = True
@@ -128,10 +133,16 @@ class ClassificationTask(BaseTask[ClipEval]):
targets: TargetProtocol,
):
metrics = [
build_classification_metrics(metric) for metric in config.metrics
build_classification_metric(metric, targets)
for metric in config.metrics
]
plots = [
build_classification_plotter(plot, targets)
for plot in config.plots
]
return ClassificationTask.build(
config=config,
plots=plots,
targets=targets,
metrics=metrics,
)

View File

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.clip_classification import (
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_classification import (
ClipClassificationPlotConfig,
build_clip_classification_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
@@ -26,6 +30,7 @@ class ClipClassificationTaskConfig(BaseTaskConfig):
ClipClassificationAveragePrecisionConfig(),
]
)
plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)
class ClipClassificationTask(BaseTask[ClipEval]):
@@ -68,8 +73,13 @@ class ClipClassificationTask(BaseTask[ClipEval]):
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_classification_plotter(plot, targets)
for plot in config.plots
]
return ClipClassificationTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@@ -9,6 +9,10 @@ from batdetect2.evaluate.metrics.clip_detection import (
ClipEval,
build_clip_metric,
)
from batdetect2.evaluate.plots.clip_detection import (
ClipDetectionPlotConfig,
build_clip_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
@@ -25,6 +29,7 @@ class ClipDetectionTaskConfig(BaseTaskConfig):
ClipDetectionAveragePrecisionConfig(),
]
)
plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)
class ClipDetectionTask(BaseTask[ClipEval]):
@@ -59,8 +64,13 @@ class ClipDetectionTask(BaseTask[ClipEval]):
targets: TargetProtocol,
):
metrics = [build_clip_metric(metric) for metric in config.metrics]
plots = [
build_clip_detection_plotter(plot, targets)
for plot in config.plots
]
return ClipDetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.detection import (
MatchEval,
build_detection_metric,
)
from batdetect2.evaluate.plots.detection import (
DetectionPlotConfig,
build_detection_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
@@ -24,6 +28,7 @@ class DetectionTaskConfig(BaseTaskConfig):
metrics: List[DetectionMetricConfig] = Field(
default_factory=lambda: [DetectionAveragePrecisionConfig()]
)
plots: List[DetectionPlotConfig] = Field(default_factory=list)
class DetectionTask(BaseTask[ClipEval]):
@@ -72,8 +77,12 @@ class DetectionTask(BaseTask[ClipEval]):
targets: TargetProtocol,
):
metrics = [build_detection_metric(metric) for metric in config.metrics]
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
return DetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
)

View File

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.top_class import (
TopClassMetricConfig,
build_top_class_metric,
)
from batdetect2.evaluate.plots.top_class import (
TopClassPlotConfig,
build_top_class_plotter,
)
from batdetect2.evaluate.tasks.base import (
BaseTask,
BaseTaskConfig,
@@ -24,6 +28,7 @@ class TopClassDetectionTaskConfig(BaseTaskConfig):
metrics: List[TopClassMetricConfig] = Field(
default_factory=lambda: [TopClassAveragePrecisionConfig()]
)
plots: List[TopClassPlotConfig] = Field(default_factory=list)
class TopClassDetectionTask(BaseTask[ClipEval]):
@@ -94,8 +99,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
targets: TargetProtocol,
):
metrics = [build_top_class_metric(metric) for metric in config.metrics]
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
return TopClassDetectionTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
)

View File

@@ -19,7 +19,7 @@ def create_ax(
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
_, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs) # type: ignore
return ax # type: ignore

View File

@@ -0,0 +1,281 @@
from typing import Dict, Optional, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.plotting.common import create_ax
def set_default_styler(ax: axes.Axes) -> axes.Axes:
color_cycler = cycler(color=sns.color_palette("muted"))
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
marker=["o", "s", "^"]
)
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
color_cycler
)
ax.set_prop_cycle(custom_cycler)
return ax
def set_default_style(ax: axes.Axes) -> axes.Axes:
ax = set_default_styler(ax)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
return ax
def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
recall,
precision,
label="PR Curve",
marker="o",
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
return ax
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
ax.plot(
recall,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
return ax
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, _, thresholds) in data.items():
ax.plot(
thresholds,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (_, recall, thresholds) in data.items():
ax.plot(
thresholds,
recall,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
fpr,
tpr,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (fpr, tpr, thresholds) in data.items():
ax.plot(
fpr,
tpr,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def _get_marker_positions(
thresholds: np.ndarray,
n_points: int = 11,
) -> np.ndarray:
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
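A quick hypothetical smoke test for these helpers, using synthetic arrays; the values are made up, and thresholds are expected in descending order, matching the sorted scores returned by compute_precision_recall:

import numpy as np

from batdetect2.plotting.metrics import plot_pr_curve

# Synthetic curve values, purely for illustration.
precision = np.array([1.0, 0.5, 0.67, 0.75, 0.6])
recall = np.array([0.33, 0.33, 0.67, 1.0, 1.0])
thresholds = np.array([0.9, 0.8, 0.7, 0.4, 0.2])  # descending scores

# With ax=None a fresh axis is created via create_ax; markers are placed
# at roughly evenly spaced threshold values.
ax = plot_pr_curve(precision, recall, thresholds)
ax.figure.savefig("pr_curve.png")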

View File

@@ -1,4 +1,4 @@
from typing import List
from typing import Any, List
from lightning import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
@@ -35,6 +35,7 @@ class ValidationMetrics(Callback):
def generate_plots(
self,
eval_outputs: Any,
pl_module: LightningModule,
):
plotter = get_image_logger(pl_module.logger) # type: ignore
@@ -42,20 +43,15 @@ class ValidationMetrics(Callback):
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(
self._clip_annotations,
self._predictions,
):
for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
plotter(figure_name, fig, pl_module.global_step)
def log_metrics(
self,
eval_outputs: Any,
pl_module: LightningModule,
):
metrics = self.evaluator.compute_metrics(
self._clip_annotations,
self._predictions,
)
metrics = self.evaluator.compute_metrics(eval_outputs)
pl_module.log_dict(metrics)
def on_validation_epoch_end(
@@ -63,8 +59,13 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self.log_metrics(pl_module)
self.generate_plots(pl_module)
eval_outputs = self.evaluator.evaluate(
self._clip_annotations,
self._predictions,
)
self.log_metrics(eval_outputs, pl_module)
self.generate_plots(eval_outputs, pl_module)
return super().on_validation_epoch_end(trainer, pl_module)