mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add average precision computation to pr curves if needed
This commit is contained in:
parent
16c401b1da
commit
7336638fa9
@ -5,6 +5,7 @@ import seaborn as sns
|
|||||||
from cycler import cycler
|
from cycler import cycler
|
||||||
from matplotlib import axes
|
from matplotlib import axes
|
||||||
|
|
||||||
|
from batdetect2.evaluate.metrics.common import _average_precision
|
||||||
from batdetect2.plotting.common import create_ax
|
from batdetect2.plotting.common import create_ax
|
||||||
|
|
||||||
|
|
||||||
@ -80,15 +81,21 @@ def plot_pr_curves(
|
|||||||
figsize: Optional[Tuple[int, int]] = None,
|
figsize: Optional[Tuple[int, int]] = None,
|
||||||
add_legend: bool = True,
|
add_legend: bool = True,
|
||||||
add_labels: bool = True,
|
add_labels: bool = True,
|
||||||
|
include_ap: bool = False,
|
||||||
) -> axes.Axes:
|
) -> axes.Axes:
|
||||||
ax = create_ax(ax=ax, figsize=figsize)
|
ax = create_ax(ax=ax, figsize=figsize)
|
||||||
ax = set_default_style(ax)
|
ax = set_default_style(ax)
|
||||||
|
|
||||||
for name, (precision, recall, thresholds) in data.items():
|
for name, (precision, recall, thresholds) in data.items():
|
||||||
|
label = name
|
||||||
|
|
||||||
|
if include_ap:
|
||||||
|
label += f" (AP={_average_precision(recall, precision):.2f})"
|
||||||
|
|
||||||
ax.plot(
|
ax.plot(
|
||||||
recall,
|
recall,
|
||||||
precision,
|
precision,
|
||||||
label=name,
|
label=label,
|
||||||
markevery=_get_marker_positions(thresholds),
|
markevery=_get_marker_positions(thresholds),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user