Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 00:59:34 +01:00)

Commit 87ed44c8f7 (parent df2abff654)

Plotting reorganised
src/batdetect2/evaluate/metrics/classification.py

@@ -19,9 +19,13 @@ from soundevent import data
 
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
-from batdetect2.typing import RawPrediction
+from batdetect2.typing import RawPrediction, TargetProtocol
 
-__all__ = []
+__all__ = [
+    "ClassificationMetric",
+    "ClassificationMetricConfig",
+    "build_classification_metric",
+]
 
 
 @dataclass
@@ -45,8 +49,8 @@ class ClipEval:
 ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
 
 
-classification_metrics: Registry[ClassificationMetric, []] = Registry(
-    "classification_metric"
+classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
+    Registry("classification_metric")
 )
 
 
@@ -58,9 +62,11 @@ class BaseClassificationConfig(BaseConfig):
 class BaseClassificationMetric:
     def __init__(
         self,
+        targets: TargetProtocol,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
        self.include = include
        self.exclude = exclude
 
@@ -84,13 +90,14 @@ class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
 class ClassificationAveragePrecision(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "average_precision",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
-        super().__init__(include=include, exclude=exclude)
+        super().__init__(include=include, exclude=exclude, targets=targets)
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@@ -98,33 +105,11 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-        num_positives = defaultdict(lambda: 0)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    is_class = m.true_class == class_name
-
-                    if is_class:
-                        num_positives[class_name] += 1
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(is_class)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, num_positives = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )
 
         class_scores = {
             class_name: average_precision(
@@ -132,7 +117,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
                 y_score[class_name],
                 num_positives=num_positives[class_name],
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }
 
         mean_score = float(
@@ -150,8 +135,12 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
 
     @classification_metrics.register(ClassificationAveragePrecisionConfig)
     @staticmethod
-    def from_config(config: ClassificationAveragePrecisionConfig):
+    def from_config(
+        config: ClassificationAveragePrecisionConfig,
+        targets: TargetProtocol,
+    ):
         return ClassificationAveragePrecision(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@@ -170,12 +159,14 @@ class ClassificationROCAUCConfig(BaseClassificationConfig):
 class ClassificationROCAUC(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "roc_auc",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@@ -185,27 +176,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(m.true_class == class_name)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, _ = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )
 
         class_scores = {
             class_name: float(
@@ -214,7 +189,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
                     y_score[class_name],
                 )
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }
 
         mean_score = float(
@@ -232,8 +207,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
 
     @classification_metrics.register(ClassificationROCAUCConfig)
     @staticmethod
-    def from_config(config: ClassificationROCAUCConfig):
+    def from_config(
+        config: ClassificationROCAUCConfig, targets: TargetProtocol
+    ):
         return ClassificationROCAUC(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@@ -249,5 +227,40 @@ ClassificationMetricConfig = Annotated[
 ]
 
 
-def build_classification_metrics(config: ClassificationMetricConfig):
-    return classification_metrics.build(config)
+def build_classification_metric(
+    config: ClassificationMetricConfig,
+    targets: TargetProtocol,
+) -> ClassificationMetric:
+    return classification_metrics.build(config, targets)
+
+
+def _extract_per_class_metric_data(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+):
+    y_true = defaultdict(list)
+    y_score = defaultdict(list)
+    num_positives = defaultdict(lambda: 0)
+
+    for clip_eval in clip_evaluations:
+        for class_name, matches in clip_eval.matches.items():
+            for m in matches:
+                # Exclude matches with ground truth sounds where the class
+                # is unknown
+                if m.is_generic and ignore_generic:
+                    continue
+
+                is_class = m.true_class == class_name
+
+                if is_class:
+                    num_positives[class_name] += 1
+
+                # Ignore matches that don't correspond to a prediction
+                if not m.is_prediction and ignore_non_predictions:
+                    continue
+
+                y_true[class_name].append(is_class)
+                y_score[class_name].append(m.score)
+
+    return y_true, y_score, num_positives
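
Review note: the registry type changed from Registry[ClassificationMetric, []] to Registry[ClassificationMetric, [TargetProtocol]], so metrics are now built against the fixed class list on targets instead of collecting class names from whatever matches they happen to see. A minimal sketch of the new call path — targets and clip_evaluations are assumed to already exist (any TargetProtocol implementation exposing class_names, and a list of ClipEval results):

    from batdetect2.evaluate.metrics.classification import (
        ClassificationAveragePrecisionConfig,
        build_classification_metric,
    )

    config = ClassificationAveragePrecisionConfig()
    metric = build_classification_metric(config, targets)

    # A ClassificationMetric is a plain callable over per-clip evaluations.
    scores = metric(clip_evaluations)  # Dict[str, float]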
src/batdetect2/evaluate/metrics/common.py

@@ -12,7 +12,7 @@ def compute_precision_recall(
     y_true,
     y_score,
     num_positives: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)
 
@@ -22,6 +22,7 @@ def compute_precision_recall(
     # Sort by score
     sort_ind = np.argsort(y_score)[::-1]
     y_true_sorted = y_true[sort_ind]
+    y_score_sorted = y_score[sort_ind]
 
     false_pos_c = np.cumsum(1 - y_true_sorted)
     true_pos_c = np.cumsum(y_true_sorted)
@@ -34,7 +35,7 @@ def compute_precision_recall(
 
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0
-    return precision, recall
+    return precision, recall, y_score_sorted
 
 
 def average_precision(
@@ -42,7 +43,7 @@ def average_precision(
     y_score,
     num_positives: Optional[int] = None,
 ) -> float:
-    precision, recall = compute_precision_recall(
+    precision, recall, _ = compute_precision_recall(
         y_true,
         y_score,
         num_positives=num_positives,
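
Review note: compute_precision_recall now returns the descending-sorted scores as a third array, which the new plotting code uses as the threshold axis for threshold-precision/recall curves. A small sketch with illustrative values:

    from batdetect2.evaluate.metrics.common import compute_precision_recall

    y_true = [1, 0, 1, 1]
    y_score = [0.9, 0.8, 0.6, 0.3]

    # thresholds is y_score sorted high-to-low, aligned with the curves.
    precision, recall, thresholds = compute_precision_recall(y_true, y_score)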
src/batdetect2/evaluate/plots/__init__.py (new file, 0 lines)
src/batdetect2/evaluate/plots/base.py (new file, 54 lines)

@@ -0,0 +1,54 @@
+from typing import Optional
+
+import matplotlib.pyplot as plt
+from matplotlib.figure import Figure
+
+from batdetect2.core import BaseConfig
+from batdetect2.typing import TargetProtocol
+
+
+class BasePlotConfig(BaseConfig):
+    label: str = "plot"
+    theme: str = "default"
+    title: Optional[str] = None
+    figsize: tuple[int, int] = (5, 5)
+    dpi: int = 100
+
+
+class BasePlot:
+    def __init__(
+        self,
+        targets: TargetProtocol,
+        label: str = "plot",
+        figsize: tuple[int, int] = (5, 5),
+        title: Optional[str] = None,
+        dpi: int = 100,
+        theme: str = "default",
+    ):
+        self.targets = targets
+        self.label = label
+        self.figsize = figsize
+        self.dpi = dpi
+        self.theme = theme
+        self.title = title
+
+    def get_figure(self) -> Figure:
+        plt.style.use(self.theme)
+        fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
+
+        if self.title is not None:
+            fig.suptitle(self.title)
+
+        return fig
+
+    @classmethod
+    def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
+        return cls(
+            targets=targets,
+            figsize=config.figsize,
+            dpi=config.dpi,
+            theme=config.theme,
+            label=config.label,
+            title=config.title,
+            **kwargs,
+        )
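
Every plot added in this commit follows the template above: the config carries figure styling, get_figure applies it, and build maps config fields onto constructor arguments. A sketch of a hypothetical custom plot against this base (ExamplePlotConfig and ExamplePlot are illustrative names, not part of the commit):

    from typing import Sequence, Tuple

    from matplotlib.figure import Figure

    from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig


    class ExamplePlotConfig(BasePlotConfig):
        label: str = "example"


    class ExamplePlot(BasePlot):
        def __call__(self, outputs: Sequence) -> Tuple[str, Figure]:
            fig = self.get_figure()
            ax = fig.subplots()
            # Assumes each output exposes a numeric score attribute.
            ax.plot([o.score for o in outputs])
            return self.label, fig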
src/batdetect2/evaluate/plots/classification.py (new file, 212 lines)

@@ -0,0 +1,212 @@
+from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
+
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.classification import (
+    ClipEval,
+    _extract_per_class_metric_data,
+)
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from batdetect2.plotting.metrics import (
+    plot_pr_curves,
+    plot_roc_curves,
+    plot_threshold_precision_curves,
+    plot_threshold_recall_curves,
+)
+from batdetect2.typing import TargetProtocol
+
+ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
+
+classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
+    Registry("classification_plot")
+)
+
+
+class PRCurveConfig(BasePlotConfig):
+    name: Literal["pr_curve"] = "pr_curve"
+    label: str = "pr_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class PRCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true, y_score, num_positives = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        data = {
+            class_name: compute_precision_recall(
+                y_true[class_name],
+                y_score[class_name],
+                num_positives=num_positives[class_name],
+            )
+            for class_name in self.targets.class_names
+        }
+
+        plot_pr_curves(data, ax=ax)
+
+        return self.label, fig
+
+    @classification_plots.register(PRCurveConfig)
+    @staticmethod
+    def from_config(config: PRCurveConfig, targets: TargetProtocol):
+        return PRCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ThresholdPRCurveConfig(BasePlotConfig):
+    name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
+    label: str = "threshold_pr_curve"
+    figsize: tuple[int, int] = (10, 5)
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class ThresholdPRCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true, y_score, num_positives = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )
+
+        data = {
+            class_name: compute_precision_recall(
+                y_true[class_name],
+                y_score[class_name],
+                num_positives[class_name],
+            )
+            for class_name in self.targets.class_names
+        }
+
+        fig = self.get_figure()
+        ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
+
+        plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
+        plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
+
+        return self.label, fig
+
+    @classification_plots.register(ThresholdPRCurveConfig)
+    @staticmethod
+    def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
+        return ThresholdPRCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ROCCurveConfig(BasePlotConfig):
+    name: Literal["roc_curve"] = "roc_curve"
+    label: str = "roc_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class ROCCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true, y_score, _ = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )
+
+        data = {
+            class_name: metrics.roc_curve(
+                y_true[class_name],
+                y_score[class_name],
+            )
+            for class_name in self.targets.class_names
+        }
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        plot_roc_curves(data, ax=ax)
+
+        return self.label, fig
+
+    @classification_plots.register(ROCCurveConfig)
+    @staticmethod
+    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
+        return ROCCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+ClassificationPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ThresholdPRCurveConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_classification_plotter(
+    config: ClassificationPlotConfig,
+    targets: TargetProtocol,
+) -> ClassificationPlotter:
+    return classification_plots.build(config, targets)
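
As with the metrics, plot classes register themselves against a config type and the build function resolves a config through the registry. A sketch of the intended call path, again assuming a TargetProtocol implementation and a list of ClipEval results are already in hand:

    from batdetect2.evaluate.plots.classification import (
        PRCurveConfig,
        build_classification_plotter,
    )

    plotter = build_classification_plotter(PRCurveConfig(), targets)
    label, fig = plotter(clip_evaluations)
    fig.savefig(f"{label}.png")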
src/batdetect2/evaluate/plots/clip_classification.py (new file, 131 lines)

@@ -0,0 +1,131 @@
+from typing import (
+    Annotated,
+    Callable,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.clip_classification import ClipEval
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from batdetect2.plotting.metrics import (
+    plot_pr_curves,
+    plot_roc_curves,
+)
+from batdetect2.typing import TargetProtocol
+
+__all__ = [
+    "ClipClassificationPlotConfig",
+    "ClipClassificationPlotter",
+    "build_clip_classification_plotter",
+]
+
+ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
+
+clip_classification_plots: Registry[
+    ClipClassificationPlotter, [TargetProtocol]
+] = Registry("clip_classification_plot")
+
+
+class PRCurveConfig(BasePlotConfig):
+    name: Literal["pr_curve"] = "pr_curve"
+    label: str = "pr_curve"
+    title: Optional[str] = "Precision-Recall Curve"
+
+
+class PRCurve(BasePlot):
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        data = {}
+
+        for class_name in self.targets.class_names:
+            y_true = [class_name in c.true_classes for c in clip_evaluations]
+            y_score = [
+                c.class_scores.get(class_name, 0) for c in clip_evaluations
+            ]
+
+            precision, recall, thresholds = compute_precision_recall(
+                y_true,
+                y_score,
+            )
+
+            data[class_name] = (precision, recall, thresholds)
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_pr_curves(data, ax=ax)
+        return self.label, fig
+
+    @clip_classification_plots.register(PRCurveConfig)
+    @staticmethod
+    def from_config(config: PRCurveConfig, targets: TargetProtocol):
+        return PRCurve.build(
+            config=config,
+            targets=targets,
+        )
+
+
+class ROCCurveConfig(BasePlotConfig):
+    name: Literal["roc_curve"] = "roc_curve"
+    label: str = "roc_curve"
+    title: Optional[str] = "ROC Curve"
+
+
+class ROCCurve(BasePlot):
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        data = {}
+
+        for class_name in self.targets.class_names:
+            y_true = [class_name in c.true_classes for c in clip_evaluations]
+            y_score = [
+                c.class_scores.get(class_name, 0) for c in clip_evaluations
+            ]
+
+            fpr, tpr, thresholds = metrics.roc_curve(
+                y_true,
+                y_score,
+            )
+
+            data[class_name] = (fpr, tpr, thresholds)
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_roc_curves(data, ax=ax)
+        return self.label, fig
+
+    @clip_classification_plots.register(ROCCurveConfig)
+    @staticmethod
+    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
+        return ROCCurve.build(
+            config=config,
+            targets=targets,
+        )
+
+
+ClipClassificationPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_clip_classification_plotter(
+    config: ClipClassificationPlotConfig,
+    targets: TargetProtocol,
+) -> ClipClassificationPlotter:
+    return clip_classification_plots.build(config, targets)
src/batdetect2/evaluate/plots/clip_detection.py (new file, 160 lines)

@@ -0,0 +1,160 @@
+from typing import (
+    Annotated,
+    Callable,
+    Literal,
+    Optional,
+    Sequence,
+    Tuple,
+    Union,
+)
+
+import pandas as pd
+import seaborn as sns
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.clip_detection import ClipEval
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
+from batdetect2.typing import TargetProtocol
+
+__all__ = [
+    "ClipDetectionPlotConfig",
+    "ClipDetectionPlotter",
+    "build_clip_detection_plotter",
+]
+
+ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
+
+
+clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
+    Registry("clip_detection_plot")
+)
+
+
+class PRCurveConfig(BasePlotConfig):
+    name: Literal["pr_curve"] = "pr_curve"
+    label: str = "pr_curve"
+    title: Optional[str] = "Precision-Recall Curve"
+
+
+class PRCurve(BasePlot):
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = [c.gt_det for c in clip_evaluations]
+        y_score = [c.score for c in clip_evaluations]
+
+        precision, recall, thresholds = compute_precision_recall(
+            y_true,
+            y_score,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_pr_curve(precision, recall, thresholds, ax=ax)
+        return self.label, fig
+
+    @clip_detection_plots.register(PRCurveConfig)
+    @staticmethod
+    def from_config(config: PRCurveConfig, targets: TargetProtocol):
+        return PRCurve.build(
+            config=config,
+            targets=targets,
+        )
+
+
+class ROCCurveConfig(BasePlotConfig):
+    name: Literal["roc_curve"] = "roc_curve"
+    label: str = "roc_curve"
+    title: Optional[str] = "ROC Curve"
+
+
+class ROCCurve(BasePlot):
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = [c.gt_det for c in clip_evaluations]
+        y_score = [c.score for c in clip_evaluations]
+
+        fpr, tpr, thresholds = metrics.roc_curve(
+            y_true,
+            y_score,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
+        return self.label, fig
+
+    @clip_detection_plots.register(ROCCurveConfig)
+    @staticmethod
+    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
+        return ROCCurve.build(
+            config=config,
+            targets=targets,
+        )
+
+
+class ScoreDistributionPlotConfig(BasePlotConfig):
+    name: Literal["score_distribution"] = "score_distribution"
+    label: str = "score_distribution"
+    title: Optional[str] = "Score Distribution"
+
+
+class ScoreDistributionPlot(BasePlot):
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = [c.gt_det for c in clip_evaluations]
+        y_score = [c.score for c in clip_evaluations]
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        df = pd.DataFrame({"is_true": y_true, "score": y_score})
+        sns.histplot(
+            data=df,
+            x="score",
+            binwidth=0.025,
+            binrange=(0, 1),
+            hue="is_true",
+            ax=ax,
+            stat="probability",
+            common_norm=False,
+        )
+
+        return self.label, fig
+
+    @clip_detection_plots.register(ScoreDistributionPlotConfig)
+    @staticmethod
+    def from_config(
+        config: ScoreDistributionPlotConfig, targets: TargetProtocol
+    ):
+        return ScoreDistributionPlot.build(
+            config=config,
+            targets=targets,
+        )
+
+
+ClipDetectionPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ScoreDistributionPlotConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_clip_detection_plotter(
+    config: ClipDetectionPlotConfig,
+    targets: TargetProtocol,
+) -> ClipDetectionPlotter:
+    return clip_detection_plots.build(config, targets)
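
The clip-detection plots consume one ground-truth flag and one score per clip (c.gt_det and c.score), so their configs are mostly styling. Since ClipDetectionPlotConfig is a discriminated union on name, a plain mapping can select the concrete plot; a sketch assuming the project's BaseConfig is a pydantic v2 model:

    from pydantic import TypeAdapter

    from batdetect2.evaluate.plots.clip_detection import (
        ClipDetectionPlotConfig,
        build_clip_detection_plotter,
    )

    config = TypeAdapter(ClipDetectionPlotConfig).validate_python(
        {"name": "score_distribution", "title": "Clip scores"}
    )
    plotter = build_clip_detection_plotter(config, targets)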
src/batdetect2/evaluate/plots/detection.py (new file, 350 lines)

@@ -0,0 +1,350 @@
+import random
+from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
+
+import matplotlib.pyplot as plt
+import pandas as pd
+import seaborn as sns
+from matplotlib import patches
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+from soundevent.plot import plot_geometry
+
+from batdetect2.audio import AudioConfig, build_audio_loader
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.metrics.detection import ClipEval
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from batdetect2.plotting.clips import plot_clip
+from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
+from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
+
+DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
+
+detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
+    name="detection_plot"
+)
+
+
+class PRCurveConfig(BasePlotConfig):
+    name: Literal["pr_curve"] = "pr_curve"
+    label: str = "pr_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class PRCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evals: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = []
+        y_score = []
+        num_positives = 0
+
+        for clip_eval in clip_evals:
+            for m in clip_eval.matches:
+                num_positives += int(m.is_ground_truth)
+
+                # Ignore matches that don't correspond to a prediction
+                if not m.is_prediction and self.ignore_non_predictions:
+                    continue
+
+                y_true.append(m.is_ground_truth)
+                y_score.append(m.score)
+
+        precision, recall, thresholds = compute_precision_recall(
+            y_true,
+            y_score,
+            num_positives=num_positives,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_pr_curve(precision, recall, thresholds, ax=ax)
+        return self.label, fig
+
+    @detection_plots.register(PRCurveConfig)
+    @staticmethod
+    def from_config(config: PRCurveConfig, targets: TargetProtocol):
+        return PRCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ROCCurveConfig(BasePlotConfig):
+    name: Literal["roc_curve"] = "roc_curve"
+    label: str = "roc_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class ROCCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                y_true.append(m.is_ground_truth)
+                y_score.append(m.score)
+
+        fpr, tpr, thresholds = metrics.roc_curve(
+            y_true,
+            y_score,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
+        return self.label, fig
+
+    @detection_plots.register(ROCCurveConfig)
+    @staticmethod
+    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
+        return ROCCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ScoreDistributionPlotConfig(BasePlotConfig):
+    name: Literal["score_distribution"] = "score_distribution"
+    label: str = "score_distribution"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class ScoreDistributionPlot(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                y_true.append(m.is_ground_truth)
+                y_score.append(m.score)
+
+        df = pd.DataFrame({"is_true": y_true, "score": y_score})
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        sns.histplot(
+            data=df,
+            x="score",
+            binwidth=0.025,
+            binrange=(0, 1),
+            hue="is_true",
+            ax=ax,
+            stat="probability",
+            common_norm=False,
+        )
+
+        return self.label, fig
+
+    @detection_plots.register(ScoreDistributionPlotConfig)
+    @staticmethod
+    def from_config(
+        config: ScoreDistributionPlotConfig, targets: TargetProtocol
+    ):
+        return ScoreDistributionPlot.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ExampleDetectionPlotConfig(BasePlotConfig):
+    name: Literal["example_detection"] = "example_detection"
+    label: str = "example_detection"
+    figsize: tuple[int, int] = (10, 15)
+    num_examples: int = 5
+    threshold: float = 0.2
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    preprocessing: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+
+
+class ExampleDetectionPlot(BasePlot):
+    def __init__(
+        self,
+        *args,
+        num_examples: int = 5,
+        threshold: float = 0.2,
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.num_examples = num_examples
+        self.audio_loader = audio_loader
+        self.threshold = threshold
+        self.preprocessor = preprocessor
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        fig = self.get_figure()
+
+        sample = clip_evaluations
+
+        if self.num_examples < len(sample):
+            sample = random.sample(sample, self.num_examples)
+
+        axes = fig.subplots(nrows=self.num_examples, ncols=1)
+
+        for ax, clip_eval in zip(axes, sample):
+            plot_clip(
+                clip_eval.clip,
+                audio_loader=self.audio_loader,
+                preprocessor=self.preprocessor,
+                ax=ax,
+            )
+
+            for m in clip_eval.matches:
+                is_match = (
+                    m.pred is not None
+                    and m.gt is not None
+                    and m.score >= self.threshold
+                )
+
+                if m.pred is not None:
+                    plot_geometry(
+                        m.pred.geometry,
+                        ax=ax,
+                        add_points=False,
+                        facecolor="none",
+                        alpha=m.pred.detection_score,
+                        linestyle="-" if not is_match else "--",
+                        color="red" if not is_match else "orange",
+                    )
+
+                if m.gt is not None:
+                    plot_geometry(
+                        m.gt.sound_event.geometry,  # type: ignore
+                        ax=ax,
+                        add_points=False,
+                        facecolor="none",
+                        color="green" if not is_match else "orange",
+                    )
+
+            ax.set_title(clip_eval.clip.recording.path.name)
+
+            # ax.legend(
+            #     handles=[
+            #         patches.Patch(
+            #             edgecolor="green",
+            #             label="Ground Truth (Unmatched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="orange",
+            #             label="Ground Truth (Matched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="red",
+            #             label="Detection (Unmatched)",
+            #             facecolor="none",
+            #         ),
+            #         patches.Patch(
+            #             edgecolor="orange",
+            #             label="Detection (Matched)",
+            #             facecolor="none",
+            #             linestyle="--",
+            #         ),
+            #     ]
+            # )
+
+        plt.tight_layout()
+
+        return self.label, fig
+
+    @detection_plots.register(ExampleDetectionPlotConfig)
+    @staticmethod
+    def from_config(
+        config: ExampleDetectionPlotConfig,
+        targets: TargetProtocol,
+    ):
+        return ExampleDetectionPlot.build(
+            config=config,
+            targets=targets,
+            num_examples=config.num_examples,
+            audio_loader=build_audio_loader(config.audio),
+            preprocessor=build_preprocessor(config.preprocessing),
+        )
+
+
+DetectionPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ScoreDistributionPlotConfig,
+        ExampleDetectionPlotConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_detection_plotter(
+    config: DetectionPlotConfig,
+    targets: TargetProtocol,
+) -> DetectionPlotter:
+    return detection_plots.build(config, targets)
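
ExampleDetectionPlot is the one plot that reaches back to the raw audio, so its config embeds AudioConfig and PreprocessingConfig sections, and from_config turns those into an AudioLoader and a preprocessor via the existing factory functions. A sketch of building it through the registry with default sub-configs:

    from batdetect2.evaluate.plots.detection import (
        ExampleDetectionPlotConfig,
        build_detection_plotter,
    )

    config = ExampleDetectionPlotConfig(num_examples=3, threshold=0.3)
    plotter = build_detection_plotter(config, targets)
    label, fig = plotter(clip_evaluations)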
src/batdetect2/evaluate/plots/top_class.py (new file, 270 lines)

@@ -0,0 +1,270 @@
+from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
+
+from matplotlib.figure import Figure
+from pydantic import Field
+from sklearn import metrics
+
+from batdetect2.core import Registry
+from batdetect2.evaluate.metrics.common import compute_precision_recall
+from batdetect2.evaluate.metrics.top_class import ClipEval
+from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
+from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
+from batdetect2.typing import TargetProtocol
+
+TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
+
+top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
+    name="top_class_plot"
+)
+
+
+class PRCurveConfig(BasePlotConfig):
+    name: Literal["pr_curve"] = "pr_curve"
+    label: str = "pr_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class PRCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = []
+        y_score = []
+        num_positives = 0
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if m.is_generic and self.ignore_generic:
+                    # Ignore gt sounds with unknown class
+                    continue
+
+                num_positives += int(m.is_ground_truth)
+
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore non predictions
+                    continue
+
+                y_true.append(m.pred_class == m.true_class)
+                y_score.append(m.score)
+
+        precision, recall, thresholds = compute_precision_recall(
+            y_true,
+            y_score,
+            num_positives=num_positives,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_pr_curve(precision, recall, thresholds, ax=ax)
+        return self.label, fig
+
+    @top_class_plots.register(PRCurveConfig)
+    @staticmethod
+    def from_config(config: PRCurveConfig, targets: TargetProtocol):
+        return PRCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ROCCurveConfig(BasePlotConfig):
+    name: Literal["roc_curve"] = "roc_curve"
+    label: str = "roc_curve"
+    ignore_non_predictions: bool = True
+    ignore_generic: bool = True
+
+
+class ROCCurve(BasePlot):
+    def __init__(
+        self,
+        *args,
+        ignore_non_predictions: bool = True,
+        ignore_generic: bool = True,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.ignore_non_predictions = ignore_non_predictions
+        self.ignore_generic = ignore_generic
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                if m.is_generic and self.ignore_generic:
+                    # Ignore gt sounds with unknown class
+                    continue
+
+                if not m.is_prediction and self.ignore_non_predictions:
+                    # Ignore non predictions
+                    continue
+
+                y_true.append(m.pred_class == m.true_class)
+                y_score.append(m.score)
+
+        fpr, tpr, thresholds = metrics.roc_curve(
+            y_true,
+            y_score,
+        )
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+        plot_roc_curve(fpr, tpr, thresholds, ax=ax)
+        return self.label, fig
+
+    @top_class_plots.register(ROCCurveConfig)
+    @staticmethod
+    def from_config(config: ROCCurveConfig, targets: TargetProtocol):
+        return ROCCurve.build(
+            config=config,
+            targets=targets,
+            ignore_non_predictions=config.ignore_non_predictions,
+            ignore_generic=config.ignore_generic,
+        )
+
+
+class ConfusionMatrixConfig(BasePlotConfig):
+    name: Literal["confusion_matrix"] = "confusion_matrix"
+    figsize: tuple[int, int] = (10, 10)
+    label: str = "confusion_matrix"
+    exclude_generic: bool = True
+    exclude_noise: bool = False
+    noise_class: str = "noise"
+    normalize: Literal["true", "pred", "all", "none"] = "true"
+    threshold: float = 0.2
+    add_colorbar: bool = True
+    cmap: str = "Blues"
+
+
+class ConfusionMatrix(BasePlot):
+    def __init__(
+        self,
+        *args,
+        exclude_generic: bool = True,
+        exclude_noise: bool = False,
+        noise_class: str = "noise",
+        add_colorbar: bool = True,
+        normalize: Literal["true", "pred", "all", "none"] = "true",
+        cmap: str = "Blues",
+        threshold: float = 0.2,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        self.exclude_generic = exclude_generic
+        self.exclude_noise = exclude_noise
+        self.noise_class = noise_class
+        self.normalize = normalize
+        self.add_colorbar = add_colorbar
+        self.threshold = threshold
+        self.cmap = cmap
+
+    def __call__(
+        self,
+        clip_evaluations: Sequence[ClipEval],
+    ) -> Tuple[str, Figure]:
+        y_true: List[str] = []
+        y_pred: List[str] = []
+
+        for clip_eval in clip_evaluations:
+            for m in clip_eval.matches:
+                true_class = m.true_class
+                pred_class = m.pred_class
+
+                if not m.is_prediction and self.exclude_noise:
+                    # Ignore matches that don't correspond to a prediction
+                    continue
+
+                if not m.is_ground_truth and self.exclude_noise:
+                    # Ignore matches that don't correspond to a ground truth
+                    continue
+
+                if m.score < self.threshold:
+                    if self.exclude_noise:
+                        continue
+
+                    pred_class = self.noise_class
+
+                if m.is_generic:
+                    if self.exclude_generic:
+                        # Ignore gt sounds with unknown class
+                        continue
+
+                    true_class = self.targets.detection_class_name
+
+                y_true.append(true_class or self.noise_class)
+                y_pred.append(pred_class or self.noise_class)
+
+        fig = self.get_figure()
+        ax = fig.subplots()
+
+        class_names = [*self.targets.class_names]
+
+        if not self.exclude_generic:
+            class_names.append(self.targets.detection_class_name)
+
+        if not self.exclude_noise:
+            class_names.append(self.noise_class)
+
+        metrics.ConfusionMatrixDisplay.from_predictions(
+            y_true,
+            y_pred,
+            labels=class_names,
+            ax=ax,
+            xticks_rotation="vertical",
+            cmap=self.cmap,
+            colorbar=self.add_colorbar,
+            normalize=self.normalize if self.normalize != "none" else None,
+            values_format=".2f",
+        )
+
+        return self.label, fig
+
+    @top_class_plots.register(ConfusionMatrixConfig)
+    @staticmethod
+    def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
+        return ConfusionMatrix.build(
+            config=config,
+            targets=targets,
+            exclude_generic=config.exclude_generic,
+            exclude_noise=config.exclude_noise,
+            noise_class=config.noise_class,
+            add_colorbar=config.add_colorbar,
+            normalize=config.normalize,
+            cmap=config.cmap,
+        )
+
+
+TopClassPlotConfig = Annotated[
+    Union[
+        PRCurveConfig,
+        ROCCurveConfig,
+        ConfusionMatrixConfig,
+    ],
+    Field(discriminator="name"),
+]
+
+
+def build_top_class_plotter(
+    config: TopClassPlotConfig,
+    targets: TargetProtocol,
+) -> TopClassPlotter:
+    return top_class_plots.build(config, targets)
src/batdetect2/evaluate/tasks/base.py

@@ -1,5 +1,16 @@
-from typing import Callable, Dict, Generic, List, Sequence, TypeVar
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)
 
+from matplotlib.figure import Figure
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
@@ -43,6 +54,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
 
     metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
 
+    plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+
     ignore_start_end: float
 
     prefix: str
@@ -54,9 +67,13 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         prefix: str,
         ignore_start_end: float = 0.01,
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
     ):
         self.matcher = matcher
         self.metrics = metrics
+        self.plots = plots or []
         self.targets = targets
         self.prefix = prefix
         self.ignore_start_end = ignore_start_end
@@ -72,6 +89,12 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             for name, score in metric_output.items()
         }
 
+    def generate_plots(
+        self, eval_outputs: List[T_Output]
+    ) -> Iterable[Tuple[str, Figure]]:
+        for plot in self.plots:
+            yield plot(eval_outputs)
+
     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
@@ -123,6 +146,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         config: BaseTaskConfig,
         targets: TargetProtocol,
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
         **kwargs,
     ):
         matcher = build_matcher(config.matching_strategy)
@@ -130,6 +156,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             matcher=matcher,
             targets=targets,
             metrics=metrics,
+            plots=plots,
             prefix=config.prefix,
             ignore_start_end=config.ignore_start_end,
             **kwargs,
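
generate_plots yields (label, Figure) pairs lazily, leaving it to the evaluation driver to decide where figures end up. A sketch of a possible consumer — the output directory and file naming here are assumptions, not part of the commit:

    from pathlib import Path

    out_dir = Path("evaluation_plots")
    out_dir.mkdir(exist_ok=True)

    for label, fig in task.generate_plots(eval_outputs):
        fig.savefig(out_dir / f"{task.prefix}_{label}.png")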
@@ -12,7 +12,11 @@ from batdetect2.evaluate.metrics.classification import (
     ClassificationMetricConfig,
     ClipEval,
     MatchEval,
-    build_classification_metrics,
+    build_classification_metric,
 )
+from batdetect2.evaluate.plots.classification import (
+    ClassificationPlotConfig,
+    build_classification_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
@@ -28,6 +32,7 @@ class ClassificationTaskConfig(BaseTaskConfig):
     metrics: List[ClassificationMetricConfig] = Field(
         default_factory=lambda: [ClassificationAveragePrecisionConfig()]
     )
+    plots: List[ClassificationPlotConfig] = Field(default_factory=list)
     include_generics: bool = True

@@ -128,10 +133,16 @@ class ClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [
-            build_classification_metrics(metric) for metric in config.metrics
+            build_classification_metric(metric, targets)
+            for metric in config.metrics
         ]
+        plots = [
+            build_classification_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClassificationTask.build(
             config=config,
+            plots=plots,
             targets=targets,
             metrics=metrics,
         )
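With the `plots` field in place, a task config opts into figures the same way it already opts into metrics. A sketch (the `ClassificationPlotConfig` constructor arguments are not shown in this diff, so it is instantiated with defaults here):

    config = ClassificationTaskConfig(
        metrics=[ClassificationAveragePrecisionConfig()],
        plots=[ClassificationPlotConfig()],
    )

Because the field defaults to an empty list, existing configs keep their current behaviour. The same pattern repeats in the clip-classification, clip-detection, detection, and top-class task configs below.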
@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.clip_classification import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_classification import (
+    ClipClassificationPlotConfig,
+    build_clip_classification_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -26,6 +30,7 @@ class ClipClassificationTaskConfig(BaseTaskConfig):
             ClipClassificationAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)


 class ClipClassificationTask(BaseTask[ClipEval]):
@@ -68,8 +73,13 @@ class ClipClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_classification_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipClassificationTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )
@@ -9,6 +9,10 @@ from batdetect2.evaluate.metrics.clip_detection import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_detection import (
+    ClipDetectionPlotConfig,
+    build_clip_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -25,6 +29,7 @@ class ClipDetectionTaskConfig(BaseTaskConfig):
             ClipDetectionAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)


 class ClipDetectionTask(BaseTask[ClipEval]):
@@ -59,8 +64,13 @@ class ClipDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_detection_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipDetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )
@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.detection import (
     MatchEval,
     build_detection_metric,
 )
+from batdetect2.evaluate.plots.detection import (
+    DetectionPlotConfig,
+    build_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -24,6 +28,7 @@ class DetectionTaskConfig(BaseTaskConfig):
     metrics: List[DetectionMetricConfig] = Field(
         default_factory=lambda: [DetectionAveragePrecisionConfig()]
     )
+    plots: List[DetectionPlotConfig] = Field(default_factory=list)


 class DetectionTask(BaseTask[ClipEval]):
@@ -72,8 +77,12 @@ class DetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_detection_metric(metric) for metric in config.metrics]
+        plots = [
+            build_detection_plotter(plot, targets) for plot in config.plots
+        ]
         return DetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )
@@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.top_class import (
     TopClassMetricConfig,
     build_top_class_metric,
 )
+from batdetect2.evaluate.plots.top_class import (
+    TopClassPlotConfig,
+    build_top_class_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@@ -24,6 +28,7 @@ class TopClassDetectionTaskConfig(BaseTaskConfig):
     metrics: List[TopClassMetricConfig] = Field(
         default_factory=lambda: [TopClassAveragePrecisionConfig()]
     )
+    plots: List[TopClassPlotConfig] = Field(default_factory=list)


 class TopClassDetectionTask(BaseTask[ClipEval]):
@@ -94,8 +99,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_top_class_metric(metric) for metric in config.metrics]
+        plots = [
+            build_top_class_plotter(plot, targets) for plot in config.plots
+        ]
         return TopClassDetectionTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )
@@ -19,7 +19,7 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
+        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore

     return ax  # type: ignore
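Pinning `nrows=1, ncols=1` makes the single-`Axes` contract of `create_ax` explicit: `plt.subplots` returns an array of axes as soon as a grid is requested, and a duplicate `nrows`/`ncols` smuggled in through `**kwargs` now raises a `TypeError` instead of silently changing the return shape. For illustration:

    import matplotlib.pyplot as plt

    _, ax = plt.subplots(nrows=1, ncols=1)  # a single Axes object
    _, grid = plt.subplots(nrows=2)         # a numpy array of Axes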
src/batdetect2/plotting/metrics.py (new file, 281 lines)
@@ -0,0 +1,281 @@
+from typing import Dict, Optional, Tuple
+
+import numpy as np
+import seaborn as sns
+from cycler import cycler
+from matplotlib import axes
+
+from batdetect2.plotting.common import create_ax
+
+
+def set_default_styler(ax: axes.Axes) -> axes.Axes:
+    color_cycler = cycler(color=sns.color_palette("muted"))
+    style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
+        marker=["o", "s", "^"]
+    )
+    custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
+        color_cycler
+    )
+
+    ax.set_prop_cycle(custom_cycler)
+    return ax
+
+
+def set_default_style(ax: axes.Axes) -> axes.Axes:
+    ax = set_default_styler(ax)
+    ax.spines.right.set_visible(False)
+    ax.spines.top.set_visible(False)
+    return ax
+
+
+def plot_pr_curve(
+    precision: np.ndarray,
+    recall: np.ndarray,
+    thresholds: np.ndarray,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_labels: bool = True,
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+
+    ax = set_default_style(ax)
+
+    ax.plot(
+        recall,
+        precision,
+        label="PR Curve",
+        marker="o",
+        markevery=_get_marker_positions(thresholds),
+    )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Recall")
+        ax.set_ylabel("Precision")
+
+    return ax
+
+
+def plot_pr_curves(
+    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_legend: bool = True,
+    add_labels: bool = True,
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax = set_default_style(ax)
+
+    for name, (precision, recall, thresholds) in data.items():
+        ax.plot(
+            recall,
+            precision,
+            label=name,
+            markevery=_get_marker_positions(thresholds),
+        )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Recall")
+        ax.set_ylabel("Precision")
+
+    if add_legend:
+        ax.legend(
+            bbox_to_anchor=(1.05, 1),
+            loc="upper left",
+            borderaxespad=0.0,
+        )
+    return ax
+
+
+def plot_threshold_precision_curve(
+    threshold: np.ndarray,
+    precision: np.ndarray,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_labels: bool = True,
+):
+    ax = create_ax(ax=ax, figsize=figsize)
+
+    ax = set_default_style(ax)
+
+    ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Threshold")
+        ax.set_ylabel("Precision")
+
+    return ax
+
+
+def plot_threshold_precision_curves(
+    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_legend: bool = True,
+    add_labels: bool = True,
+):
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax = set_default_style(ax)
+
+    for name, (precision, _, thresholds) in data.items():
+        ax.plot(
+            thresholds,
+            precision,
+            label=name,
+            markevery=_get_marker_positions(thresholds),
+        )
+
+    if add_legend:
+        ax.legend(
+            bbox_to_anchor=(1.05, 1),
+            loc="upper left",
+            borderaxespad=0.0,
+        )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Threshold")
+        ax.set_ylabel("Precision")
+
+    return ax
+
+
+def plot_threshold_recall_curve(
+    threshold: np.ndarray,
+    recall: np.ndarray,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_labels: bool = True,
+):
+    ax = create_ax(ax=ax, figsize=figsize)
+
+    ax = set_default_style(ax)
+
+    ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Threshold")
+        ax.set_ylabel("Recall")
+
+    return ax
+
+
+def plot_threshold_recall_curves(
+    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_legend: bool = True,
+    add_labels: bool = True,
+):
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax = set_default_style(ax)
+
+    for name, (_, recall, thresholds) in data.items():
+        ax.plot(
+            thresholds,
+            recall,
+            label=name,
+            markevery=_get_marker_positions(thresholds),
+        )
+
+    if add_legend:
+        ax.legend(
+            bbox_to_anchor=(1.05, 1),
+            loc="upper left",
+            borderaxespad=0.0,
+        )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("Threshold")
+        ax.set_ylabel("Recall")
+
+    return ax
+
+
+def plot_roc_curve(
+    fpr: np.ndarray,
+    tpr: np.ndarray,
+    thresholds: np.ndarray,
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_labels: bool = True,
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+
+    ax = set_default_style(ax)
+
+    ax.plot(
+        fpr,
+        tpr,
+        markevery=_get_marker_positions(thresholds),
+    )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("False Positive Rate")
+        ax.set_ylabel("True Positive Rate")
+
+    return ax
+
+
+def plot_roc_curves(
+    data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
+    ax: Optional[axes.Axes] = None,
+    figsize: Optional[Tuple[int, int]] = None,
+    add_legend: bool = True,
+    add_labels: bool = True,
+) -> axes.Axes:
+    ax = create_ax(ax=ax, figsize=figsize)
+    ax = set_default_style(ax)
+
+    for name, (fpr, tpr, thresholds) in data.items():
+        ax.plot(
+            fpr,
+            tpr,
+            label=name,
+            markevery=_get_marker_positions(thresholds),
+        )
+
+    if add_legend:
+        ax.legend(
+            bbox_to_anchor=(1.05, 1),
+            loc="upper left",
+            borderaxespad=0.0,
+        )
+
+    ax.set_xlim(0, 1.05)
+    ax.set_ylim(0, 1.05)
+
+    if add_labels:
+        ax.set_xlabel("False Positive Rate")
+        ax.set_ylabel("True Positive Rate")
+
+    return ax
+
+
+def _get_marker_positions(
+    thresholds: np.ndarray,
+    n_points: int = 11,
+) -> np.ndarray:
+    size = len(thresholds)
+    cut_points = np.linspace(0, 1, n_points)
+    indices = np.searchsorted(thresholds[::-1], cut_points)
+    return np.clip(size - indices, 0, size - 1)  # type: ignore
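The `_get_marker_positions` helper places the line markers at the indices whose thresholds fall nearest to 0.0, 0.1, ..., 1.0 (it searches the reversed array, so it expects thresholds in descending order), which keeps marker positions comparable across curves. A quick smoke test of the new module with synthetic data; the class names are placeholders:

    import numpy as np

    from batdetect2.plotting.metrics import plot_pr_curves

    recall = np.linspace(0, 1, 50)
    thresholds = np.linspace(1, 0, 50)  # descending, as the helper expects
    data = {
        "Myotis": (1 - 0.4 * recall, recall, thresholds),
        "Pipistrellus": (1 - 0.7 * recall, recall, thresholds),
    }
    ax = plot_pr_curves(data)
    ax.figure.savefig("pr_curves.png", bbox_inches="tight")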
@@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
@@ -35,6 +35,7 @@ class ValidationMetrics(Callback):

     def generate_plots(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
         plotter = get_image_logger(pl_module.logger)  # type: ignore
@@ -42,20 +43,15 @@ class ValidationMetrics(Callback):
         if plotter is None:
             return

-        for figure_name, fig in self.evaluator.generate_plots(
-            self._clip_annotations,
-            self._predictions,
-        ):
+        for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
             plotter(figure_name, fig, pl_module.global_step)

     def log_metrics(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
-        metrics = self.evaluator.compute_metrics(
-            self._clip_annotations,
-            self._predictions,
-        )
+        metrics = self.evaluator.compute_metrics(eval_outputs)
         pl_module.log_dict(metrics)

     def on_validation_epoch_end(
@@ -63,8 +59,13 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        self.log_metrics(pl_module)
-        self.generate_plots(pl_module)
+        eval_outputs = self.evaluator.evaluate(
+            self._clip_annotations,
+            self._predictions,
+        )
+
+        self.log_metrics(eval_outputs, pl_module)
+        self.generate_plots(eval_outputs, pl_module)

         return super().on_validation_epoch_end(trainer, pl_module)
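Net effect of the callback change: the evaluator now runs once per validation epoch, and its outputs are shared between metric logging and plot generation instead of each method re-running the evaluation. The contract implied by the call sites, as a sketch only (method names come from this commit; signatures are inferred and the real EvaluatorProtocol may differ):

    from typing import Any, Dict, Iterable, List, Protocol, Tuple

    from matplotlib.figure import Figure


    class EvaluatorSketch(Protocol):
        # evaluate() produces per-clip outputs once; the other two
        # methods consume that shared result.
        def evaluate(
            self, clip_annotations: Any, predictions: Any
        ) -> List[Any]: ...

        def compute_metrics(
            self, eval_outputs: List[Any]
        ) -> Dict[str, float]: ...

        def generate_plots(
            self, eval_outputs: List[Any]
        ) -> Iterable[Tuple[str, Figure]]: ...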