Compare commits

...

2 Commits

Author SHA1 Message Date
mbsantiago
fdbb9c2b43 Go back to serial matching 2025-08-12 19:57:49 +01:00
mbsantiago
75e555ff00 Fix matching error after optimising 2025-08-12 19:49:24 +01:00
2 changed files with 10 additions and 18 deletions

View File

@ -9,7 +9,6 @@ from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig):
@ -126,8 +125,9 @@ def match_sound_events_and_raw_predictions(
class_scores = (
{
str(class_name): float(score)
for class_name, score in iterate_over_array(
prediction.raw.class_scores
for class_name, score in zip(
targets.class_names,
prediction.raw.class_scores,
)
}
if prediction is not None

View File

@ -1,6 +1,3 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer
@ -170,19 +167,14 @@ def _match_all_collected_examples(
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions")
cpu_count = os.cpu_count() or 1
with Pool(processes=min(cpu_count, 4)) as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in pre_matches
for match in match_sound_events_and_raw_predictions(
clip_annotation, raw_predictions, targets=targets, config=config
)
return [match for clip_matches in matches for match in clip_matches]
]
def _is_in_subclip(