Compare commits

..

No commits in common. "fdbb9c2b43b41629ea949a5c489da8cfeed4066f" and "c7b110feebff0bb7fcde361484435b9bd809687a" have entirely different histories.

2 changed files with 18 additions and 10 deletions

View File

@ -9,6 +9,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig): class BBoxMatchConfig(BaseConfig):
@ -125,9 +126,8 @@ def match_sound_events_and_raw_predictions(
class_scores = ( class_scores = (
{ {
str(class_name): float(score) str(class_name): float(score)
for class_name, score in zip( for class_name, score in iterate_over_array(
targets.class_names, prediction.raw.class_scores
prediction.raw.class_scores,
) )
} }
if prediction is not None if prediction is not None

View File

@ -1,3 +1,6 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer from lightning import LightningModule, Trainer
@ -167,14 +170,19 @@ def _match_all_collected_examples(
targets: TargetProtocol, targets: TargetProtocol,
config: Optional[MatchConfig] = None, config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...") logger.info("Matching all annotations and predictions")
return [
match cpu_count = os.cpu_count() or 1
for clip_annotation, raw_predictions in pre_matches with Pool(processes=min(cpu_count, 4)) as p:
for match in match_sound_events_and_raw_predictions( matches = p.starmap(
clip_annotation, raw_predictions, targets=targets, config=config partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
) )
] return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip( def _is_in_subclip(