diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index c71be73..15f3a34 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,6 +1,5 @@ from collections.abc import Callable, Iterable, Mapping from dataclasses import dataclass, field -from functools import partial from typing import List, Literal, Optional, Protocol, Tuple import numpy as np @@ -9,7 +8,6 @@ from soundevent import data from soundevent.evaluation import compute_affinity from soundevent.evaluation import match_geometries as optimal_match from soundevent.geometry import compute_bounds -from torch.multiprocessing import Pool from batdetect2.configs import BaseConfig from batdetect2.typing import ( @@ -430,17 +428,19 @@ def match_all_predictions( config: Optional[MatchConfig] = None, ) -> List[MatchEvaluation]: logger.info("Matching all annotations and predictions...") - with Pool() as p: - all_matches = p.starmap( - partial( - match_sound_events_and_raw_predictions, - targets=targets, - config=config, - ), - zip(clip_annotations, predictions), + return [ + match + for clip_annotation, raw_predictions in zip( + clip_annotations, + predictions, ) - - return [match for matches in all_matches for match in matches] + for match in match_sound_events_and_raw_predictions( + clip_annotation, + raw_predictions, + targets=targets, + config=config, + ) + ] @dataclass