mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-12 01:39:35 +01:00
Compare commits
No commits in common. "fdbb9c2b43b41629ea949a5c489da8cfeed4066f" and "c7b110feebff0bb7fcde361484435b9bd809687a" have entirely different histories.
fdbb9c2b43
...
c7b110feeb
@ -9,6 +9,7 @@ from batdetect2.configs import BaseConfig
|
|||||||
from batdetect2.evaluate.types import MatchEvaluation
|
from batdetect2.evaluate.types import MatchEvaluation
|
||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
|
|
||||||
class BBoxMatchConfig(BaseConfig):
|
class BBoxMatchConfig(BaseConfig):
|
||||||
@ -125,9 +126,8 @@ def match_sound_events_and_raw_predictions(
|
|||||||
class_scores = (
|
class_scores = (
|
||||||
{
|
{
|
||||||
str(class_name): float(score)
|
str(class_name): float(score)
|
||||||
for class_name, score in zip(
|
for class_name, score in iterate_over_array(
|
||||||
targets.class_names,
|
prediction.raw.class_scores
|
||||||
prediction.raw.class_scores,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if prediction is not None
|
if prediction is not None
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from multiprocessing import Pool
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
@ -167,14 +170,19 @@ def _match_all_collected_examples(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
logger.info("Matching all annotations and predictions...")
|
logger.info("Matching all annotations and predictions")
|
||||||
return [
|
|
||||||
match
|
cpu_count = os.cpu_count() or 1
|
||||||
for clip_annotation, raw_predictions in pre_matches
|
with Pool(processes=min(cpu_count, 4)) as p:
|
||||||
for match in match_sound_events_and_raw_predictions(
|
matches = p.starmap(
|
||||||
clip_annotation, raw_predictions, targets=targets, config=config
|
partial(
|
||||||
|
match_sound_events_and_raw_predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config,
|
||||||
|
),
|
||||||
|
pre_matches,
|
||||||
)
|
)
|
||||||
]
|
return [match for clip_matches in matches for match in clip_matches]
|
||||||
|
|
||||||
|
|
||||||
def _is_in_subclip(
|
def _is_in_subclip(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user