Implement previous matching strategy (greedy)

This commit is contained in:
mbsantiago 2025-08-17 22:52:00 +01:00
parent 4aea3fb2b0
commit 7af72912da
3 changed files with 217 additions and 87 deletions

View File

@ -4,7 +4,7 @@ from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
from batdetect2.evaluate.match import MatchConfig
__all__ = [
"EvaluationConfig",
@ -13,9 +13,7 @@ __all__ = [
class EvaluationConfig(BaseConfig):
match: MatchConfig = Field(
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
)
match: MatchConfig = Field(default_factory=MatchConfig)
def load_evaluation_config(

View File

@ -1,8 +1,12 @@
from typing import Annotated, List, Literal, Optional, Union
from collections.abc import Callable, Iterable, Mapping
from typing import List, Literal, Optional, Tuple
from pydantic import Field
import numpy as np
from soundevent import data
from soundevent.evaluation import match_geometries
from soundevent.evaluation import compute_affinity
from soundevent.evaluation import (
match_geometries as optimal_match,
)
from soundevent.geometry import compute_bounds
from batdetect2.configs import BaseConfig
@ -10,70 +14,174 @@ from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
MatchingStrategy = Literal["greedy", "optimal"]
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
class BBoxMatchConfig(BaseConfig):
match_method: Literal["BBoxIOU"] = "BBoxIOU"
affinity_threshold: float = 0.5
time_buffer: float = 0.01
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
"""The geometry representation to use for matching."""
class MatchConfig(BaseConfig):
    """Settings controlling how source and target geometries are paired.

    Attributes
    ----------
    strategy : MatchingStrategy, default="greedy"
        Matching algorithm to run. ``"greedy"`` walks through the source
        geometries in order of decreasing score and assigns each one its
        highest-affinity target; ``"optimal"`` delegates to the matcher
        provided by ``soundevent.evaluation``.
    geometry : MatchingGeometry, default="timestamp"
        Representation every geometry is cast to before affinities are
        computed (bounding box, time interval or start timestamp).
    affinity_threshold : float, default=0.0
        Affinity cut-off for accepting a match. The greedy strategy only
        accepts a pair whose affinity is strictly greater than this value.
    time_buffer : float, default=0.005
        Time tolerance, in seconds, passed to the affinity computation.
    frequency_buffer : float, default=1000
        Frequency tolerance, in Hertz, passed to the affinity computation.
    """

    strategy: MatchingStrategy = "greedy"
    geometry: MatchingGeometry = "timestamp"
    affinity_threshold: float = 0.0
    time_buffer: float = 0.005
    frequency_buffer: float = 1_000
class IntervalMatchConfig(BaseConfig):
match_method: Literal["IntervalIOU"] = "IntervalIOU"
affinity_threshold: float = 0.5
time_buffer: float = 0.01
class StartTimeMatchConfig(BaseConfig):
match_method: Literal["StartTime"] = "StartTime"
time_buffer: float = 0.01
MatchConfig = Annotated[
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
Field(discriminator="match_method"),
]
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
def prepare_geometry(
geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
if config.match_method == "BBoxIOU":
return data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq]
)
if config.match_method == "IntervalIOU":
return data.TimeInterval(coordinates=[start_time, end_time])
if config.match_method == "StartTime":
return data.TimeStamp(coordinates=start_time)
raise NotImplementedError(
f"Invalid matching configuration. Unknown match method: {config.match_method}"
return data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq]
)
def _get_frequency_buffer(config: MatchConfig) -> float:
if config.match_method == "BBoxIOU":
return config.frequency_buffer
return 0
def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
    """Reduce a geometry to the time interval covered by its bounds.

    The frequency extent reported by ``compute_bounds`` is discarded; only
    the start and end times are kept.
    """
    bounds = compute_bounds(geometry)
    onset, offset = bounds[0], bounds[2]
    return data.TimeInterval(coordinates=[onset, offset])
def _get_affinity_threshold(config: MatchConfig) -> float:
if (
config.match_method == "BBoxIOU"
or config.match_method == "IntervalIOU"
):
return config.affinity_threshold
def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
    """Reduce a geometry to a timestamp placed at the start of its bounds."""
    onset, *_ = compute_bounds(geometry)
    return data.TimeStamp(coordinates=onset)
return 0
# Dispatch table from the configured matching geometry name to the function
# that casts an arbitrary geometry into that representation.
_geometry_cast_functions: Mapping[
    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
] = dict(
    bbox=_to_bbox,
    interval=_to_interval,
    timestamp=_to_timestamp,
)
def match_geometries(
    source: List[data.Geometry],
    target: List[data.Geometry],
    config: MatchConfig,
    scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Match source geometries to target geometries as configured.

    Every geometry is first cast to the representation selected by
    ``config.geometry``; the cast lists are then handed to the matcher
    selected by ``config.strategy``.

    Parameters
    ----------
    source
        Source geometries (e.g. predictions).
    target
        Target geometries (e.g. ground truths).
    config
        Matching configuration providing the strategy, geometry, buffers
        and affinity threshold.
    scores
        Optional confidence score per source geometry. Only forwarded to
        the greedy strategy.

    Returns
    -------
    Iterable[Tuple[Optional[int], Optional[int], float]]
        Whatever the selected matcher produces; the greedy matcher yields
        ``(source_idx, target_idx, affinity)`` tuples.

    Raises
    ------
    NotImplementedError
        If ``config.strategy`` is neither ``"optimal"`` nor ``"greedy"``.
    """
    to_matching_geometry = _geometry_cast_functions[config.geometry]
    strategy = config.strategy

    if strategy not in ("optimal", "greedy"):
        raise NotImplementedError(
            f"Matching strategy not implemented {config.strategy}"
        )

    # Both matchers take the same geometry, buffer and threshold arguments.
    shared_arguments = {
        "source": [to_matching_geometry(item) for item in source],
        "target": [to_matching_geometry(item) for item in target],
        "time_buffer": config.time_buffer,
        "freq_buffer": config.frequency_buffer,
        "affinity_threshold": config.affinity_threshold,
    }

    if strategy == "optimal":
        return optimal_match(**shared_arguments)

    return greedy_match(scores=scores, **shared_arguments)
def greedy_match(
    source: List[data.Geometry],
    target: List[data.Geometry],
    scores: Optional[List[float]] = None,
    affinity_threshold: float = 0.5,
    time_buffer: float = 0.001,
    freq_buffer: float = 1000,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Perform a greedy, one-to-one matching of source to target geometries.

    Source geometries are visited in order of decreasing score (or in their
    given order when no scores are provided). Each source is compared with
    every target and paired with its highest-affinity target, provided that
    affinity is strictly greater than the threshold and that target has not
    already been assigned. A source whose best target is already taken is
    reported as unmatched; no second-best target is tried.

    Parameters
    ----------
    source
        A list of source geometries (e.g., predictions).
    target
        A list of target geometries (e.g., ground truths).
    scores
        Confidence scores for each source geometry, used for
        prioritization. Must have the same length as ``source``.
    affinity_threshold
        The affinity a pair must exceed to count as a valid match.
    time_buffer
        Time tolerance in seconds for affinity calculation.
    freq_buffer
        Frequency tolerance in Hertz for affinity calculation.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple ``(source_idx, target_idx, affinity)``. There are
        three possible formats:

        - Successful Match: ``(source_idx, target_idx, affinity)``
        - Unmatched Source (False Positive): ``(source_idx, None, 0)``
        - Unmatched Target (False Negative): ``(None, target_idx, 0)``

        Unmatched targets are emitted last, in increasing index order.

    Raises
    ------
    ValueError
        If ``scores`` is given and its length differs from ``source``.
        Because this is a generator, the error surfaces on first iteration.
    """
    if scores is None:
        order = range(len(source))
    else:
        if len(scores) != len(source):
            raise ValueError(
                "scores must have one entry per source geometry "
                f"(got {len(scores)} scores for {len(source)} sources)"
            )
        # Highest-scoring sources get first pick of the targets.
        order = np.argsort(scores)[::-1]

    if not target:
        # Nothing to match against: every source is a false positive.
        # (np.argmax would raise on the empty affinity array.)
        for source_idx in order:
            yield int(source_idx), None, 0
        return

    assigned = set()

    for source_idx in order:
        # Emit plain ints rather than numpy integer scalars.
        source_idx = int(source_idx)
        source_geometry = source[source_idx]

        affinities = np.array(
            [
                compute_affinity(
                    source_geometry,
                    target_geometry,
                    time_buffer=time_buffer,
                    freq_buffer=freq_buffer,
                )
                for target_geometry in target
            ]
        )

        best_target = int(np.argmax(affinities))
        affinity = float(affinities[best_target])

        if affinity <= affinity_threshold or best_target in assigned:
            yield source_idx, None, 0
            continue

        assigned.add(best_target)
        yield source_idx, best_target, affinity

    # Sorted so the false negatives come out in a deterministic order.
    for target_idx in sorted(set(range(len(target))) - assigned):
        yield None, target_idx, 0
def match_sound_events_and_raw_predictions(
@ -82,7 +190,7 @@ def match_sound_events_and_raw_predictions(
targets: TargetProtocol,
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
config = config or DEFAULT_MATCH_CONFIG
config = config or MatchConfig()
target_sound_events = [
targets.transform(sound_event_annotation)
@ -92,30 +200,34 @@ def match_sound_events_and_raw_predictions(
]
target_geometries: List[data.Geometry] = [ # type: ignore
prepare_geometry(
sound_event_annotation.sound_event.geometry,
config=config,
)
sound_event_annotation.sound_event.geometry
for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_geometries = [
prepare_geometry(raw_prediction.raw.geometry, config=config)
raw_prediction.raw.geometry for raw_prediction in raw_predictions
]
scores = [
raw_prediction.raw.detection_score
for raw_prediction in raw_predictions
]
matches = []
for id1, id2, affinity in match_geometries(
target_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=target_geometries,
config=config,
scores=scores,
):
target = target_sound_events[id1] if id1 is not None else None
prediction = raw_predictions[id2] if id2 is not None else None
target = (
target_sound_events[target_idx] if target_idx is not None else None
)
prediction = (
raw_predictions[source_idx] if source_idx is not None else None
)
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
@ -158,7 +270,7 @@ def match_predictions_and_annotations(
clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]:
config = config or DEFAULT_MATCH_CONFIG
config = config or MatchConfig()
annotated_sound_events = [
sound_event_annotation
@ -173,29 +285,46 @@ def match_predictions_and_annotations(
]
annotated_geometries: List[data.Geometry] = [
prepare_geometry(sound_event.sound_event.geometry, config=config)
sound_event.sound_event.geometry
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
prepare_geometry(sound_event.sound_event.geometry, config=config)
sound_event.sound_event.geometry
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
scores = [
sound_event.score
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
matches = []
for id1, id2, affinity in match_geometries(
annotated_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
for source_idx, target_idx, affinity in match_geometries(
source=predicted_geometries,
target=annotated_geometries,
config=config,
scores=scores,
):
target = annotated_sound_events[id1] if id1 is not None else None
source = predicted_sound_events[id2] if id2 is not None else None
target = (
annotated_sound_events[target_idx]
if target_idx is not None
else None
)
source = (
predicted_sound_events[source_idx]
if source_idx is not None
else None
)
matches.append(
data.Match(source=source, target=target, affinity=affinity)
data.Match(
source=source,
target=target,
affinity=affinity,
)
)
return matches

View File

@ -177,7 +177,10 @@ def _match_all_collected_examples(
match
for clip_annotation, raw_predictions in pre_matches
for match in match_sound_events_and_raw_predictions(
clip_annotation, raw_predictions, targets=targets, config=config
clip_annotation,
raw_predictions,
targets=targets,
config=config,
)
]