Compare commits

..

No commits in common. "db2ad1174363f1c228829c7a139e229bf9171c42" and "71c2301c216c52fc516b11bcda9b8f1e86ea1570" have entirely different histories.

2 changed files with 13 additions and 14 deletions

View File

@ -1,6 +1,5 @@
from collections.abc import Callable, Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial
from typing import List, Literal, Optional, Protocol, Tuple from typing import List, Literal, Optional, Protocol, Tuple
import numpy as np import numpy as np
@ -9,7 +8,6 @@ from soundevent import data
from soundevent.evaluation import compute_affinity from soundevent.evaluation import compute_affinity
from soundevent.evaluation import match_geometries as optimal_match from soundevent.evaluation import match_geometries as optimal_match
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from torch.multiprocessing import Pool
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.typing import ( from batdetect2.typing import (
@ -430,17 +428,19 @@ def match_all_predictions(
config: Optional[MatchConfig] = None, config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]: ) -> List[MatchEvaluation]:
logger.info("Matching all annotations and predictions...") logger.info("Matching all annotations and predictions...")
with Pool() as p: return [
all_matches = p.starmap( match
partial( for clip_annotation, raw_predictions in zip(
match_sound_events_and_raw_predictions, clip_annotations,
predictions,
)
for match in match_sound_events_and_raw_predictions(
clip_annotation,
raw_predictions,
targets=targets, targets=targets,
config=config, config=config,
),
zip(clip_annotations, predictions),
) )
]
return [match for matches in all_matches for match in matches]
@dataclass @dataclass

View File

@ -108,8 +108,7 @@ class ValidationMetrics(Callback):
trainer: Trainer, trainer: Trainer,
pl_module: LightningModule, pl_module: LightningModule,
) -> None: ) -> None:
self._clip_annotations = [] self._matches = []
self._predictions = []
return super().on_validation_epoch_start(trainer, pl_module) return super().on_validation_epoch_start(trainer, pl_module)
def on_validation_batch_end( # type: ignore def on_validation_batch_end( # type: ignore