Added matching configs

This commit is contained in:
mbsantiago 2025-08-08 13:05:50 +01:00
parent a485ea4f79
commit 6213238585
4 changed files with 140 additions and 36 deletions

View File

@ -1,13 +1,9 @@
from batdetect2.evaluate.evaluate import (
compute_error_auc,
)
from batdetect2.evaluate.match import (
match_predictions_and_annotations,
match_sound_events_and_raw_predictions,
)
__all__ = [
"compute_error_auc",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions",
"match_predictions_and_annotations",
]

View File

@ -1,54 +1,133 @@
from typing import List
from typing import Annotated, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_geometries
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.types import Match
from batdetect2.postprocess.types import RawPrediction
from batdetect2.configs import BaseConfig
from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig):
    """Matching configuration selected by ``match_method="BBoxIOU"``.

    With this method ``prepare_geometry`` replaces every geometry by the
    bounding box built from its ``compute_bounds`` result before the
    geometries are handed to ``match_geometries``.
    """

    # Discriminator value identifying this config inside ``MatchConfig``.
    match_method: Literal["BBoxIOU"] = "BBoxIOU"
    # Forwarded to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5
    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01
    # Forwarded to ``match_geometries`` as ``freq_buffer``
    # (presumably Hz -- TODO confirm against soundevent).
    frequency_buffer: float = 1_000
class IntervalMatchConfig(BaseConfig):
    """Matching configuration selected by ``match_method="IntervalIOU"``.

    With this method ``prepare_geometry`` keeps only the start and end time
    of every geometry (a ``data.TimeInterval``). There is no frequency
    buffer field; ``_get_frequency_buffer`` reports 0 for this config.
    """

    # Discriminator value identifying this config inside ``MatchConfig``.
    match_method: Literal["IntervalIOU"] = "IntervalIOU"
    # Forwarded to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5
    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01
class StartTimeMatchConfig(BaseConfig):
    """Matching configuration selected by ``match_method="StartTime"``.

    With this method ``prepare_geometry`` keeps only the start time of every
    geometry (a ``data.TimeStamp``). The config carries neither an affinity
    threshold nor a frequency buffer; the ``_get_*`` helpers report 0 for
    both.
    """

    # Discriminator value identifying this config inside ``MatchConfig``.
    match_method: Literal["StartTime"] = "StartTime"
    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01
# Tagged union of the supported matching configurations. The
# ``match_method`` field is declared as the discriminator, so its value
# decides which of the three config classes is used.
MatchConfig = Annotated[
    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
    Field(discriminator="match_method"),
]

