Handle case of empty preds or gt

This commit is contained in:
mbsantiago 2025-08-17 23:00:40 +01:00
parent 7af72912da
commit 9d1497b3f4

View File

@ -139,19 +139,31 @@ def greedy_match(
Tuple[Optional[int], Optional[int], float] Tuple[Optional[int], Optional[int], float]
A 3-element tuple describing a match or a miss. There are three A 3-element tuple describing a match or a miss. There are three
possible formats: possible formats:
- Successful Match: `(target_idx, source_idx, affinity)` - Successful Match: `(source_idx, target_idx, affinity)`
- Unmatched Source (False Positive): `(None, source_idx, 0)` - Unmatched Source (False Positive): `(source_idx, None, 0)`
- Unmatched Target (False Negative): `(target_idx, None, 0)` - Unmatched Target (False Negative): `(None, target_idx, 0)`
""" """
assigned = set() assigned = set()
if not source:
for target_idx in range(len(target)):
yield None, target_idx, 0
return
if not target:
for source_idx in range(len(source)):
yield source_idx, None, 0
return
if scores is None: if scores is None:
indices = np.arange(len(source)) indices = np.arange(len(source))
else: else:
indices = np.argsort(scores)[::-1] indices = np.argsort(scores)[::-1]
for index in indices: for source_idx in indices:
source_geometry = source[index] source_geometry = source[source_idx]
affinities = np.array( affinities = np.array(
[ [
@ -169,19 +181,19 @@ def greedy_match(
affinity = affinities[closest_target] affinity = affinities[closest_target]
if affinities[closest_target] <= affinity_threshold: if affinities[closest_target] <= affinity_threshold:
yield index, None, 0 yield source_idx, None, 0
continue continue
if closest_target in assigned: if closest_target in assigned:
yield index, None, 0 yield source_idx, None, 0
continue continue
assigned.add(closest_target) assigned.add(closest_target)
yield index, closest_target, affinity yield source_idx, closest_target, affinity
missed_ground_truth = set(range(len(target))) - assigned missed_ground_truth = set(range(len(target))) - assigned
for index in missed_ground_truth: for target_idx in missed_ground_truth:
yield None, index, 0 yield None, target_idx, 0
def match_sound_events_and_raw_predictions( def match_sound_events_and_raw_predictions(