diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index 06d4b0c..6c80cf8 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -139,19 +139,31 @@ def greedy_match( Tuple[Optional[int], Optional[int], float] A 3-element tuple describing a match or a miss. There are three possible formats: - - Successful Match: `(target_idx, source_idx, affinity)` - - Unmatched Source (False Positive): `(None, source_idx, 0)` - - Unmatched Target (False Negative): `(target_idx, None, 0)` + - Successful Match: `(source_idx, target_idx, affinity)` + - Unmatched Source (False Positive): `(source_idx, None, 0)` + - Unmatched Target (False Negative): `(None, target_idx, 0)` """ assigned = set() + if not source: + for target_idx in range(len(target)): + yield None, target_idx, 0 + + return + + if not target: + for source_idx in range(len(source)): + yield source_idx, None, 0 + + return + if scores is None: indices = np.arange(len(source)) else: indices = np.argsort(scores)[::-1] - for index in indices: - source_geometry = source[index] + for source_idx in indices: + source_geometry = source[source_idx] affinities = np.array( [ @@ -169,19 +181,19 @@ def greedy_match( affinity = affinities[closest_target] if affinities[closest_target] <= affinity_threshold: - yield index, None, 0 + yield source_idx, None, 0 continue if closest_target in assigned: - yield index, None, 0 + yield source_idx, None, 0 continue assigned.add(closest_target) - yield index, closest_target, affinity + yield source_idx, closest_target, affinity missed_ground_truth = set(range(len(target))) - assigned - for index in missed_ground_truth: - yield None, index, 0 + for target_idx in missed_ground_truth: + yield None, target_idx, 0 def match_sound_events_and_raw_predictions(