mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add GreedyAffinityMatching as an alternative to optimal affinity matching
This commit is contained in:
parent
6039b2c3eb
commit
69921f258a
@ -3,14 +3,15 @@ from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from scipy.optimize import linear_sum_assignment
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.evaluation import match_geometries as optimal_match
|
||||
from soundevent.geometry import compute_bounds
|
||||
from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.affinity import (
|
||||
AffinityConfig,
|
||||
BBoxIOUConfig,
|
||||
GeometricIOUConfig,
|
||||
build_affinity_function,
|
||||
)
|
||||
@ -357,23 +358,32 @@ def greedy_match(
|
||||
yield None, gt_idx, 0
|
||||
|
||||
|
||||
class OptimalMatchConfig(BaseConfig):
|
||||
name: Literal["optimal_match"] = "optimal_match"
|
||||
class GreedyAffinityMatchConfig(BaseConfig):
    """Configuration for the ``greedy_affinity_match`` matching strategy.

    The field names correspond to the keyword arguments accepted by
    ``GreedyAffinityMatcher``.
    """

    # Discriminator value used to select this config within ``MatchConfig``.
    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"

    # Configuration of the affinity measure; a new ``BBoxIOUConfig`` by default.
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)

    # A pair is only accepted when its affinity is strictly above this value.
    affinity_threshold: float = 0.5

    # Buffer sizes for the matcher's ``time_buffer`` / ``frequency_buffer``
    # arguments. Each field is declared exactly once: an earlier duplicate
    # declaration (0.005 / 1_000) was overridden by these values anyway.
    time_buffer: float = 0
    frequency_buffer: float = 0

    # Scale factors for the matcher's ``time_scale`` / ``frequency_scale``.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
|
||||
|
||||
|
||||
class OptimalMatcher(MatcherProtocol):
|
||||
class GreedyAffinityMatcher(MatcherProtocol):
|
||||
def __init__(
    self,
    affinity_threshold: float,
    affinity_function: AffinityFunction,
    time_buffer: float = 0,
    frequency_buffer: float = 0,
    time_scale: float = 1.0,
    frequency_scale: float = 1.0,
):
    """Store the parameters used when matching.

    Each parameter is listed once; repeating ``time_buffer`` and
    ``frequency_buffer`` in the signature is a ``SyntaxError``. The
    signature mirrors ``OptimalMatcher.__init__``.

    Parameters
    ----------
    affinity_threshold
        Threshold forwarded to ``select_greedy_matches``.
    affinity_function
        Callable forwarded to ``compute_affinity_matrix``.
    time_buffer, frequency_buffer
        Buffers passed to ``buffer_geometry``; buffering is skipped when
        both are 0.
    time_scale, frequency_scale
        Scale factors forwarded to ``compute_affinity_matrix``.
    """
    self.affinity_threshold = affinity_threshold
    self.affinity_function = affinity_function
    self.time_buffer = time_buffer
    self.frequency_buffer = frequency_buffer
    self.time_scale = time_scale
    self.frequency_scale = frequency_scale
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -381,21 +391,125 @@ class OptimalMatcher(MatcherProtocol):
|
||||
predictions: Sequence[data.Geometry],
|
||||
scores: Sequence[float],
|
||||
):
|
||||
return optimal_match(
|
||||
source=predictions,
|
||||
target=ground_truth,
|
||||
time_buffer=self.time_buffer,
|
||||
freq_buffer=self.frequency_buffer,
|
||||
if self.time_buffer != 0 or self.frequency_buffer != 0:
|
||||
ground_truth = [
|
||||
buffer_geometry(
|
||||
geometry,
|
||||
time_buffer=self.time_buffer,
|
||||
freq_buffer=self.frequency_buffer,
|
||||
)
|
||||
for geometry in ground_truth
|
||||
]
|
||||
|
||||
predictions = [
|
||||
buffer_geometry(
|
||||
geometry,
|
||||
time_buffer=self.time_buffer,
|
||||
freq_buffer=self.frequency_buffer,
|
||||
)
|
||||
for geometry in predictions
|
||||
]
|
||||
|
||||
affinity_matrix = compute_affinity_matrix(
|
||||
ground_truth,
|
||||
predictions,
|
||||
self.affinity_function,
|
||||
time_scale=self.time_scale,
|
||||
frequency_scale=self.frequency_scale,
|
||||
)
|
||||
|
||||
return select_greedy_matches(
|
||||
affinity_matrix,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
)
|
||||
|
||||
@matching_strategies.register(GreedyAffinityMatchConfig)
@staticmethod
def from_config(config: GreedyAffinityMatchConfig):
    """Build a ``GreedyAffinityMatcher`` from its configuration.

    Every tunable field of ``config`` is forwarded, including
    ``time_buffer`` and ``frequency_buffer``; leaving those out would make
    the matcher silently fall back to its zero defaults regardless of the
    configured values.

    Parameters
    ----------
    config
        Configuration holding the affinity function settings, threshold,
        buffers and scale factors.

    Returns
    -------
    GreedyAffinityMatcher
        Matcher initialised from ``config``.
    """
    affinity_function = build_affinity_function(config.affinity_function)
    return GreedyAffinityMatcher(
        affinity_threshold=config.affinity_threshold,
        affinity_function=affinity_function,
        time_buffer=config.time_buffer,
        frequency_buffer=config.frequency_buffer,
        time_scale=config.time_scale,
        frequency_scale=config.frequency_scale,
    )
|
||||
|
||||
|
||||
class OptimalMatchConfig(BaseConfig):
    """Configuration for the ``optimal_affinity_match`` matching strategy.

    The fields are read by ``OptimalMatcher.from_config`` to construct the
    matcher.
    """

    # Discriminator value used to select this config within ``MatchConfig``.
    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"

    # Configuration of the affinity measure; a new ``BBoxIOUConfig`` by default.
    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)

    # A pair is only accepted when its affinity is strictly above this value.
    affinity_threshold: float = 0.5

    # Buffers passed to ``buffer_geometry``; buffering is skipped when both are 0.
    time_buffer: float = 0
    frequency_buffer: float = 0

    # Scale factors forwarded to ``compute_affinity_matrix``.
    time_scale: float = 1.0
    frequency_scale: float = 1.0
