mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Handle case of empty preds or gt
This commit is contained in:
parent
7af72912da
commit
9d1497b3f4
@ -139,19 +139,31 @@ def greedy_match(
|
|||||||
Tuple[Optional[int], Optional[int], float]
|
Tuple[Optional[int], Optional[int], float]
|
||||||
A 3-element tuple describing a match or a miss. There are three
|
A 3-element tuple describing a match or a miss. There are three
|
||||||
possible formats:
|
possible formats:
|
||||||
- Successful Match: `(target_idx, source_idx, affinity)`
|
- Successful Match: `(source_idx, target_idx, affinity)`
|
||||||
- Unmatched Source (False Positive): `(None, source_idx, 0)`
|
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
||||||
- Unmatched Target (False Negative): `(target_idx, None, 0)`
|
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
||||||
"""
|
"""
|
||||||
assigned = set()
|
assigned = set()
|
||||||
|
|
||||||
|
if not source:
|
||||||
|
for target_idx in range(len(target)):
|
||||||
|
yield None, target_idx, 0
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
if not target:
|
||||||
|
for source_idx in range(len(source)):
|
||||||
|
yield source_idx, None, 0
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
if scores is None:
|
if scores is None:
|
||||||
indices = np.arange(len(source))
|
indices = np.arange(len(source))
|
||||||
else:
|
else:
|
||||||
indices = np.argsort(scores)[::-1]
|
indices = np.argsort(scores)[::-1]
|
||||||
|
|
||||||
for index in indices:
|
for source_idx in indices:
|
||||||
source_geometry = source[index]
|
source_geometry = source[source_idx]
|
||||||
|
|
||||||
affinities = np.array(
|
affinities = np.array(
|
||||||
[
|
[
|
||||||
@ -169,19 +181,19 @@ def greedy_match(
|
|||||||
affinity = affinities[closest_target]
|
affinity = affinities[closest_target]
|
||||||
|
|
||||||
if affinities[closest_target] <= affinity_threshold:
|
if affinities[closest_target] <= affinity_threshold:
|
||||||
yield index, None, 0
|
yield source_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if closest_target in assigned:
|
if closest_target in assigned:
|
||||||
yield index, None, 0
|
yield source_idx, None, 0
|
||||||
continue
|
continue
|
||||||
|
|
||||||
assigned.add(closest_target)
|
assigned.add(closest_target)
|
||||||
yield index, closest_target, affinity
|
yield source_idx, closest_target, affinity
|
||||||
|
|
||||||
missed_ground_truth = set(range(len(target))) - assigned
|
missed_ground_truth = set(range(len(target))) - assigned
|
||||||
for index in missed_ground_truth:
|
for target_idx in missed_ground_truth:
|
||||||
yield None, index, 0
|
yield None, target_idx, 0
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
def match_sound_events_and_raw_predictions(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user