diff --git a/src/batdetect2/plotting/metrics.py b/src/batdetect2/plotting/metrics.py index 709c3b9..acb9099 100644 --- a/src/batdetect2/plotting/metrics.py +++ b/src/batdetect2/plotting/metrics.py @@ -5,6 +5,7 @@ import seaborn as sns from cycler import cycler from matplotlib import axes +from batdetect2.evaluate.metrics.common import _average_precision from batdetect2.plotting.common import create_ax @@ -80,15 +81,21 @@ def plot_pr_curves( figsize: Optional[Tuple[int, int]] = None, add_legend: bool = True, add_labels: bool = True, + include_ap: bool = False, ) -> axes.Axes: ax = create_ax(ax=ax, figsize=figsize) ax = set_default_style(ax) for name, (precision, recall, thresholds) in data.items(): + label = name + + if include_ap: + label += f" (AP={_average_precision(recall, precision):.2f})" + ax.plot( recall, precision, - label=name, + label=label, markevery=_get_marker_positions(thresholds), )