torch.multiprocessing didn't work, returning to serial processing

This commit is contained in:
mbsantiago 2025-09-01 11:27:23 +01:00
parent db2ad11743
commit 709b6355c2

View File

@ -1,6 +1,5 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional, Protocol, Tuple from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np import numpy as np
@ -9,7 +8,6 @@ from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from torch.multiprocessing import Pool
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.typing import ( from batdetect2.typing import (
@ -430,17 +428,19 @@ def match_all_predictions(
config: Optional[MatchConfig] = None, config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...") logger.info("Matching all annotations and predictions...")
with Pool() as p: return [
all_matches = p.starmap( match
partial( for clip_annotation, raw_predictions in zip(
match_sound_events_and_raw_predictions, clip_annotations,
targets=targets, predictions,
config=config,
),
zip(clip_annotations, predictions),
) )
for match in match_sound_events_and_raw_predictions(
return [match for matches in all_matches for match in matches] clip_annotation,
raw_predictions,
targets=targets,
config=config,
)
]
@dataclass @dataclass