Compare commits

..

2 Commits

Author SHA1 Message Date
mbsantiago
fdbb9c2b43 Go back to serial matching 2025-08-12 19:57:49 +01:00
mbsantiago
75e555ff00 Fix matching error after optimising 2025-08-12 19:49:24 +01:00
2 changed files with 10 additions and 18 deletions

View File

@ -9,7 +9,6 @@ from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig): class BBoxMatchConfig(BaseConfig):
@ -126,8 +125,9 @@ def match_sound_events_and_raw_predictions(
class_scores = ( class_scores = (
{ {
str(class_name): float(score) str(class_name): float(score)
for class_name, score in iterate_over_array( for class_name, score in zip(
prediction.raw.class_scores targets.class_names,
prediction.raw.class_scores,
) )
} }
if prediction is not None if prediction is not None

View File

@ -1,6 +1,3 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
@ -170,19 +167,14 @@ def _match_all_collected_examples(
targets: TargetProtocol, targets: TargetProtocol,
config: Optional[MatchConfig] = None, config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions") logger.info("Matching all annotations and predictions...")
return [
cpu_count = os.cpu_count() or 1 match
with Pool(processes=min(cpu_count, 4)) as p: for clip_annotation, raw_predictions in pre_matches
matches = p.starmap( for match in match_sound_events_and_raw_predictions(
partial( clip_annotation, raw_predictions, targets=targets, config=config
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
) )
return [match for clip_matches in matches for match in clip_matches] ]
def _is_in_subclip( def _is_in_subclip(