Added matching configs

This commit is contained in:
mbsantiago 2025-08-08 13:05:50 +01:00
parent a485ea4f79
commit 6213238585
4 changed files with 140 additions and 36 deletions

View File

@ -1,13 +1,9 @@
from batdetect2.evaluate.evaluate import (
compute_error_auc,
)
from batdetect2.evaluate.match import ( from batdetect2.evaluate.match import (
match_predictions_and_annotations, match_predictions_and_annotations,
match_sound_events_and_raw_predictions, match_sound_events_and_raw_predictions,
) )
__all__ = [ __all__ = [
"compute_error_auc",
"match_predictions_and_annotations",
"match_sound_events_and_raw_predictions", "match_sound_events_and_raw_predictions",
"match_predictions_and_annotations",
] ]

View File

@ -1,54 +1,133 @@
from typing import List from typing import Annotated, List, Literal, Optional, Union
from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_geometries from soundevent.evaluation import match_geometries
from soundevent.geometry import compute_bounds
from batdetect2.evaluate.types import Match from batdetect2.configs import BaseConfig
from batdetect2.postprocess.types import RawPrediction from batdetect2.evaluate.types import MatchEvaluation
from batdetect2.postprocess.types import BatDetect2Prediction
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
from batdetect2.utils.arrays import iterate_over_array from batdetect2.utils.arrays import iterate_over_array
class BBoxMatchConfig(BaseConfig):
    """Matching settings selected by ``match_method == "BBoxIOU"``.

    With this method ``prepare_geometry`` reduces every geometry to the
    bounding box returned by ``compute_bounds`` before the geometries are
    handed to ``match_geometries``.
    """

    # Discriminator tag that selects this variant of ``MatchConfig``.
    match_method: Literal["BBoxIOU"] = "BBoxIOU"

    # Forwarded to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5

    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01

    # Forwarded to ``match_geometries`` as ``freq_buffer``
    # (presumably Hz -- TODO confirm against soundevent).
    frequency_buffer: float = 1_000
class IntervalMatchConfig(BaseConfig):
    """Matching settings selected by ``match_method == "IntervalIOU"``.

    With this method ``prepare_geometry`` reduces every geometry to a time
    interval running from its start time to its end time, so no frequency
    buffer is configured here.
    """

    # Discriminator tag that selects this variant of ``MatchConfig``.
    match_method: Literal["IntervalIOU"] = "IntervalIOU"

    # Forwarded to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5

    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01
class StartTimeMatchConfig(BaseConfig):
    """Matching settings selected by ``match_method == "StartTime"``.

    With this method ``prepare_geometry`` reduces every geometry to a time
    stamp placed at its start time. The config carries neither an affinity
    threshold nor a frequency buffer; the helper functions in this module
    substitute ``0`` for both.
    """

    # Discriminator tag that selects this variant of ``MatchConfig``.
    match_method: Literal["StartTime"] = "StartTime"

    # Forwarded to ``match_geometries`` as ``time_buffer``
    # (presumably seconds -- TODO confirm against soundevent).
    time_buffer: float = 0.01
