diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index 22ae469..5fc9d81 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -4,7 +4,7 @@ from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config -from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig +from batdetect2.evaluate.match import MatchConfig __all__ = [ "EvaluationConfig", @@ -13,9 +13,7 @@ __all__ = [ class EvaluationConfig(BaseConfig): - match: MatchConfig = Field( - default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(), - ) + match: MatchConfig = Field(default_factory=MatchConfig) def load_evaluation_config( diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index ac2db47..06d4b0c 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,8 +1,12 @@ -from typing import Annotated, List, Literal, Optional, Union +from collections.abc import Callable, Iterable, Mapping +from typing import List, Literal, Optional, Tuple -from pydantic import Field +import numpy as np from soundevent import data -from soundevent.evaluation import match_geometries +from soundevent.evaluation import compute_affinity +from soundevent.evaluation import ( + match_geometries as optimal_match, +) from soundevent.geometry import compute_bounds from batdetect2.configs import BaseConfig @@ -10,70 +14,174 @@ from batdetect2.evaluate.types import MatchEvaluation from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol +MatchingStrategy = Literal["greedy", "optimal"] +"""The type of matching algorithm to use: 'greedy' or 'optimal'.""" -class BBoxMatchConfig(BaseConfig): - match_method: Literal["BBoxIOU"] = "BBoxIOU" - affinity_threshold: float = 0.5 - time_buffer: float = 0.01 + +MatchingGeometry = Literal["bbox", "interval", "timestamp"] +"""The geometry representation to use for matching.""" + + +class MatchConfig(BaseConfig): + """Configuration for matching geometries. + + Attributes + ---------- + strategy : MatchingStrategy, default="greedy" + The matching algorithm to use. 'greedy' prioritizes high-confidence + predictions, while 'optimal' finds the globally best set of matches. + geometry : MatchingGeometry, default="timestamp" + The geometric representation to use when computing affinity. + affinity_threshold : float, default=0.0 + The minimum affinity score (e.g., IoU) required for a valid match. + time_buffer : float, default=0.005 + Time tolerance in seconds used in affinity calculations. + frequency_buffer : float, default=1000 + Frequency tolerance in Hertz used in affinity calculations. + """ + + strategy: MatchingStrategy = "greedy" + geometry: MatchingGeometry = "timestamp" + affinity_threshold: float = 0.0 + time_buffer: float = 0.005 frequency_buffer: float = 1_000 -class IntervalMatchConfig(BaseConfig): - match_method: Literal["IntervalIOU"] = "IntervalIOU" - affinity_threshold: float = 0.5 - time_buffer: float = 0.01 - - -class StartTimeMatchConfig(BaseConfig): - match_method: Literal["StartTime"] = "StartTime" - time_buffer: float = 0.01 - - -MatchConfig = Annotated[ - Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig], - Field(discriminator="match_method"), -] - - -DEFAULT_MATCH_CONFIG = BBoxMatchConfig() - - -def prepare_geometry( - geometry: data.Geometry, config: MatchConfig -) -> data.Geometry: +def _to_bbox(geometry: data.Geometry) -> data.BoundingBox: start_time, low_freq, end_time, high_freq = compute_bounds(geometry) - - if config.match_method == "BBoxIOU": - return data.BoundingBox( - coordinates=[start_time, low_freq, end_time, high_freq] - ) - - if config.match_method == "IntervalIOU": - return data.TimeInterval(coordinates=[start_time, end_time]) - - if config.match_method == "StartTime": - return data.TimeStamp(coordinates=start_time) - - raise NotImplementedError( - f"Invalid matching configuration. Unknown match method: {config.match_method}" + return data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq] ) -def _get_frequency_buffer(config: MatchConfig) -> float: - if config.match_method == "BBoxIOU": - return config.frequency_buffer - - return 0 +def _to_interval(geometry: data.Geometry) -> data.TimeInterval: + start_time, _, end_time, _ = compute_bounds(geometry) + return data.TimeInterval(coordinates=[start_time, end_time]) -def _get_affinity_threshold(config: MatchConfig) -> float: - if ( - config.match_method == "BBoxIOU" - or config.match_method == "IntervalIOU" - ): - return config.affinity_threshold +def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp: + start_time = compute_bounds(geometry)[0] + return data.TimeStamp(coordinates=start_time) - return 0 + +_geometry_cast_functions: Mapping[ + MatchingGeometry, Callable[[data.Geometry], data.Geometry] +] = { + "bbox": _to_bbox, + "interval": _to_interval, + "timestamp": _to_timestamp, +} + + +def match_geometries( + source: List[data.Geometry], + target: List[data.Geometry], + config: MatchConfig, + scores: Optional[List[float]] = None, +) -> Iterable[Tuple[Optional[int], Optional[int], float]]: + geometry_cast = _geometry_cast_functions[config.geometry] + + if config.strategy == "optimal": + return optimal_match( + source=[geometry_cast(geom) for geom in source], + target=[geometry_cast(geom) for geom in target], + time_buffer=config.time_buffer, + freq_buffer=config.frequency_buffer, + affinity_threshold=config.affinity_threshold, + ) + + if config.strategy == "greedy": + return greedy_match( + source=[geometry_cast(geom) for geom in source], + target=[geometry_cast(geom) for geom in target], + time_buffer=config.time_buffer, + freq_buffer=config.frequency_buffer, + affinity_threshold=config.affinity_threshold, + scores=scores, + ) + + raise NotImplementedError( + f"Matching strategy not implemented {config.strategy}" + ) + + +def greedy_match( + source: List[data.Geometry], + target: List[data.Geometry], + scores: Optional[List[float]] = None, + affinity_threshold: float = 0.5, + time_buffer: float = 0.001, + freq_buffer: float = 1000, +) -> Iterable[Tuple[Optional[int], Optional[int], float]]: + """Performs a greedy, one-to-one matching of source to target geometries. + + Iterates through source geometries, prioritizing by score if provided. Each + source is matched to the best available target, provided the affinity + exceeds the threshold and the target has not already been assigned. + + Parameters + ---------- + source + A list of source geometries (e.g., predictions). + target + A list of target geometries (e.g., ground truths). + scores + Confidence scores for each source geometry for prioritization. + affinity_threshold + The minimum affinity score required for a valid match. + time_buffer + Time tolerance in seconds for affinity calculation. + freq_buffer + Frequency tolerance in Hertz for affinity calculation. + + Yields + ------ + Tuple[Optional[int], Optional[int], float] + A 3-element tuple describing a match or a miss. There are three + possible formats: + - Successful Match: `(target_idx, source_idx, affinity)` + - Unmatched Source (False Positive): `(None, source_idx, 0)` + - Unmatched Target (False Negative): `(target_idx, None, 0)` + """ + assigned = set() + + if scores is None: + indices = np.arange(len(source)) + else: + indices = np.argsort(scores)[::-1] + + for index in indices: + source_geometry = source[index] + + affinities = np.array( + [ + compute_affinity( + source_geometry, + target_geometry, + time_buffer=time_buffer, + freq_buffer=freq_buffer, + ) + for target_geometry in target + ] + ) + + closest_target = int(np.argmax(affinities)) + affinity = affinities[closest_target] + + if affinities[closest_target] <= affinity_threshold: + yield index, None, 0 + continue + + if closest_target in assigned: + yield index, None, 0 + continue + + assigned.add(closest_target) + yield index, closest_target, affinity + + missed_ground_truth = set(range(len(target))) - assigned + for index in missed_ground_truth: + yield None, index, 0 def match_sound_events_and_raw_predictions( @@ -82,7 +190,7 @@ def match_sound_events_and_raw_predictions( targets: TargetProtocol, config: Optional[MatchConfig] = None, ) -> List[MatchEvaluation]: - config = config or DEFAULT_MATCH_CONFIG + config = config or MatchConfig() target_sound_events = [ targets.transform(sound_event_annotation) @@ -92,30 +200,34 @@ def match_sound_events_and_raw_predictions( ] target_geometries: List[data.Geometry] = [ # type: ignore - prepare_geometry( - sound_event_annotation.sound_event.geometry, - config=config, - ) + sound_event_annotation.sound_event.geometry for sound_event_annotation in target_sound_events if sound_event_annotation.sound_event.geometry is not None ] predicted_geometries = [ - prepare_geometry(raw_prediction.raw.geometry, config=config) + raw_prediction.raw.geometry for raw_prediction in raw_predictions + ] + + scores = [ + raw_prediction.raw.detection_score for raw_prediction in raw_predictions ] matches = [] - for id1, id2, affinity in match_geometries( - target_geometries, - predicted_geometries, - time_buffer=config.time_buffer, - freq_buffer=_get_frequency_buffer(config), - affinity_threshold=_get_affinity_threshold(config), + for source_idx, target_idx, affinity in match_geometries( + source=predicted_geometries, + target=target_geometries, + config=config, + scores=scores, ): - target = target_sound_events[id1] if id1 is not None else None - prediction = raw_predictions[id2] if id2 is not None else None + target = ( + target_sound_events[target_idx] if target_idx is not None else None + ) + prediction = ( + raw_predictions[source_idx] if source_idx is not None else None + ) gt_det = target is not None gt_class = targets.encode_class(target) if target is not None else None @@ -158,7 +270,7 @@ def match_predictions_and_annotations( clip_prediction: data.ClipPrediction, config: Optional[MatchConfig] = None, ) -> List[data.Match]: - config = config or DEFAULT_MATCH_CONFIG + config = config or MatchConfig() annotated_sound_events = [ sound_event_annotation @@ -173,29 +285,46 @@ def match_predictions_and_annotations( ] annotated_geometries: List[data.Geometry] = [ - prepare_geometry(sound_event.sound_event.geometry, config=config) + sound_event.sound_event.geometry for sound_event in annotated_sound_events if sound_event.sound_event.geometry is not None ] predicted_geometries: List[data.Geometry] = [ - prepare_geometry(sound_event.sound_event.geometry, config=config) + sound_event.sound_event.geometry + for sound_event in predicted_sound_events + if sound_event.sound_event.geometry is not None + ] + + scores = [ + sound_event.score for sound_event in predicted_sound_events if sound_event.sound_event.geometry is not None ] matches = [] - for id1, id2, affinity in match_geometries( - annotated_geometries, - predicted_geometries, - time_buffer=config.time_buffer, - freq_buffer=_get_frequency_buffer(config), - affinity_threshold=_get_affinity_threshold(config), + for source_idx, target_idx, affinity in match_geometries( + source=predicted_geometries, + target=annotated_geometries, + config=config, + scores=scores, ): - target = annotated_sound_events[id1] if id1 is not None else None - source = predicted_sound_events[id2] if id2 is not None else None + target = ( + annotated_sound_events[target_idx] + if target_idx is not None + else None + ) + source = ( + predicted_sound_events[source_idx] + if source_idx is not None + else None + ) matches.append( - data.Match(source=source, target=target, affinity=affinity) + data.Match( + source=source, + target=target, + affinity=affinity, + ) ) return matches diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index 74d93ab..99d4967 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -177,7 +177,10 @@ def _match_all_collected_examples( match for clip_annotation, raw_predictions in pre_matches for match in match_sound_events_and_raw_predictions( - clip_annotation, raw_predictions, targets=targets, config=config + clip_annotation, + raw_predictions, + targets=targets, + config=config, ) ]