From 7336638fa99e110fe723fe4b0c400cd9590bfbc6 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 22 Nov 2025 00:34:42 +0000 Subject: [PATCH] Add average precision computation to pr curves if needed --- src/batdetect2/plotting/metrics.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index 709c3b9..acb9099 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -5,6 +5,7 @@ import seaborn as sns from cycler import cycler from matplotlib import axes +from batdetect2.evaluate.metrics.common import _average_precision from batdetect2.plotting.common import create_ax @@ -80,15 +81,21 @@ def plot_pr_curves( figsize: Optional[Tuple[int, int]] = None, add_legend: bool = True, add_labels: bool = True, + include_ap: bool = False, ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) ax = set_default_style(ax) for name, (precision, recall, thresholds) in data.items(): + label = name + + if include_ap: + label += f" (AP={_average_precision(recall, precision):.2f})" + ax.plot( recall, precision, - label=name, + label=label, markevery=_get_marker_positions(thresholds), )