Compare commits

...

2 Commits

Author SHA1 Message Date
mbsantiago
db2ad11743 Make matching in parallel for speedup 2025-09-01 11:19:02 +01:00
mbsantiago
e0ecc3c3d1 Clear evaluation callback after epoch ends 2025-09-01 08:56:38 +01:00
2 changed files with 14 additions and 13 deletions

View File

@ -1,5 +1,6 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional, Protocol, Tuple from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np import numpy as np
@ -8,6 +9,7 @@ from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from torch.multiprocessing import Pool
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.typing import ( from batdetect2.typing import (
@ -428,19 +430,17 @@ def match_all_predictions(
config: Optional[MatchConfig] = None, config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...") logger.info("Matching all annotations and predictions...")
return [ with Pool() as p:
match all_matches = p.starmap(
for clip_annotation, raw_predictions in zip( partial(
clip_annotations, match_sound_events_and_raw_predictions,
predictions, targets=targets,
config=config,
),
zip(clip_annotations, predictions),
) )
for match in match_sound_events_and_raw_predictions(
clip_annotation, return [match for matches in all_matches for match in matches]
raw_predictions,
targets=targets,
config=config,
)
]
@dataclass @dataclass

View File

@ -108,7 +108,8 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self._matches = [] self._clip_annotations = []
self._predictions = []
return super().on_validation_epoch_start(trainer, pl_module) return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore def on_validation_batch_end( # type: ignore