Compare commits

..

No commits in common. "fdbb9c2b43b41629ea949a5c489da8cfeed4066f" and "c7b110feebff0bb7fcde361484435b9bd809687a" have entirely different histories.

2 changed files with 18 additions and 10 deletions

View File

@ -9,6 +9,7 @@ from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig):
@ -125,9 +126,8 @@ def match_sound_events_and_raw_predictions(
class_scores = (
{
str(class_name): float(score)
for class_name, score in zip(
targets.class_names,
prediction.raw.class_scores,
for class_name, score in iterate_over_array(
prediction.raw.class_scores
)
}
if prediction is not None

View File

@ -1,3 +1,6 @@
import os
from functools import partial
from multiprocessing import Pool
from typing import List, Optional, Tuple
from lightning import LightningModule, Trainer
@ -167,14 +170,19 @@ def _match_all_collected_examples(
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in pre_matches
for match in match_sound_events_and_raw_predictions(
clip_annotation, raw_predictions, targets=targets, config=config
logger.info("Matching all annotations and predictions")
cpu_count = os.cpu_count() or 1
with Pool(processes=min(cpu_count, 4)) as p:
matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
pre_matches,
)
]
return [match for clip_matches in matches for match in clip_matches]
def _is_in_subclip(