mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Compute mAP
This commit is contained in:
parent
d80377981e
commit
8628133fd7
@ -47,7 +47,7 @@ def evaluate(
|
|||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.train,
|
config=config.train.val_loader,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -67,7 +67,8 @@ def evaluate(
|
|||||||
predictions = get_raw_predictions(
|
predictions = get_raw_predictions(
|
||||||
outputs,
|
outputs,
|
||||||
start_times=[
|
start_times=[
|
||||||
clip_annotation.clip for clip_annotation in clip_annotations
|
clip_annotation.clip.start_time
|
||||||
|
for clip_annotation in clip_annotations
|
||||||
],
|
],
|
||||||
targets=targets,
|
targets=targets,
|
||||||
postprocessor=model.postprocessor,
|
postprocessor=model.postprocessor,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from sklearn.preprocessing import label_binarize
|
from sklearn.preprocessing import label_binarize
|
||||||
@ -19,9 +20,8 @@ class DetectionAveragePrecision(MetricsProtocol):
|
|||||||
|
|
||||||
|
|
||||||
class ClassificationMeanAveragePrecision(MetricsProtocol):
|
class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||||
def __init__(self, class_names: List[str], per_class: bool = True):
|
def __init__(self, class_names: List[str]):
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.per_class = per_class
|
|
||||||
|
|
||||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||||
y_true = label_binarize(
|
y_true = label_binarize(
|
||||||
@ -40,14 +40,8 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
for match in matches
|
for match in matches
|
||||||
]
|
]
|
||||||
).fillna(0)
|
).fillna(0)
|
||||||
mAP = metrics.average_precision_score(y_true, y_pred[self.class_names])
|
|
||||||
|
|
||||||
ret = {
|
ret = {}
|
||||||
"classification_mAP": float(mAP),
|
|
||||||
}
|
|
||||||
|
|
||||||
if not self.per_class:
|
|
||||||
return ret
|
|
||||||
|
|
||||||
for class_index, class_name in enumerate(self.class_names):
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
y_true_class = y_true[:, class_index]
|
y_true_class = y_true[:, class_index]
|
||||||
@ -58,6 +52,10 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
)
|
)
|
||||||
ret[f"classification_AP/{class_name}"] = float(class_ap)
|
ret[f"classification_AP/{class_name}"] = float(class_ap)
|
||||||
|
|
||||||
|
ret["classification_mAP"] = np.mean(
|
||||||
|
[value for value in ret.values() if value != 0]
|
||||||
|
)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user