|
||||
|
||||
|
||||
class OptimalMatcher(MatcherProtocol):
    """Match ground truth and predicted geometries by optimal assignment.

    Calling the matcher optionally buffers both geometry sets, computes
    their affinity matrix with ``compute_affinity_matrix`` and delegates
    the pairing to ``select_optimal_matches``.
    """

    def __init__(
        self,
        affinity_threshold: float,
        affinity_function: AffinityFunction,
        time_buffer: float = 0,
        frequency_buffer: float = 0,
        time_scale: float = 1.0,
        frequency_scale: float = 1.0,
    ):
        """Store the parameters used when matching.

        Parameters
        ----------
        affinity_threshold
            Threshold forwarded to ``select_optimal_matches``.
        affinity_function
            Callable forwarded to ``compute_affinity_matrix``.
        time_buffer, frequency_buffer
            Buffers passed to ``buffer_geometry``; buffering is skipped
            when both are 0.
        time_scale, frequency_scale
            Scale factors forwarded to ``compute_affinity_matrix``.
        """
        self.affinity_threshold = affinity_threshold
        self.affinity_function = affinity_function
        self.time_buffer = time_buffer
        self.frequency_buffer = frequency_buffer
        self.time_scale = time_scale
        self.frequency_scale = frequency_scale

    def __call__(
        self,
        ground_truth: Sequence[data.Geometry],
        predictions: Sequence[data.Geometry],
        scores: Sequence[float],
    ):
        """Match ``predictions`` against ``ground_truth``.

        ``scores`` is accepted but not used by this matcher.

        Returns
        -------
        The iterable produced by ``select_optimal_matches`` for the
        computed affinity matrix.
        """
        no_buffering = self.time_buffer == 0 and self.frequency_buffer == 0

        if not no_buffering:

            def expand(geometry):
                # Same buffer settings are applied to both geometry sets.
                return buffer_geometry(
                    geometry,
                    time_buffer=self.time_buffer,
                    freq_buffer=self.frequency_buffer,
                )

            ground_truth = [expand(geometry) for geometry in ground_truth]
            predictions = [expand(geometry) for geometry in predictions]

        affinities = compute_affinity_matrix(
            ground_truth,
            predictions,
            self.affinity_function,
            time_scale=self.time_scale,
            frequency_scale=self.frequency_scale,
        )

        return select_optimal_matches(
            affinities,
            affinity_threshold=self.affinity_threshold,
        )

    @matching_strategies.register(OptimalMatchConfig)
    @staticmethod
    def from_config(config: OptimalMatchConfig):
        """Build an ``OptimalMatcher`` from its configuration."""
        return OptimalMatcher(
            affinity_threshold=config.affinity_threshold,
            affinity_function=build_affinity_function(
                config.affinity_function
            ),
            time_buffer=config.time_buffer,
            frequency_buffer=config.frequency_buffer,
            time_scale=config.time_scale,
            frequency_scale=config.frequency_scale,
        )
