Compute mAP

This commit is contained in:
mbsantiago 2025-09-14 10:08:51 +01:00
parent d80377981e
commit 8628133fd7
2 changed files with 10 additions and 11 deletions

View File

@ -47,7 +47,7 @@ def evaluate(
audio_loader=audio_loader, audio_loader=audio_loader,
labeller=labeller, labeller=labeller,
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.train, config=config.train.val_loader,
num_workers=num_workers, num_workers=num_workers,
) )
@ -67,7 +67,8 @@ def evaluate(
predictions = get_raw_predictions( predictions = get_raw_predictions(
outputs, outputs,
start_times=[ start_times=[
clip_annotation.clip for clip_annotation in clip_annotations clip_annotation.clip.start_time
for clip_annotation in clip_annotations
], ],
targets=targets, targets=targets,
postprocessor=model.postprocessor, postprocessor=model.postprocessor,

View File

@ -1,5 +1,6 @@
from typing import Dict, List from typing import Dict, List
import numpy as np
import pandas as pd import pandas as pd
from sklearn import metrics from sklearn import metrics
from sklearn.preprocessing import label_binarize from sklearn.preprocessing import label_binarize
@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
class ClassificationMeanAveragePrecision(MetricsProtocol): class ClassificationMeanAveragePrecision(MetricsProtocol):
def __init__(self, class_names: List[str], per_class: bool = True): def __init__(self, class_names: List[str]):
self.class_names = class_names self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize( y_true = label_binarize(
@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
for match in matches for match in matches
] ]
).fillna(0) ).fillna(0)
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
ret = { ret = {}
"classification_mAP": float(mAP),
}
if not self.per_class:
return ret
for class_index, class_name in enumerate(self.class_names): for class_index, class_name in enumerate(self.class_names):
y_true_class = y_true[:, class_index] y_true_class = y_true[:, class_index]
@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
) )
ret[f"classification_AP/{class_name}"] = float(class_ap) ret[f"classification_AP/{class_name}"] = float(class_ap)
ret["classification_mAP"] = np.mean(
[value for value in ret.values() if value != 0]
)
return ret return ret