Compare commits

..

2 Commits

Author SHA1 Message Date
mbsantiago
db2ad11743 Make matching in parallel for speedup 2025-09-01 11:19:02 +01:00
mbsantiago
e0ecc3c3d1 Clear evaluation callback after epoch ends 2025-09-01 08:56:38 +01:00
2 changed files with 14 additions and 13 deletions

View File

@ -1,5 +1,6 @@
from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np
@ -8,6 +9,7 @@ from soundevent import data
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds
from torch.multiprocessing import Pool
from batdetect2.configs import BaseConfig
from batdetect2.typing import (
@ -428,19 +430,17 @@ def match_all_predictions(
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...")
return [
match
for clip_annotation, raw_predictions in zip(
clip_annotations,
predictions,
with Pool() as p:
all_matches = p.starmap(
partial(
match_sound_events_and_raw_predictions,
targets=targets,
config=config,
),
zip(clip_annotations, predictions),
)
for match in match_sound_events_and_raw_predictions(
clip_annotation,
raw_predictions,
targets=targets,
config=config,
)
]
return [match for matches in all_matches for match in matches]
@dataclass

View File

@ -108,7 +108,8 @@ class ValidationMetrics(Callback):
trainer: Trainer,
pl_module: LightningModule,
) -> None:
self._matches = []
self._clip_annotations = []
self._predictions = []
return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore