mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Handle case of empty preds or gt
This commit is contained in:
parent
7af72912da
commit
9d1497b3f4
@ -139,19 +139,31 @@ def greedy_match(
|
||||
Tuple[Optional[int], Optional[int], float]
|
||||
A 3-element tuple describing a match or a miss. There are three
|
||||
possible formats:
|
||||
- Successful Match: `(target_idx, source_idx, affinity)`
|
||||
- Unmatched Source (False Positive): `(None, source_idx, 0)`
|
||||
- Unmatched Target (False Negative): `(target_idx, None, 0)`
|
||||
- Successful Match: `(source_idx, target_idx, affinity)`
|
||||
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
||||
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
||||
"""
|
||||
assigned = set()
|
||||
|
||||
if not source:
|
||||
for target_idx in range(len(target)):
|
||||
yield None, target_idx, 0
|
||||
|
||||
return
|
||||
|
||||
if not target:
|
||||
for source_idx in range(len(source)):
|
||||
yield source_idx, None, 0
|
||||
|
||||
return
|
||||
|
||||
if scores is None:
|
||||
indices = np.arange(len(source))
|
||||
else:
|
||||
indices = np.argsort(scores)[::-1]
|
||||
|
||||
for index in indices:
|
||||
source_geometry = source[index]
|
||||
for source_idx in indices:
|
||||
source_geometry = source[source_idx]
|
||||
|
||||
affinities = np.array(
|
||||
[
|
||||
@ -169,19 +181,19 @@ def greedy_match(
|
||||
affinity = affinities[closest_target]
|
||||
|
||||
if affinities[closest_target] <= affinity_threshold:
|
||||
yield index, None, 0
|
||||
yield source_idx, None, 0
|
||||
continue
|
||||
|
||||
if closest_target in assigned:
|
||||
yield index, None, 0
|
||||
yield source_idx, None, 0
|
||||
continue
|
||||
|
||||
assigned.add(closest_target)
|
||||
yield index, closest_target, affinity
|
||||
yield source_idx, closest_target, affinity
|
||||
|
||||
missed_ground_truth = set(range(len(target))) - assigned
|
||||
for index in missed_ground_truth:
|
||||
yield None, index, 0
|
||||
for target_idx in missed_ground_truth:
|
||||
yield None, target_idx, 0
|
||||
|
||||
|
||||
def match_sound_events_and_raw_predictions(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user