Plotting reorganised

mbsantiago 2025-09-27 23:58:06 +01:00
parent df2abff654
commit 87ed44c8f7
18 changed files with 1626 additions and 77 deletions

View File

@ -19,9 +19,13 @@ from soundevent import data
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.metrics.common import average_precision
-from batdetect2.typing import RawPrediction
+from batdetect2.typing import RawPrediction, TargetProtocol

-__all__ = []
+__all__ = [
+    "ClassificationMetric",
+    "ClassificationMetricConfig",
+    "build_classification_metric",
+]


 @dataclass
@ -45,8 +49,8 @@ class ClipEval:
 ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

-classification_metrics: Registry[ClassificationMetric, []] = Registry(
-    "classification_metric"
+classification_metrics: Registry[ClassificationMetric, [TargetProtocol]] = (
+    Registry("classification_metric")
 )
@ -58,9 +62,11 @@ class BaseClassificationConfig(BaseConfig):
 class BaseClassificationMetric:
     def __init__(
         self,
+        targets: TargetProtocol,
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
         self.include = include
         self.exclude = exclude
@ -84,13 +90,14 @@ class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
 class ClassificationAveragePrecision(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "average_precision",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
-        super().__init__(include=include, exclude=exclude)
+        super().__init__(include=include, exclude=exclude, targets=targets)
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@ -98,33 +105,11 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-        num_positives = defaultdict(lambda: 0)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    is_class = m.true_class == class_name
-
-                    if is_class:
-                        num_positives[class_name] += 1
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(is_class)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, num_positives = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )

         class_scores = {
             class_name: average_precision(
@ -132,7 +117,7 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
                 y_score[class_name],
                 num_positives=num_positives[class_name],
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }

         mean_score = float(
@ -150,8 +135,12 @@ class ClassificationAveragePrecision(BaseClassificationMetric):
     @classification_metrics.register(ClassificationAveragePrecisionConfig)
     @staticmethod
-    def from_config(config: ClassificationAveragePrecisionConfig):
+    def from_config(
+        config: ClassificationAveragePrecisionConfig,
+        targets: TargetProtocol,
+    ):
         return ClassificationAveragePrecision(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@ -170,12 +159,14 @@ class ClassificationROCAUCConfig(BaseClassificationConfig):
 class ClassificationROCAUC(BaseClassificationMetric):
     def __init__(
         self,
+        targets: TargetProtocol,
         ignore_non_predictions: bool = True,
         ignore_generic: bool = True,
         label: str = "roc_auc",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.targets = targets
         self.ignore_non_predictions = ignore_non_predictions
         self.ignore_generic = ignore_generic
         self.label = label
@ -185,27 +176,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
     def __call__(
         self, clip_evaluations: Sequence[ClipEval]
     ) -> Dict[str, float]:
-        y_true = defaultdict(list)
-        y_score = defaultdict(list)
-
-        class_names = set()
-
-        for clip_eval in clip_evaluations:
-            for class_name, matches in clip_eval.matches.items():
-                class_names.add(class_name)
-                for m in matches:
-                    # Exclude matches with ground truth sounds where the class
-                    # is unknown
-                    if m.is_generic and self.ignore_generic:
-                        continue
-
-                    # Ignore matches that don't correspond to a prediction
-                    if not m.is_prediction and self.ignore_non_predictions:
-                        continue
-
-                    y_true[class_name].append(m.true_class == class_name)
-                    y_score[class_name].append(m.score)
+        y_true, y_score, _ = _extract_per_class_metric_data(
+            clip_evaluations,
+            ignore_non_predictions=self.ignore_non_predictions,
+            ignore_generic=self.ignore_generic,
+        )

         class_scores = {
             class_name: float(
@ -214,7 +189,7 @@ class ClassificationROCAUC(BaseClassificationMetric):
                     y_score[class_name],
                 )
             )
-            for class_name in class_names
+            for class_name in self.targets.class_names
         }

         mean_score = float(
@ -232,8 +207,11 @@ class ClassificationROCAUC(BaseClassificationMetric):
     @classification_metrics.register(ClassificationROCAUCConfig)
     @staticmethod
-    def from_config(config: ClassificationROCAUCConfig):
+    def from_config(
+        config: ClassificationROCAUCConfig, targets: TargetProtocol
+    ):
         return ClassificationROCAUC(
+            targets=targets,
             ignore_non_predictions=config.ignore_non_predictions,
             ignore_generic=config.ignore_generic,
             label=config.label,
@ -249,5 +227,40 @@ ClassificationMetricConfig = Annotated[
 ]

-def build_classification_metrics(config: ClassificationMetricConfig):
-    return classification_metrics.build(config)
+def build_classification_metric(
+    config: ClassificationMetricConfig,
+    targets: TargetProtocol,
+) -> ClassificationMetric:
+    return classification_metrics.build(config, targets)
+
+
+def _extract_per_class_metric_data(
+    clip_evaluations: Sequence[ClipEval],
+    ignore_non_predictions: bool = True,
+    ignore_generic: bool = True,
+):
+    y_true = defaultdict(list)
+    y_score = defaultdict(list)
+    num_positives = defaultdict(lambda: 0)
+
+    for clip_eval in clip_evaluations:
+        for class_name, matches in clip_eval.matches.items():
+            for m in matches:
+                # Exclude matches with ground truth sounds where the class
+                # is unknown
+                if m.is_generic and ignore_generic:
+                    continue
+
+                is_class = m.true_class == class_name
+
+                if is_class:
+                    num_positives[class_name] += 1
+
+                # Ignore matches that don't correspond to a prediction
+                if not m.is_prediction and ignore_non_predictions:
+                    continue
+
+                y_true[class_name].append(is_class)
+                y_score[class_name].append(m.score)
+
+    return y_true, y_score, num_positives
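
Note: with this change the metric builder takes the evaluation targets as a second argument, and per-class scores are keyed by targets.class_names rather than whichever classes happened to appear in the matches. A minimal usage sketch, assuming a `targets` object satisfying TargetProtocol and a list of ClipEval results from the matching step (both come from the existing evaluation pipeline and are not constructed here):

    from batdetect2.evaluate.metrics.classification import (
        ClassificationAveragePrecisionConfig,
        build_classification_metric,
    )

    # `targets` and `clip_evaluations` are assumed to already exist.
    config = ClassificationAveragePrecisionConfig(label="class_ap")
    metric = build_classification_metric(config, targets)

    # A ClassificationMetric maps Sequence[ClipEval] -> Dict[str, float].
    scores = metric(clip_evaluations)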

View File

@ -12,7 +12,7 @@ def compute_precision_recall(
     y_true,
     y_score,
     num_positives: Optional[int] = None,
-) -> Tuple[np.ndarray, np.ndarray]:
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)
@ -22,6 +22,7 @@ def compute_precision_recall(
     # Sort by score
     sort_ind = np.argsort(y_score)[::-1]
     y_true_sorted = y_true[sort_ind]
+    y_score_sorted = y_score[sort_ind]

     false_pos_c = np.cumsum(1 - y_true_sorted)
     true_pos_c = np.cumsum(y_true_sorted)
@ -34,7 +35,7 @@ def compute_precision_recall(
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0

-    return precision, recall
+    return precision, recall, y_score_sorted


 def average_precision(
@ -42,7 +43,7 @@ def average_precision(
     y_score,
     num_positives: Optional[int] = None,
 ) -> float:
-    precision, recall = compute_precision_recall(
+    precision, recall, _ = compute_precision_recall(
         y_true,
         y_score,
         num_positives=num_positives,
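
compute_precision_recall now also returns the score thresholds (the scores sorted in descending order) alongside the precision and recall arrays, which the new plotting helpers use for marker placement. A sketch of the updated unpacking, assuming y_true and y_score are already available:

    precision, recall, thresholds = compute_precision_recall(y_true, y_score)

    # Callers that only need the curve can discard the thresholds, as
    # average_precision() itself now does:
    precision, recall, _ = compute_precision_recall(y_true, y_score)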

View File

@ -0,0 +1,54 @@
from typing import Optional
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from batdetect2.core import BaseConfig
from batdetect2.typing import TargetProtocol
class BasePlotConfig(BaseConfig):
label: str = "plot"
theme: str = "default"
title: Optional[str] = None
figsize: tuple[int, int] = (5, 5)
dpi: int = 100
class BasePlot:
def __init__(
self,
targets: TargetProtocol,
label: str = "plot",
figsize: tuple[int, int] = (5, 5),
title: Optional[str] = None,
dpi: int = 100,
theme: str = "default",
):
self.targets = targets
self.label = label
self.figsize = figsize
self.dpi = dpi
self.theme = theme
self.title = title
def get_figure(self) -> Figure:
plt.style.use(self.theme)
fig = plt.figure(figsize=self.figsize, dpi=self.dpi)
if self.title is not None:
fig.suptitle(self.title)
return fig
@classmethod
def build(cls, config: BasePlotConfig, targets: TargetProtocol, **kwargs):
return cls(
targets=targets,
figsize=config.figsize,
dpi=config.dpi,
theme=config.theme,
label=config.label,
title=config.title,
**kwargs,
)
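
BasePlot centralises the figure boilerplate (theme, size, DPI, optional title) and the config-driven construction, so concrete plots only implement __call__. A minimal sketch of a subclass under that assumption; the ExamplePlot/ExamplePlotConfig names are illustrative and not part of this commit:

    from typing import Literal, Sequence, Tuple

    from matplotlib.figure import Figure

    from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig


    class ExamplePlotConfig(BasePlotConfig):
        # `name` would act as the discriminator in a per-task plot registry.
        name: Literal["example_plot"] = "example_plot"
        label: str = "example_plot"


    class ExamplePlot(BasePlot):
        def __call__(self, outputs: Sequence) -> Tuple[str, Figure]:
            fig = self.get_figure()  # applies theme, figsize, dpi and title
            ax = fig.subplots()
            ax.plot([0, 1], [0, 1])  # placeholder content
            return self.label, fig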

View File

@ -0,0 +1,212 @@
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.classification import (
ClipEval,
_extract_per_class_metric_data,
)
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curves,
plot_roc_curves,
plot_threshold_precision_curves,
plot_threshold_recall_curves,
)
from batdetect2.typing import TargetProtocol
ClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
classification_plots: Registry[ClassificationPlotter, [TargetProtocol]] = (
Registry("classification_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
fig = self.get_figure()
ax = fig.subplots()
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives=num_positives[class_name],
)
for class_name in self.targets.class_names
}
plot_pr_curves(data, ax=ax)
return self.label, fig
@classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ThresholdPRCurveConfig(BasePlotConfig):
name: Literal["threshold_pr_curve"] = "threshold_pr_curve"
label: str = "threshold_pr_curve"
figsize: tuple[int, int] = (10, 5)
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ThresholdPRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, num_positives = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: compute_precision_recall(
y_true[class_name],
y_score[class_name],
num_positives[class_name],
)
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax1, ax2 = fig.subplots(nrows=1, ncols=2, sharey=True)
plot_threshold_precision_curves(data, ax=ax1, add_legend=False)
plot_threshold_recall_curves(data, ax=ax2, add_legend=True)
return self.label, fig
@classification_plots.register(ThresholdPRCurveConfig)
@staticmethod
def from_config(config: ThresholdPRCurveConfig, targets: TargetProtocol):
return ThresholdPRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true, y_score, _ = _extract_per_class_metric_data(
clip_evaluations,
ignore_non_predictions=self.ignore_non_predictions,
ignore_generic=self.ignore_generic,
)
data = {
class_name: metrics.roc_curve(
y_true[class_name],
y_score[class_name],
)
for class_name in self.targets.class_names
}
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
return self.label, fig
@classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
ClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ThresholdPRCurveConfig,
],
Field(discriminator="name"),
]
def build_classification_plotter(
config: ClassificationPlotConfig,
targets: TargetProtocol,
) -> ClassificationPlotter:
return classification_plots.build(config, targets)
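
Each plot config carries a `name` literal, so the ClassificationPlotConfig union is resolved by pydantic's discriminator and turned into a plotter through the registry. A short usage sketch, assuming `targets` satisfies TargetProtocol and `clip_evaluations` comes from the classification matching step:

    from batdetect2.evaluate.plots.classification import (
        PRCurveConfig,
        build_classification_plotter,
    )

    plotter = build_classification_plotter(PRCurveConfig(), targets)
    name, fig = plotter(clip_evaluations)
    fig.savefig(f"{name}.png")  # or hand the figure to an experiment logger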

View File

@ -0,0 +1,131 @@
from typing import (
Annotated,
Callable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_classification import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import (
plot_pr_curves,
plot_roc_curves,
)
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipClassificationPlotConfig",
"ClipClassificationPlotter",
"build_clip_classification_plotter",
]
ClipClassificationPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
clip_classification_plots: Registry[
ClipClassificationPlotter, [TargetProtocol]
] = Registry("clip_classification_plot")
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
data[class_name] = (precision, recall, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curves(data, ax=ax)
return self.label, fig
@clip_classification_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
data = {}
for class_name in self.targets.class_names:
y_true = [class_name in c.true_classes for c in clip_evaluations]
y_score = [
c.class_scores.get(class_name, 0) for c in clip_evaluations
]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
data[class_name] = (fpr, tpr, thresholds)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curves(data, ax=ax)
return self.label, fig
@clip_classification_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
ClipClassificationPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
],
Field(discriminator="name"),
]
def build_clip_classification_plotter(
config: ClipClassificationPlotConfig,
targets: TargetProtocol,
) -> ClipClassificationPlotter:
return clip_classification_plots.build(config, targets)

View File

@ -0,0 +1,160 @@
from typing import (
Annotated,
Callable,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.clip_detection import ClipEval
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
__all__ = [
"ClipDetectionPlotConfig",
"ClipDetectionPlotter",
"build_clip_detection_plotter",
]
ClipDetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
clip_detection_plots: Registry[ClipDetectionPlotter, [TargetProtocol]] = (
Registry("clip_detection_plot")
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
title: Optional[str] = "Precision-Recall Curve"
class PRCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@clip_detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
title: Optional[str] = "ROC Curve"
class ROCCurve(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@clip_detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
title: Optional[str] = "Score Distribution"
class ScoreDistributionPlot(BasePlot):
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = [c.gt_det for c in clip_evaluations]
y_score = [c.score for c in clip_evaluations]
fig = self.get_figure()
ax = fig.subplots()
df = pd.DataFrame({"is_true": y_true, "score": y_score})
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
return self.label, fig
@clip_detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
)
ClipDetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
],
Field(discriminator="name"),
]
def build_clip_detection_plotter(
config: ClipDetectionPlotConfig,
targets: TargetProtocol,
) -> ClipDetectionPlotter:
return clip_detection_plots.build(config, targets)

View File

@ -0,0 +1,350 @@
import random
from typing import Annotated, Callable, Literal, Sequence, Tuple, Union
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib import patches
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from soundevent.plot import plot_geometry
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.detection import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.clips import plot_clip
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol, TargetProtocol
DetectionPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
detection_plots: Registry[DetectionPlotter, [TargetProtocol]] = Registry(
name="detection_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evals: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evals:
for m in clip_eval.matches:
num_positives += int(m.is_ground_truth)
# Ignore matches that don't correspond to a prediction
if not m.is_prediction and self.ignore_non_predictions:
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@detection_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@detection_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ScoreDistributionPlotConfig(BasePlotConfig):
name: Literal["score_distribution"] = "score_distribution"
label: str = "score_distribution"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ScoreDistributionPlot(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if not m.is_prediction and self.ignore_non_predictions:
# Ignore matches that don't correspond to a prediction
continue
y_true.append(m.is_ground_truth)
y_score.append(m.score)
df = pd.DataFrame({"is_true": y_true, "score": y_score})
fig = self.get_figure()
ax = fig.subplots()
sns.histplot(
data=df,
x="score",
binwidth=0.025,
binrange=(0, 1),
hue="is_true",
ax=ax,
stat="probability",
common_norm=False,
)
return self.label, fig
@detection_plots.register(ScoreDistributionPlotConfig)
@staticmethod
def from_config(
config: ScoreDistributionPlotConfig, targets: TargetProtocol
):
return ScoreDistributionPlot.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ExampleDetectionPlotConfig(BasePlotConfig):
name: Literal["example_detection"] = "example_detection"
label: str = "example_detection"
figsize: tuple[int, int] = (10, 15)
num_examples: int = 5
threshold: float = 0.2
audio: AudioConfig = Field(default_factory=AudioConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
class ExampleDetectionPlot(BasePlot):
def __init__(
self,
*args,
num_examples: int = 5,
threshold: float = 0.2,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
**kwargs,
):
super().__init__(*args, **kwargs)
self.num_examples = num_examples
self.audio_loader = audio_loader
self.threshold = threshold
self.preprocessor = preprocessor
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
fig = self.get_figure()
sample = clip_evaluations
if self.num_examples < len(sample):
sample = random.sample(sample, self.num_examples)
axes = fig.subplots(nrows=self.num_examples, ncols=1)
for ax, clip_eval in zip(axes, sample):
plot_clip(
clip_eval.clip,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
ax=ax,
)
for m in clip_eval.matches:
is_match = (
m.pred is not None
and m.gt is not None
and m.score >= self.threshold
)
if m.pred is not None:
plot_geometry(
m.pred.geometry,
ax=ax,
add_points=False,
facecolor="none",
alpha=m.pred.detection_score,
linestyle="-" if not is_match else "--",
color="red" if not is_match else "orange",
)
if m.gt is not None:
plot_geometry(
m.gt.sound_event.geometry, # type: ignore
ax=ax,
add_points=False,
facecolor="none",
color="green" if not is_match else "orange",
)
ax.set_title(clip_eval.clip.recording.path.name)
# ax.legend(
# handles=[
# patches.Patch(
# edgecolor="green",
# label="Ground Truth (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Ground Truth (Matched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="red",
# label="Detection (Unmatched)",
# facecolor="none",
# ),
# patches.Patch(
# edgecolor="orange",
# label="Detection (Matched)",
# facecolor="none",
# linestyle="--",
# ),
# ]
# )
plt.tight_layout()
return self.label, fig
@detection_plots.register(ExampleDetectionPlotConfig)
@staticmethod
def from_config(
config: ExampleDetectionPlotConfig,
targets: TargetProtocol,
):
return ExampleDetectionPlot.build(
config=config,
targets=targets,
num_examples=config.num_examples,
audio_loader=build_audio_loader(config.audio),
preprocessor=build_preprocessor(config.preprocessing),
)
DetectionPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ScoreDistributionPlotConfig,
ExampleDetectionPlotConfig,
],
Field(discriminator="name"),
]
def build_detection_plotter(
config: DetectionPlotConfig,
targets: TargetProtocol,
) -> DetectionPlotter:
return detection_plots.build(config, targets)
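
ExampleDetectionPlot is the only plotter that needs to touch the audio again, so its config nests an AudioConfig and a PreprocessingConfig used to rebuild the loader and preprocessor. A configuration sketch, assuming `targets` and `clip_evaluations` are available; the nested audio/preprocessing configs fall back to their defaults here, but in practice they should mirror the training setup so the plotted spectrograms match what the model saw:

    from batdetect2.evaluate.plots.detection import (
        ExampleDetectionPlotConfig,
        build_detection_plotter,
    )

    config = ExampleDetectionPlotConfig(num_examples=3, threshold=0.3)
    plotter = build_detection_plotter(config, targets)
    name, fig = plotter(clip_evaluations)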

View File

@ -0,0 +1,270 @@
from typing import Annotated, Callable, List, Literal, Sequence, Tuple, Union
from matplotlib.figure import Figure
from pydantic import Field
from sklearn import metrics
from batdetect2.core import Registry
from batdetect2.evaluate.metrics.common import compute_precision_recall
from batdetect2.evaluate.metrics.top_class import ClipEval
from batdetect2.evaluate.plots.base import BasePlot, BasePlotConfig
from batdetect2.plotting.metrics import plot_pr_curve, plot_roc_curve
from batdetect2.typing import TargetProtocol
TopClassPlotter = Callable[[Sequence[ClipEval]], Tuple[str, Figure]]
top_class_plots: Registry[TopClassPlotter, [TargetProtocol]] = Registry(
name="top_class_plot"
)
class PRCurveConfig(BasePlotConfig):
name: Literal["pr_curve"] = "pr_curve"
label: str = "pr_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class PRCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
num_positives = 0
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
num_positives += int(m.is_ground_truth)
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
precision, recall, thresholds = compute_precision_recall(
y_true,
y_score,
num_positives=num_positives,
)
fig = self.get_figure()
ax = fig.subplots()
plot_pr_curve(precision, recall, thresholds, ax=ax)
return self.label, fig
@top_class_plots.register(PRCurveConfig)
@staticmethod
def from_config(config: PRCurveConfig, targets: TargetProtocol):
return PRCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ROCCurveConfig(BasePlotConfig):
name: Literal["roc_curve"] = "roc_curve"
label: str = "roc_curve"
ignore_non_predictions: bool = True
ignore_generic: bool = True
class ROCCurve(BasePlot):
def __init__(
self,
*args,
ignore_non_predictions: bool = True,
ignore_generic: bool = True,
**kwargs,
):
super().__init__(*args, **kwargs)
self.ignore_non_predictions = ignore_non_predictions
self.ignore_generic = ignore_generic
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
if m.is_generic and self.ignore_generic:
# Ignore gt sounds with unknown class
continue
if not m.is_prediction and self.ignore_non_predictions:
# Ignore non predictions
continue
y_true.append(m.pred_class == m.true_class)
y_score.append(m.score)
fpr, tpr, thresholds = metrics.roc_curve(
y_true,
y_score,
)
fig = self.get_figure()
ax = fig.subplots()
plot_roc_curve(fpr, tpr, thresholds, ax=ax)
return self.label, fig
@top_class_plots.register(ROCCurveConfig)
@staticmethod
def from_config(config: ROCCurveConfig, targets: TargetProtocol):
return ROCCurve.build(
config=config,
targets=targets,
ignore_non_predictions=config.ignore_non_predictions,
ignore_generic=config.ignore_generic,
)
class ConfusionMatrixConfig(BasePlotConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
figsize: tuple[int, int] = (10, 10)
label: str = "confusion_matrix"
exclude_generic: bool = True
exclude_noise: bool = False
noise_class: str = "noise"
normalize: Literal["true", "pred", "all", "none"] = "true"
threshold: float = 0.2
add_colorbar: bool = True
cmap: str = "Blues"
class ConfusionMatrix(BasePlot):
def __init__(
self,
*args,
exclude_generic: bool = True,
exclude_noise: bool = False,
noise_class: str = "noise",
add_colorbar: bool = True,
normalize: Literal["true", "pred", "all", "none"] = "true",
cmap: str = "Blues",
threshold: float = 0.2,
**kwargs,
):
super().__init__(*args, **kwargs)
self.exclude_generic = exclude_generic
self.exclude_noise = exclude_noise
self.noise_class = noise_class
self.normalize = normalize
self.add_colorbar = add_colorbar
self.threshold = threshold
self.cmap = cmap
def __call__(
self,
clip_evaluations: Sequence[ClipEval],
) -> Tuple[str, Figure]:
y_true: List[str] = []
y_pred: List[str] = []
for clip_eval in clip_evaluations:
for m in clip_eval.matches:
true_class = m.true_class
pred_class = m.pred_class
if not m.is_prediction and self.exclude_noise:
# Ignore matches that don't correspond to a prediction
continue
if not m.is_ground_truth and self.exclude_noise:
# Ignore matches that don't correspond to a ground truth
continue
if m.score < self.threshold:
if self.exclude_noise:
continue
pred_class = self.noise_class
if m.is_generic:
if self.exclude_generic:
# Ignore gt sounds with unknown class
continue
true_class = self.targets.detection_class_name
y_true.append(true_class or self.noise_class)
y_pred.append(pred_class or self.noise_class)
fig = self.get_figure()
ax = fig.subplots()
class_names = [*self.targets.class_names]
if not self.exclude_generic:
class_names.append(self.targets.detection_class_name)
if not self.exclude_noise:
class_names.append(self.noise_class)
metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=class_names,
ax=ax,
xticks_rotation="vertical",
cmap=self.cmap,
colorbar=self.add_colorbar,
normalize=self.normalize if self.normalize != "none" else None,
values_format=".2f",
)
return self.label, fig
@top_class_plots.register(ConfusionMatrixConfig)
@staticmethod
def from_config(config: ConfusionMatrixConfig, targets: TargetProtocol):
return ConfusionMatrix.build(
config=config,
targets=targets,
exclude_generic=config.exclude_generic,
exclude_noise=config.exclude_noise,
noise_class=config.noise_class,
add_colorbar=config.add_colorbar,
normalize=config.normalize,
cmap=config.cmap,
)
TopClassPlotConfig = Annotated[
Union[
PRCurveConfig,
ROCCurveConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
def build_top_class_plotter(
config: TopClassPlotConfig,
targets: TargetProtocol,
) -> TopClassPlotter:
return top_class_plots.build(config, targets)
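
The confusion matrix folds low-scoring predictions into a synthetic noise class and, unless excluded, maps generic ground truth onto the generic detection class, so the matrix axes are targets.class_names plus those extra labels. A configuration sketch, assuming `targets` and `clip_evaluations` are available:

    from batdetect2.evaluate.plots.top_class import (
        ConfusionMatrixConfig,
        build_top_class_plotter,
    )

    # Predictions scoring below `threshold` are counted under `noise_class`
    # unless exclude_noise is set; rows are normalised per true class.
    config = ConfusionMatrixConfig(threshold=0.3, normalize="true")
    plotter = build_top_class_plotter(config, targets)
    name, fig = plotter(clip_evaluations)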

View File

@ -1,5 +1,16 @@
-from typing import Callable, Dict, Generic, List, Sequence, TypeVar
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Tuple,
+    TypeVar,
+)

+from matplotlib.figure import Figure
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
@ -43,6 +54,8 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
     metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

+    plots: List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+
     ignore_start_end: float

     prefix: str
@ -54,9 +67,13 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         prefix: str,
         ignore_start_end: float = 0.01,
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
     ):
         self.matcher = matcher
         self.metrics = metrics
+        self.plots = plots or []
         self.targets = targets
         self.prefix = prefix
         self.ignore_start_end = ignore_start_end
@ -72,6 +89,12 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             for name, score in metric_output.items()
         }

+    def generate_plots(
+        self, eval_outputs: List[T_Output]
+    ) -> Iterable[Tuple[str, Figure]]:
+        for plot in self.plots:
+            yield plot(eval_outputs)
+
     def evaluate(
         self,
         clip_annotations: Sequence[data.ClipAnnotation],
@ -123,6 +146,9 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
         config: BaseTaskConfig,
         targets: TargetProtocol,
         metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
+        plots: Optional[
+            List[Callable[[Sequence[T_Output]], Tuple[str, Figure]]]
+        ] = None,
         **kwargs,
     ):
         matcher = build_matcher(config.matching_strategy)
@ -130,6 +156,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
             matcher=matcher,
             targets=targets,
             metrics=metrics,
+            plots=plots,
             prefix=config.prefix,
             ignore_start_end=config.ignore_start_end,
             **kwargs,
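
generate_plots is a generator, so figures are only built as they are consumed and can be saved or logged one at a time. A small consumption sketch, assuming `task` is a BaseTask subclass instance and `eval_outputs` is the list produced by its evaluate step:

    for name, fig in task.generate_plots(eval_outputs):
        fig.savefig(f"{task.prefix}{name}.png")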

View File

@ -12,7 +12,11 @@ from batdetect2.evaluate.metrics.classification import (
     ClassificationMetricConfig,
     ClipEval,
     MatchEval,
-    build_classification_metrics,
+    build_classification_metric,
+)
+from batdetect2.evaluate.plots.classification import (
+    ClassificationPlotConfig,
+    build_classification_plotter,
 )
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
@ -28,6 +32,7 @@ class ClassificationTaskConfig(BaseTaskConfig):
     metrics: List[ClassificationMetricConfig] = Field(
         default_factory=lambda: [ClassificationAveragePrecisionConfig()]
     )
+    plots: List[ClassificationPlotConfig] = Field(default_factory=list)

     include_generics: bool = True
@ -128,10 +133,16 @@ class ClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [
-            build_classification_metrics(metric) for metric in config.metrics
+            build_classification_metric(metric, targets)
+            for metric in config.metrics
+        ]
+        plots = [
+            build_classification_plotter(plot, targets)
+            for plot in config.plots
         ]
         return ClassificationTask.build(
             config=config,
+            plots=plots,
             targets=targets,
             metrics=metrics,
         )

View File

@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.clip_classification import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_classification import (
+    ClipClassificationPlotConfig,
+    build_clip_classification_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@ -26,6 +30,7 @@ class ClipClassificationTaskConfig(BaseTaskConfig):
             ClipClassificationAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipClassificationPlotConfig] = Field(default_factory=list)


 class ClipClassificationTask(BaseTask[ClipEval]):
@ -68,8 +73,13 @@ class ClipClassificationTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_classification_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipClassificationTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )

View File

@ -9,6 +9,10 @@ from batdetect2.evaluate.metrics.clip_detection import (
     ClipEval,
     build_clip_metric,
 )
+from batdetect2.evaluate.plots.clip_detection import (
+    ClipDetectionPlotConfig,
+    build_clip_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@ -25,6 +29,7 @@ class ClipDetectionTaskConfig(BaseTaskConfig):
             ClipDetectionAveragePrecisionConfig(),
         ]
     )
+    plots: List[ClipDetectionPlotConfig] = Field(default_factory=list)


 class ClipDetectionTask(BaseTask[ClipEval]):
@ -59,8 +64,13 @@ class ClipDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_clip_metric(metric) for metric in config.metrics]
+        plots = [
+            build_clip_detection_plotter(plot, targets)
+            for plot in config.plots
+        ]
         return ClipDetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )

View File

@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.detection import (
     MatchEval,
     build_detection_metric,
 )
+from batdetect2.evaluate.plots.detection import (
+    DetectionPlotConfig,
+    build_detection_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@ -24,6 +28,7 @@ class DetectionTaskConfig(BaseTaskConfig):
     metrics: List[DetectionMetricConfig] = Field(
         default_factory=lambda: [DetectionAveragePrecisionConfig()]
     )
+    plots: List[DetectionPlotConfig] = Field(default_factory=list)


 class DetectionTask(BaseTask[ClipEval]):
@ -72,8 +77,12 @@ class DetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_detection_metric(metric) for metric in config.metrics]
+        plots = [
+            build_detection_plotter(plot, targets) for plot in config.plots
+        ]
         return DetectionTask.build(
             config=config,
             metrics=metrics,
             targets=targets,
+            plots=plots,
         )

View File

@ -10,6 +10,10 @@ from batdetect2.evaluate.metrics.top_class import (
     TopClassMetricConfig,
     build_top_class_metric,
 )
+from batdetect2.evaluate.plots.top_class import (
+    TopClassPlotConfig,
+    build_top_class_plotter,
+)
 from batdetect2.evaluate.tasks.base import (
     BaseTask,
     BaseTaskConfig,
@ -24,6 +28,7 @@ class TopClassDetectionTaskConfig(BaseTaskConfig):
     metrics: List[TopClassMetricConfig] = Field(
         default_factory=lambda: [TopClassAveragePrecisionConfig()]
     )
+    plots: List[TopClassPlotConfig] = Field(default_factory=list)


 class TopClassDetectionTask(BaseTask[ClipEval]):
@ -94,8 +99,12 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
         targets: TargetProtocol,
     ):
         metrics = [build_top_class_metric(metric) for metric in config.metrics]
+        plots = [
+            build_top_class_plotter(plot, targets) for plot in config.plots
+        ]
         return TopClassDetectionTask.build(
             config=config,
+            plots=plots,
             metrics=metrics,
             targets=targets,
         )

View File

@ -19,7 +19,7 @@ def create_ax(
 ) -> axes.Axes:
     """Create a new axis if none is provided"""
     if ax is None:
-        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore
+        _, ax = plt.subplots(figsize=figsize, nrows=1, ncols=1, **kwargs)  # type: ignore
     return ax  # type: ignore

View File

@ -0,0 +1,281 @@
from typing import Dict, Optional, Tuple
import numpy as np
import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.plotting.common import create_ax
def set_default_styler(ax: axes.Axes) -> axes.Axes:
color_cycler = cycler(color=sns.color_palette("muted"))
style_cycler = cycler(linestyle=["-", "--", ":"]) * cycler(
marker=["o", "s", "^"]
)
custom_cycler = color_cycler * len(style_cycler) + style_cycler * len(
color_cycler
)
ax.set_prop_cycle(custom_cycler)
return ax
def set_default_style(ax: axes.Axes) -> axes.Axes:
ax = set_default_styler(ax)
ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)
return ax
def plot_pr_curve(
precision: np.ndarray,
recall: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
recall,
precision,
label="PR Curve",
marker="o",
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
return ax
def plot_pr_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
ax.plot(
recall,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
return ax
def plot_threshold_precision_curve(
threshold: np.ndarray,
precision: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, precision, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_precision_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, _, thresholds) in data.items():
ax.plot(
thresholds,
precision,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Precision")
return ax
def plot_threshold_recall_curve(
threshold: np.ndarray,
recall: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(threshold, recall, markevery=_get_marker_positions(threshold))
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_threshold_recall_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
):
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (_, recall, thresholds) in data.items():
ax.plot(
thresholds,
recall,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("Threshold")
ax.set_ylabel("Recall")
return ax
def plot_roc_curve(
fpr: np.ndarray,
tpr: np.ndarray,
thresholds: np.ndarray,
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
ax.plot(
fpr,
tpr,
markevery=_get_marker_positions(thresholds),
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def plot_roc_curves(
data: Dict[str, Tuple[np.ndarray, np.ndarray, np.ndarray]],
ax: Optional[axes.Axes] = None,
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (fpr, tpr, thresholds) in data.items():
ax.plot(
fpr,
tpr,
label=name,
markevery=_get_marker_positions(thresholds),
)
if add_legend:
ax.legend(
bbox_to_anchor=(1.05, 1),
loc="upper left",
borderaxespad=0.0,
)
ax.set_xlim(0, 1.05)
ax.set_ylim(0, 1.05)
if add_labels:
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
return ax
def _get_marker_positions(
thresholds: np.ndarray,
n_points: int = 11,
) -> np.ndarray:
size = len(thresholds)
cut_points = np.linspace(0, 1, n_points)
indices = np.searchsorted(thresholds[::-1], cut_points)
return np.clip(size - indices, 0, size - 1) # type: ignore
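
_get_marker_positions converts the descending score thresholds into indices for matplotlib's markevery, aiming for one marker per 0.1 step in score rather than one marker per curve point. A worked sketch of what it computes, using an assumed thresholds array:

    import numpy as np

    # Thresholds as returned by compute_precision_recall: sorted descending.
    thresholds = np.array([0.95, 0.8, 0.62, 0.41, 0.33, 0.15, 0.02])

    cut_points = np.linspace(0, 1, 11)  # 0.0, 0.1, ..., 1.0
    indices = np.searchsorted(thresholds[::-1], cut_points)
    marker_idx = np.clip(len(thresholds) - indices, 0, len(thresholds) - 1)
    # marker_idx -> [6, 6, 5, 5, 4, 3, 3, 2, 2, 1, 0]: roughly one index per
    # 0.1 of score, which ax.plot(..., markevery=marker_idx) marks on the curve.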

View File

@ -1,4 +1,4 @@
-from typing import List
+from typing import Any, List

 from lightning import LightningModule, Trainer
 from lightning.pytorch.callbacks import Callback
@ -35,6 +35,7 @@ class ValidationMetrics(Callback):
     def generate_plots(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
         plotter = get_image_logger(pl_module.logger)  # type: ignore
@ -42,20 +43,15 @@ class ValidationMetrics(Callback):
         if plotter is None:
             return

-        for figure_name, fig in self.evaluator.generate_plots(
-            self._clip_annotations,
-            self._predictions,
-        ):
+        for figure_name, fig in self.evaluator.generate_plots(eval_outputs):
             plotter(figure_name, fig, pl_module.global_step)

     def log_metrics(
         self,
+        eval_outputs: Any,
         pl_module: LightningModule,
     ):
-        metrics = self.evaluator.compute_metrics(
-            self._clip_annotations,
-            self._predictions,
-        )
+        metrics = self.evaluator.compute_metrics(eval_outputs)
         pl_module.log_dict(metrics)

     def on_validation_epoch_end(
@ -63,8 +59,13 @@ class ValidationMetrics(Callback):
         trainer: Trainer,
         pl_module: LightningModule,
     ) -> None:
-        self.log_metrics(pl_module)
-        self.generate_plots(pl_module)
+        eval_outputs = self.evaluator.evaluate(
+            self._clip_annotations,
+            self._predictions,
+        )
+        self.log_metrics(eval_outputs, pl_module)
+        self.generate_plots(eval_outputs, pl_module)
         return super().on_validation_epoch_end(trainer, pl_module)