diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py
index 74feb4f..d93cb7a 100644
--- a/src/batdetect2/evaluate/match.py
+++ b/src/batdetect2/evaluate/match.py
@@ -3,14 +3,15 @@ from typing import Annotated, List, Literal, Optional, Sequence, Tuple, Union
 
 import numpy as np
 from pydantic import Field
+from scipy.optimize import linear_sum_assignment
 from soundevent import data
 from soundevent.evaluation import compute_affinity
-from soundevent.evaluation import match_geometries as optimal_match
-from soundevent.geometry import compute_bounds
+from soundevent.geometry import buffer_geometry, compute_bounds, scale_geometry
 
 from batdetect2.core import BaseConfig, Registry
 from batdetect2.evaluate.affinity import (
     AffinityConfig,
+    BBoxIOUConfig,
     GeometricIOUConfig,
     build_affinity_function,
 )
@@ -357,23 +358,36 @@ def greedy_match(
         yield None, gt_idx, 0
 
 
-class OptimalMatchConfig(BaseConfig):
-    name: Literal["optimal_match"] = "optimal_match"
+class GreedyAffinityMatchConfig(BaseConfig):
+    """Configuration for the ``greedy_affinity_match`` matching strategy."""
+
+    name: Literal["greedy_affinity_match"] = "greedy_affinity_match"
+    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
     affinity_threshold: float = 0.5
-    time_buffer: float = 0.005
-    frequency_buffer: float = 1_000
+    time_buffer: float = 0
+    frequency_buffer: float = 0
+    time_scale: float = 1.0
+    frequency_scale: float = 1.0
 
 
-class OptimalMatcher(MatcherProtocol):
+class GreedyAffinityMatcher(MatcherProtocol):
+    """Matcher that greedily pairs geometries using pairwise affinities."""
+
     def __init__(
         self,
         affinity_threshold: float,
-        time_buffer: float,
-        frequency_buffer: float,
+        affinity_function: AffinityFunction,
+        time_buffer: float = 0,
+        frequency_buffer: float = 0,
+        time_scale: float = 1.0,
+        frequency_scale: float = 1.0,
     ):
         self.affinity_threshold = affinity_threshold
+        self.affinity_function = affinity_function
         self.time_buffer = time_buffer
         self.frequency_buffer = frequency_buffer
+        self.time_scale = time_scale
+        self.frequency_scale = frequency_scale
 
     def __call__(
         self,
@@ -381,21 +395,141 @@ class OptimalMatcher(MatcherProtocol):
         predictions: Sequence[data.Geometry],
         scores: Sequence[float],
     ):
-        return optimal_match(
-            source=predictions,
-            target=ground_truth,
-            time_buffer=self.time_buffer,
-            freq_buffer=self.frequency_buffer,
+        """Yield ``(prediction_idx, ground_truth_idx, affinity)`` tuples.
+
+        ``scores`` is not used by this strategy.
+        """
+        if self.time_buffer != 0 or self.frequency_buffer != 0:
+            ground_truth = [
+                buffer_geometry(
+                    geometry,
+                    time_buffer=self.time_buffer,
+                    freq_buffer=self.frequency_buffer,
+                )
+                for geometry in ground_truth
+            ]
+
+            predictions = [
+                buffer_geometry(
+                    geometry,
+                    time_buffer=self.time_buffer,
+                    freq_buffer=self.frequency_buffer,
+                )
+                for geometry in predictions
+            ]
+
+        affinity_matrix = compute_affinity_matrix(
+            ground_truth,
+            predictions,
+            self.affinity_function,
+            time_scale=self.time_scale,
+            frequency_scale=self.frequency_scale,
+        )
+
+        return select_greedy_matches(
+            affinity_matrix,
+            affinity_threshold=self.affinity_threshold,
+        )
+
+    @matching_strategies.register(GreedyAffinityMatchConfig)
+    @staticmethod
+    def from_config(config: GreedyAffinityMatchConfig):
+        """Build the matcher, forwarding every option set on ``config``."""
+        affinity_function = build_affinity_function(config.affinity_function)
+        return GreedyAffinityMatcher(
+            affinity_threshold=config.affinity_threshold,
+            affinity_function=affinity_function,
+            time_buffer=config.time_buffer,
+            frequency_buffer=config.frequency_buffer,
+            time_scale=config.time_scale,
+            frequency_scale=config.frequency_scale,
+        )
+
+
+class OptimalMatchConfig(BaseConfig):
+    """Configuration for the ``optimal_affinity_match`` matching strategy."""
+
+    name: Literal["optimal_affinity_match"] = "optimal_affinity_match"
+    affinity_function: AffinityConfig = Field(default_factory=BBoxIOUConfig)
+    affinity_threshold: float = 0.5
+    time_buffer: float = 0
+    frequency_buffer: float = 0
+    time_scale: float = 1.0
+    frequency_scale: float = 1.0
+
+
+class OptimalMatcher(MatcherProtocol):
+    """Matcher that solves an optimal assignment over pairwise affinities."""
+
+    def __init__(
+        self,
+        affinity_threshold: float,
+        affinity_function: AffinityFunction,
+        time_buffer: float = 0,
+        frequency_buffer: float = 0,
+        time_scale: float = 1.0,
+        frequency_scale: float = 1.0,
+    ):
+        self.affinity_threshold = affinity_threshold
+        self.affinity_function = affinity_function
+        self.time_buffer = time_buffer
+        self.frequency_buffer = frequency_buffer
+        self.time_scale = time_scale
+        self.frequency_scale = frequency_scale
+
+    def __call__(
+        self,
+        ground_truth: Sequence[data.Geometry],
+        predictions: Sequence[data.Geometry],
+        scores: Sequence[float],
+    ):
+        """Yield ``(prediction_idx, ground_truth_idx, affinity)`` tuples.
+
+        ``scores`` is not used by this strategy.
+        """
+        if self.time_buffer != 0 or self.frequency_buffer != 0:
+            ground_truth = [
+                buffer_geometry(
+                    geometry,
+                    time_buffer=self.time_buffer,
+                    freq_buffer=self.frequency_buffer,
+                )
+                for geometry in ground_truth
+            ]
+
+            predictions = [
+                buffer_geometry(
+                    geometry,
+                    time_buffer=self.time_buffer,
+                    freq_buffer=self.frequency_buffer,
+                )
+                for geometry in predictions
+            ]
+
+        affinity_matrix = compute_affinity_matrix(
+            ground_truth,
+            predictions,
+            self.affinity_function,
+            time_scale=self.time_scale,
+            frequency_scale=self.frequency_scale,
+        )
+        return select_optimal_matches(
+            affinity_matrix,
             affinity_threshold=self.affinity_threshold,
         )
 
     @matching_strategies.register(OptimalMatchConfig)
     @staticmethod
     def from_config(config: OptimalMatchConfig):
+        """Build the matcher, forwarding every option set on ``config``."""
+        affinity_function = build_affinity_function(config.affinity_function)
         return OptimalMatcher(
             affinity_threshold=config.affinity_threshold,
+            affinity_function=affinity_function,
             time_buffer=config.time_buffer,
             frequency_buffer=config.frequency_buffer,
+            time_scale=config.time_scale,
+            frequency_scale=config.frequency_scale,
         )
 
 
@@ -404,11 +538,132 @@ MatchConfig = Annotated[
         GreedyMatchConfig,
         StartTimeMatchConfig,
         OptimalMatchConfig,
+        GreedyAffinityMatchConfig,
     ],
     Field(discriminator="name"),
 ]
 
 
+def compute_affinity_matrix(
+    ground_truth: Sequence[data.Geometry],
+    predictions: Sequence[data.Geometry],
+    affinity_function: AffinityFunction,
+    time_scale: float = 1,
+    frequency_scale: float = 1,
+) -> np.ndarray:
+    """Compute the pairwise affinity between ground truths and predictions.
+
+    Geometries are first rescaled with ``scale_geometry`` whenever either
+    scale factor differs from 1.
+
+    Returns an array of shape ``(len(ground_truth), len(predictions))`` whose
+    ``[gt_idx, pred_idx]`` entry is the affinity of that pair.
+    """
+    # Scale geometries if necessary
+    if time_scale != 1 or frequency_scale != 1:
+        ground_truth = [
+            scale_geometry(geometry, time_scale, frequency_scale)
+            for geometry in ground_truth
+        ]
+
+        predictions = [
+            scale_geometry(geometry, time_scale, frequency_scale)
+            for geometry in predictions
+        ]
+
+    affinity_matrix = np.zeros((len(ground_truth), len(predictions)))
+    for gt_idx, gt_geometry in enumerate(ground_truth):
+        for pred_idx, pred_geometry in enumerate(predictions):
+            affinity = affinity_function(
+                gt_geometry,
+                pred_geometry,
+            )
+            affinity_matrix[gt_idx, pred_idx] = affinity
+
+    return affinity_matrix
+
+
+def select_optimal_matches(
+    affinity_matrix: np.ndarray,
+    affinity_threshold: float = 0.5,
+) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
+    """Select the one-to-one assignment that maximises total affinity.
+
+    Rows of ``affinity_matrix`` index ground truths and columns index
+    predictions. Yields ``(prediction_idx, ground_truth_idx, affinity)``
+    tuples, the same order ``select_greedy_matches`` uses. Assigned pairs
+    whose affinity does not exceed ``affinity_threshold`` are dropped, and
+    every index left without a partner is yielded with ``None`` on the
+    missing side and an affinity of 0.
+    """
+    num_gt, num_pred = affinity_matrix.shape
+    gts = set(range(num_gt))
+    preds = set(range(num_pred))
+
+    assigned_rows, assigned_columns = linear_sum_assignment(
+        affinity_matrix,
+        maximize=True,
+    )
+
+    for gt_idx, pred_idx in zip(assigned_rows, assigned_columns):
+        affinity = float(affinity_matrix[gt_idx, pred_idx])
+
+        if affinity <= affinity_threshold:
+            continue
+
+        yield int(pred_idx), int(gt_idx), affinity
+        gts.remove(gt_idx)
+        preds.remove(pred_idx)
+
+    for gt_idx in gts:
+        yield None, gt_idx, 0
+
+    for pred_idx in preds:
+        yield pred_idx, None, 0
+
+
+def select_greedy_matches(
+    affinity_matrix: np.ndarray,
+    affinity_threshold: float = 0.5,
+) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
+    """Match each ground truth to its single highest-affinity prediction.
+
+    Rows of ``affinity_matrix`` index ground truths and columns index
+    predictions. Ground truths are visited in row order; one is matched only
+    when its best prediction exceeds ``affinity_threshold`` and has not
+    already been claimed by an earlier row. Yields ``(prediction_idx,
+    ground_truth_idx, affinity)`` tuples, with ``None`` on the missing side
+    and an affinity of 0 for unmatched entries.
+    """
+    num_gt, num_pred = affinity_matrix.shape
+    unmatched_pred = set(range(num_pred))
+
+    for gt_idx in range(num_gt):
+        # ``np.argmax`` raises on an empty row, so with no predictions every
+        # ground truth is simply reported as unmatched.
+        if num_pred == 0:
+            yield None, gt_idx, 0
+            continue
+
+        row = affinity_matrix[gt_idx]
+
+        top_pred = int(np.argmax(row))
+        top_affinity = float(row[top_pred])
+
+        if (
+            top_affinity <= affinity_threshold
+            or top_pred not in unmatched_pred
+        ):
+            yield None, gt_idx, 0
+            continue
+
+        unmatched_pred.remove(top_pred)
+        yield top_pred, gt_idx, top_affinity
+
+    for pred_idx in unmatched_pred:
+        yield pred_idx, None, 0
+
+
 def build_matcher(config: Optional[MatchConfig] = None) -> MatcherProtocol:
     config = config or StartTimeMatchConfig()
     return matching_strategies.build(config)