mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Implement previous matching strategy (greedy)
This commit is contained in:
parent
4aea3fb2b0
commit
7af72912da
@ -4,7 +4,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
|
from batdetect2.evaluate.match import MatchConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
@ -13,9 +13,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
match: MatchConfig = Field(
|
match: MatchConfig = Field(default_factory=MatchConfig)
|
||||||
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
|
|||||||
@ -1,8 +1,12 @@
|
|||||||
from typing import Annotated, List, Literal, Optional, Union
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
|
from typing import List, Literal, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import Field
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_geometries
|
from soundevent.evaluation import compute_affinity
|
||||||
|
from soundevent.evaluation import (
|
||||||
|
match_geometries as optimal_match,
|
||||||
|
)
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
@ -10,70 +14,174 @@ from batdetect2.evaluate.types import MatchEvaluation
|
|||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
MatchingStrategy = Literal["greedy", "optimal"]
|
||||||
|
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
||||||
|
|
||||||
class BBoxMatchConfig(BaseConfig):
|
|
||||||
match_method: Literal["BBoxIOU"] = "BBoxIOU"
|
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||||
affinity_threshold: float = 0.5
|
"""The geometry representation to use for matching."""
|
||||||
time_buffer: float = 0.01
|
|
||||||
|
|
||||||
|
class MatchConfig(BaseConfig):
|
||||||
|
"""Configuration for matching geometries.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
strategy : MatchingStrategy, default="greedy"
|
||||||
|
The matching algorithm to use. 'greedy' prioritizes high-confidence
|
||||||
|
predictions, while 'optimal' finds the globally best set of matches.
|
||||||
|
geometry : MatchingGeometry, default="timestamp"
|
||||||
|
The geometric representation to use when computing affinity.
|
||||||
|
affinity_threshold : float, default=0.0
|
||||||
|
The minimum affinity score (e.g., IoU) required for a valid match.
|
||||||
|
time_buffer : float, default=0.005
|
||||||
|
Time tolerance in seconds used in affinity calculations.
|
||||||
|
frequency_buffer : float, default=1000
|
||||||
|
Frequency tolerance in Hertz used in affinity calculations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
strategy: MatchingStrategy = "greedy"
|
||||||
|
geometry: MatchingGeometry = "timestamp"
|
||||||
|
affinity_threshold: float = 0.0
|
||||||
|
time_buffer: float = 0.005
|
||||||
frequency_buffer: float = 1_000
|
frequency_buffer: float = 1_000
|
||||||
|
|
||||||
|
|
||||||
class IntervalMatchConfig(BaseConfig):
|
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
|
||||||
match_method: Literal["IntervalIOU"] = "IntervalIOU"
|
|
||||||
affinity_threshold: float = 0.5
|
|
||||||
time_buffer: float = 0.01
|
|
||||||
|
|
||||||
|
|
||||||
class StartTimeMatchConfig(BaseConfig):
|
|
||||||
match_method: Literal["StartTime"] = "StartTime"
|
|
||||||
time_buffer: float = 0.01
|
|
||||||
|
|
||||||
|
|
||||||
MatchConfig = Annotated[
|
|
||||||
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
|
|
||||||
Field(discriminator="match_method"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_geometry(
|
|
||||||
geometry: data.Geometry, config: MatchConfig
|
|
||||||
) -> data.Geometry:
|
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||||
|
return data.BoundingBox(
|
||||||
if config.match_method == "BBoxIOU":
|
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||||
return data.BoundingBox(
|
|
||||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.match_method == "IntervalIOU":
|
|
||||||
return data.TimeInterval(coordinates=[start_time, end_time])
|
|
||||||
|
|
||||||
if config.match_method == "StartTime":
|
|
||||||
return data.TimeStamp(coordinates=start_time)
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"Invalid matching configuration. Unknown match method: {config.match_method}"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _get_frequency_buffer(config: MatchConfig) -> float:
|
def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
|
||||||
if config.match_method == "BBoxIOU":
|
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||||
return config.frequency_buffer
|
return data.TimeInterval(coordinates=[start_time, end_time])
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_affinity_threshold(config: MatchConfig) -> float:
|
def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
|
||||||
if (
|
start_time = compute_bounds(geometry)[0]
|
||||||
config.match_method == "BBoxIOU"
|
return data.TimeStamp(coordinates=start_time)
|
||||||
or config.match_method == "IntervalIOU"
|
|
||||||
):
|
|
||||||
return config.affinity_threshold
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
_geometry_cast_functions: Mapping[
|
||||||
|
MatchingGeometry, Callable[[data.Geometry], data.Geometry]
|
||||||
|
] = {
|
||||||
|
"bbox": _to_bbox,
|
||||||
|
"interval": _to_interval,
|
||||||
|
"timestamp": _to_timestamp,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def match_geometries(
    source: List[data.Geometry],
    target: List[data.Geometry],
    config: MatchConfig,
    scores: Optional[List[float]] = None,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Match source geometries to target geometries as configured.

    Every geometry is first converted with the cast function registered in
    ``_geometry_cast_functions`` under ``config.geometry``. The converted
    geometries are then handed to the matcher selected by
    ``config.strategy``.

    Parameters
    ----------
    source
        Geometries to be matched (the callers in this module pass the
        predicted geometries).
    target
        Geometries to match against (the callers in this module pass the
        annotated geometries).
    config
        Supplies ``strategy``, ``geometry``, ``time_buffer``,
        ``frequency_buffer`` and ``affinity_threshold``.
    scores
        Optional per-source scores. Only forwarded to the greedy matcher;
        the optimal matcher does not receive them.

    Returns
    -------
    Iterable[Tuple[Optional[int], Optional[int], float]]
        Whatever the selected matcher produces. Callers in this module
        unpack each item as ``(source_idx, target_idx, affinity)``.

    Raises
    ------
    NotImplementedError
        If ``config.strategy`` is neither ``"optimal"`` nor ``"greedy"``.
    """
    cast = _geometry_cast_functions[config.geometry]
    strategy = config.strategy

    # Reject unknown strategies before doing any conversion work.
    if strategy != "optimal" and strategy != "greedy":
        raise NotImplementedError(
            f"Matching strategy not implemented {config.strategy}"
        )

    cast_source = [cast(geom) for geom in source]
    cast_target = [cast(geom) for geom in target]

    # Keyword arguments accepted by both matchers.
    shared_kwargs = dict(
        time_buffer=config.time_buffer,
        freq_buffer=config.frequency_buffer,
        affinity_threshold=config.affinity_threshold,
    )

    if strategy == "optimal":
        return optimal_match(
            source=cast_source,
            target=cast_target,
            **shared_kwargs,
        )

    return greedy_match(
        source=cast_source,
        target=cast_target,
        scores=scores,
        **shared_kwargs,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def greedy_match(
    source: List[data.Geometry],
    target: List[data.Geometry],
    scores: Optional[List[float]] = None,
    affinity_threshold: float = 0.5,
    time_buffer: float = 0.001,
    freq_buffer: float = 1000,
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
    """Perform a greedy, one-to-one matching of source to target geometries.

    Source geometries are visited in descending score order when ``scores``
    is given, otherwise in their original order. For each source, only the
    single target with the highest affinity is considered. The pair is
    accepted if that affinity is strictly greater than
    ``affinity_threshold`` and the target has not already been claimed by
    an earlier source. Otherwise the source is reported as unmatched; no
    second-best target is tried.

    Parameters
    ----------
    source
        A list of source geometries (e.g., predictions).
    target
        A list of target geometries (e.g., ground truths). May be empty,
        in which case every source is reported as unmatched.
    scores
        Confidence scores for each source geometry, used only to decide the
        visiting order. Must have the same length as ``source``.
    affinity_threshold
        The affinity a pair must strictly exceed to count as a match.
    time_buffer
        Time tolerance forwarded to ``compute_affinity``.
    freq_buffer
        Frequency tolerance forwarded to ``compute_affinity``.

    Yields
    ------
    Tuple[Optional[int], Optional[int], float]
        A 3-element tuple ``(source_idx, target_idx, affinity)``. There are
        three possible formats:

        - Successful match: ``(source_idx, target_idx, affinity)``
        - Unmatched source (false positive): ``(source_idx, None, 0)``
        - Unmatched target (false negative): ``(None, target_idx, 0)``

        Unmatched targets are yielded last, in ascending index order.

    Raises
    ------
    ValueError
        If ``scores`` is given and its length differs from ``source``.
    """
    if scores is None:
        order = np.arange(len(source))
    else:
        if len(scores) != len(source):
            raise ValueError(
                "scores must have the same length as source "
                f"(got {len(scores)} scores for {len(source)} geometries)"
            )
        # Highest-scoring sources get first pick of the targets.
        order = np.argsort(scores)[::-1]

    num_targets = len(target)

    # With no targets there is nothing to take an argmax over (np.argmax
    # raises on an empty array), so every source is a false positive.
    if num_targets == 0:
        for source_idx in order:
            yield int(source_idx), None, 0
        return

    assigned = set()

    for source_idx in order:
        # Hand plain Python ints back to callers rather than NumPy scalars.
        source_idx = int(source_idx)
        source_geometry = source[source_idx]

        affinities = np.array(
            [
                compute_affinity(
                    source_geometry,
                    target_geometry,
                    time_buffer=time_buffer,
                    freq_buffer=freq_buffer,
                )
                for target_geometry in target
            ]
        )

        closest_target = int(np.argmax(affinities))
        affinity = float(affinities[closest_target])

        # Either the best candidate is too weak, or it was already claimed
        # by an earlier (higher-priority) source.
        if affinity <= affinity_threshold or closest_target in assigned:
            yield source_idx, None, 0
            continue

        assigned.add(closest_target)
        yield source_idx, closest_target, affinity

    # Report the remaining targets in a deterministic (ascending) order.
    for target_idx in range(num_targets):
        if target_idx not in assigned:
            yield None, target_idx, 0
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
def match_sound_events_and_raw_predictions(
|
||||||
@ -82,7 +190,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
config = config or DEFAULT_MATCH_CONFIG
|
config = config or MatchConfig()
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
targets.transform(sound_event_annotation)
|
targets.transform(sound_event_annotation)
|
||||||
@ -92,30 +200,34 @@ def match_sound_events_and_raw_predictions(
|
|||||||
]
|
]
|
||||||
|
|
||||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||||
prepare_geometry(
|
sound_event_annotation.sound_event.geometry
|
||||||
sound_event_annotation.sound_event.geometry,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
for sound_event_annotation in target_sound_events
|
for sound_event_annotation in target_sound_events
|
||||||
if sound_event_annotation.sound_event.geometry is not None
|
if sound_event_annotation.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries = [
|
predicted_geometries = [
|
||||||
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
raw_prediction.raw.geometry for raw_prediction in raw_predictions
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = [
|
||||||
|
raw_prediction.raw.detection_score
|
||||||
for raw_prediction in raw_predictions
|
for raw_prediction in raw_predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for id1, id2, affinity in match_geometries(
|
for source_idx, target_idx, affinity in match_geometries(
|
||||||
target_geometries,
|
source=predicted_geometries,
|
||||||
predicted_geometries,
|
target=target_geometries,
|
||||||
time_buffer=config.time_buffer,
|
config=config,
|
||||||
freq_buffer=_get_frequency_buffer(config),
|
scores=scores,
|
||||||
affinity_threshold=_get_affinity_threshold(config),
|
|
||||||
):
|
):
|
||||||
target = target_sound_events[id1] if id1 is not None else None
|
target = (
|
||||||
prediction = raw_predictions[id2] if id2 is not None else None
|
target_sound_events[target_idx] if target_idx is not None else None
|
||||||
|
)
|
||||||
|
prediction = (
|
||||||
|
raw_predictions[source_idx] if source_idx is not None else None
|
||||||
|
)
|
||||||
|
|
||||||
gt_det = target is not None
|
gt_det = target is not None
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
@ -158,7 +270,7 @@ def match_predictions_and_annotations(
|
|||||||
clip_prediction: data.ClipPrediction,
|
clip_prediction: data.ClipPrediction,
|
||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[data.Match]:
|
) -> List[data.Match]:
|
||||||
config = config or DEFAULT_MATCH_CONFIG
|
config = config or MatchConfig()
|
||||||
|
|
||||||
annotated_sound_events = [
|
annotated_sound_events = [
|
||||||
sound_event_annotation
|
sound_event_annotation
|
||||||
@ -173,29 +285,46 @@ def match_predictions_and_annotations(
|
|||||||
]
|
]
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
annotated_geometries: List[data.Geometry] = [
|
||||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
sound_event.sound_event.geometry
|
||||||
for sound_event in annotated_sound_events
|
for sound_event in annotated_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
predicted_geometries: List[data.Geometry] = [
|
||||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
sound_event.sound_event.geometry
|
||||||
|
for sound_event in predicted_sound_events
|
||||||
|
if sound_event.sound_event.geometry is not None
|
||||||
|
]
|
||||||
|
|
||||||
|
scores = [
|
||||||
|
sound_event.score
|
||||||
for sound_event in predicted_sound_events
|
for sound_event in predicted_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for id1, id2, affinity in match_geometries(
|
for source_idx, target_idx, affinity in match_geometries(
|
||||||
annotated_geometries,
|
source=predicted_geometries,
|
||||||
predicted_geometries,
|
target=annotated_geometries,
|
||||||
time_buffer=config.time_buffer,
|
config=config,
|
||||||
freq_buffer=_get_frequency_buffer(config),
|
scores=scores,
|
||||||
affinity_threshold=_get_affinity_threshold(config),
|
|
||||||
):
|
):
|
||||||
target = annotated_sound_events[id1] if id1 is not None else None
|
target = (
|
||||||
source = predicted_sound_events[id2] if id2 is not None else None
|
annotated_sound_events[target_idx]
|
||||||
|
if target_idx is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
source = (
|
||||||
|
predicted_sound_events[source_idx]
|
||||||
|
if source_idx is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
matches.append(
|
matches.append(
|
||||||
data.Match(source=source, target=target, affinity=affinity)
|
data.Match(
|
||||||
|
source=source,
|
||||||
|
target=target,
|
||||||
|
affinity=affinity,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|||||||
@ -177,7 +177,10 @@ def _match_all_collected_examples(
|
|||||||
match
|
match
|
||||||
for clip_annotation, raw_predictions in pre_matches
|
for clip_annotation, raw_predictions in pre_matches
|
||||||
for match in match_sound_events_and_raw_predictions(
|
for match in match_sound_events_and_raw_predictions(
|
||||||
clip_annotation, raw_predictions, targets=targets, config=config
|
clip_annotation,
|
||||||
|
raw_predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config,
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user