From db2ad1174363f1c228829c7a139e229bf9171c42 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 1 Sep 2025 11:19:02 +0100 Subject: [PATCH] Make matching in parallel for speedup --- src/batdetect2/evaluate/match.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 15f3a34..c71be73 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,5 +1,6 @@ from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, field +from functools import partial from typing import List, Literal, Optional, Protocol, Tuple import numpy as np @@ -8,6 +9,7 @@ from soundevent import data from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds +from torch.multiprocessing import Pool from batdetect2.configs import BaseConfig from batdetect2.typing import ( @@ -428,19 +430,17 @@ def match_all_predictions( config: Optional[MatchConfig] = None, ) -> List[MatchEvaluation]: logger.info("Matching all annotations and predictions...") - return [ - match - for clip_annotation, raw_predictions in zip( - clip_annotations, - predictions, + with Pool() as p: + all_matches = p.starmap( + partial( + match_sound_events_and_raw_predictions, + targets=targets, + config=config, + ), + zip(clip_annotations, predictions), ) - for match in match_sound_events_and_raw_predictions( - clip_annotation, - raw_predictions, - targets=targets, - config=config, - ) - ] + + return [match for matches in all_matches for match in matches] @dataclass