mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add average precision computation to pr curves if needed
This commit is contained in:
parent
16c401b1da
commit
7336638fa9
@ -5,6 +5,7 @@ import seaborn as sns
|
||||
from cycler import cycler
|
||||
from matplotlib import axes
|
||||
|
||||
from batdetect2.evaluate.metrics.common import _average_precision
|
||||
from batdetect2.plotting.common import create_ax
|
||||
|
||||
|
||||
@ -80,15 +81,21 @@ def plot_pr_curves(
|
||||
figsize: Optional[Tuple[int, int]] = None,
|
||||
add_legend: bool = True,
|
||||
add_labels: bool = True,
|
||||
include_ap: bool = False,
|
||||
) -> axes.Axes:
|
||||
ax = create_ax(ax=ax, figsize=figsize)
|
||||
ax = set_default_style(ax)
|
||||
|
||||
for name, (precision, recall, thresholds) in data.items():
|
||||
label = name
|
||||
|
||||
if include_ap:
|
||||
label += f" (AP={_average_precision(recall, precision):.2f})"
|
||||
|
||||
ax.plot(
|
||||
recall,
|
||||
precision,
|
||||
label=name,
|
||||
label=label,
|
||||
markevery=_get_marker_positions(thresholds),
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user