mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
2 Commits
4aea3fb2b0
...
9d1497b3f4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d1497b3f4 | ||
|
|
7af72912da |
@ -4,7 +4,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
|
||||
from batdetect2.evaluate.match import MatchConfig
|
||||
|
||||
__all__ = [
|
||||
"EvaluationConfig",
|
||||
@ -13,9 +13,7 @@ __all__ = [
|
||||
|
||||
|
||||
class EvaluationConfig(BaseConfig):
|
||||
match: MatchConfig = Field(
|
||||
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
|
||||
)
|
||||
match: MatchConfig = Field(default_factory=MatchConfig)
|
||||
|
||||
|
||||
def load_evaluation_config(
|
||||
|
||||
@ -1,8 +1,12 @@
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
from collections.abc import Callable, Iterable, Mapping
|
||||
from typing import List, Literal, Optional, Tuple
|
||||
|
||||
from pydantic import Field
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
from soundevent.evaluation import compute_affinity
|
||||
from soundevent.evaluation import (
|
||||
match_geometries as optimal_match,
|
||||
)
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
@ -10,70 +14,186 @@ from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
MatchingStrategy = Literal["greedy", "optimal"]
|
||||
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
||||
|
||||
class BBoxMatchConfig(BaseConfig):
|
||||
match_method: Literal["BBoxIOU"] = "BBoxIOU"
|
||||
affinity_threshold: float = 0.5
|
||||
time_buffer: float = 0.01
|
||||
|
||||
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
||||
"""The geometry representation to use for matching."""
|
||||
|
||||
|
||||
class MatchConfig(BaseConfig):
|
||||
"""Configuration for matching geometries.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
strategy : MatchingStrategy, default="greedy"
|
||||
The matching algorithm to use. 'greedy' prioritizes high-confidence
|
||||
predictions, while 'optimal' finds the globally best set of matches.
|
||||
geometry : MatchingGeometry, default="timestamp"
|
||||
The geometric representation to use when computing affinity.
|
||||
affinity_threshold : float, default=0.0
|
||||
The minimum affinity score (e.g., IoU) required for a valid match.
|
||||
time_buffer : float, default=0.005
|
||||
Time tolerance in seconds used in affinity calculations.
|
||||
frequency_buffer : float, default=1000
|
||||
Frequency tolerance in Hertz used in affinity calculations.
|
||||
"""
|
||||
|
||||
strategy: MatchingStrategy = "greedy"
|
||||
geometry: MatchingGeometry = "timestamp"
|
||||
affinity_threshold: float = 0.0
|
||||
time_buffer: float = 0.005
|
||||
frequency_buffer: float = 1_000
|
||||
|
||||
|
||||
class IntervalMatchConfig(BaseConfig):
|
||||
match_method: Literal["IntervalIOU"] = "IntervalIOU"
|
||||
affinity_threshold: float = 0.5
|
||||
time_buffer: float = 0.01
|
||||
|
||||
|
||||
class StartTimeMatchConfig(BaseConfig):
|
||||
match_method: Literal["StartTime"] = "StartTime"
|
||||
time_buffer: float = 0.01
|
||||
|
||||
|
||||
MatchConfig = Annotated[
|
||||
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
|
||||
Field(discriminator="match_method"),
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
||||
|
||||
|
||||
def prepare_geometry(
|
||||
geometry: data.Geometry, config: MatchConfig
|
||||
) -> data.Geometry:
|
||||
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||
|
||||
if config.match_method == "BBoxIOU":
|
||||
return data.BoundingBox(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
if config.match_method == "IntervalIOU":
|
||||
|
||||
def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
|
||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||
return data.TimeInterval(coordinates=[start_time, end_time])
|
||||
|
||||
if config.match_method == "StartTime":
|
||||
|
||||
def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
|
||||
start_time = compute_bounds(geometry)[0]
|
||||
return data.TimeStamp(coordinates=start_time)
|
||||
|
||||
|
||||
_geometry_cast_functions: Mapping[
|
||||
MatchingGeometry, Callable[[data.Geometry], data.Geometry]
|
||||
] = {
|
||||
"bbox": _to_bbox,
|
||||
"interval": _to_interval,
|
||||
"timestamp": _to_timestamp,
|
||||
}
|
||||
|
||||
|
||||
def match_geometries(
|
||||
source: List[data.Geometry],
|
||||
target: List[data.Geometry],
|
||||
config: MatchConfig,
|
||||
scores: Optional[List[float]] = None,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
geometry_cast = _geometry_cast_functions[config.geometry]
|
||||
|
||||
if config.strategy == "optimal":
|
||||
return optimal_match(
|
||||
source=[geometry_cast(geom) for geom in source],
|
||||
target=[geometry_cast(geom) for geom in target],
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=config.frequency_buffer,
|
||||
affinity_threshold=config.affinity_threshold,
|
||||
)
|
||||
|
||||
if config.strategy == "greedy":
|
||||
return greedy_match(
|
||||
source=[geometry_cast(geom) for geom in source],
|
||||
target=[geometry_cast(geom) for geom in target],
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=config.frequency_buffer,
|
||||
affinity_threshold=config.affinity_threshold,
|
||||
scores=scores,
|
||||
)
|
||||
|
||||
raise NotImplementedError(
|
||||
f"Invalid matching configuration. Unknown match method: {config.match_method}"
|
||||
f"Matching strategy not implemented {config.strategy}"
|
||||
)
|
||||
|
||||
|
||||
def _get_frequency_buffer(config: MatchConfig) -> float:
|
||||
if config.match_method == "BBoxIOU":
|
||||
return config.frequency_buffer
|
||||
def greedy_match(
|
||||
source: List[data.Geometry],
|
||||
target: List[data.Geometry],
|
||||
scores: Optional[List[float]] = None,
|
||||
affinity_threshold: float = 0.5,
|
||||
time_buffer: float = 0.001,
|
||||
freq_buffer: float = 1000,
|
||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
||||
"""Performs a greedy, one-to-one matching of source to target geometries.
|
||||
|
||||
return 0
|
||||
Iterates through source geometries, prioritizing by score if provided. Each
|
||||
source is matched to the best available target, provided the affinity
|
||||
exceeds the threshold and the target has not already been assigned.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source
|
||||
A list of source geometries (e.g., predictions).
|
||||
target
|
||||
A list of target geometries (e.g., ground truths).
|
||||
scores
|
||||
Confidence scores for each source geometry for prioritization.
|
||||
affinity_threshold
|
||||
The minimum affinity score required for a valid match.
|
||||
time_buffer
|
||||
Time tolerance in seconds for affinity calculation.
|
||||
freq_buffer
|
||||
Frequency tolerance in Hertz for affinity calculation.
|
||||
|
||||
def _get_affinity_threshold(config: MatchConfig) -> float:
|
||||
if (
|
||||
config.match_method == "BBoxIOU"
|
||||
or config.match_method == "IntervalIOU"
|
||||
):
|
||||
return config.affinity_threshold
|
||||
Yields
|
||||
------
|
||||
Tuple[Optional[int], Optional[int], float]
|
||||
A 3-element tuple describing a match or a miss. There are three
|
||||
possible formats:
|
||||
- Successful Match: `(source_idx, target_idx, affinity)`
|
||||
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
||||
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
||||
"""
|
||||
assigned = set()
|
||||
|
||||
return 0
|
||||
if not source:
|
||||
for target_idx in range(len(target)):
|
||||
yield None, target_idx, 0
|
||||
|
||||
return
|
||||
|
||||
if not target:
|
||||
for source_idx in range(len(source)):
|
||||
yield source_idx, None, 0
|
||||
|
||||
return
|
||||
|
||||
if scores is None:
|
||||
indices = np.arange(len(source))
|
||||
else:
|
||||
indices = np.argsort(scores)[::-1]
|
||||
|
||||
for source_idx in indices:
|
||||
source_geometry = source[source_idx]
|
||||
|
||||
affinities = np.array(
|
||||
[
|
||||
compute_affinity(
|
||||
source_geometry,
|
||||
target_geometry,
|
||||
time_buffer=time_buffer,
|
||||
freq_buffer=freq_buffer,
|
||||
)
|
||||
for target_geometry in target
|
||||
]
|
||||
)
|
||||
|
||||
closest_target = int(np.argmax(affinities))
|
||||
affinity = affinities[closest_target]
|
||||
|
||||
if affinities[closest_target] <= affinity_threshold:
|
||||
yield source_idx, None, 0
|
||||
continue
|
||||
|
||||
if closest_target in assigned:
|
||||
yield source_idx, None, 0
|
||||
continue
|
||||
|
||||
assigned.add(closest_target)
|
||||
yield source_idx, closest_target, affinity
|
||||
|
||||
missed_ground_truth = set(range(len(target))) - assigned
|
||||
for target_idx in missed_ground_truth:
|
||||
yield None, target_idx, 0
|
||||
|
||||
|
||||
def match_sound_events_and_raw_predictions(
|
||||
@ -82,7 +202,7 @@ def match_sound_events_and_raw_predictions(
|
||||
targets: TargetProtocol,
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[MatchEvaluation]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
config = config or MatchConfig()
|
||||
|
||||
target_sound_events = [
|
||||
targets.transform(sound_event_annotation)
|
||||
@ -92,30 +212,34 @@ def match_sound_events_and_raw_predictions(
|
||||
]
|
||||
|
||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||
prepare_geometry(
|
||||
sound_event_annotation.sound_event.geometry,
|
||||
config=config,
|
||||
)
|
||||
sound_event_annotation.sound_event.geometry
|
||||
for sound_event_annotation in target_sound_events
|
||||
if sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries = [
|
||||
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
||||
raw_prediction.raw.geometry for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
scores = [
|
||||
raw_prediction.raw.detection_score
|
||||
for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
matches = []
|
||||
|
||||
for id1, id2, affinity in match_geometries(
|
||||
target_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
for source_idx, target_idx, affinity in match_geometries(
|
||||
source=predicted_geometries,
|
||||
target=target_geometries,
|
||||
config=config,
|
||||
scores=scores,
|
||||
):
|
||||
target = target_sound_events[id1] if id1 is not None else None
|
||||
prediction = raw_predictions[id2] if id2 is not None else None
|
||||
target = (
|
||||
target_sound_events[target_idx] if target_idx is not None else None
|
||||
)
|
||||
prediction = (
|
||||
raw_predictions[source_idx] if source_idx is not None else None
|
||||
)
|
||||
|
||||
gt_det = target is not None
|
||||
gt_class = targets.encode_class(target) if target is not None else None
|
||||
@ -158,7 +282,7 @@ def match_predictions_and_annotations(
|
||||
clip_prediction: data.ClipPrediction,
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[data.Match]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
config = config or MatchConfig()
|
||||
|
||||
annotated_sound_events = [
|
||||
sound_event_annotation
|
||||
@ -173,29 +297,46 @@ def match_predictions_and_annotations(
|
||||
]
|
||||
|
||||
annotated_geometries: List[data.Geometry] = [
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
sound_event.sound_event.geometry
|
||||
for sound_event in annotated_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries: List[data.Geometry] = [
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
sound_event.sound_event.geometry
|
||||
for sound_event in predicted_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
scores = [
|
||||
sound_event.score
|
||||
for sound_event in predicted_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
matches = []
|
||||
for id1, id2, affinity in match_geometries(
|
||||
annotated_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
for source_idx, target_idx, affinity in match_geometries(
|
||||
source=predicted_geometries,
|
||||
target=annotated_geometries,
|
||||
config=config,
|
||||
scores=scores,
|
||||
):
|
||||
target = annotated_sound_events[id1] if id1 is not None else None
|
||||
source = predicted_sound_events[id2] if id2 is not None else None
|
||||
target = (
|
||||
annotated_sound_events[target_idx]
|
||||
if target_idx is not None
|
||||
else None
|
||||
)
|
||||
source = (
|
||||
predicted_sound_events[source_idx]
|
||||
if source_idx is not None
|
||||
else None
|
||||
)
|
||||
matches.append(
|
||||
data.Match(source=source, target=target, affinity=affinity)
|
||||
data.Match(
|
||||
source=source,
|
||||
target=target,
|
||||
affinity=affinity,
|
||||
)
|
||||
)
|
||||
|
||||
return matches
|
||||
|
||||
@ -177,7 +177,10 @@ def _match_all_collected_examples(
|
||||
match
|
||||
for clip_annotation, raw_predictions in pre_matches
|
||||
for match in match_sound_events_and_raw_predictions(
|
||||
clip_annotation, raw_predictions, targets=targets, config=config
|
||||
clip_annotation,
|
||||
raw_predictions,
|
||||
targets=targets,
|
||||
config=config,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user