mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-12 01:39:35 +01:00
Compare commits
No commits in common. "9d1497b3f4a5e9e80325cadfc0d01a3c0380afea" and "4aea3fb2b0b79edade6f072810a18a27de2ad661" have entirely different histories.
9d1497b3f4
...
4aea3fb2b0
@ -4,7 +4,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.match import MatchConfig
|
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"EvaluationConfig",
|
"EvaluationConfig",
|
||||||
@ -13,7 +13,9 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class EvaluationConfig(BaseConfig):
|
class EvaluationConfig(BaseConfig):
|
||||||
match: MatchConfig = Field(default_factory=MatchConfig)
|
match: MatchConfig = Field(
|
||||||
|
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
|
|||||||
@ -1,12 +1,8 @@
|
|||||||
from collections.abc import Callable, Iterable, Mapping
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
from typing import List, Literal, Optional, Tuple
|
|
||||||
|
|
||||||
import numpy as np
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import compute_affinity
|
from soundevent.evaluation import match_geometries
|
||||||
from soundevent.evaluation import (
|
|
||||||
match_geometries as optimal_match,
|
|
||||||
)
|
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
@ -14,186 +10,70 @@ from batdetect2.evaluate.types import MatchEvaluation
|
|||||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
MatchingStrategy = Literal["greedy", "optimal"]
|
|
||||||
"""The type of matching algorithm to use: 'greedy' or 'optimal'."""
|
|
||||||
|
|
||||||
|
class BBoxMatchConfig(BaseConfig):
|
||||||
MatchingGeometry = Literal["bbox", "interval", "timestamp"]
|
match_method: Literal["BBoxIOU"] = "BBoxIOU"
|
||||||
"""The geometry representation to use for matching."""
|
affinity_threshold: float = 0.5
|
||||||
|
time_buffer: float = 0.01
|
||||||
|
|
||||||
class MatchConfig(BaseConfig):
|
|
||||||
"""Configuration for matching geometries.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
strategy : MatchingStrategy, default="greedy"
|
|
||||||
The matching algorithm to use. 'greedy' prioritizes high-confidence
|
|
||||||
predictions, while 'optimal' finds the globally best set of matches.
|
|
||||||
geometry : MatchingGeometry, default="timestamp"
|
|
||||||
The geometric representation to use when computing affinity.
|
|
||||||
affinity_threshold : float, default=0.0
|
|
||||||
The minimum affinity score (e.g., IoU) required for a valid match.
|
|
||||||
time_buffer : float, default=0.005
|
|
||||||
Time tolerance in seconds used in affinity calculations.
|
|
||||||
frequency_buffer : float, default=1000
|
|
||||||
Frequency tolerance in Hertz used in affinity calculations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
strategy: MatchingStrategy = "greedy"
|
|
||||||
geometry: MatchingGeometry = "timestamp"
|
|
||||||
affinity_threshold: float = 0.0
|
|
||||||
time_buffer: float = 0.005
|
|
||||||
frequency_buffer: float = 1_000
|
frequency_buffer: float = 1_000
|
||||||
|
|
||||||
|
|
||||||
def _to_bbox(geometry: data.Geometry) -> data.BoundingBox:
|
class IntervalMatchConfig(BaseConfig):
|
||||||
|
match_method: Literal["IntervalIOU"] = "IntervalIOU"
|
||||||
|
affinity_threshold: float = 0.5
|
||||||
|
time_buffer: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class StartTimeMatchConfig(BaseConfig):
|
||||||
|
match_method: Literal["StartTime"] = "StartTime"
|
||||||
|
time_buffer: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
MatchConfig = Annotated[
|
||||||
|
Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
|
||||||
|
Field(discriminator="match_method"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_geometry(
|
||||||
|
geometry: data.Geometry, config: MatchConfig
|
||||||
|
) -> data.Geometry:
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
|
if config.match_method == "BBoxIOU":
|
||||||
return data.BoundingBox(
|
return data.BoundingBox(
|
||||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.match_method == "IntervalIOU":
|
||||||
def _to_interval(geometry: data.Geometry) -> data.TimeInterval:
|
|
||||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
|
||||||
return data.TimeInterval(coordinates=[start_time, end_time])
|
return data.TimeInterval(coordinates=[start_time, end_time])
|
||||||
|
|
||||||
|
if config.match_method == "StartTime":
|
||||||
def _to_timestamp(geometry: data.Geometry) -> data.TimeStamp:
|
|
||||||
start_time = compute_bounds(geometry)[0]
|
|
||||||
return data.TimeStamp(coordinates=start_time)
|
return data.TimeStamp(coordinates=start_time)
|
||||||
|
|
||||||
|
|
||||||
_geometry_cast_functions: Mapping[
|
|
||||||
MatchingGeometry, Callable[[data.Geometry], data.Geometry]
|
|
||||||
] = {
|
|
||||||
"bbox": _to_bbox,
|
|
||||||
"interval": _to_interval,
|
|
||||||
"timestamp": _to_timestamp,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def match_geometries(
|
|
||||||
source: List[data.Geometry],
|
|
||||||
target: List[data.Geometry],
|
|
||||||
config: MatchConfig,
|
|
||||||
scores: Optional[List[float]] = None,
|
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
|
||||||
geometry_cast = _geometry_cast_functions[config.geometry]
|
|
||||||
|
|
||||||
if config.strategy == "optimal":
|
|
||||||
return optimal_match(
|
|
||||||
source=[geometry_cast(geom) for geom in source],
|
|
||||||
target=[geometry_cast(geom) for geom in target],
|
|
||||||
time_buffer=config.time_buffer,
|
|
||||||
freq_buffer=config.frequency_buffer,
|
|
||||||
affinity_threshold=config.affinity_threshold,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.strategy == "greedy":
|
|
||||||
return greedy_match(
|
|
||||||
source=[geometry_cast(geom) for geom in source],
|
|
||||||
target=[geometry_cast(geom) for geom in target],
|
|
||||||
time_buffer=config.time_buffer,
|
|
||||||
freq_buffer=config.frequency_buffer,
|
|
||||||
affinity_threshold=config.affinity_threshold,
|
|
||||||
scores=scores,
|
|
||||||
)
|
|
||||||
|
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"Matching strategy not implemented {config.strategy}"
|
f"Invalid matching configuration. Unknown match method: {config.match_method}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def greedy_match(
|
def _get_frequency_buffer(config: MatchConfig) -> float:
|
||||||
source: List[data.Geometry],
|
if config.match_method == "BBoxIOU":
|
||||||
target: List[data.Geometry],
|
return config.frequency_buffer
|
||||||
scores: Optional[List[float]] = None,
|
|
||||||
affinity_threshold: float = 0.5,
|
|
||||||
time_buffer: float = 0.001,
|
|
||||||
freq_buffer: float = 1000,
|
|
||||||
) -> Iterable[Tuple[Optional[int], Optional[int], float]]:
|
|
||||||
"""Performs a greedy, one-to-one matching of source to target geometries.
|
|
||||||
|
|
||||||
Iterates through source geometries, prioritizing by score if provided. Each
|
return 0
|
||||||
source is matched to the best available target, provided the affinity
|
|
||||||
exceeds the threshold and the target has not already been assigned.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
source
|
|
||||||
A list of source geometries (e.g., predictions).
|
|
||||||
target
|
|
||||||
A list of target geometries (e.g., ground truths).
|
|
||||||
scores
|
|
||||||
Confidence scores for each source geometry for prioritization.
|
|
||||||
affinity_threshold
|
|
||||||
The minimum affinity score required for a valid match.
|
|
||||||
time_buffer
|
|
||||||
Time tolerance in seconds for affinity calculation.
|
|
||||||
freq_buffer
|
|
||||||
Frequency tolerance in Hertz for affinity calculation.
|
|
||||||
|
|
||||||
Yields
|
def _get_affinity_threshold(config: MatchConfig) -> float:
|
||||||
------
|
if (
|
||||||
Tuple[Optional[int], Optional[int], float]
|
config.match_method == "BBoxIOU"
|
||||||
A 3-element tuple describing a match or a miss. There are three
|
or config.match_method == "IntervalIOU"
|
||||||
possible formats:
|
):
|
||||||
- Successful Match: `(source_idx, target_idx, affinity)`
|
return config.affinity_threshold
|
||||||
- Unmatched Source (False Positive): `(source_idx, None, 0)`
|
|
||||||
- Unmatched Target (False Negative): `(None, target_idx, 0)`
|
|
||||||
"""
|
|
||||||
assigned = set()
|
|
||||||
|
|
||||||
if not source:
|
return 0
|
||||||
for target_idx in range(len(target)):
|
|
||||||
yield None, target_idx, 0
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
if not target:
|
|
||||||
for source_idx in range(len(source)):
|
|
||||||
yield source_idx, None, 0
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
if scores is None:
|
|
||||||
indices = np.arange(len(source))
|
|
||||||
else:
|
|
||||||
indices = np.argsort(scores)[::-1]
|
|
||||||
|
|
||||||
for source_idx in indices:
|
|
||||||
source_geometry = source[source_idx]
|
|
||||||
|
|
||||||
affinities = np.array(
|
|
||||||
[
|
|
||||||
compute_affinity(
|
|
||||||
source_geometry,
|
|
||||||
target_geometry,
|
|
||||||
time_buffer=time_buffer,
|
|
||||||
freq_buffer=freq_buffer,
|
|
||||||
)
|
|
||||||
for target_geometry in target
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
closest_target = int(np.argmax(affinities))
|
|
||||||
affinity = affinities[closest_target]
|
|
||||||
|
|
||||||
if affinities[closest_target] <= affinity_threshold:
|
|
||||||
yield source_idx, None, 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
if closest_target in assigned:
|
|
||||||
yield source_idx, None, 0
|
|
||||||
continue
|
|
||||||
|
|
||||||
assigned.add(closest_target)
|
|
||||||
yield source_idx, closest_target, affinity
|
|
||||||
|
|
||||||
missed_ground_truth = set(range(len(target))) - assigned
|
|
||||||
for target_idx in missed_ground_truth:
|
|
||||||
yield None, target_idx, 0
|
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
def match_sound_events_and_raw_predictions(
|
||||||
@ -202,7 +82,7 @@ def match_sound_events_and_raw_predictions(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[MatchEvaluation]:
|
) -> List[MatchEvaluation]:
|
||||||
config = config or MatchConfig()
|
config = config or DEFAULT_MATCH_CONFIG
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
targets.transform(sound_event_annotation)
|
targets.transform(sound_event_annotation)
|
||||||
@ -212,34 +92,30 @@ def match_sound_events_and_raw_predictions(
|
|||||||
]
|
]
|
||||||
|
|
||||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||||
sound_event_annotation.sound_event.geometry
|
prepare_geometry(
|
||||||
|
sound_event_annotation.sound_event.geometry,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
for sound_event_annotation in target_sound_events
|
for sound_event_annotation in target_sound_events
|
||||||
if sound_event_annotation.sound_event.geometry is not None
|
if sound_event_annotation.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries = [
|
predicted_geometries = [
|
||||||
raw_prediction.raw.geometry for raw_prediction in raw_predictions
|
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
||||||
]
|
|
||||||
|
|
||||||
scores = [
|
|
||||||
raw_prediction.raw.detection_score
|
|
||||||
for raw_prediction in raw_predictions
|
for raw_prediction in raw_predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for source_idx, target_idx, affinity in match_geometries(
|
for id1, id2, affinity in match_geometries(
|
||||||
source=predicted_geometries,
|
target_geometries,
|
||||||
target=target_geometries,
|
predicted_geometries,
|
||||||
config=config,
|
time_buffer=config.time_buffer,
|
||||||
scores=scores,
|
freq_buffer=_get_frequency_buffer(config),
|
||||||
|
affinity_threshold=_get_affinity_threshold(config),
|
||||||
):
|
):
|
||||||
target = (
|
target = target_sound_events[id1] if id1 is not None else None
|
||||||
target_sound_events[target_idx] if target_idx is not None else None
|
prediction = raw_predictions[id2] if id2 is not None else None
|
||||||
)
|
|
||||||
prediction = (
|
|
||||||
raw_predictions[source_idx] if source_idx is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
gt_det = target is not None
|
gt_det = target is not None
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
@ -282,7 +158,7 @@ def match_predictions_and_annotations(
|
|||||||
clip_prediction: data.ClipPrediction,
|
clip_prediction: data.ClipPrediction,
|
||||||
config: Optional[MatchConfig] = None,
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[data.Match]:
|
) -> List[data.Match]:
|
||||||
config = config or MatchConfig()
|
config = config or DEFAULT_MATCH_CONFIG
|
||||||
|
|
||||||
annotated_sound_events = [
|
annotated_sound_events = [
|
||||||
sound_event_annotation
|
sound_event_annotation
|
||||||
@ -297,46 +173,29 @@ def match_predictions_and_annotations(
|
|||||||
]
|
]
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
annotated_geometries: List[data.Geometry] = [
|
||||||
sound_event.sound_event.geometry
|
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||||
for sound_event in annotated_sound_events
|
for sound_event in annotated_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
predicted_geometries: List[data.Geometry] = [
|
||||||
sound_event.sound_event.geometry
|
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||||
for sound_event in predicted_sound_events
|
|
||||||
if sound_event.sound_event.geometry is not None
|
|
||||||
]
|
|
||||||
|
|
||||||
scores = [
|
|
||||||
sound_event.score
|
|
||||||
for sound_event in predicted_sound_events
|
for sound_event in predicted_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
for source_idx, target_idx, affinity in match_geometries(
|
for id1, id2, affinity in match_geometries(
|
||||||
source=predicted_geometries,
|
annotated_geometries,
|
||||||
target=annotated_geometries,
|
predicted_geometries,
|
||||||
config=config,
|
time_buffer=config.time_buffer,
|
||||||
scores=scores,
|
freq_buffer=_get_frequency_buffer(config),
|
||||||
|
affinity_threshold=_get_affinity_threshold(config),
|
||||||
):
|
):
|
||||||
target = (
|
target = annotated_sound_events[id1] if id1 is not None else None
|
||||||
annotated_sound_events[target_idx]
|
source = predicted_sound_events[id2] if id2 is not None else None
|
||||||
if target_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
source = (
|
|
||||||
predicted_sound_events[source_idx]
|
|
||||||
if source_idx is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
matches.append(
|
matches.append(
|
||||||
data.Match(
|
data.Match(source=source, target=target, affinity=affinity)
|
||||||
source=source,
|
|
||||||
target=target,
|
|
||||||
affinity=affinity,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return matches
|
return matches
|
||||||
|
|||||||
@ -177,10 +177,7 @@ def _match_all_collected_examples(
|
|||||||
match
|
match
|
||||||
for clip_annotation, raw_predictions in pre_matches
|
for clip_annotation, raw_predictions in pre_matches
|
||||||
for match in match_sound_events_and_raw_predictions(
|
for match in match_sound_events_and_raw_predictions(
|
||||||
clip_annotation,
|
clip_annotation, raw_predictions, targets=targets, config=config
|
||||||
raw_predictions,
|
|
||||||
targets=targets,
|
|
||||||
config=config,
|
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user