mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Added matching configs
This commit is contained in:
parent
a485ea4f79
commit
6213238585
@ -1,13 +1,9 @@
|
||||
from batdetect2.evaluate.evaluate import (
|
||||
compute_error_auc,
|
||||
)
|
||||
from batdetect2.evaluate.match import (
|
||||
match_predictions_and_annotations,
|
||||
match_sound_events_and_raw_predictions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"compute_error_auc",
|
||||
"match_predictions_and_annotations",
|
||||
"match_sound_events_and_raw_predictions",
|
||||
"match_predictions_and_annotations",
|
||||
]
|
||||
|
||||
@ -1,54 +1,133 @@
|
||||
from typing import List
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_geometries
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.evaluate.types import Match
|
||||
from batdetect2.postprocess.types import RawPrediction
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.types import MatchEvaluation
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.utils.arrays import iterate_over_array
|
||||
|
||||
|
||||
class BBoxMatchConfig(BaseConfig):
    # Matching configuration for the "BBoxIOU" method. With this method
    # `prepare_geometry` reduces every geometry to its bounding box
    # before the geometries are handed to `match_geometries`.

    # Discriminator tag; `MatchConfig` selects the concrete model from it.
    match_method: Literal["BBoxIOU"] = "BBoxIOU"

    # Forwarded to `match_geometries` as `affinity_threshold`.
    affinity_threshold: float = 0.5

    # Forwarded to `match_geometries` as `time_buffer`.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01

    # Forwarded to `match_geometries` as `freq_buffer`.
    # NOTE(review): presumably expressed in Hz -- confirm.
    frequency_buffer: float = 1_000
|
||||
|
||||
|
||||
class IntervalMatchConfig(BaseConfig):
    # Matching configuration for the "IntervalIOU" method. With this
    # method `prepare_geometry` reduces every geometry to the time
    # interval spanned by its bounds, so no frequency buffer is defined.

    # Discriminator tag; `MatchConfig` selects the concrete model from it.
    match_method: Literal["IntervalIOU"] = "IntervalIOU"

    # Forwarded to `match_geometries` as `affinity_threshold`.
    affinity_threshold: float = 0.5

    # Forwarded to `match_geometries` as `time_buffer`.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01
|
||||
|
||||
|
||||
class StartTimeMatchConfig(BaseConfig):
    # Matching configuration for the "StartTime" method. With this method
    # `prepare_geometry` reduces every geometry to a time stamp at its
    # start time. No affinity threshold or frequency buffer is defined
    # here; the helper functions fall back to zero for both.

    # Discriminator tag; `MatchConfig` selects the concrete model from it.
    match_method: Literal["StartTime"] = "StartTime"

    # Forwarded to `match_geometries` as `time_buffer`.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01
|
||||
|
||||
|
||||
# Tagged union of every supported matching configuration. The
# `match_method` literal declared on each model is the discriminator
# used to pick the concrete configuration class.
MatchConfig = Annotated[
    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
    Field(discriminator="match_method"),
]


