From 8628133fd7f99f9a0f8905548f6db39da5be2417 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sun, 14 Sep 2025 10:08:51 +0100 Subject: [PATCH] Compute mAP --- src/batdetect2/evaluate/evaluate.py | 5 +++-- src/batdetect2/evaluate/metrics.py | 16 +++++++--------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 75c2c93..aed3a08 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -47,7 +47,7 @@ def evaluate( audio_loader=audio_loader, labeller=labeller, preprocessor=preprocessor, - config=config.train, + config=config.train.val_loader, num_workers=num_workers, ) @@ -67,7 +67,8 @@ def evaluate( predictions = get_raw_predictions( outputs, start_times=[ - clip_annotation.clip for clip_annotation in clip_annotations + clip_annotation.clip.start_time + for clip_annotation in clip_annotations ], targets=targets, postprocessor=model.postprocessor, diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index c5df1d0..b42230d 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -1,5 +1,6 @@ from typing import Dict, List +import numpy as np import pandas as pd from sklearn import metrics from sklearn.preprocessing import label_binarize @@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol): class ClassificationMeanAveragePrecision(MetricsProtocol): - def __init__(self, class_names: List[str], per_class: bool = True): + def __init__(self, class_names: List[str]): self.class_names = class_names - self.per_class = per_class def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: y_true = label_binarize( @@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): for match in matches ] ).fillna(0) - mAP = metrics.average_precision_score(y_true, y_pred[self.class_names]) - ret = { - "classification_mAP": float(mAP), - } - - if not self.per_class: - return ret + ret = {} for class_index, class_name in enumerate(self.class_names): y_true_class = y_true[:, class_index] @@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): ) ret[f"classification_AP/{class_name}"] = float(class_ap) + ret["classification_mAP"] = np.mean( + [value for value in ret.values() if value != 0] + ) + return ret