diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 2557e07..d973c37 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -1,6 +1,3 @@ -import os -from functools import partial -from multiprocessing import Pool from typing import List, Optional, Tuple from lightning import LightningModule, Trainer @@ -170,19 +167,14 @@ def _match_all_collected_examples( targets: TargetProtocol, config: Optional[MatchConfig] = None, ) -> List[MatchEvaluation]: - logger.info("Matching all annotations and predictions") - - with Pool() as p: - matches = p.starmap( - partial( - match_sound_events_and_raw_predictions, - targets=targets, - config=config, - ), - pre_matches, + logger.info("Matching all annotations and predictions...") + return [ + match + for clip_annotation, raw_predictions in pre_matches + for match in match_sound_events_and_raw_predictions( + clip_annotation, raw_predictions, targets=targets, config=config ) - - return [match for clip_matches in matches for match in clip_matches] + ] def _is_in_subclip(