Plotting reorganised

commit 87ed44c8f7 (parent df2abff654)
@@ -19,9 +19,13 @@ from soundevent import data
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
-from batdetect2.typing import RawPrediction
+from batdetect2.typing import RawPrediction, TargetProtocol

-__all__ = []
+__all__ = [
+    "ClassificationMetric",
+    "ClassificationMetricConfig",
+    "build_classification_metric",
+]


 @dataclass
@@ -45,8 +49,8 @@ class ClipEval:
 ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


-classification_metrics: Registry[ClassificationMetric, []] = Registry(
-    "classification_metric"
+classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
+    Registry("classification_metric")
 )
@@ -58,9 +62,11 @@ class BaseClassificationConfig(BaseConfig):
 class BaseClassificationMetric:
     def __init__(
         self,
+        targets: TargetProtocol,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
         self.include = include
         self.exclude = exclude
@@ -84,13 +90,14 @@ class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
 class ClassificationAveragePrecision(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "average_precision",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
-        super().__init__(include=include, exclude=exclude)
+        super().__init__(include=include, exclude=exclude, targets=targets)
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@@ -98,33 +105,11 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-        num_positives = defaultdict(lambda: 0)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    is_class = m.true_class == class_name
-
-                    if is_class:
-                        num_positives[class_name] += 1
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(is_class)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, num_positives = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )

         class_scores = {
             class_name: average_precision(
@@ -132,7 +117,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
                 y_score[class_name],
                 num_positives=num_positives[class_name],
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }

         mean_score = float(
@@ -150,8 +135,12 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
     @classification_metrics.register(ClassificationAveragePrecisionConfig)
     @staticmethod
-    def from_config(config: ClassificationAveragePrecisionConfig):
+    def from_config(
+        config: ClassificationAveragePrecisionConfig,
+        targets: TargetProtocol,
+    ):
         return ClassificationAveragePrecision(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@@ -170,12 +159,14 @@ class ClassificationROCAUCConfig(BaseClassificationConfig):
 class ClassificationROCAUC(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "roc_auc",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@@ -185,27 +176,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(m.true_class == class_name)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, _ = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )

         class_scores = {
             class_name: float(
@@ -214,7 +189,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
                     y_score[class_name],
                 )
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }

         mean_score = float(
@@ -232,8 +207,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
     @classification_metrics.register(ClassificationROCAUCConfig)
     @staticmethod
-    def from_config(config: ClassificationROCAUCConfig):
+    def from_config(
+        config: ClassificationROCAUCConfig, targets: TargetProtocol
+    ):
         return ClassificationROCAUC(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@@ -249,5 +227,40 @@ ClassificationMetricConfig = Annotated[
 ]


-def build_classification_metrics(config: ClassificationMetricConfig):
-    return classification_metrics.build(config)
+def build_classification_metric(
+    config: ClassificationMetricConfig,
+    targets: TargetProtocol,
+) -> ClassificationMetric:
+    return classification_metrics.build(config, targets)
+
+
+def _extract_per_class_metric_data(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+):
+    y_true = defaultdict(list)
+    y_score = defaultdict(list)
+    num_positives = defaultdict(lambda: 0)
+
+    for clip_eval in clip_evaluations:
+        for class_name, matches in clip_eval.matches.items():
+            for m in matches:
+                # Exclude matches with ground truth sounds where the class
+                # is unknown
+                if m.is_generic and ignore_generic:
+                    continue
+
+                is_class = m.true_class == class_name
+
+                if is_class:
+                    num_positives[class_name] += 1
+
+                # Ignore matches that don't correspond to a prediction
+                if not m.is_prediction and ignore_non_predictions:
+                    continue
+
+                y_true[class_name].append(is_class)
+                y_score[class_name].append(m.score)
+
+    return y_true, y_score, num_positives

@@ -12,7 +12,7 @@ def compute_precision_recall(
     y_true,
     y_score,
     num_positives: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)
@@ -22,6 +22,7 @@ def compute_precision_recall(
     # Sort by score
     sort_ind = np.argsort(y_score)[::-1]
     y_true_sorted = y_true[sort_ind]
+    y_score_sorted = y_score[sort_ind]

     false_pos_c = np.cumsum(1 - y_true_sorted)
     true_pos_c = np.cumsum(y_true_sorted)
@@ -34,7 +35,7 @@ def compute_precision_recall(
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0
-    return precision, recall
+    return precision, recall, y_score_sorted


 def average_precision(
@@ -42,7 +43,7 @@ def average_precision(
     y_score,
     num_positives: Optional[int] = None,
 ) -> float:
-    precision, recall = compute_precision_recall(
+    precision, recall, _ = compute_precision_recall(
         y_true,
         y_score,
         num_positives=num_positives,
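Editor's note on the registry change above: `Registry` is now parameterized with the extra positional context its builders receive, and `Registry.build(config, targets)` forwards that context to whichever `from_config` was registered for the config's class. The real `batdetect2.core.Registry` implementation is not part of this diff; the sketch below is only an illustration of the contract the call sites assume, not the project's code.

# Hypothetical sketch of the Registry contract assumed by the call sites
# above. The actual class is generic; this minimal version is not.
from typing import Any, Callable, Dict, Type


class RegistrySketch:
    def __init__(self, name: str):
        self.name = name
        self._builders: Dict[Type, Callable[..., Any]] = {}

    def register(self, config_cls: Type):
        # Used as a decorator on from_config; in the diff it is stacked
        # above @staticmethod, so unwrap the staticmethod object if needed.
        def decorator(builder):
            fn = builder.__func__ if isinstance(builder, staticmethod) else builder
            self._builders[config_cls] = fn
            return builder

        return decorator

    def build(self, config, *context):
        # Dispatch on the concrete config type and forward extra context
        # (e.g. the TargetProtocol instance) to the registered builder.
        builder = self._builders[type(config)]
        return builder(config, *context)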
src/batdetect2/evaluate/plots/__init__.py (new file, 0 lines)

src/batdetect2/evaluate/plots/base.py (new file, 54 lines)
@@ -0,0 +1,54 @@
from typing import Optional

import matplotlib.pyplot as plt
from matplotlib.figure import Figure

from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol


class BasePlotConfig(BaseConfig):
    label: str = "plot"
    theme: str = "default"
    title: Optional[str] = None
    figsize: tuple[int, int] = (5, 5)
    dpi: int = 100


class BasePlot:
    def __init__(
        self,
        targets: TargetProtocol,
        label: str = "plot",
        figsize: tuple[int, int] = (5, 5),
        title: Optional[str] = None,
        dpi: int = 100,
        theme: str = "default",
    ):
        self.targets = targets
        self.label = label
        self.figsize = figsize
        self.dpi = dpi
        self.theme = theme
        self.title = title

    def get_figure(self) -> Figure:
        plt.style.use(self.theme)
        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)

        if self.title is not None:
            fig.suptitle(self.title)

        return fig

    @classmethod
    def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
        return cls(
            targets=targets,
            figsize=config.figsize,
            dpi=config.dpi,
            theme=config.theme,
            label=config.label,
            title=config.title,
            **kwargs,
        )
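Editor's note: `BasePlotConfig`/`BasePlot` give every plot a shared figure-setup path (theme, size, dpi, optional suptitle), while `BasePlot.build` maps the shared config fields and forwards subclass-specific options through `**kwargs`. A minimal sketch of a custom plot following this pattern; the `HistogramPlot` name and its `bins` option are invented for illustration:

# Hypothetical subclass, for illustration only.
class HistogramPlotConfig(BasePlotConfig):
    bins: int = 20


class HistogramPlot(BasePlot):
    def __init__(self, *args, bins: int = 20, **kwargs):
        super().__init__(*args, **kwargs)  # targets, label, figsize, ...
        self.bins = bins

    def __call__(self, values):
        fig = self.get_figure()  # applies theme, figsize, dpi, title
        ax = fig.subplots()
        ax.hist(values, bins=self.bins)
        return self.label, fig

    @staticmethod
    def from_config(config: HistogramPlotConfig, targets):
        # BasePlot.build handles the shared fields; only bins is extra.
        return HistogramPlot.build(config=config, targets=targets, bins=config.bins)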
src/batdetect2/evaluate/plots/classification.py (new file, 212 lines)
@@ -0,0 +1,212 @@
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union

from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
    ClipEval,
    _extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curves,
    plot_roc_curves,
    plot_threshold_precision_curves,
    plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol

ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]

classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
    Registry("classification_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        fig = self.get_figure()
        ax = fig.subplots()

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        plot_pr_curves(data, ax=ax)

        return self.label, fig

    @classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ThresholdPRCurveConfig(BasePlotConfig):
    name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
    label: str = "threshold_pr_curve"
    figsize: tuple[int, int] = (10, 5)
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ThresholdPRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true, y_score, num_positives = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: compute_precision_recall(
                y_true[class_name],
                y_score[class_name],
                num_positives[class_name],
            )
            for class_name in self.targets.class_names
        }

        fig = self.get_figure()
        ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)

        plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
        plot_threshold_recall_curves(data, ax=ax2, add_legend=True)

        return self.label, fig

    @classification_plots.register(ThresholdPRCurveConfig)
    @staticmethod
    def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
        return ThresholdPRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true, y_score, _ = _extract_per_class_metric_data(
            clip_evaluations,
            ignore_non_predictions=self.ignore_non_predictions,
            ignore_generic=self.ignore_generic,
        )

        data = {
            class_name: metrics.roc_curve(
                y_true[class_name],
                y_score[class_name],
            )
            for class_name in self.targets.class_names
        }

        fig = self.get_figure()
        ax = fig.subplots()

        plot_roc_curves(data, ax=ax)

        return self.label, fig

    @classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


ClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ThresholdPRCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_plotter(
    config: ClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClassificationPlotter:
    return classification_plots.build(config, targets)
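Editor's note: since `ClassificationPlotConfig` is a discriminated union on `name`, a raw mapping (e.g. parsed from YAML) can be validated into the right config class before being handed to `build_classification_plotter`. A short sketch, assuming pydantic v2's `TypeAdapter` and an existing `targets` object and `clip_evaluations` sequence:

from pydantic import TypeAdapter

# 'targets' is assumed to be a TargetProtocol instance from elsewhere.
raw = {"name": "threshold_pr_curve", "label": "thr_pr", "dpi": 150}
config = TypeAdapter(ClassificationPlotConfig).validate_python(raw)
plotter = build_classification_plotter(config, targets)

label, fig = plotter(clip_evaluations)  # returns (label, matplotlib Figure)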
src/batdetect2/evaluate/plots/clip_classification.py (new file, 131 lines)
@@ -0,0 +1,131 @@
from typing import (
    Annotated,
    Callable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
    plot_pr_curves,
    plot_roc_curves,
)
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipClassificationPlotConfig",
    "ClipClassificationPlotter",
    "build_clip_classification_plotter",
]

ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]

clip_classification_plots: Registry[
    ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Precision-Recall Curve"


class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            precision, recall, thresholds = compute_precision_recall(
                y_true,
                y_score,
            )

            data[class_name] = (precision, recall, thresholds)

        fig = self.get_figure()
        ax = fig.subplots()
        plot_pr_curves(data, ax=ax)
        return self.label, fig

    @clip_classification_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "ROC Curve"


class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        data = {}

        for class_name in self.targets.class_names:
            y_true = [class_name in c.true_classes for c in clip_evaluations]
            y_score = [
                c.class_scores.get(class_name, 0) for c in clip_evaluations
            ]

            fpr, tpr, thresholds = metrics.roc_curve(
                y_true,
                y_score,
            )

            data[class_name] = (fpr, tpr, thresholds)

        fig = self.get_figure()
        ax = fig.subplots()
        plot_roc_curves(data, ax=ax)
        return self.label, fig

    @clip_classification_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
        )


ClipClassificationPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_classification_plotter(
    config: ClipClassificationPlotConfig,
    targets: TargetProtocol,
) -> ClipClassificationPlotter:
    return clip_classification_plots.build(config, targets)
src/batdetect2/evaluate/plots/clip_detection.py (new file, 160 lines)
@@ -0,0 +1,160 @@
from typing import (
    Annotated,
    Callable,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol

__all__ = [
    "ClipDetectionPlotConfig",
    "ClipDetectionPlotter",
    "build_clip_detection_plotter",
]

ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]


clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
    Registry("clip_detection_plot")
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    title: Optional[str] = "Precision-Recall Curve"


class PRCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        return self.label, fig

    @clip_detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    title: Optional[str] = "ROC Curve"


class ROCCurve(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        return self.label, fig

    @clip_detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    title: Optional[str] = "Score Distribution"


class ScoreDistributionPlot(BasePlot):
    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = [c.gt_det for c in clip_evaluations]
        y_score = [c.score for c in clip_evaluations]

        fig = self.get_figure()
        ax = fig.subplots()

        df = pd.DataFrame({"is_true": y_true, "score": y_score})
        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        return self.label, fig

    @clip_detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
        )


ClipDetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_detection_plotter(
    config: ClipDetectionPlotConfig,
    targets: TargetProtocol,
) -> ClipDetectionPlotter:
    return clip_detection_plots.build(config, targets)
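Editor's note on the `ScoreDistributionPlot` settings: `stat="probability"` with `common_norm=False` normalizes each hue group (true vs. false detections) independently, so the two histograms stay comparable even when the class balance is heavily skewed. A self-contained sketch of the same seaborn call on synthetic scores (the beta-distributed data here is invented for illustration):

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
# Synthetic detector scores: true detections skew high, false ones low.
df = pd.DataFrame(
    {
        "is_true": [True] * 200 + [False] * 800,
        "score": np.concatenate(
            [rng.beta(5, 2, 200), rng.beta(2, 5, 800)]
        ),
    }
)

fig, ax = plt.subplots()
sns.histplot(
    data=df,
    x="score",
    binwidth=0.025,
    binrange=(0, 1),
    hue="is_true",
    ax=ax,
    stat="probability",
    common_norm=False,  # each group sums to 1 on its own
)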
src/batdetect2/evaluate/plots/detection.py (new file, 350 lines)
@@ -0,0 +1,350 @@
import random
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import patches
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from soundevent.plot import plot_geometry

from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol

DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]

detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
    name="detection_plot"
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                num_positives += int(m.is_ground_truth)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
            num_positives=num_positives,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        return self.label, fig

    @detection_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        return self.label, fig

    @detection_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ScoreDistributionPlotConfig(BasePlotConfig):
    name: Literal["score_distribution"] = "score_distribution"
    label: str = "score_distribution"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ScoreDistributionPlot(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        df = pd.DataFrame({"is_true": y_true, "score": y_score})

        fig = self.get_figure()
        ax = fig.subplots()

        sns.histplot(
            data=df,
            x="score",
            binwidth=0.025,
            binrange=(0, 1),
            hue="is_true",
            ax=ax,
            stat="probability",
            common_norm=False,
        )

        return self.label, fig

    @detection_plots.register(ScoreDistributionPlotConfig)
    @staticmethod
    def from_config(
        config: ScoreDistributionPlotConfig, targets: TargetProtocol
    ):
        return ScoreDistributionPlot.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ExampleDetectionPlotConfig(BasePlotConfig):
    name: Literal["example_detection"] = "example_detection"
    label: str = "example_detection"
    figsize: tuple[int, int] = (10, 15)
    num_examples: int = 5
    threshold: float = 0.2
    audio: AudioConfig = Field(default_factory=AudioConfig)
    preprocessing: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig
    )


class ExampleDetectionPlot(BasePlot):
    def __init__(
        self,
        *args,
        num_examples: int = 5,
        threshold: float = 0.2,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.num_examples = num_examples
        self.audio_loader = audio_loader
        self.threshold = threshold
        self.preprocessor = preprocessor

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        fig = self.get_figure()

        sample = clip_evaluations

        if self.num_examples < len(sample):
            sample = random.sample(sample, self.num_examples)

        axes = fig.subplots(nrows=self.num_examples, ncols=1)

        for ax, clip_eval in zip(axes, sample):
            plot_clip(
                clip_eval.clip,
                audio_loader=self.audio_loader,
                preprocessor=self.preprocessor,
                ax=ax,
            )

            for m in clip_eval.matches:
                is_match = (
                    m.pred is not None
                    and m.gt is not None
                    and m.score >= self.threshold
                )

                if m.pred is not None:
                    plot_geometry(
                        m.pred.geometry,
                        ax=ax,
                        add_points=False,
                        facecolor="none",
                        alpha=m.pred.detection_score,
                        linestyle="-" if not is_match else "--",
                        color="red" if not is_match else "orange",
                    )

                if m.gt is not None:
                    plot_geometry(
                        m.gt.sound_event.geometry,  # type: ignore
                        ax=ax,
                        add_points=False,
                        facecolor="none",
                        color="green" if not is_match else "orange",
                    )

            ax.set_title(clip_eval.clip.recording.path.name)

            # ax.legend(
            #     handles=[
            #         patches.Patch(
            #             edgecolor="green",
            #             label="Ground Truth (Unmatched)",
            #             facecolor="none",
            #         ),
            #         patches.Patch(
            #             edgecolor="orange",
            #             label="Ground Truth (Matched)",
            #             facecolor="none",
            #         ),
            #         patches.Patch(
            #             edgecolor="red",
            #             label="Detection (Unmatched)",
            #             facecolor="none",
            #         ),
            #         patches.Patch(
            #             edgecolor="orange",
            #             label="Detection (Matched)",
            #             facecolor="none",
            #             linestyle="--",
            #         ),
            #     ]
            # )

        plt.tight_layout()

        return self.label, fig

    @detection_plots.register(ExampleDetectionPlotConfig)
    @staticmethod
    def from_config(
        config: ExampleDetectionPlotConfig,
        targets: TargetProtocol,
    ):
        return ExampleDetectionPlot.build(
            config=config,
            targets=targets,
            num_examples=config.num_examples,
            audio_loader=build_audio_loader(config.audio),
            preprocessor=build_preprocessor(config.preprocessing),
        )


DetectionPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ScoreDistributionPlotConfig,
        ExampleDetectionPlotConfig,
    ],
    Field(discriminator="name"),
]


def build_detection_plotter(
    config: DetectionPlotConfig,
    targets: TargetProtocol,
) -> DetectionPlotter:
    return detection_plots.build(config, targets)
src/batdetect2/evaluate/plots/top_class.py (new file, 270 lines)
@@ -0,0 +1,270 @@
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union

from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics

from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol

TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]

top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
    name="top_class_plot"
)


class PRCurveConfig(BasePlotConfig):
    name: Literal["pr_curve"] = "pr_curve"
    label: str = "pr_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class PRCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        precision, recall, thresholds = compute_precision_recall(
            y_true,
            y_score,
            num_positives=num_positives,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_pr_curve(precision, recall, thresholds, ax=ax)
        return self.label, fig

    @top_class_plots.register(PRCurveConfig)
    @staticmethod
    def from_config(config: PRCurveConfig, targets: TargetProtocol):
        return PRCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ROCCurveConfig(BasePlotConfig):
    name: Literal["roc_curve"] = "roc_curve"
    label: str = "roc_curve"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ROCCurve(BasePlot):
    def __init__(
        self,
        *args,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        fpr, tpr, thresholds = metrics.roc_curve(
            y_true,
            y_score,
        )

        fig = self.get_figure()
        ax = fig.subplots()
        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
        return self.label, fig

    @top_class_plots.register(ROCCurveConfig)
    @staticmethod
    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
        return ROCCurve.build(
            config=config,
            targets=targets,
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ConfusionMatrixConfig(BasePlotConfig):
    name: Literal["confusion_matrix"] = "confusion_matrix"
    figsize: tuple[int, int] = (10, 10)
    label: str = "confusion_matrix"
    exclude_generic: bool = True
    exclude_noise: bool = False
    noise_class: str = "noise"
    normalize: Literal["true", "pred", "all", "none"] = "true"
    threshold: float = 0.2
    add_colorbar: bool = True
    cmap: str = "Blues"


class ConfusionMatrix(BasePlot):
    def __init__(
        self,
        *args,
        exclude_generic: bool = True,
        exclude_noise: bool = False,
        noise_class: str = "noise",
        add_colorbar: bool = True,
        normalize: Literal["true", "pred", "all", "none"] = "true",
        cmap: str = "Blues",
        threshold: float = 0.2,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.exclude_generic = exclude_generic
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.normalize = normalize
        self.add_colorbar = add_colorbar
        self.threshold = threshold
        self.cmap = cmap

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Tuple[str, Figure]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                true_class = m.true_class
                pred_class = m.pred_class

                if not m.is_prediction and self.exclude_noise:
                    # Ignore matches that don't correspond to a prediction
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore matches that don't correspond to a ground truth
                    continue

                if m.score < self.threshold:
                    if self.exclude_noise:
                        continue

                    pred_class = self.noise_class

                if m.is_generic:
                    if self.exclude_generic:
                        # Ignore gt sounds with unknown class
                        continue

                    true_class = self.targets.detection_class_name

                y_true.append(true_class or self.noise_class)
                y_pred.append(pred_class or self.noise_class)

        fig = self.get_figure()
        ax = fig.subplots()

        class_names = [*self.targets.class_names]

        if not self.exclude_generic:
            class_names.append(self.targets.detection_class_name)

        if not self.exclude_noise:
            class_names.append(self.noise_class)

        metrics.ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            labels=class_names,
            ax=ax,
            xticks_rotation="vertical",
            cmap=self.cmap,
            colorbar=self.add_colorbar,
            normalize=self.normalize if self.normalize != "none" else None,
            values_format=".2f",
        )

        return self.label, fig

    @top_class_plots.register(ConfusionMatrixConfig)
    @staticmethod
    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
        return ConfusionMatrix.build(
            config=config,
            targets=targets,
            exclude_generic=config.exclude_generic,
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            add_colorbar=config.add_colorbar,
            normalize=config.normalize,
            cmap=config.cmap,
        )


TopClassPlotConfig = Annotated[
    Union[
        PRCurveConfig,
        ROCCurveConfig,
        ConfusionMatrixConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_plotter(
    config: TopClassPlotConfig,
    targets: TargetProtocol,
) -> TopClassPlotter:
    return top_class_plots.build(config, targets)
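Editor's note: the `ConfusionMatrix` plot above folds two edge cases into synthetic labels: sub-threshold predictions become `noise_class`, and generic ground truth becomes `targets.detection_class_name`. Passing a fixed `labels` order to sklearn is what keeps those synthetic classes in the matrix even when they are empty. A standalone sketch of the same sklearn call (the class names here are invented for illustration):

from sklearn import metrics
import matplotlib.pyplot as plt

y_true = ["pip", "noct", "pip", "noise"]
y_pred = ["pip", "pip", "noise", "noise"]

fig, ax = plt.subplots(figsize=(4, 4))
metrics.ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    labels=["pip", "noct", "noise"],  # fixed order, keeps empty classes
    normalize="true",                 # rows sum to 1
    values_format=".2f",
    ax=ax,
)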
@@ -1,5 +1,16 @@
-from typing import Callable, Dict, Generic, List, Sequence, TypeVar
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)

+from matplotlib.figure import Figure
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
@@ -43,6 +54,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):

     metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

+    plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+
     ignore_start_end: float

     prefix: str
@@ -54,9 +67,13 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         prefix: str,
         ignore_start_end: float = 0.01,
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
     ):
         self.matcher = matcher
         self.metrics = metrics
+        self.plots = plots or []
         self.targets = targets
         self.prefix = prefix
         self.ignore_start_end = ignore_start_end
@@ -72,6 +89,12 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             for name, score in metric_output.items()
         }

+    def generate_plots(
+        self, eval_outputs: List[T_Output]
+    ) -> Iterable[Tuple[str, Figure]]:
+        for plot in self.plots:
+            yield plot(eval_outputs)
+
     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
@@ -123,6 +146,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         config: BaseTaskConfig,
         targets: TargetProtocol,
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
         **kwargs,
     ):
         matcher = build_matcher(config.matching_strategy)
@@ -130,6 +156,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             matcher=matcher,
             targets=targets,
             metrics=metrics,
+            plots=plots,
             prefix=config.prefix,
             ignore_start_end=config.ignore_start_end,
             **kwargs,

@@ -12,7 +12,11 @@ from batdetect2.evaluate.metrics.classification import (
     ClassificationMetricConfig,
     ClipEval,
     MatchEval,
-    build_classification_metrics,
+    build_classification_metric,
+)
+from batdetect2.evaluate.plots.classification import (
+    ClassificationPlotConfig,
+    build_classification_plotter,
 )
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
@@ -28,6 +32,7 @@ class ClassificationTaskConfig(BaseTaskConfig):
     metrics: List[ClassificationMetricConfig] = Field(
         default_factory=lambda: [ClassificationAveragePrecisionConfig()]
     )
+    plots: List[ClassificationPlotConfig] = Field(default_factory=list)
     include_generics: bool = True
@@ -128,10 +133,16 @@ class ClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [
-            build_classification_metrics(metric) for metric in config.metrics
+            build_classification_metric(metric, targets)
+            for metric in config.metrics
         ]
+        plots = [
+            build_classification_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClassificationTask.build(
             config=config,
+            plots=plots,
             targets=targets,
             metrics=metrics,
         )

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.clip_classification import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_classification import (
+    ClipClassificationPlotConfig,
+    build_clip_classification_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -26,6 +30,7 @@ class ClipClassificationTaskConfig(BaseTaskConfig):
             ClipClassificationAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)


 class ClipClassificationTask(BaseTask[ClipEval]):
@@ -68,8 +73,13 @@ class ClipClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_classification_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipClassificationTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )

@@ -9,6 +9,10 @@ from batdetect2.evaluate.metrics.clip_detection import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_detection import (
+    ClipDetectionPlotConfig,
+    build_clip_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -25,6 +29,7 @@ class ClipDetectionTaskConfig(BaseTaskConfig):
             ClipDetectionAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)


 class ClipDetectionTask(BaseTask[ClipEval]):
@@ -59,8 +64,13 @@ class ClipDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_detection_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipDetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.detection import (
     MatchEval,
     build_detection_metric,
 )
+from batdetect2.evaluate.plots.detection import (
+    DetectionPlotConfig,
+    build_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -24,6 +28,7 @@ class DetectionTaskConfig(BaseTaskConfig):
     metrics: List[DetectionMetricConfig] = Field(
         default_factory=lambda: [DetectionAveragePrecisionConfig()]
     )
+    plots: List[DetectionPlotConfig] = Field(default_factory=list)


 class DetectionTask(BaseTask[ClipEval]):
@@ -72,8 +77,12 @@ class DetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_detection_metric(metric) for metric in config.metrics]
+        plots = [
+            build_detection_plotter(plot, targets) for plot in config.plots
+        ]
         return DetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )

@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.top_class import (
     TopClassMetricConfig,
     build_top_class_metric,
 )
+from batdetect2.evaluate.plots.top_class import (
+    TopClassPlotConfig,
+    build_top_class_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -24,6 +28,7 @@ class TopClassDetectionTaskConfig(BaseTaskConfig):
     metrics: List[TopClassMetricConfig] = Field(
         default_factory=lambda: [TopClassAveragePrecisionConfig()]
     )
+    plots: List[TopClassPlotConfig] = Field(default_factory=list)


 class TopClassDetectionTask(BaseTask[ClipEval]):
@@ -94,8 +99,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_top_class_metric(metric) for metric in config.metrics]
+        plots = [
+            build_top_class_plotter(plot, targets) for plot in config.plots
+        ]
         return TopClassDetectionTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )

@@ -19,7 +19,7 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
+        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore

     return ax  # type: ignore
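Editor's note: with the `plots` field added to each task config, enabling plots becomes a configuration concern rather than code. A hedged sketch of wiring a detection task's plots; the plot config classes and the `plots` field come from the diff, while the assumption that the remaining `DetectionTaskConfig` fields have usable defaults, and the exact `from_config` call shape, are inferred:

from batdetect2.evaluate.plots.detection import (
    PRCurveConfig,
    ScoreDistributionPlotConfig,
)

# Assumes the other DetectionTaskConfig fields default sensibly.
config = DetectionTaskConfig(
    plots=[
        PRCurveConfig(),
        ScoreDistributionPlotConfig(title="Detection scores"),
    ],
)
task = DetectionTask.from_config(config, targets)  # targets: TargetProtocol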
src/batdetect2/plotting/metrics.py (new file, 281 lines)
@@ -0,0 +1,281 @@
from typing import Dict, Optional, Tuple

import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes

from batdetect2.plotting.common import create_ax


def set_default_styler(ax: axes.Axes) -> axes.Axes:
    color_cycler = cycler(color=sns.color_palette("muted"))
    style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
        marker=["o", "s", "^"]
    )
    custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
        color_cycler
    )

    ax.set_prop_cycle(custom_cycler)
    return ax


def set_default_style(ax: axes.Axes) -> axes.Axes:
    ax = set_default_styler(ax)
    ax.spines.right.set_visible(False)
    ax.spines.top.set_visible(False)
    return ax


def plot_pr_curve(
    precision: np.ndarray,
    recall: np.ndarray,
    thresholds: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(
        recall,
        precision,
        label="PR Curve",
        marker="o",
        markevery=_get_marker_positions(thresholds),
    )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    return ax


def plot_pr_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (precision, recall, thresholds) in data.items():
        ax.plot(
            recall,
            precision,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )
    return ax


def plot_threshold_precision_curve(
    threshold: np.ndarray,
    precision: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Precision")

    return ax


def plot_threshold_precision_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (precision, _, thresholds) in data.items():
        ax.plot(
            thresholds,
            precision,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Precision")

    return ax


def plot_threshold_recall_curve(
    threshold: np.ndarray,
    recall: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Recall")

    return ax


def plot_threshold_recall_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
):
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (_, recall, thresholds) in data.items():
        ax.plot(
            thresholds,
            recall,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("Threshold")
        ax.set_ylabel("Recall")

    return ax


def plot_roc_curve(
    fpr: np.ndarray,
    tpr: np.ndarray,
    thresholds: np.ndarray,
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)

    ax = set_default_style(ax)

    ax.plot(
        fpr,
        tpr,
        markevery=_get_marker_positions(thresholds),
    )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")

    return ax


def plot_roc_curves(
    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
    ax: Optional[axes.Axes] = None,
    figsize: Optional[Tuple[int, int]] = None,
    add_legend: bool = True,
    add_labels: bool = True,
) -> axes.Axes:
    ax = create_ax(ax=ax, figsize=figsize)
    ax = set_default_style(ax)

    for name, (fpr, tpr, thresholds) in data.items():
        ax.plot(
            fpr,
            tpr,
            label=name,
            markevery=_get_marker_positions(thresholds),
        )

    if add_legend:
        ax.legend(
            bbox_to_anchor=(1.05, 1),
            loc="upper left",
            borderaxespad=0.0,
        )

    ax.set_xlim(0, 1.05)
    ax.set_ylim(0, 1.05)

    if add_labels:
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")

    return ax


def _get_marker_positions(
    thresholds: np.ndarray,
    n_points: int = 11,
) -> np.ndarray:
    size = len(thresholds)
    cut_points = np.linspace(0, 1, n_points)
    indices = np.searchsorted(thresholds[::-1], cut_points)
    return np.clip(size - indices, 0, size - 1)  # type: ignore
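Editor's note: `_get_marker_positions` converts evenly spaced threshold cut points into indices of the threshold array, so curve markers land near thresholds 0.0, 0.1, ..., 1.0 instead of at every sample; the returned indices feed matplotlib's `markevery`. A quick worked check, under the assumption (consistent with `compute_precision_recall` above) that `thresholds` is sorted in descending order:

import numpy as np

thresholds = np.array([0.9, 0.7, 0.45, 0.3, 0.1])  # descending scores

cut_points = np.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0
indices = np.searchsorted(thresholds[::-1], cut_points)
positions = np.clip(len(thresholds) - indices, 0, len(thresholds) - 1)
print(positions)  # -> [4 4 4 4 3 2 2 2 1 1 0], indices into the descending array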
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
@@ -35,6 +35,7 @@ class ValidationMetrics(Callback):

     def generate_plots(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
         plotter = get_image_logger(pl_module.logger)  # type: ignore
@@ -42,20 +43,15 @@ class ValidationMetrics(Callback):
         if plotter is None:
             return

-        for figure_name, fig in self.evaluator.generate_plots(
-            self._clip_annotations,
-            self._predictions,
-        ):
+        for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
             plotter(figure_name, fig, pl_module.global_step)

     def log_metrics(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
-        metrics = self.evaluator.compute_metrics(
-            self._clip_annotations,
-            self._predictions,
-        )
+        metrics = self.evaluator.compute_metrics(eval_outputs)
         pl_module.log_dict(metrics)

     def on_validation_epoch_end(
@@ -63,8 +59,13 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        self.log_metrics(pl_module)
-        self.generate_plots(pl_module)
+        eval_outputs = self.evaluator.evaluate(
+            self._clip_annotations,
+            self._predictions,
+        )
+
+        self.log_metrics(eval_outputs, pl_module)
+        self.generate_plots(eval_outputs, pl_module)

         return super().on_validation_epoch_end(trainer, pl_module)
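Editor's note: this callback change is the behavioural core of the commit. The evaluator now runs once per validation epoch, and its outputs are threaded into both metric logging and plot generation, instead of each step re-deriving results from `self._clip_annotations` and `self._predictions`. In abbreviated sketch form (method names from the diff; the surrounding class is elided and this is not a runnable excerpt of the full callback):

# Abbreviated flow of ValidationMetrics.on_validation_epoch_end after
# this commit:
eval_outputs = evaluator.evaluate(clip_annotations, predictions)  # once

metrics = evaluator.compute_metrics(eval_outputs)  # scalar metrics
pl_module.log_dict(metrics)

for name, fig in evaluator.generate_plots(eval_outputs):  # figures
    image_logger(name, fig, pl_module.global_step)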