batdetect2/batdetect2/evaluate/metrics.py
2025-04-30 22:51:33 +01:00

98 lines
2.8 KiB
Python

from typing import Dict, List
import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import Match, MetricsProtocol
# Public metric implementations exported by this module.
__all__ = [
    "DetectionAveragePrecision",
    "ClassificationMeanAveragePrecision",
    "ClassificationAccuracy",
]
class DetectionAveragePrecision(MetricsProtocol):
    """Average precision (AP) of the detection scores.

    Every match contributes one sample: ``match.gt_det`` is used as the
    binary ground-truth label and ``match.pred_score`` as the ranking score
    passed to ``sklearn.metrics.average_precision_score``.
    """

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Compute detection AP over ``matches``.

        Returns ``{"detection_AP": ap}``. When ``matches`` is empty the AP is
        undefined and NaN is returned instead of raising.
        """
        if not matches:
            # ``zip(*[])`` yields nothing, so unpacking it into two names used
            # to fail with an opaque ValueError.
            return {"detection_AP": float("nan")}

        y_true = [match.gt_det for match in matches]
        y_score = [match.pred_score for match in matches]
        score = float(metrics.average_precision_score(y_true, y_score))
        return {"detection_AP": score}
class ClassificationMeanAveragePrecision(MetricsProtocol):
    """Mean average precision (mAP) of the per-class classification scores.

    Each class is scored as a one-vs-rest binary problem: a match is a
    positive for ``class_name`` when ``match.gt_class == class_name``, and its
    score is ``match.class_scores[class_name]``. Missing scores and NaN/None
    scores count as 0. Matches whose ``gt_class`` is ``None`` are negatives
    for every class.

    The reported mAP is the unweighted mean of the per-class APs, which is
    what ``average_precision_score(..., average="macro")`` computes for a
    multilabel target.

    Parameters
    ----------
    class_names
        Classes to evaluate, in the order their per-class keys are emitted.
    per_class
        If True, also report each class's AP under
        ``"classification_AP/<class_name>"``.
    """

    def __init__(self, class_names: List[str], per_class: bool = True):
        self.class_names = class_names
        self.per_class = per_class

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Compute the classification mAP (and optionally per-class APs).

        When ``matches`` is empty every AP is undefined and reported as NaN;
        when ``class_names`` is empty the mAP is NaN.
        """
        if matches:
            class_aps = self._per_class_ap(matches)
        else:
            class_aps = {name: float("nan") for name in self.class_names}

        if class_aps:
            mAP = sum(class_aps.values()) / len(class_aps)
        else:
            mAP = float("nan")

        ret = {
            "classification_mAP": float(mAP),
        }
        if not self.per_class:
            return ret

        for class_name, class_ap in class_aps.items():
            ret[f"classification_AP/{class_name}"] = class_ap
        return ret

    def _per_class_ap(self, matches: List[Match]) -> Dict[str, float]:
        """Return the one-vs-rest AP of every configured class."""
        # The one-hot target is built explicitly rather than with
        # ``label_binarize``: for a binary target with <= 2 classes that
        # helper returns a single column (or all zeros for one class),
        # which broke per-class indexing and the mAP.
        y_true = pd.DataFrame(
            [
                {
                    name: int(match.gt_class == name)
                    for name in self.class_names
                }
                for match in matches
            ],
            columns=self.class_names,
        )
        y_pred = pd.DataFrame(
            [
                {
                    name: match.class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ],
            columns=self.class_names,
        ).fillna(0)

        return {
            name: float(
                metrics.average_precision_score(y_true[name], y_pred[name])
            )
            for name in self.class_names
        }
class ClassificationAccuracy(MetricsProtocol):
    """Balanced accuracy of hard class labels derived from the class scores.

    The target for each match is ``match.gt_class``, with ``None`` mapped to
    the sentinel ``"__NONE__"``. The prediction is the highest-scoring class
    when its score is at least ``1 - sum(scores)``, otherwise ``"__NONE__"``.
    Scores missing from ``match.class_scores``, or NaN/None, count as 0.

    Despite the output key ``"classification_acc"``, the value is
    ``sklearn.metrics.balanced_accuracy_score``.
    """

    def __init__(self, class_names: List[str]):
        self.class_names = class_names

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Return ``{"classification_acc": balanced_accuracy}``."""
        no_class = "__NONE__"

        targets = [
            no_class if match.gt_class is None else match.gt_class
            for match in matches
        ]

        score_rows = []
        for match in matches:
            score_rows.append(
                {name: match.class_scores.get(name, 0) for name in self.class_names}
            )
        scores = pd.DataFrame(score_rows).fillna(0)

        def _pick_label(row: pd.Series) -> str:
            # Keep the top class only when it is at least as large as the
            # score mass not assigned to any class.
            if row.max() >= (1 - row.sum()):
                return row.idxmax()
            return no_class

        predictions = scores.apply(_pick_label, axis=1)

        balanced_acc = metrics.balanced_accuracy_score(targets, predictions)
        return {"classification_acc": float(balanced_acc)}