Compare commits


No commits in common. "9d1497b3f4a5e9e80325cadfc0d01a3c0380afea" and "4aea3fb2b0b79edade6f072810a18a27de2ad661" have entirely different histories.

3 changed files with 81 additions and 223 deletions

Changed file 1 of 3

@@ -4,7 +4,7 @@ from pydantic import Field
 from soundevent import data
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.evaluate.match import MatchConfig
+from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
 
 __all__ = [
     "EvaluationConfig",
@@ -13,7 +13,9 @@ __all__ = [
 
 
 class EvaluationConfig(BaseConfig):
-    match: MatchConfig = Field(default_factory=MatchConfig)
+    match: MatchConfig = Field(
+        default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
+    )
 
 
 def load_evaluation_config(
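
The new default wiring gives every EvaluationConfig its own copy of DEFAULT_MATCH_CONFIG rather than a shared instance. A minimal sketch of the intended behaviour, assuming BaseConfig is a pydantic v2 model (so model_validate and model_copy are available) and that the snippet runs where EvaluationConfig is in scope; it is illustrative only, not part of the change:

# Sketch only; assumes pydantic v2 semantics for the discriminated union.
config = EvaluationConfig.model_validate(
    {"match": {"match_method": "IntervalIOU", "affinity_threshold": 0.4}}
)
# The "match_method" discriminator selects IntervalMatchConfig for this field.
assert config.match.match_method == "IntervalIOU"

# Without an override, each instance gets a fresh copy of the default
# BBoxMatchConfig, so mutating one config cannot leak into another.
a, b = EvaluationConfig(), EvaluationConfig()
assert a.match is not b.match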

Changed file 2 of 3

@@ -1,12 +1,8 @@
-from collections.abc import Callable, Iterable, Mapping
-from typing import List, Literal, Optional, Tuple
+from typing import Annotated, List, Literal, Optional, Union
 
-import numpy as np
+from pydantic import Field
 from soundevent import data
-from soundevent.evaluation import compute_affinity
-from soundevent.evaluation import (
-    match_geometries as optimal_match,
-)
+from soundevent.evaluation import match_geometries
 from soundevent.geometry import compute_bounds
 
 from batdetect2.configs import BaseConfig
@@ -14,186 +10,70 @@ from batdetect2.evaluate.types import MatchEvaluation
 from batdetect2.postprocess.types import BatDetect2Prediction
 from batdetect2.targets.types import TargetProtocol
 
-MatchingStrategy = Literal["greedy", "optimal"]
-"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
-
-MatchingGeometry = Literal["bbox", "interval", "timestamp"]
-"""The geometry representation to use for matching."""
-
-
-class MatchConfig(BaseConfig):
-    """Configuration for matching geometries.
-
-    Attributes
-    ----------
-    strategy : MatchingStrategy, default="greedy"
-        The matching algorithm to use. 'greedy' prioritizes high-confidence
-        predictions, while 'optimal' finds the globally best set of matches.
-    geometry : MatchingGeometry, default="timestamp"
-        The geometric representation to use when computing affinity.
-    affinity_threshold : float, default=0.0
-        The minimum affinity score (e.g., IoU) required for a valid match.
-    time_buffer : float, default=0.005
-        Time tolerance in seconds used in affinity calculations.
-    frequency_buffer : float, default=1000
-        Frequency tolerance in Hertz used in affinity calculations.
-    """
-
-    strategy: MatchingStrategy = "greedy"
-    geometry: MatchingGeometry = "timestamp"
-    affinity_threshold: float = 0.0
-    time_buffer: float = 0.005
-    frequency_buffer: float = 1_000
-
-
-def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
-    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
-    return data.BoundingBox(
-        coordinates=[start_time, low_freq, end_time, high_freq]
-    )
-
-
-def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
-    start_time, _, end_time, _ = compute_bounds(geometry)
-    return data.TimeInterval(coordinates=[start_time, end_time])
-
-
-def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
-    start_time = compute_bounds(geometry)[0]
-    return data.TimeStamp(coordinates=start_time)
-
-
-_geometry_cast_functions: Mapping[
-    MatchingGeometry, Callable[[data.Geometry], data.Geometry]
-] = {
-    "bbox": _to_bbox,
-    "interval": _to_interval,
-    "timestamp": _to_timestamp,
-}
-
-
-def match_geometries(
-    source: List[data.Geometry],
-    target: List[data.Geometry],
-    config: MatchConfig,
-    scores: Optional[List[float]] = None,
-) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
-    geometry_cast = _geometry_cast_functions[config.geometry]
-
-    if config.strategy == "optimal":
-        return optimal_match(
-            source=[geometry_cast(geom) for geom in source],
-            target=[geometry_cast(geom) for geom in target],
-            time_buffer=config.time_buffer,
-            freq_buffer=config.frequency_buffer,
-            affinity_threshold=config.affinity_threshold,
-        )
-
-    if config.strategy == "greedy":
-        return greedy_match(
-            source=[geometry_cast(geom) for geom in source],
-            target=[geometry_cast(geom) for geom in target],
-            time_buffer=config.time_buffer,
-            freq_buffer=config.frequency_buffer,
-            affinity_threshold=config.affinity_threshold,
-            scores=scores,
-        )
-
-    raise NotImplementedError(
-        f"Matching strategy not implemented {config.strategy}"
-    )
-
-
-def greedy_match(
-    source: List[data.Geometry],
-    target: List[data.Geometry],
-    scores: Optional[List[float]] = None,
-    affinity_threshold: float = 0.5,
-    time_buffer: float = 0.001,
-    freq_buffer: float = 1000,
-) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
-    """Performs a greedy, one-to-one matching of source to target geometries.
-
-    Iterates through source geometries, prioritizing by score if provided. Each
-    source is matched to the best available target, provided the affinity
-    exceeds the threshold and the target has not already been assigned.
-
-    Parameters
-    ----------
-    source
-        A list of source geometries (e.g., predictions).
-    target
-        A list of target geometries (e.g., ground truths).
-    scores
-        Confidence scores for each source geometry for prioritization.
-    affinity_threshold
-        The minimum affinity score required for a valid match.
-    time_buffer
-        Time tolerance in seconds for affinity calculation.
-    freq_buffer
-        Frequency tolerance in Hertz for affinity calculation.
-
-    Yields
-    ------
-    Tuple[Optional[int], Optional[int], float]
-        A 3-element tuple describing a match or a miss. There are three
-        possible formats:
-        - Successful Match: `(source_idx, target_idx, affinity)`
-        - Unmatched Source (False Positive): `(source_idx, None, 0)`
-        - Unmatched Target (False Negative): `(None, target_idx, 0)`
-    """
-    assigned = set()
-
-    if not source:
-        for target_idx in range(len(target)):
-            yield None, target_idx, 0
-        return
-
-    if not target:
-        for source_idx in range(len(source)):
-            yield source_idx, None, 0
-        return
-
-    if scores is None:
-        indices = np.arange(len(source))
-    else:
-        indices = np.argsort(scores)[::-1]
-
-    for source_idx in indices:
-        source_geometry = source[source_idx]
-
-        affinities = np.array(
-            [
-                compute_affinity(
-                    source_geometry,
-                    target_geometry,
-                    time_buffer=time_buffer,
-                    freq_buffer=freq_buffer,
-                )
-                for target_geometry in target
-            ]
-        )
-
-        closest_target = int(np.argmax(affinities))
-        affinity = affinities[closest_target]
-
-        if affinities[closest_target] <= affinity_threshold:
-            yield source_idx, None, 0
-            continue
-
-        if closest_target in assigned:
-            yield source_idx, None, 0
-            continue
-
-        assigned.add(closest_target)
-        yield source_idx, closest_target, affinity
-
-    missed_ground_truth = set(range(len(target))) - assigned
-    for target_idx in missed_ground_truth:
-        yield None, target_idx, 0
+
+class BBoxMatchConfig(BaseConfig):
+    match_method: Literal["BBoxIOU"] = "BBoxIOU"
+    affinity_threshold: float = 0.5
+    time_buffer: float = 0.01
+    frequency_buffer: float = 1_000
+
+
+class IntervalMatchConfig(BaseConfig):
+    match_method: Literal["IntervalIOU"] = "IntervalIOU"
+    affinity_threshold: float = 0.5
+    time_buffer: float = 0.01
+
+
+class StartTimeMatchConfig(BaseConfig):
+    match_method: Literal["StartTime"] = "StartTime"
+    time_buffer: float = 0.01
+
+
+MatchConfig = Annotated[
+    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
+    Field(discriminator="match_method"),
+]
+
+DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
+
+
+def prepare_geometry(
+    geometry: data.Geometry, config: MatchConfig
+) -> data.Geometry:
+    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
+
+    if config.match_method == "BBoxIOU":
+        return data.BoundingBox(
+            coordinates=[start_time, low_freq, end_time, high_freq]
+        )
+
+    if config.match_method == "IntervalIOU":
+        return data.TimeInterval(coordinates=[start_time, end_time])
+
+    if config.match_method == "StartTime":
+        return data.TimeStamp(coordinates=start_time)
+
+    raise NotImplementedError(
+        f"Invalid matching configuration. Unknown match method: {config.match_method}"
+    )
+
+
+def _get_frequency_buffer(config: MatchConfig) -> float:
+    if config.match_method == "BBoxIOU":
+        return config.frequency_buffer
+
+    return 0
+
+
+def _get_affinity_threshold(config: MatchConfig) -> float:
+    if (
+        config.match_method == "BBoxIOU"
+        or config.match_method == "IntervalIOU"
+    ):
+        return config.affinity_threshold
+
+    return 0
 
 
 def match_sound_events_and_raw_predictions(
@@ -202,7 +82,7 @@ def match_sound_events_and_raw_predictions(
     targets: TargetProtocol,
     config: Optional[MatchConfig] = None,
 ) -> List[MatchEvaluation]:
-    config = config or MatchConfig()
+    config = config or DEFAULT_MATCH_CONFIG
 
     target_sound_events = [
         targets.transform(sound_event_annotation)
@@ -212,34 +92,30 @@ def match_sound_events_and_raw_predictions(
     ]
 
     target_geometries: List[data.Geometry] = [  # type: ignore
-        sound_event_annotation.sound_event.geometry
+        prepare_geometry(
+            sound_event_annotation.sound_event.geometry,
+            config=config,
+        )
         for sound_event_annotation in target_sound_events
         if sound_event_annotation.sound_event.geometry is not None
     ]
 
     predicted_geometries = [
-        raw_prediction.raw.geometry for raw_prediction in raw_predictions
-    ]
-
-    scores = [
-        raw_prediction.raw.detection_score
+        prepare_geometry(raw_prediction.raw.geometry, config=config)
         for raw_prediction in raw_predictions
     ]
 
     matches = []
-    for source_idx, target_idx, affinity in match_geometries(
-        source=predicted_geometries,
-        target=target_geometries,
-        config=config,
-        scores=scores,
+    for id1, id2, affinity in match_geometries(
+        target_geometries,
+        predicted_geometries,
+        time_buffer=config.time_buffer,
+        freq_buffer=_get_frequency_buffer(config),
+        affinity_threshold=_get_affinity_threshold(config),
     ):
-        target = (
-            target_sound_events[target_idx] if target_idx is not None else None
-        )
-        prediction = (
-            raw_predictions[source_idx] if source_idx is not None else None
-        )
+        target = target_sound_events[id1] if id1 is not None else None
+        prediction = raw_predictions[id2] if id2 is not None else None
 
         gt_det = target is not None
         gt_class = targets.encode_class(target) if target is not None else None
@@ -282,7 +158,7 @@ def match_predictions_and_annotations(
     clip_prediction: data.ClipPrediction,
     config: Optional[MatchConfig] = None,
 ) -> List[data.Match]:
-    config = config or MatchConfig()
+    config = config or DEFAULT_MATCH_CONFIG
 
     annotated_sound_events = [
         sound_event_annotation
@@ -297,46 +173,29 @@ def match_predictions_and_annotations(
     ]
 
     annotated_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
+        prepare_geometry(sound_event.sound_event.geometry, config=config)
        for sound_event in annotated_sound_events
         if sound_event.sound_event.geometry is not None
     ]
 
     predicted_geometries: List[data.Geometry] = [
-        sound_event.sound_event.geometry
-        for sound_event in predicted_sound_events
-        if sound_event.sound_event.geometry is not None
-    ]
-
-    scores = [
-        sound_event.score
+        prepare_geometry(sound_event.sound_event.geometry, config=config)
         for sound_event in predicted_sound_events
         if sound_event.sound_event.geometry is not None
     ]
 
     matches = []
-    for source_idx, target_idx, affinity in match_geometries(
-        source=predicted_geometries,
-        target=annotated_geometries,
-        config=config,
-        scores=scores,
+    for id1, id2, affinity in match_geometries(
+        annotated_geometries,
+        predicted_geometries,
+        time_buffer=config.time_buffer,
+        freq_buffer=_get_frequency_buffer(config),
+        affinity_threshold=_get_affinity_threshold(config),
     ):
-        target = (
-            annotated_sound_events[target_idx]
-            if target_idx is not None
-            else None
-        )
-        source = (
-            predicted_sound_events[source_idx]
-            if source_idx is not None
-            else None
-        )
+        target = annotated_sound_events[id1] if id1 is not None else None
+        source = predicted_sound_events[id2] if id2 is not None else None
 
         matches.append(
-            data.Match(
-                source=source,
-                target=target,
-                affinity=affinity,
-            )
+            data.Match(source=source, target=target, affinity=affinity)
         )
 
     return matches
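
Taken together, prepare_geometry collapses every geometry to the representation its match method compares on, and the private helpers supply soundevent's match_geometries with method-appropriate buffers and threshold. A rough sketch of that behaviour, assuming it runs where the classes and helpers above are in scope; coordinate values are made up:

# Sketch only; illustrative values, not part of the change.
from soundevent import data

box = data.BoundingBox(coordinates=[0.10, 20_000.0, 0.12, 45_000.0])

print(prepare_geometry(box, config=BBoxMatchConfig()))      # full time-frequency box
print(prepare_geometry(box, config=IntervalMatchConfig()))  # TimeInterval [0.10, 0.12]
print(prepare_geometry(box, config=StartTimeMatchConfig())) # TimeStamp at 0.10

# Only the bbox method keeps a frequency tolerance; the other methods
# collapse it to zero before calling match_geometries.
assert _get_frequency_buffer(BBoxMatchConfig()) == 1_000
assert _get_frequency_buffer(StartTimeMatchConfig()) == 0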

Changed file 3 of 3

@@ -177,10 +177,7 @@ def _match_all_collected_examples(
         match
         for clip_annotation, raw_predictions in pre_matches
         for match in match_sound_events_and_raw_predictions(
-            clip_annotation,
-            raw_predictions,
-            targets=targets,
-            config=config,
+            clip_annotation, raw_predictions, targets=targets, config=config
         )
     ]