diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py
index 0a9be88..9c48f37 100644
--- a/src/batdetect2/evaluate/metrics.py
+++ b/src/batdetect2/evaluate/metrics.py
@@ -1,4 +1,14 @@
-from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
+from collections.abc import Callable, Mapping
+from typing import (
+    Annotated,
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Union,
+)
 
 import numpy as np
 from pydantic import Field
@@ -16,11 +26,61 @@ __all__ = ["DetectionAP", "ClassificationAP"]
 metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
 
 
+AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
+
+
 class DetectionAPConfig(BaseConfig):
     name: Literal["detection_ap"] = "detection_ap"
+    implementation: AveragePrecisionImplementation = "pascal_voc"
+
+
+def pascal_voc_average_precision(y_true, y_score) -> float:
+    y_true = np.array(y_true)
+    y_score = np.array(y_score)
+
+    sort_ind = np.argsort(y_score)[::-1]
+    y_true_sorted = y_true[sort_ind]
+
+    num_positives = y_true.sum()
+    false_pos_c = np.cumsum(1 - y_true_sorted)
+    true_pos_c = np.cumsum(y_true_sorted)
+
+    recall = true_pos_c / num_positives
+    precision = true_pos_c / np.maximum(
+        true_pos_c + false_pos_c,
+        np.finfo(np.float64).eps,
+    )
+
+    precision[np.isnan(precision)] = 0
+    recall[np.isnan(recall)] = 0
+
+    # pascal 12 way
+    mprec = np.hstack((0, precision, 0))
+    mrec = np.hstack((0, recall, 1))
+    for ii in range(mprec.shape[0] - 2, -1, -1):
+        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
+    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
+    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
+
+    return ave_prec
+
+
+_ap_impl_mapping: Mapping[
+    AveragePrecisionImplementation, Callable[[Any, Any], float]
+] = {
+    "sklearn": metrics.average_precision_score,
+    "pascal_voc": pascal_voc_average_precision,
+}
 
 
 class DetectionAP(MetricsProtocol):
+    def __init__(
+        self,
+        implementation: AveragePrecisionImplementation = "pascal_voc",
+    ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
+
     def __call__(
         self, clip_evaluations: Sequence[ClipEvaluation]
     ) -> Dict[str, float]:
@@ -31,12 +91,12 @@ class DetectionAP(MetricsProtocol):
                 for match in clip_eval.matches
             ]
         )
-        score = float(metrics.average_precision_score(y_true, y_score))
+        score = float(self.metric(y_true, y_score))
         return {"detection_AP": score}
 
     @classmethod
     def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
-        return cls()
+        return cls(implementation=config.implementation)
 
 
 metrics_registry.register(DetectionAPConfig, DetectionAP)
@@ -52,9 +112,12 @@ class ClassificationAP(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
+        implementation: AveragePrecisionImplementation = "pascal_voc",
         include: Optional[List[str]] = None,
         exclude: Optional[List[str]] = None,
     ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
         self.class_names = class_names
         self.selected = class_names
 
@@ -107,10 +170,7 @@ class ClassificationAP(MetricsProtocol):
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
             y_pred_class = y_pred[:, class_index]
-            class_ap = metrics.average_precision_score(
-                y_true_class,
-                y_pred_class,
-            )
+            class_ap = self.metric(y_true_class, y_pred_class)
             class_scores[class_name] = float(class_ap)
 
         mean_ap = np.mean(
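
Not part of the patch: a minimal usage sketch of the implementation switch introduced above, assuming this diff is applied and that batdetect2, numpy, and scikit-learn are importable. The toy labels and scores below are illustrative only.

import numpy as np
from sklearn import metrics

from batdetect2.evaluate.metrics import (
    DetectionAP,
    DetectionAPConfig,
    pascal_voc_average_precision,
)

# Toy detection outcomes: 1 = true positive, 0 = false positive, with scores.
y_true = np.array([1, 0, 1, 1, 0, 0, 1, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.55, 0.4, 0.3, 0.1])

# PASCAL VOC 2012-style interpolated AP (new default) vs. sklearn's step-wise AP.
print(pascal_voc_average_precision(y_true, y_score))
print(metrics.average_precision_score(y_true, y_score))

# The config field round-trips into the metric chosen by the mapping.
config = DetectionAPConfig(implementation="sklearn")
detection_ap = DetectionAP.from_config(config, class_names=[])
assert detection_ap.metric is metrics.average_precision_score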