# Fallback used by the matching functions when no config is supplied.
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
def prepare_geometry(
    geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
    """Reduce a geometry to the shape required by the matching method.

    The bounds of ``geometry`` are computed first and then repackaged
    according to ``config.match_method``:

    - ``"BBoxIOU"``: a ``data.BoundingBox`` spanning the full bounds.
    - ``"IntervalIOU"``: a ``data.TimeInterval`` from start to end time.
    - ``"StartTime"``: a ``data.TimeStamp`` at the start time.

    Raises
    ------
    NotImplementedError
        If ``config.match_method`` is none of the methods above.
    """
    t_start, f_low, t_end, f_high = compute_bounds(geometry)
    method = config.match_method

    if method == "BBoxIOU":
        bbox_coordinates = [t_start, f_low, t_end, f_high]
        return data.BoundingBox(coordinates=bbox_coordinates)

    if method == "IntervalIOU":
        interval_coordinates = [t_start, t_end]
        return data.TimeInterval(coordinates=interval_coordinates)

    if method == "StartTime":
        return data.TimeStamp(coordinates=t_start)

    raise NotImplementedError(
        f"Invalid matching configuration. Unknown match method: {config.match_method}"
    )
def _get_frequency_buffer(config: MatchConfig) -> float:
if config.match_method == "BBoxIOU":
return config.frequency_buffer
return 0
def _get_affinity_threshold(config: MatchConfig) -> float:
if (
config.match_method == "BBoxIOU"
or config.match_method == "IntervalIOU"
):
return config.affinity_threshold
return 0
def match_sound_events_and_raw_predictions(
sound_events: List[data.SoundEventAnnotation],
raw_predictions: List[RawPrediction],
clip_annotation: data.ClipAnnotation,
raw_predictions: List[BatDetect2Prediction],
targets: TargetProtocol,
) -> List[Match]:
config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
config = config or DEFAULT_MATCH_CONFIG
target_sound_events = [
targets.transform(sound_event_annotation)
for sound_event_annotation in sound_events
for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None
]
target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry
prepare_geometry(
sound_event_annotation.sound_event.geometry,
config=config,
)
for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
]
predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions
prepare_geometry(raw_prediction.raw.geometry, config=config)
for raw_prediction in raw_predictions
]
matches = []
for id1, id2, affinity in match_geometries(
target_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
):
target = target_sound_events[id1] if id1 is not None else None
prediction = raw_predictions[id2] if id2 is not None else None
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0
pred_score = float(prediction.raw.detection_score) if prediction else 0
class_scores = (
{
str(class_name): float(score)
for class_name, score in iterate_over_array(
prediction.class_scores
prediction.raw.class_scores
)
}
if prediction is not None
@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
)
matches.append(
Match(
gt_uuid=gt_uuid,
MatchEvaluation(
match=data.Match(
source=None
if prediction is None
else prediction.sound_event_prediction,
target=target,
affinity=affinity,
),
gt_det=gt_det,
gt_class=gt_class,
pred_score=pred_score,
affinity=affinity,
class_scores=class_scores,
pred_class_scores=class_scores,
)
)
@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions(
def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]:
config = config or DEFAULT_MATCH_CONFIG
annotated_sound_events = [
sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events
@ -86,13 +173,13 @@ def match_predictions_and_annotations(
]
annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None
]
predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry
prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None
]
@ -101,6 +188,9 @@ def match_predictions_and_annotations(
for id1, id2, affinity in match_geometries(
annotated_geometries,
predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
):
target = annotated_sound_events[id1] if id1 is not None else None
source = predicted_sound_events[id2] if id2 is not None else None

View File

@ -4,13 +4,13 @@ import pandas as pd
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import Match, MetricsProtocol
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
__all__ = ["DetectionAveragePrecision"]
class DetectionAveragePrecision(MetricsProtocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true, y_score = zip(
*[(match.gt_det, match.pred_score) for match in matches]
)
@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
self.class_names = class_names
self.per_class = per_class
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize(
[
match.gt_class if match.gt_class is not None else "__NONE__"
@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
y_pred = pd.DataFrame(
[
{
name: match.class_scores.get(name, 0)
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches
@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]):
self.class_names = class_names
def __call__(self, matches: List[Match]) -> Dict[str, float]:
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = [
match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches
@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
y_pred = pd.DataFrame(
[
{
name: match.class_scores.get(name, 0)
name: match.pred_class_scores.get(name, 0)
for name in self.class_names
}
for match in matches

View File

@ -1,22 +1,40 @@
from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol
from uuid import UUID
from soundevent import data
__all__ = [
"MetricsProtocol",
"Match",
"MatchEvaluation",
]
@dataclass
class Match:
gt_uuid: Optional[UUID]
class MatchEvaluation:
match: data.Match
gt_det: bool
gt_class: Optional[str]
pred_score: float
affinity: float
class_scores: Dict[str, float]
pred_class_scores: Dict[str, float]
@property
def pred_class(self) -> Optional[str]:
    """Name of the class with the highest predicted score.

    Returns ``None`` when ``pred_class_scores`` is empty. On ties the class
    that appears first in the mapping wins.
    """
    scores = self.pred_class_scores
    if scores:
        best_name, _ = max(scores.items(), key=lambda item: item[1])
        return best_name
    return None
@property
def pred_class_score(self) -> float:
    """Score of the predicted class, or 0 when there is no predicted class."""
    best_name = self.pred_class
    return 0 if best_name is None else self.pred_class_scores[best_name]
class MetricsProtocol(Protocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...