mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
98 lines
2.8 KiB
Python
98 lines
2.8 KiB
Python
from typing import Dict, List
|
|
|
|
import pandas as pd
|
|
from sklearn import metrics
|
|
from sklearn.preprocessing import label_binarize
|
|
|
|
from batdetect2.evaluate.types import Match, MetricsProtocol
|
|
|
|
__all__ = ["DetectionAveragePrecision"]
|
|
|
|
|
|
class DetectionAveragePrecision(MetricsProtocol):
    """Average precision of detection scores against ground-truth presence.

    Each match contributes one sample: ``match.gt_det`` is used as the binary
    target and ``match.pred_score`` as the ranking score passed to
    ``sklearn.metrics.average_precision_score``.
    """

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Compute detection AP over ``matches``.

        Args:
            matches: Matched ground-truth/prediction pairs to evaluate.

        Returns:
            ``{"detection_AP": score}``. When ``matches`` is empty the AP is
            undefined, so NaN is reported instead of raising.
        """
        if not matches:
            # ``zip(*[])`` would fail to unpack; AP has no value without samples.
            return {"detection_AP": float("nan")}

        y_true = [match.gt_det for match in matches]
        y_score = [match.pred_score for match in matches]

        score = float(metrics.average_precision_score(y_true, y_score))
        return {"detection_AP": score}
|
|
|
|
|
|
class ClassificationMeanAveragePrecision(MetricsProtocol):
    """Mean (and optionally per-class) average precision of class scores.

    For each match, the one-vs-rest target for class ``name`` is 1 when
    ``match.gt_class == name`` and 0 otherwise (including when ``gt_class``
    is ``None``). The score is ``match.class_scores.get(name, 0)``, with
    ``None`` scores filled as 0.
    """

    def __init__(self, class_names: List[str], per_class: bool = True):
        """
        Args:
            class_names: Classes to evaluate; also fixes the column order.
            per_class: If True, also report one AP value per class.
        """
        self.class_names = class_names
        self.per_class = per_class

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Compute classification mAP and, optionally, per-class AP.

        Args:
            matches: Matched ground-truth/prediction pairs to evaluate.

        Returns:
            ``{"classification_mAP": ...}`` plus, when ``per_class`` is set,
            one ``"classification_AP/<class_name>"`` entry per class. All
            values are NaN when ``matches`` is empty (AP is undefined).
        """
        if not matches:
            nan_scores: Dict[str, float] = {"classification_mAP": float("nan")}
            if self.per_class:
                for class_name in self.class_names:
                    nan_scores[f"classification_AP/{class_name}"] = float("nan")
            return nan_scores

        # Build the one-hot targets explicitly. sklearn's ``label_binarize``
        # can return a single column when there are fewer than three classes
        # (all zeros for a single class), which breaks the column-per-class
        # shape expected by ``average_precision_score`` and the loop below.
        y_true = pd.DataFrame(
            [
                {name: int(match.gt_class == name) for name in self.class_names}
                for match in matches
            ],
            columns=self.class_names,
        )

        y_pred = pd.DataFrame(
            [
                {
                    name: match.class_scores.get(name, 0)
                    for name in self.class_names
                }
                for match in matches
            ],
            columns=self.class_names,
        ).fillna(0)

        # Uses sklearn's default ``average`` ("macro") across class columns.
        mAP = metrics.average_precision_score(y_true, y_pred)

        ret = {
            "classification_mAP": float(mAP),
        }

        if not self.per_class:
            return ret

        for class_name in self.class_names:
            class_ap = metrics.average_precision_score(
                y_true[class_name],
                y_pred[class_name],
            )
            ret[f"classification_AP/{class_name}"] = float(class_ap)

        return ret
|
|
|
|
|
|
class ClassificationAccuracy(MetricsProtocol):
    """Balanced accuracy of the top-class prediction for each match.

    Ground truth is ``match.gt_class``, or ``"__NONE__"`` when it is ``None``.
    The prediction is the highest-scoring class in ``self.class_names``. It is
    replaced by ``"__NONE__"`` when that top score is below the implied
    "no class" score ``1 - sum(class scores)``.
    """

    def __init__(self, class_names: List[str]):
        """
        Args:
            class_names: Classes whose scores are read from each match.
        """
        self.class_names = class_names

    def __call__(self, matches: List[Match]) -> Dict[str, float]:
        """Compute balanced classification accuracy over ``matches``.

        Args:
            matches: Matched ground-truth/prediction pairs to evaluate.

        Returns:
            ``{"classification_acc": value}``. NaN when ``matches`` is empty,
            since accuracy is undefined without samples.
        """
        if not matches:
            return {"classification_acc": float("nan")}

        y_true = [
            match.gt_class if match.gt_class is not None else "__NONE__"
            for match in matches
        ]

        scores = (
            pd.DataFrame(
                [
                    {
                        name: match.class_scores.get(name, 0)
                        for name in self.class_names
                    }
                    for match in matches
                ],
                columns=self.class_names,
            )
            .fillna(0)
            .astype(float)
        )

        # Vectorized equivalent of a per-row ``idxmax``. A class is predicted
        # only if its score is at least the leftover "no class" mass
        # (1 - total score); otherwise the prediction is "__NONE__".
        is_confident = scores.max(axis=1) >= (1 - scores.sum(axis=1))
        y_pred = scores.idxmax(axis=1).where(is_confident, "__NONE__")

        accuracy = metrics.balanced_accuracy_score(
            y_true,
            y_pred,
        )

        return {
            "classification_acc": float(accuracy),
        }
|