From 6213238585284e8662313ea304dd80fcc6a6380f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 8 Aug 2025 13:05:50 +0100 Subject: [PATCH] Added matching configs --- src/batdetect2/evaluate/__init__.py | 6 +- src/batdetect2/evaluate/match.py | 126 ++++++++++++++++++++++++---- src/batdetect2/evaluate/metrics.py | 12 +-- src/batdetect2/evaluate/types.py | 32 +++++-- 4 files changed, 140 insertions(+), 36 deletions(-) diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index d9235df..bf7c41d 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,13 +1,9 @@ -from batdetect2.evaluate.evaluate import ( - compute_error_auc, -) from batdetect2.evaluate.match import ( match_predictions_and_annotations, match_sound_events_and_raw_predictions, ) __all__ = [ - "compute_error_auc", - "match_predictions_and_annotations", "match_sound_events_and_raw_predictions", + "match_predictions_and_annotations", ] diff --git a/src/batdetect2/evaluate/match.py b/src/batdetect2/evaluate/match.py index b20b361..5ed4062 100644 --- a/src/batdetect2/evaluate/match.py +++ b/src/batdetect2/evaluate/match.py @@ -1,54 +1,133 @@ -from typing import List +from typing import Annotated, List, Literal, Optional, Union +from pydantic import Field from soundevent import data from soundevent.evaluation import match_geometries +from soundevent.geometry import compute_bounds -from batdetect2.evaluate.types import Match -from batdetect2.postprocess.types import RawPrediction +from batdetect2.configs import BaseConfig +from batdetect2.evaluate.types import MatchEvaluation +from batdetect2.postprocess.types import BatDetect2Prediction from batdetect2.targets.types import TargetProtocol from batdetect2.utils.arrays import iterate_over_array +class BBoxMatchConfig(BaseConfig): + match_method: Literal["BBoxIOU"] = "BBoxIOU" + affinity_threshold: float = 0.5 + time_buffer: float = 0.01 + frequency_buffer: float = 1_000 + + +class IntervalMatchConfig(BaseConfig): + match_method: Literal["IntervalIOU"] = "IntervalIOU" + affinity_threshold: float = 0.5 + time_buffer: float = 0.01 + + +class StartTimeMatchConfig(BaseConfig): + match_method: Literal["StartTime"] = "StartTime" + time_buffer: float = 0.01 + + +MatchConfig = Annotated[ + Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig], + Field(discriminator="match_method"), +] + + +DEFAULT_MATCH_CONFIG = BBoxMatchConfig() + + +def prepare_geometry( + geometry: data.Geometry, config: MatchConfig +) -> data.Geometry: + start_time, low_freq, end_time, high_freq = compute_bounds(geometry) + + if config.match_method == "BBoxIOU": + return data.BoundingBox( + coordinates=[start_time, low_freq, end_time, high_freq] + ) + + if config.match_method == "IntervalIOU": + return data.TimeInterval(coordinates=[start_time, end_time]) + + if config.match_method == "StartTime": + return data.TimeStamp(coordinates=start_time) + + raise NotImplementedError( + f"Invalid matching configuration. Unknown match method: {config.match_method}" + ) + + +def _get_frequency_buffer(config: MatchConfig) -> float: + if config.match_method == "BBoxIOU": + return config.frequency_buffer + + return 0 + + +def _get_affinity_threshold(config: MatchConfig) -> float: + if ( + config.match_method == "BBoxIOU" + or config.match_method == "IntervalIOU" + ): + return config.affinity_threshold + + return 0 + + def match_sound_events_and_raw_predictions( - sound_events: List[data.SoundEventAnnotation], - raw_predictions: List[RawPrediction], + clip_annotation: data.ClipAnnotation, + raw_predictions: List[BatDetect2Prediction], targets: TargetProtocol, -) -> List[Match]: + config: Optional[MatchConfig] = None, +) -> List[MatchEvaluation]: + config = config or DEFAULT_MATCH_CONFIG + target_sound_events = [ targets.transform(sound_event_annotation) - for sound_event_annotation in sound_events + for sound_event_annotation in clip_annotation.sound_events if targets.filter(sound_event_annotation) and sound_event_annotation.sound_event.geometry is not None ] target_geometries: List[data.Geometry] = [ # type: ignore - sound_event_annotation.sound_event.geometry + prepare_geometry( + sound_event_annotation.sound_event.geometry, + config=config, + ) for sound_event_annotation in target_sound_events + if sound_event_annotation.sound_event.geometry is not None ] predicted_geometries = [ - raw_prediction.geometry for raw_prediction in raw_predictions + prepare_geometry(raw_prediction.raw.geometry, config=config) + for raw_prediction in raw_predictions ] matches = [] + for id1, id2, affinity in match_geometries( target_geometries, predicted_geometries, + time_buffer=config.time_buffer, + freq_buffer=_get_frequency_buffer(config), + affinity_threshold=_get_affinity_threshold(config), ): target = target_sound_events[id1] if id1 is not None else None prediction = raw_predictions[id2] if id2 is not None else None - gt_uuid = target.uuid if target is not None else None gt_det = target is not None gt_class = targets.encode_class(target) if target is not None else None - pred_score = float(prediction.detection_score) if prediction else 0 + pred_score = float(prediction.raw.detection_score) if prediction else 0 class_scores = ( { str(class_name): float(score) for class_name, score in iterate_over_array( - prediction.class_scores + prediction.raw.class_scores ) } if prediction is not None @@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions( ) matches.append( - Match( - gt_uuid=gt_uuid, + MatchEvaluation( + match=data.Match( + source=None + if prediction is None + else prediction.sound_event_prediction, + target=target, + affinity=affinity, + ), gt_det=gt_det, gt_class=gt_class, pred_score=pred_score, - affinity=affinity, - class_scores=class_scores, + pred_class_scores=class_scores, ) ) @@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions( def match_predictions_and_annotations( clip_annotation: data.ClipAnnotation, clip_prediction: data.ClipPrediction, + config: Optional[MatchConfig] = None, ) -> List[data.Match]: + config = config or DEFAULT_MATCH_CONFIG + annotated_sound_events = [ sound_event_annotation for sound_event_annotation in clip_annotation.sound_events @@ -86,13 +173,13 @@ def match_predictions_and_annotations( ] annotated_geometries: List[data.Geometry] = [ - sound_event.sound_event.geometry + prepare_geometry(sound_event.sound_event.geometry, config=config) for sound_event in annotated_sound_events if sound_event.sound_event.geometry is not None ] predicted_geometries: List[data.Geometry] = [ - sound_event.sound_event.geometry + prepare_geometry(sound_event.sound_event.geometry, config=config) for sound_event in predicted_sound_events if sound_event.sound_event.geometry is not None ] @@ -101,6 +188,9 @@ def match_predictions_and_annotations( for id1, id2, affinity in match_geometries( annotated_geometries, predicted_geometries, + time_buffer=config.time_buffer, + freq_buffer=_get_frequency_buffer(config), + affinity_threshold=_get_affinity_threshold(config), ): target = annotated_sound_events[id1] if id1 is not None else None source = predicted_sound_events[id2] if id2 is not None else None diff --git a/src/batdetect2/evaluate/metrics.py b/src/batdetect2/evaluate/metrics.py index c1bc924..7b1c933 100644 --- a/src/batdetect2/evaluate/metrics.py +++ b/src/batdetect2/evaluate/metrics.py @@ -4,13 +4,13 @@ import pandas as pd from sklearn import metrics from sklearn.preprocessing import label_binarize -from batdetect2.evaluate.types import Match, MetricsProtocol +from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol __all__ = ["DetectionAveragePrecision"] class DetectionAveragePrecision(MetricsProtocol): - def __call__(self, matches: List[Match]) -> Dict[str, float]: + def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: y_true, y_score = zip( *[(match.gt_det, match.pred_score) for match in matches] ) @@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): self.class_names = class_names self.per_class = per_class - def __call__(self, matches: List[Match]) -> Dict[str, float]: + def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: y_true = label_binarize( [ match.gt_class if match.gt_class is not None else "__NONE__" @@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol): y_pred = pd.DataFrame( [ { - name: match.class_scores.get(name, 0) + name: match.pred_class_scores.get(name, 0) for name in self.class_names } for match in matches @@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol): def __init__(self, class_names: List[str]): self.class_names = class_names - def __call__(self, matches: List[Match]) -> Dict[str, float]: + def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: y_true = [ match.gt_class if match.gt_class is not None else "__NONE__" for match in matches @@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol): y_pred = pd.DataFrame( [ { - name: match.class_scores.get(name, 0) + name: match.pred_class_scores.get(name, 0) for name in self.class_names } for match in matches diff --git a/src/batdetect2/evaluate/types.py b/src/batdetect2/evaluate/types.py index 76e39a6..2ef9206 100644 --- a/src/batdetect2/evaluate/types.py +++ b/src/batdetect2/evaluate/types.py @@ -1,22 +1,40 @@ from dataclasses import dataclass from typing import Dict, List, Optional, Protocol -from uuid import UUID + +from soundevent import data __all__ = [ "MetricsProtocol", - "Match", + "MatchEvaluation", ] @dataclass -class Match: - gt_uuid: Optional[UUID] +class MatchEvaluation: + match: data.Match + gt_det: bool gt_class: Optional[str] + pred_score: float - affinity: float - class_scores: Dict[str, float] + pred_class_scores: Dict[str, float] + + @property + def pred_class(self) -> Optional[str]: + if not self.pred_class_scores: + return None + + return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore + + @property + def pred_class_score(self) -> float: + pred_class = self.pred_class + + if pred_class is None: + return 0 + + return self.pred_class_scores[pred_class] class MetricsProtocol(Protocol): - def __call__(self, matches: List[Match]) -> Dict[str, float]: ... + def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...