Add average precision computation to pr curves if needed

This commit is contained in:
mbsantiago 2025-11-22 00:34:42 +00:00
parent 16c401b1da
commit 7336638fa9

View File

@ -5,6 +5,7 @@ import seaborn as sns
from cycler import cycler
from matplotlib import axes
from batdetect2.evaluate.metrics.common import _average_precision
from batdetect2.plotting.common import create_ax
@ -80,15 +81,21 @@ def plot_pr_curves(
figsize: Optional[Tuple[int, int]] = None,
add_legend: bool = True,
add_labels: bool = True,
include_ap: bool = False,
) -> axes.Axes:
ax = create_ax(ax=ax, figsize=figsize)
ax = set_default_style(ax)
for name, (precision, recall, thresholds) in data.items():
label = name
if include_ap:
label += f" (AP={_average_precision(recall, precision):.2f})"
ax.plot(
recall,
precision,
label=name,
label=label,
markevery=_get_marker_positions(thresholds),
)