# Discriminated union of the supported matching configurations: the
# ``match_method`` field is the discriminator that tells the three config
# classes apart.
MatchConfig = Annotated[
    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
    Field(discriminator="match_method"),
]
# Fallback used by the matching functions in this module when the caller
# passes no config (``config = config or DEFAULT_MATCH_CONFIG``).
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
def prepare_geometry(
    geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
    """Reduce a geometry to the shape required by the matching method.

    The bounds of ``geometry`` are computed once with ``compute_bounds``
    and then repackaged according to ``config.match_method``:

    - ``"BBoxIOU"``: a ``data.BoundingBox`` built from all four bounds.
    - ``"IntervalIOU"``: a ``data.TimeInterval`` from start to end time.
    - ``"StartTime"``: a ``data.TimeStamp`` at the start time.

    Parameters
    ----------
    geometry
        Geometry of an annotated or predicted sound event.
    config
        Matching configuration whose ``match_method`` picks the output.

    Returns
    -------
    data.Geometry
        The simplified geometry.

    Raises
    ------
    NotImplementedError
        If ``config.match_method`` is none of the methods listed above.
    """
    bounds = compute_bounds(geometry)
    onset, lowest, offset, highest = bounds
    method = config.match_method

    if method == "StartTime":
        return data.TimeStamp(coordinates=onset)

    if method == "IntervalIOU":
        return data.TimeInterval(coordinates=[onset, offset])

    if method == "BBoxIOU":
        return data.BoundingBox(coordinates=[onset, lowest, offset, highest])

    raise NotImplementedError(
        f"Invalid matching configuration. Unknown match method: {config.match_method}"
    )
def _get_frequency_buffer(config: MatchConfig) -> float:
if config.match_method == "BBoxIOU":
return config.frequency_buffer
return 0
def _get_affinity_threshold(config: MatchConfig) -> float:
if (
config.match_method == "BBoxIOU"
or config.match_method == "IntervalIOU"
):
return config.affinity_threshold
return 0
def match_sound_events_and_raw_predictions( def match_sound_events_and_raw_predictions(
sound_events: List[data.SoundEventAnnotation], clip_annotation: data.ClipAnnotation,
raw_predictions: List[RawPrediction], raw_predictions: List[BatDetect2Prediction],
targets: TargetProtocol, targets: TargetProtocol,
) -> List[Match]: config: Optional[MatchConfig] = None,
) -> List[MatchEvaluation]:
config = config or DEFAULT_MATCH_CONFIG
target_sound_events = [ target_sound_events = [
targets.transform(sound_event_annotation) targets.transform(sound_event_annotation)
for sound_event_annotation in sound_events for sound_event_annotation in clip_annotation.sound_events
if targets.filter(sound_event_annotation) if targets.filter(sound_event_annotation)
and sound_event_annotation.sound_event.geometry is not None and sound_event_annotation.sound_event.geometry is not None
] ]
target_geometries: List[data.Geometry] = [ # type: ignore target_geometries: List[data.Geometry] = [ # type: ignore
sound_event_annotation.sound_event.geometry prepare_geometry(
sound_event_annotation.sound_event.geometry,
config=config,
)
for sound_event_annotation in target_sound_events for sound_event_annotation in target_sound_events
if sound_event_annotation.sound_event.geometry is not None
] ]
predicted_geometries = [ predicted_geometries = [
raw_prediction.geometry for raw_prediction in raw_predictions prepare_geometry(raw_prediction.raw.geometry, config=config)
for raw_prediction in raw_predictions
] ]
matches = [] matches = []
for id1, id2, affinity in match_geometries( for id1, id2, affinity in match_geometries(
target_geometries, target_geometries,
predicted_geometries, predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
): ):
target = target_sound_events[id1] if id1 is not None else None target = target_sound_events[id1] if id1 is not None else None
prediction = raw_predictions[id2] if id2 is not None else None prediction = raw_predictions[id2] if id2 is not None else None
gt_uuid = target.uuid if target is not None else None
gt_det = target is not None gt_det = target is not None
gt_class = targets.encode_class(target) if target is not None else None gt_class = targets.encode_class(target) if target is not None else None
pred_score = float(prediction.detection_score) if prediction else 0 pred_score = float(prediction.raw.detection_score) if prediction else 0
class_scores = ( class_scores = (
{ {
str(class_name): float(score) str(class_name): float(score)
for class_name, score in iterate_over_array( for class_name, score in iterate_over_array(
prediction.class_scores prediction.raw.class_scores
) )
} }
if prediction is not None if prediction is not None
@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
) )
matches.append( matches.append(
Match( MatchEvaluation(
gt_uuid=gt_uuid, match=data.Match(
source=None
if prediction is None
else prediction.sound_event_prediction,
target=target,
affinity=affinity,
),
gt_det=gt_det, gt_det=gt_det,
gt_class=gt_class, gt_class=gt_class,
pred_score=pred_score, pred_score=pred_score,
affinity=affinity, pred_class_scores=class_scores,
class_scores=class_scores,
) )
) )
@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions(
def match_predictions_and_annotations( def match_predictions_and_annotations(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
clip_prediction: data.ClipPrediction, clip_prediction: data.ClipPrediction,
config: Optional[MatchConfig] = None,
) -> List[data.Match]: ) -> List[data.Match]:
config = config or DEFAULT_MATCH_CONFIG
annotated_sound_events = [ annotated_sound_events = [
sound_event_annotation sound_event_annotation
for sound_event_annotation in clip_annotation.sound_events for sound_event_annotation in clip_annotation.sound_events
@ -86,13 +173,13 @@ def match_predictions_and_annotations(
] ]
annotated_geometries: List[data.Geometry] = [ annotated_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in annotated_sound_events for sound_event in annotated_sound_events
if sound_event.sound_event.geometry is not None if sound_event.sound_event.geometry is not None
] ]
predicted_geometries: List[data.Geometry] = [ predicted_geometries: List[data.Geometry] = [
sound_event.sound_event.geometry prepare_geometry(sound_event.sound_event.geometry, config=config)
for sound_event in predicted_sound_events for sound_event in predicted_sound_events
if sound_event.sound_event.geometry is not None if sound_event.sound_event.geometry is not None
] ]
@ -101,6 +188,9 @@ def match_predictions_and_annotations(
for id1, id2, affinity in match_geometries( for id1, id2, affinity in match_geometries(
annotated_geometries, annotated_geometries,
predicted_geometries, predicted_geometries,
time_buffer=config.time_buffer,
freq_buffer=_get_frequency_buffer(config),
affinity_threshold=_get_affinity_threshold(config),
): ):
target = annotated_sound_events[id1] if id1 is not None else None target = annotated_sound_events[id1] if id1 is not None else None
source = predicted_sound_events[id2] if id2 is not None else None source = predicted_sound_events[id2] if id2 is not None else None

View File

@ -4,13 +4,13 @@ import pandas as pd
from sklearn import metrics from sklearn import metrics
from sklearn.preprocessing import label_binarize from sklearn.preprocessing import label_binarize
from batdetect2.evaluate.types import Match, MetricsProtocol from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
__all__ = ["DetectionAveragePrecision"] __all__ = ["DetectionAveragePrecision"]
class DetectionAveragePrecision(MetricsProtocol): class DetectionAveragePrecision(MetricsProtocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true, y_score = zip( y_true, y_score = zip(
*[(match.gt_det, match.pred_score) for match in matches] *[(match.gt_det, match.pred_score) for match in matches]
) )
@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
self.class_names = class_names self.class_names = class_names
self.per_class = per_class self.per_class = per_class
def __call__(self, matches: List[Match]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = label_binarize( y_true = label_binarize(
[ [
match.gt_class if match.gt_class is not None else "__NONE__" match.gt_class if match.gt_class is not None else "__NONE__"
@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
y_pred = pd.DataFrame( y_pred = pd.DataFrame(
[ [
{ {
name: match.class_scores.get(name, 0) name: match.pred_class_scores.get(name, 0)
for name in self.class_names for name in self.class_names
} }
for match in matches for match in matches
@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
def __init__(self, class_names: List[str]): def __init__(self, class_names: List[str]):
self.class_names = class_names self.class_names = class_names
def __call__(self, matches: List[Match]) -> Dict[str, float]: def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
y_true = [ y_true = [
match.gt_class if match.gt_class is not None else "__NONE__" match.gt_class if match.gt_class is not None else "__NONE__"
for match in matches for match in matches
@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
y_pred = pd.DataFrame( y_pred = pd.DataFrame(
[ [
{ {
name: match.class_scores.get(name, 0) name: match.pred_class_scores.get(name, 0)
for name in self.class_names for name in self.class_names
} }
for match in matches for match in matches

View File

@ -1,22 +1,40 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, List, Optional, Protocol from typing import Dict, List, Optional, Protocol
from uuid import UUID
from soundevent import data
__all__ = [ __all__ = [
"MetricsProtocol", "MetricsProtocol",
"Match", "MatchEvaluation",
] ]
@dataclass @dataclass
class Match: class MatchEvaluation:
gt_uuid: Optional[UUID] match: data.Match
gt_det: bool gt_det: bool
gt_class: Optional[str] gt_class: Optional[str]
pred_score: float pred_score: float
affinity: float pred_class_scores: Dict[str, float]
class_scores: Dict[str, float]
@property
def pred_class(self) -> Optional[str]:
    """Name of the class with the highest predicted score.

    Returns ``None`` when ``pred_class_scores`` is empty. When several
    classes share the top score, the one that comes first in the mapping
    is returned.
    """
    scores = self.pred_class_scores
    if scores:
        return max(scores, key=lambda name: scores[name])

    return None
@property
def pred_class_score(self) -> float:
    """Score of the top predicted class.

    Looks up the score of ``self.pred_class`` in ``pred_class_scores``;
    returns ``0`` when there is no predicted class.
    """
    best = self.pred_class
    return 0 if best is None else self.pred_class_scores[best]
class MetricsProtocol(Protocol): class MetricsProtocol(Protocol):
def __call__(self, matches: List[Match]) -> Dict[str, float]: ... def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...