Compute mAP

This commit is contained in:
mbsantiago 2025-09-14 10:08:51 +01:00
parent d80377981e
commit 8628133fd7
2 changed files with 10 additions and 11 deletions

View File

@ -47,7 +47,7 @@ def evaluate(
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
config=config.train,
config=config.train.val_loader,
num_workers=num_workers,
)
@ -67,7 +67,8 @@ def evaluate(
predictions = get_raw_predictions(
outputs,
start_times=[
clip_annotation.clip for clip_annotation in clip_annotations
clip_annotation.clip.start_time
for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=model.postprocessor,

View File

@ -1,5 +1,6 @@
from typing import Dict, List
import numpy as np
import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str], per_class: bool = True):
def __init__(self, class_names: List[str]):
self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize(
@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
for match in matches
]
).fillna(0)
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
ret = {
"classification_mAP": float(mAP),
}
if not self.per_class:
return ret
ret = {}
for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index]
@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
)
ret[f"classification_AP/{class_name}"] = float(class_ap)
ret["classification_mAP"] = np.mean(
[value for value in ret.values() if value != 0]
)
return ret