|
||||
|
||||
|
||||
@ -404,11 +518,100 @@ MatchConfig = Annotated[
|
||||
GreedyMatchConfig,
|
||||
StartTimeMatchConfig,
|
||||
OptimalMatchConfig,
|
||||
GreedyAffinityMatchConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def compute_affinity_matrix(
    ground_truth: Sequence[data.Geometry],
    predictions: Sequence[data.Geometry],
    affinity_function: AffinityFunction,
    time_scale: float = 1,
    frequency_scale: float = 1,
) -> np.ndarray:
    """Compute the pairwise affinity between ground truth and predictions.

    Parameters
    ----------
    ground_truth
        Geometries indexing the rows of the result.
    predictions
        Geometries indexing the columns of the result.
    affinity_function
        Callable invoked as ``affinity_function(gt_geometry, pred_geometry)``.
    time_scale, frequency_scale
        Factors passed to ``scale_geometry`` for every geometry; scaling is
        skipped when both equal 1.

    Returns
    -------
    np.ndarray
        Array of shape ``(len(ground_truth), len(predictions))`` holding the
        affinity of each pair.
    """
    scaling_is_identity = time_scale == 1 and frequency_scale == 1

    if not scaling_is_identity:
        ground_truth = [
            scale_geometry(item, time_scale, frequency_scale)
            for item in ground_truth
        ]
        predictions = [
            scale_geometry(item, time_scale, frequency_scale)
            for item in predictions
        ]

    shape = (len(ground_truth), len(predictions))
    matrix = np.zeros(shape)

    for row, gt_geometry in enumerate(ground_truth):
        for col, pred_geometry in enumerate(predictions):
            matrix[row, col] = affinity_function(gt_geometry, pred_geometry)

    return matrix
|
||||
|
||||
|
||||
def select_optimal_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Select matches that maximise the total affinity.

    The assignment is solved with ``linear_sum_assignment``; assigned pairs
    whose affinity is not strictly above ``affinity_threshold`` are dropped
    and both members are reported as unmatched.

    Parameters
    ----------
    affinity_matrix
        2D array with one row per ground truth and one column per
        prediction.
    affinity_threshold
        Minimum (exclusive) affinity for an assigned pair to be kept.

    Yields
    ------
    tuple
        ``(pred_idx, gt_idx, affinity)`` triples, the same order used by
        ``select_greedy_matches``. Matched pairs come first, then
        ``(None, gt_idx, 0)`` for every unmatched ground truth and finally
        ``(pred_idx, None, 0)`` for every unmatched prediction. Indices are
        plain Python ``int`` values.
    """
    num_gt, num_pred = affinity_matrix.shape
    unmatched_gt = set(range(num_gt))
    unmatched_pred = set(range(num_pred))

    if num_gt and num_pred:
        assigned_rows, assigned_columns = linear_sum_assignment(
            affinity_matrix,
            maximize=True,
        )
        # ``tolist`` converts the numpy integers into Python ints.
        assignment = zip(assigned_rows.tolist(), assigned_columns.tolist())
    else:
        # Nothing can be paired when either side is empty.
        assignment = ()

    for gt_idx, pred_idx in assignment:
        affinity = float(affinity_matrix[gt_idx, pred_idx])

        if affinity <= affinity_threshold:
            continue

        unmatched_gt.discard(gt_idx)
        unmatched_pred.discard(pred_idx)
        yield pred_idx, gt_idx, affinity

    # Sorted so the order of unmatched entries is deterministic.
    for gt_idx in sorted(unmatched_gt):
        yield None, gt_idx, 0

    for pred_idx in sorted(unmatched_pred):
        yield pred_idx, None, 0
|
||||
|
||||
|
||||
def select_greedy_matches(
    affinity_matrix: np.ndarray,
    affinity_threshold: float = 0.5,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Greedily match each ground truth to its highest-affinity prediction.

    Ground truths are visited in row order. A ground truth is matched to
    the prediction with the largest affinity in its row, provided that
    affinity is strictly above ``affinity_threshold`` and the prediction
    has not been taken by an earlier ground truth. No second-best
    prediction is tried: if the top prediction is unavailable the ground
    truth is reported as unmatched.

    Parameters
    ----------
    affinity_matrix
        2D array with one row per ground truth and one column per
        prediction.
    affinity_threshold
        Minimum (exclusive) affinity for a match to be accepted.

    Yields
    ------
    tuple
        ``(pred_idx, gt_idx, affinity)`` for matches, ``(None, gt_idx, 0)``
        for unmatched ground truths and, at the end,
        ``(pred_idx, None, 0)`` for unmatched predictions.
    """
    num_gt, num_pred = affinity_matrix.shape

    if num_pred == 0:
        # ``np.argmax`` raises on an empty row, so handle the
        # "no predictions" case explicitly: every ground truth is unmatched.
        for gt_idx in range(num_gt):
            yield None, gt_idx, 0
        return

    unmatched_pred = set(range(num_pred))

    for gt_idx in range(num_gt):
        row = affinity_matrix[gt_idx]

        top_pred = int(np.argmax(row))
        top_affinity = float(row[top_pred])

        if (
            top_affinity <= affinity_threshold
            or top_pred not in unmatched_pred
        ):
            yield None, gt_idx, 0
            continue

        unmatched_pred.remove(top_pred)
        yield top_pred, gt_idx, top_affinity

    # Sorted so the order of unmatched predictions is deterministic.
    for pred_idx in sorted(unmatched_pred):
        yield pred_idx, None, 0
|
||||
|
||||
|
||||
def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
    """Build the matcher described by ``config``.

    A falsy ``config`` (including ``None``) is replaced by a default
    ``StartTimeMatchConfig`` before the ``matching_strategies`` registry is
    asked to build the matcher.
    """
    if not config:
        config = StartTimeMatchConfig()

    matcher = matching_strategies.build(config)
    return matcher
|
||||
|
||||
Loading…
Reference in New Issue
Block a user