# Configuration the matching functions fall back to when the caller
# passes `config=None`.
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
||||
|
||||
|
||||
def prepare_geometry(
    geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
    """Reduce a geometry to the shape compared by the configured method.

    The bounds of ``geometry`` are computed once and then repackaged
    according to ``config.match_method``:

    - ``"BBoxIOU"``: a bounding box over the full time/frequency bounds.
    - ``"IntervalIOU"``: a time interval from start time to end time.
    - ``"StartTime"``: a time stamp located at the start time.

    Parameters
    ----------
    geometry
        Geometry to simplify.
    config
        Matching configuration whose ``match_method`` picks the output
        geometry type.

    Returns
    -------
    data.Geometry
        The simplified geometry.

    Raises
    ------
    NotImplementedError
        If ``config.match_method`` is not one of the methods above.
    """
    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    method = config.match_method

    if method == "BBoxIOU":
        bbox_coordinates = [start_time, low_freq, end_time, high_freq]
        return data.BoundingBox(coordinates=bbox_coordinates)
    elif method == "IntervalIOU":
        interval_coordinates = [start_time, end_time]
        return data.TimeInterval(coordinates=interval_coordinates)
    elif method == "StartTime":
        return data.TimeStamp(coordinates=start_time)

    # Reached only when the configuration carries an unrecognised tag.
    raise NotImplementedError(
        f"Invalid matching configuration. Unknown match method: {config.match_method}"
    )
|
||||
|
||||
|
||||
def _get_frequency_buffer(config: MatchConfig) -> float:
    """Return the frequency buffer to pass to the geometry matcher.

    Only the ``"BBoxIOU"`` configuration declares a ``frequency_buffer``
    field; the other configurations compare geometries that carry no
    frequency extent, so they get a zero buffer.

    Parameters
    ----------
    config
        Matching configuration to read the buffer from.

    Returns
    -------
    float
        ``config.frequency_buffer`` for ``"BBoxIOU"``, otherwise ``0.0``.
    """
    if config.match_method == "BBoxIOU":
        return config.frequency_buffer

    # Float literal so the return type matches the annotation on every path.
    return 0.0
|
||||
|
||||
|
||||
def _get_affinity_threshold(config: MatchConfig) -> float:
    """Return the affinity threshold to pass to the geometry matcher.

    The ``"BBoxIOU"`` and ``"IntervalIOU"`` configurations declare an
    ``affinity_threshold`` field. Any other configuration (currently
    ``"StartTime"``) has no such field and gets a zero threshold.

    Parameters
    ----------
    config
        Matching configuration to read the threshold from.

    Returns
    -------
    float
        ``config.affinity_threshold`` for the IOU based methods,
        otherwise ``0.0``.
    """
    if config.match_method in ("BBoxIOU", "IntervalIOU"):
        return config.affinity_threshold

    # Float literal so the return type matches the annotation on every path.
    return 0.0
|
||||
|
||||
|
||||
def match_sound_events_and_raw_predictions(
|
||||
sound_events: List[data.SoundEventAnnotation],
|
||||
raw_predictions: List[RawPrediction],
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
raw_predictions: List[BatDetect2Prediction],
|
||||
targets: TargetProtocol,
|
||||
) -> List[Match]:
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[MatchEvaluation]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
|
||||
target_sound_events = [
|
||||
targets.transform(sound_event_annotation)
|
||||
for sound_event_annotation in sound_events
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
if targets.filter(sound_event_annotation)
|
||||
and sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||
sound_event_annotation.sound_event.geometry
|
||||
prepare_geometry(
|
||||
sound_event_annotation.sound_event.geometry,
|
||||
config=config,
|
||||
)
|
||||
for sound_event_annotation in target_sound_events
|
||||
if sound_event_annotation.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries = [
|
||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
||||
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
||||
for raw_prediction in raw_predictions
|
||||
]
|
||||
|
||||
matches = []
|
||||
|
||||
for id1, id2, affinity in match_geometries(
|
||||
target_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
):
|
||||
target = target_sound_events[id1] if id1 is not None else None
|
||||
prediction = raw_predictions[id2] if id2 is not None else None
|
||||
|
||||
gt_uuid = target.uuid if target is not None else None
|
||||
gt_det = target is not None
|
||||
gt_class = targets.encode_class(target) if target is not None else None
|
||||
|
||||
pred_score = float(prediction.detection_score) if prediction else 0
|
||||
pred_score = float(prediction.raw.detection_score) if prediction else 0
|
||||
|
||||
class_scores = (
|
||||
{
|
||||
str(class_name): float(score)
|
||||
for class_name, score in iterate_over_array(
|
||||
prediction.class_scores
|
||||
prediction.raw.class_scores
|
||||
)
|
||||
}
|
||||
if prediction is not None
|
||||
@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
|
||||
)
|
||||
|
||||
matches.append(
|
||||
Match(
|
||||
gt_uuid=gt_uuid,
|
||||
MatchEvaluation(
|
||||
match=data.Match(
|
||||
source=None
|
||||
if prediction is None
|
||||
else prediction.sound_event_prediction,
|
||||
target=target,
|
||||
affinity=affinity,
|
||||
),
|
||||
gt_det=gt_det,
|
||||
gt_class=gt_class,
|
||||
pred_score=pred_score,
|
||||
affinity=affinity,
|
||||
class_scores=class_scores,
|
||||
pred_class_scores=class_scores,
|
||||
)
|
||||
)
|
||||
|
||||
@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions(
|
||||
def match_predictions_and_annotations(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
clip_prediction: data.ClipPrediction,
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[data.Match]:
|
||||
config = config or DEFAULT_MATCH_CONFIG
|
||||
|
||||
annotated_sound_events = [
|
||||
sound_event_annotation
|
||||
for sound_event_annotation in clip_annotation.sound_events
|
||||
@ -86,13 +173,13 @@ def match_predictions_and_annotations(
|
||||
]
|
||||
|
||||
annotated_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
for sound_event in annotated_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
|
||||
predicted_geometries: List[data.Geometry] = [
|
||||
sound_event.sound_event.geometry
|
||||
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||
for sound_event in predicted_sound_events
|
||||
if sound_event.sound_event.geometry is not None
|
||||
]
|
||||
@ -101,6 +188,9 @@ def match_predictions_and_annotations(
|
||||
for id1, id2, affinity in match_geometries(
|
||||
annotated_geometries,
|
||||
predicted_geometries,
|
||||
time_buffer=config.time_buffer,
|
||||
freq_buffer=_get_frequency_buffer(config),
|
||||
affinity_threshold=_get_affinity_threshold(config),
|
||||
):
|
||||
target = annotated_sound_events[id1] if id1 is not None else None
|
||||
source = predicted_sound_events[id2] if id2 is not None else None
|
||||
|
||||
@ -4,13 +4,13 @@ import pandas as pd
|
||||
from sklearn import metrics
|
||||
from sklearn.preprocessing import label_binarize
|
||||
|
||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
|
||||
__all__ = ["DetectionAveragePrecision"]
|
||||
|
||||
|
||||
class DetectionAveragePrecision(MetricsProtocol):
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true, y_score = zip(
|
||||
*[(match.gt_det, match.pred_score) for match in matches]
|
||||
)
|
||||
@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||
self.class_names = class_names
|
||||
self.per_class = per_class
|
||||
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true = label_binarize(
|
||||
[
|
||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||
@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
||||
y_pred = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
name: match.class_scores.get(name, 0)
|
||||
name: match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
}
|
||||
for match in matches
|
||||
@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
||||
def __init__(self, class_names: List[str]):
|
||||
self.class_names = class_names
|
||||
|
||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
||||
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||
y_true = [
|
||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||
for match in matches
|
||||
@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
||||
y_pred = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
name: match.class_scores.get(name, 0)
|
||||
name: match.pred_class_scores.get(name, 0)
|
||||
for name in self.class_names
|
||||
}
|
||||
for match in matches
|
||||
|
||||
@ -1,22 +1,40 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Protocol
|
||||
from uuid import UUID
|
||||
|
||||
from soundevent import data
|
||||
|
||||
__all__ = [
|
||||
"MetricsProtocol",
|
||||
"Match",
|
||||
"MatchEvaluation",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class MatchEvaluation:
    """Evaluation record for one matched (or unmatched) pair.

    Bundles a ``data.Match`` with the ground-truth and prediction values
    that the metric callables read (``gt_det``, ``gt_class``,
    ``pred_score`` and ``pred_class_scores``).
    """

    # The underlying match between a prediction and an annotation.
    match: data.Match

    # Ground-truth side: whether a target exists, and its class name.
    gt_det: bool
    gt_class: Optional[str]

    # Prediction side: detection score, match affinity and the score of
    # every class keyed by class name.
    pred_score: float
    affinity: float
    pred_class_scores: Dict[str, float]

    @property
    def pred_class(self) -> Optional[str]:
        """Name of the highest scoring class, or ``None`` without scores.

        Ties resolve to the class that appears first in
        ``pred_class_scores``.
        """
        if not self.pred_class_scores:
            return None

        # `__getitem__` always returns a float for an existing key, so no
        # type-ignore is needed (unlike `dict.get`, which may return None).
        return max(
            self.pred_class_scores,
            key=self.pred_class_scores.__getitem__,
        )

    @property
    def pred_class_score(self) -> float:
        """Score of ``pred_class``, or ``0.0`` when there are no scores."""
        pred_class = self.pred_class

        if pred_class is None:
            return 0.0

        return self.pred_class_scores[pred_class]
|
||||
|
||||
|
||||
class MetricsProtocol(Protocol):
    # Structural interface for metric callables: given the list of match
    # evaluations, return a mapping from metric name to value.
    def __call__(
        self, matches: List[MatchEvaluation]
    ) -> Dict[str, float]: ...
|
||||
|
||||
Loading…
Reference in New Issue
Block a user