diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 5ed4062..ac2db47 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -9,7 +9,6 @@ from batdetect2.configs import BaseConfig from batdetect2.evaluate.types import MatchEvaluation from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol -from batdetect2.utils.arrays import iterate_over_array class BBoxMatchConfig(BaseConfig): @@ -126,8 +125,9 @@ def match_sound_events_and_raw_predictions( class_scores = ( { str(class_name): float(score) - for class_name, score in iterate_over_array( - prediction.raw.class_scores + for class_name, score in zip( + targets.class_names, + prediction.raw.class_scores, ) } if prediction is not None diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index edaf064..2557e07 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -172,8 +172,7 @@ def _match_all_collected_examples( ) -> List[MatchEvaluation]: logger.info("Matching all annotations and predictions") - cpu_count = os.cpu_count() or 1 - with Pool(processes=min(cpu_count, 4)) as p: + with Pool() as p: matches = p.starmap( partial( match_sound_events_and_raw_predictions, @@ -182,6 +181,7 @@ def _match_all_collected_examples( ), pre_matches, ) + return [match for clip_matches in matches for match in clip_matches]