mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Added matching configs
This commit is contained in:
parent
a485ea4f79
commit
6213238585
@ -1,13 +1,9 @@
|
|||||||
from batdetect2.evaluate.evaluate import (
|
|
||||||
compute_error_auc,
|
|
||||||
)
|
|
||||||
from batdetect2.evaluate.match import (
|
from batdetect2.evaluate.match import (
|
||||||
match_predictions_and_annotations,
|
match_predictions_and_annotations,
|
||||||
match_sound_events_and_raw_predictions,
|
match_sound_events_and_raw_predictions,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"compute_error_auc",
|
|
||||||
"match_predictions_and_annotations",
|
|
||||||
"match_sound_events_and_raw_predictions",
|
"match_sound_events_and_raw_predictions",
|
||||||
|
"match_predictions_and_annotations",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,54 +1,133 @@
|
|||||||
from typing import List
|
from typing import Annotated, List, Literal, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_geometries
|
from soundevent.evaluation import match_geometries
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.evaluate.types import Match
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.postprocess.types import RawPrediction
|
from batdetect2.evaluate.types import MatchEvaluation
|
||||||
|
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.utils.arrays import iterate_over_array
|
from batdetect2.utils.arrays import iterate_over_array
|
||||||
|
|
||||||
|
|
||||||
|
class BBoxMatchConfig(BaseConfig):
    """Settings for matching sound events by bounding-box overlap.

    Chosen when ``match_method`` is ``"BBoxIOU"``. With this method
    ``prepare_geometry`` reduces every geometry to its bounding box before
    the geometries are handed to ``match_geometries``.
    """

    # Discriminator value that selects this variant of ``MatchConfig``.
    match_method: Literal["BBoxIOU"] = "BBoxIOU"

    # Passed to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5

    # Passed to ``match_geometries`` as ``time_buffer``.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01

    # Passed to ``match_geometries`` as ``freq_buffer``.
    # NOTE(review): presumably expressed in Hz -- confirm.
    frequency_buffer: float = 1_000
|
||||||
|
|
||||||
|
|
||||||
|
class IntervalMatchConfig(BaseConfig):
    """Settings for matching sound events by time-interval overlap.

    Chosen when ``match_method`` is ``"IntervalIOU"``. With this method
    ``prepare_geometry`` keeps only the start and end time of every
    geometry, so frequency information plays no part in the matching.
    """

    # Discriminator value that selects this variant of ``MatchConfig``.
    match_method: Literal["IntervalIOU"] = "IntervalIOU"

    # Passed to ``match_geometries`` as ``affinity_threshold``.
    affinity_threshold: float = 0.5

    # Passed to ``match_geometries`` as ``time_buffer``.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
class StartTimeMatchConfig(BaseConfig):
    """Settings for matching sound events by their start time only.

    Chosen when ``match_method`` is ``"StartTime"``. With this method
    ``prepare_geometry`` collapses every geometry to a time stamp placed at
    its start time. There is no affinity threshold or frequency buffer
    field; the helper functions fall back to zero for both.
    """

    # Discriminator value that selects this variant of ``MatchConfig``.
    match_method: Literal["StartTime"] = "StartTime"

    # Passed to ``match_geometries`` as ``time_buffer``.
    # NOTE(review): presumably expressed in seconds -- confirm.
    time_buffer: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
|
# Tagged union of every supported matching configuration. The
# ``match_method`` field is the discriminator used to pick the concrete
# config class.
MatchConfig = Annotated[
    Union[BBoxMatchConfig, IntervalMatchConfig, StartTimeMatchConfig],
    Field(discriminator="match_method"),
]


# Configuration used whenever a caller does not supply one: bounding-box
# matching with the default field values of ``BBoxMatchConfig``.
DEFAULT_MATCH_CONFIG = BBoxMatchConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def prepare_geometry(
    geometry: data.Geometry, config: MatchConfig
) -> data.Geometry:
    """Reduce a geometry to the shape compared by the configured method.

    The bounds of ``geometry`` are computed first and then repackaged
    according to ``config.match_method``:

    - ``"BBoxIOU"``: a bounding box spanning the full time/frequency bounds.
    - ``"IntervalIOU"``: a time interval from start time to end time.
    - ``"StartTime"``: a time stamp located at the start time.

    Parameters
    ----------
    geometry
        Geometry of an annotated or predicted sound event.
    config
        Matching configuration whose ``match_method`` selects the output
        geometry type.

    Returns
    -------
    data.Geometry
        The simplified geometry.

    Raises
    ------
    NotImplementedError
        If ``config.match_method`` is not one of the methods listed above.
    """
    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    method = config.match_method

    if method == "StartTime":
        return data.TimeStamp(coordinates=start_time)

    if method == "IntervalIOU":
        return data.TimeInterval(coordinates=[start_time, end_time])

    if method == "BBoxIOU":
        bbox_coordinates = [start_time, low_freq, end_time, high_freq]
        return data.BoundingBox(coordinates=bbox_coordinates)

    raise NotImplementedError(
        f"Invalid matching configuration. Unknown match method: {config.match_method}"
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _get_frequency_buffer(config: MatchConfig) -> float:
    """Return the frequency buffer to hand to ``match_geometries``.

    Only the ``"BBoxIOU"`` configuration has a ``frequency_buffer`` field.
    Every other match method yields ``0.0``.

    Parameters
    ----------
    config
        Matching configuration to read the buffer from.

    Returns
    -------
    float
        ``config.frequency_buffer`` for bounding-box matching, otherwise
        ``0.0``.
    """
    if config.match_method == "BBoxIOU":
        return config.frequency_buffer

    # Float literal so the fallback honours the declared return type.
    return 0.0
|
||||||
|
|
||||||
|
|
||||||
|
def _get_affinity_threshold(config: MatchConfig) -> float:
    """Return the affinity threshold to hand to ``match_geometries``.

    The ``"BBoxIOU"`` and ``"IntervalIOU"`` configurations have an
    ``affinity_threshold`` field. Any other match method (currently
    ``"StartTime"``) has none, so ``0.0`` is returned.

    Parameters
    ----------
    config
        Matching configuration to read the threshold from.

    Returns
    -------
    float
        ``config.affinity_threshold`` for the IOU based methods, otherwise
        ``0.0``.
    """
    # Explicit equality checks (rather than a membership test) keep the
    # discriminated-union narrowing that static type checkers rely on to
    # accept the attribute access below.
    if (
        config.match_method == "BBoxIOU"
        or config.match_method == "IntervalIOU"
    ):
        return config.affinity_threshold

    # Float literal so the fallback honours the declared return type.
    return 0.0
|
||||||
|
|
||||||
|
|
||||||
def match_sound_events_and_raw_predictions(
|
def match_sound_events_and_raw_predictions(
|
||||||
sound_events: List[data.SoundEventAnnotation],
|
clip_annotation: data.ClipAnnotation,
|
||||||
raw_predictions: List[RawPrediction],
|
raw_predictions: List[BatDetect2Prediction],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> List[Match]:
|
config: Optional[MatchConfig] = None,
|
||||||
|
) -> List[MatchEvaluation]:
|
||||||
|
config = config or DEFAULT_MATCH_CONFIG
|
||||||
|
|
||||||
target_sound_events = [
|
target_sound_events = [
|
||||||
targets.transform(sound_event_annotation)
|
targets.transform(sound_event_annotation)
|
||||||
for sound_event_annotation in sound_events
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
if targets.filter(sound_event_annotation)
|
if targets.filter(sound_event_annotation)
|
||||||
and sound_event_annotation.sound_event.geometry is not None
|
and sound_event_annotation.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
target_geometries: List[data.Geometry] = [ # type: ignore
|
target_geometries: List[data.Geometry] = [ # type: ignore
|
||||||
sound_event_annotation.sound_event.geometry
|
prepare_geometry(
|
||||||
|
sound_event_annotation.sound_event.geometry,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
for sound_event_annotation in target_sound_events
|
for sound_event_annotation in target_sound_events
|
||||||
|
if sound_event_annotation.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries = [
|
predicted_geometries = [
|
||||||
raw_prediction.geometry for raw_prediction in raw_predictions
|
prepare_geometry(raw_prediction.raw.geometry, config=config)
|
||||||
|
for raw_prediction in raw_predictions
|
||||||
]
|
]
|
||||||
|
|
||||||
matches = []
|
matches = []
|
||||||
|
|
||||||
for id1, id2, affinity in match_geometries(
|
for id1, id2, affinity in match_geometries(
|
||||||
target_geometries,
|
target_geometries,
|
||||||
predicted_geometries,
|
predicted_geometries,
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
freq_buffer=_get_frequency_buffer(config),
|
||||||
|
affinity_threshold=_get_affinity_threshold(config),
|
||||||
):
|
):
|
||||||
target = target_sound_events[id1] if id1 is not None else None
|
target = target_sound_events[id1] if id1 is not None else None
|
||||||
prediction = raw_predictions[id2] if id2 is not None else None
|
prediction = raw_predictions[id2] if id2 is not None else None
|
||||||
|
|
||||||
gt_uuid = target.uuid if target is not None else None
|
|
||||||
gt_det = target is not None
|
gt_det = target is not None
|
||||||
gt_class = targets.encode_class(target) if target is not None else None
|
gt_class = targets.encode_class(target) if target is not None else None
|
||||||
|
|
||||||
pred_score = float(prediction.detection_score) if prediction else 0
|
pred_score = float(prediction.raw.detection_score) if prediction else 0
|
||||||
|
|
||||||
class_scores = (
|
class_scores = (
|
||||||
{
|
{
|
||||||
str(class_name): float(score)
|
str(class_name): float(score)
|
||||||
for class_name, score in iterate_over_array(
|
for class_name, score in iterate_over_array(
|
||||||
prediction.class_scores
|
prediction.raw.class_scores
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
if prediction is not None
|
if prediction is not None
|
||||||
@ -56,13 +135,18 @@ def match_sound_events_and_raw_predictions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
matches.append(
|
matches.append(
|
||||||
Match(
|
MatchEvaluation(
|
||||||
gt_uuid=gt_uuid,
|
match=data.Match(
|
||||||
|
source=None
|
||||||
|
if prediction is None
|
||||||
|
else prediction.sound_event_prediction,
|
||||||
|
target=target,
|
||||||
|
affinity=affinity,
|
||||||
|
),
|
||||||
gt_det=gt_det,
|
gt_det=gt_det,
|
||||||
gt_class=gt_class,
|
gt_class=gt_class,
|
||||||
pred_score=pred_score,
|
pred_score=pred_score,
|
||||||
affinity=affinity,
|
pred_class_scores=class_scores,
|
||||||
class_scores=class_scores,
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -72,7 +156,10 @@ def match_sound_events_and_raw_predictions(
|
|||||||
def match_predictions_and_annotations(
|
def match_predictions_and_annotations(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
clip_prediction: data.ClipPrediction,
|
clip_prediction: data.ClipPrediction,
|
||||||
|
config: Optional[MatchConfig] = None,
|
||||||
) -> List[data.Match]:
|
) -> List[data.Match]:
|
||||||
|
config = config or DEFAULT_MATCH_CONFIG
|
||||||
|
|
||||||
annotated_sound_events = [
|
annotated_sound_events = [
|
||||||
sound_event_annotation
|
sound_event_annotation
|
||||||
for sound_event_annotation in clip_annotation.sound_events
|
for sound_event_annotation in clip_annotation.sound_events
|
||||||
@ -86,13 +173,13 @@ def match_predictions_and_annotations(
|
|||||||
]
|
]
|
||||||
|
|
||||||
annotated_geometries: List[data.Geometry] = [
|
annotated_geometries: List[data.Geometry] = [
|
||||||
sound_event.sound_event.geometry
|
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||||
for sound_event in annotated_sound_events
|
for sound_event in annotated_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
predicted_geometries: List[data.Geometry] = [
|
predicted_geometries: List[data.Geometry] = [
|
||||||
sound_event.sound_event.geometry
|
prepare_geometry(sound_event.sound_event.geometry, config=config)
|
||||||
for sound_event in predicted_sound_events
|
for sound_event in predicted_sound_events
|
||||||
if sound_event.sound_event.geometry is not None
|
if sound_event.sound_event.geometry is not None
|
||||||
]
|
]
|
||||||
@ -101,6 +188,9 @@ def match_predictions_and_annotations(
|
|||||||
for id1, id2, affinity in match_geometries(
|
for id1, id2, affinity in match_geometries(
|
||||||
annotated_geometries,
|
annotated_geometries,
|
||||||
predicted_geometries,
|
predicted_geometries,
|
||||||
|
time_buffer=config.time_buffer,
|
||||||
|
freq_buffer=_get_frequency_buffer(config),
|
||||||
|
affinity_threshold=_get_affinity_threshold(config),
|
||||||
):
|
):
|
||||||
target = annotated_sound_events[id1] if id1 is not None else None
|
target = annotated_sound_events[id1] if id1 is not None else None
|
||||||
source = predicted_sound_events[id2] if id2 is not None else None
|
source = predicted_sound_events[id2] if id2 is not None else None
|
||||||
|
|||||||
@ -4,13 +4,13 @@ import pandas as pd
|
|||||||
from sklearn import metrics
|
from sklearn import metrics
|
||||||
from sklearn.preprocessing import label_binarize
|
from sklearn.preprocessing import label_binarize
|
||||||
|
|
||||||
from batdetect2.evaluate.types import Match, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
|
|
||||||
__all__ = ["DetectionAveragePrecision"]
|
__all__ = ["DetectionAveragePrecision"]
|
||||||
|
|
||||||
|
|
||||||
class DetectionAveragePrecision(MetricsProtocol):
|
class DetectionAveragePrecision(MetricsProtocol):
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||||
y_true, y_score = zip(
|
y_true, y_score = zip(
|
||||||
*[(match.gt_det, match.pred_score) for match in matches]
|
*[(match.gt_det, match.pred_score) for match in matches]
|
||||||
)
|
)
|
||||||
@ -23,7 +23,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
self.per_class = per_class
|
self.per_class = per_class
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||||
y_true = label_binarize(
|
y_true = label_binarize(
|
||||||
[
|
[
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||||
@ -34,7 +34,7 @@ class ClassificationMeanAveragePrecision(MetricsProtocol):
|
|||||||
y_pred = pd.DataFrame(
|
y_pred = pd.DataFrame(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
name: match.class_scores.get(name, 0)
|
name: match.pred_class_scores.get(name, 0)
|
||||||
for name in self.class_names
|
for name in self.class_names
|
||||||
}
|
}
|
||||||
for match in matches
|
for match in matches
|
||||||
@ -65,7 +65,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
|||||||
def __init__(self, class_names: List[str]):
|
def __init__(self, class_names: List[str]):
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
|
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]:
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]:
|
||||||
y_true = [
|
y_true = [
|
||||||
match.gt_class if match.gt_class is not None else "__NONE__"
|
match.gt_class if match.gt_class is not None else "__NONE__"
|
||||||
for match in matches
|
for match in matches
|
||||||
@ -74,7 +74,7 @@ class ClassificationAccuracy(MetricsProtocol):
|
|||||||
y_pred = pd.DataFrame(
|
y_pred = pd.DataFrame(
|
||||||
[
|
[
|
||||||
{
|
{
|
||||||
name: match.class_scores.get(name, 0)
|
name: match.pred_class_scores.get(name, 0)
|
||||||
for name in self.class_names
|
for name in self.class_names
|
||||||
}
|
}
|
||||||
for match in matches
|
for match in matches
|
||||||
|
|||||||
@ -1,22 +1,40 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Dict, List, Optional, Protocol
|
from typing import Dict, List, Optional, Protocol
|
||||||
from uuid import UUID
|
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MetricsProtocol",
|
"MetricsProtocol",
|
||||||
"Match",
|
"MatchEvaluation",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Match:
|
class MatchEvaluation:
|
||||||
gt_uuid: Optional[UUID]
|
match: data.Match
|
||||||
|
|
||||||
gt_det: bool
|
gt_det: bool
|
||||||
gt_class: Optional[str]
|
gt_class: Optional[str]
|
||||||
|
|
||||||
pred_score: float
|
pred_score: float
|
||||||
affinity: float
|
pred_class_scores: Dict[str, float]
|
||||||
class_scores: Dict[str, float]
|
|
||||||
|
@property
|
||||||
|
def pred_class(self) -> Optional[str]:
|
||||||
|
if not self.pred_class_scores:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return max(self.pred_class_scores, key=self.pred_class_scores.get) # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def pred_class_score(self) -> float:
|
||||||
|
pred_class = self.pred_class
|
||||||
|
|
||||||
|
if pred_class is None:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
return self.pred_class_scores[pred_class]
|
||||||
|
|
||||||
|
|
||||||
class MetricsProtocol(Protocol):
|
class MetricsProtocol(Protocol):
|
||||||
def __call__(self, matches: List[Match]) -> Dict[str, float]: ...
|
def __call__(self, matches: List[MatchEvaluation]) -> Dict[str, float]: ...
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user