Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 00:59:34 +01:00)

Commit: Task/Metrics restructure
Parent: d6ddc4514c
Commit: df2abff654
@@ -123,10 +123,7 @@ class BatDetect2API:
             config=config.postprocess,
         )

-        evaluator = build_evaluator(
-            config=config.evaluation.evaluator,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)

         # NOTE: Better to have a separate instance of
         # preprocessor and postprocessor as these may be moved
@@ -178,10 +175,7 @@ class BatDetect2API:
             config=config.postprocess,
         )

-        evaluator = build_evaluator(
-            config=config.evaluation.evaluator,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)

         return cls(
             config=config,
@@ -1,11 +1,14 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
 from batdetect2.evaluate.evaluate import evaluate
-from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator
+from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.tasks import TaskConfig, build_task

 __all__ = [
     "EvaluationConfig",
-    "load_evaluation_config",
-    "evaluate",
-    "MultipleEvaluator",
+    "Evaluator",
+    "TaskConfig",
     "build_evaluator",
+    "build_task",
+    "evaluate",
+    "load_evaluation_config",
 ]
@@ -1,13 +1,14 @@
-from typing import Optional
+from typing import List, Optional

 from pydantic import Field
 from soundevent import data

 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.evaluate.evaluator import (
-    EvaluatorConfig,
-    MultipleEvaluatorConfig,
-)
+from batdetect2.evaluate.tasks import (
+    TaskConfig,
+)
+from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
+from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
@@ -17,7 +18,12 @@ __all__ = [


 class EvaluationConfig(BaseConfig):
-    evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig)
+    tasks: List[TaskConfig] = Field(
+        default_factory=lambda: [
+            DetectionTaskConfig(),
+            ClassificationTaskConfig(),
+        ]
+    )
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)

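For context, a minimal sketch (not part of the commit) of constructing the restructured config in Python. The class names and module paths are taken from the imports shown in this hunk; everything else is assumed.

# Hypothetical sketch: the old single `evaluator` field is replaced by a
# list of task configs, so an explicit construction might look like this.
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig

config = EvaluationConfig(
    tasks=[
        DetectionTaskConfig(),
        ClassificationTaskConfig(),
    ],
)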
@@ -55,10 +55,7 @@ def evaluate(
         num_workers=num_workers,
     )

-    evaluator = build_evaluator(
-        config=config.evaluation.evaluator,
-        targets=targets,
-    )
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)

     logger = build_logger(
         config.evaluation.logger,
src/batdetect2/evaluate/evaluator.py (new file, @@ -0,0 +1,67 @@):

from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.tasks = tasks

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for task, outputs in zip(self.tasks, eval_outputs):
            results.update(task.compute_metrics(outputs))

        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for task, outputs in zip(self.tasks, eval_outputs):
            for name, fig in task.generate_plots(outputs):
                yield name, fig


def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

    if config is None:
        config = EvaluationConfig()

    if not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    return Evaluator(
        targets=targets,
        tasks=[build_task(task, targets=targets) for task in config.tasks],
    )
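A hypothetical usage sketch (not part of the commit) of the new Evaluator, using only the methods defined above; `clip_annotations` and `raw_predictions` are placeholders for real inputs.

# Hypothetical sketch: driving the new Evaluator end to end.
# `clip_annotations` is a Sequence[data.ClipAnnotation] and
# `raw_predictions` a Sequence[Sequence[RawPrediction]].
evaluator = build_evaluator()  # default EvaluationConfig and default targets

outputs = evaluator.evaluate(clip_annotations, raw_predictions)
metrics = evaluator.compute_metrics(outputs)  # flat dict of metric name -> value
for name, figure in evaluator.generate_plots(outputs):
    figure.savefig(f"{name.replace('/', '_')}.png")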
Deleted file (@@ -1,114 +0,0 @@), removed content:

from typing import (
    Annotated,
    Any,
    Dict,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data

from batdetect2.core.configs import BaseConfig
from batdetect2.evaluate.evaluator.base import evaluators
from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig
from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig
from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig
from batdetect2.targets import build_targets
from batdetect2.typing import (
    EvaluatorProtocol,
    RawPrediction,
    TargetProtocol,
)

__all__ = [
    "EvaluatorConfig",
    "build_evaluator",
]


EvaluatorConfig = Annotated[
    Union[
        ClassificationMetricsConfig,
        GlobalEvaluatorConfig,
        ClipMetricsConfig,
        "MultipleEvaluatorConfig",
    ],
    Field(discriminator="name"),
]


class MultipleEvaluatorConfig(BaseConfig):
    name: Literal["multiple_evaluations"] = "multiple_evaluations"
    evaluations: List[EvaluatorConfig] = Field(
        default_factory=lambda: [
            ClassificationMetricsConfig(),
            GlobalEvaluatorConfig(),
        ]
    )


class MultipleEvaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        evaluators: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.evaluators = evaluators

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        return [
            evaluator.evaluate(
                clip_annotations,
                predictions,
            )
            for evaluator in self.evaluators
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for evaluator, outputs in zip(self.evaluators, eval_outputs):
            results.update(evaluator.compute_metrics(outputs))

        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for evaluator, outputs in zip(self.evaluators, eval_outputs):
            for name, fig in evaluator.generate_plots(outputs):
                yield name, fig

    @evaluators.register(MultipleEvaluatorConfig)
    @staticmethod
    def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol):
        return MultipleEvaluator(
            evaluators=[
                build_evaluator(conf, targets=targets)
                for conf in config.evaluations
            ],
            targets=targets,
        )


def build_evaluator(
    config: Optional[EvaluatorConfig] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

    config = config or MultipleEvaluatorConfig()
    return evaluators.build(config, targets)
Deleted file (@@ -1,163 +0,0 @@), removed content:

from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, Dict, List, Literal, Sequence, Set

from pydantic import Field, field_validator
from sklearn import metrics
from soundevent import data

from batdetect2.evaluate.evaluator.base import (
    BaseEvaluator,
    BaseEvaluatorConfig,
    evaluators,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol


@dataclass
class ClipInfo:
    gt_det: bool
    gt_classes: Set[str]
    pred_score: float
    pred_class_scores: Dict[str, float]


ClipMetric = Callable[[Sequence[ClipInfo]], float]


def clip_detection_average_precision(
    clip_evaluations: Sequence[ClipInfo],
) -> float:
    y_true = []
    y_score = []

    for clip_eval in clip_evaluations:
        y_true.append(clip_eval.gt_det)
        y_score.append(clip_eval.pred_score)

    return average_precision(y_true=y_true, y_score=y_score)


def clip_detection_roc_auc(
    clip_evaluations: Sequence[ClipInfo],
) -> float:
    y_true = []
    y_score = []

    for clip_eval in clip_evaluations:
        y_true.append(clip_eval.gt_det)
        y_score.append(clip_eval.pred_score)

    return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))


clip_metrics = {
    "average_precision": clip_detection_average_precision,
    "roc_auc": clip_detection_roc_auc,
}


class ClipMetricsConfig(BaseEvaluatorConfig):
    name: Literal["clip"] = "clip"
    prefix: str = "clip"
    metrics: List[str] = Field(
        default_factory=lambda: [
            "average_precision",
            "roc_auc",
        ]
    )

    @field_validator("metrics", mode="after")
    @classmethod
    def validate_metrics(cls, v: List[str]) -> List[str]:
        for metric_name in v:
            if metric_name not in clip_metrics:
                raise ValueError(f"Unknown metric {metric_name}")
        return v


class ClipEvaluator(BaseEvaluator):
    def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics = metrics

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipInfo]:
        return [
            self.match_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def compute_metrics(
        self,
        eval_outputs: List[ClipInfo],
    ) -> Dict[str, float]:
        scores = {
            name: metric(eval_outputs) for name, metric in self.metrics.items()
        }
        return {
            f"{self.prefix}/{name}": score for name, score in scores.items()
        }

    def match_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipInfo:
        clip = clip_annotation.clip

        gt_det = False
        gt_classes = set()
        for sound_event in clip_annotation.sound_events:
            if self.filter_sound_event_annotations(sound_event, clip):
                continue

            gt_det = True
            class_name = self.targets.encode_class(sound_event)

            if class_name is None:
                continue

            gt_classes.add(class_name)

        pred_score = 0
        pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0)
        for pred in predictions:
            if self.filter_predictions(pred, clip):
                continue

            pred_score = max(pred_score, pred.detection_score)

            for class_name, class_score in zip(
                self.targets.class_names,
                pred.class_scores,
            ):
                pred_class_scores[class_name] = max(
                    pred_class_scores[class_name],
                    class_score,
                )

        return ClipInfo(
            gt_det=gt_det,
            gt_classes=gt_classes,
            pred_score=pred_score,
            pred_class_scores=pred_class_scores,
        )

    @evaluators.register(ClipMetricsConfig)
    @staticmethod
    def from_config(
        config: ClipMetricsConfig,
        targets: TargetProtocol,
    ):
        metrics = {name: clip_metrics.get(name) for name in config.metrics}
        return ClipEvaluator.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
Deleted file (@@ -1,219 +0,0 @@), removed content:

from collections import defaultdict
from typing import (
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
)

import numpy as np
from pydantic import Field
from soundevent import data

from batdetect2.evaluate.evaluator.base import (
    BaseEvaluator,
    BaseEvaluatorConfig,
    evaluators,
)
from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.per_class_matches import (
    ClassificationAveragePrecisionConfig,
    PerClassMatchMetric,
    PerClassMatchMetricConfig,
    build_per_class_matches_metric,
)
from batdetect2.typing import (
    ClipMatches,
    RawPrediction,
    TargetProtocol,
)

ScoreFn = Callable[[RawPrediction, int], float]


def score_by_class_score(pred: RawPrediction, class_index: int) -> float:
    return float(pred.class_scores[class_index])


def score_by_adjusted_class_score(
    pred: RawPrediction,
    class_index: int,
) -> float:
    return float(pred.class_scores[class_index]) * pred.detection_score


ScoreFunctionOption = Literal["class_score", "adjusted_class_score"]
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
    "class_score": score_by_class_score,
    "adjusted_class_score": score_by_adjusted_class_score,
}


def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
    return score_functions[name]


class ClassificationMetricsConfig(BaseEvaluatorConfig):
    name: Literal["classification"] = "classification"
    prefix: str = "classification"
    include_generics: bool = True
    score_by: ScoreFunctionOption = "class_score"
    metrics: List[PerClassMatchMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class PerClassEvaluator(BaseEvaluator):
    def __init__(
        self,
        *args,
        metrics: Dict[str, PerClassMatchMetric],
        score_fn: ScoreFn,
        include_generics: bool = True,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.score_fn = score_fn
        self.metrics = metrics

        self.include_generics = include_generics

        self.include = include
        self.exclude = exclude

        self.selected = self.targets.class_names
        if include is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name in include
            ]

        if exclude is not None:
            self.selected = [
                class_name
                for class_name in self.selected
                if class_name not in exclude
            ]

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> Dict[str, List[ClipMatches]]:
        ret = defaultdict(list)

        for clip_annotation, preds in zip(clip_annotations, predictions):
            matches = self.match_clip(clip_annotation, preds)
            for class_name, clip_eval in matches.items():
                ret[class_name].append(clip_eval)

        return ret

    def compute_metrics(
        self,
        eval_outputs: Dict[str, List[ClipMatches]],
    ) -> Dict[str, float]:
        results = {}

        for metric_name, metric in self.metrics.items():
            class_scores = {
                class_name: metric(eval_outputs[class_name], class_name)
                for class_name in self.targets.class_names
            }
            mean = float(
                np.mean([v for v in class_scores.values() if v != np.nan])
            )

            results[f"{self.prefix}/mean_{metric_name}"] = mean

            for class_name, value in class_scores.items():
                if class_name not in self.selected:
                    continue

                results[f"{self.prefix}/{metric_name}/{class_name}"] = value

        return results

    def match_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> Dict[str, ClipMatches]:
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.filter_predictions(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.filter_sound_event_annotations(sound_event, clip)
        ]

        ret = {}

        for class_name in self.targets.class_names:
            class_idx = self.targets.class_names.index(class_name)

            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [self.score_fn(pred, class_idx) for pred in preds]

            ret[class_name] = match(
                gts,
                preds,
                clip=clip,
                scores=scores,
                targets=self.targets,
                matcher=self.matcher,
            )

        return ret

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given
            # class
            return True

        return sound_event_class == class_name

    @evaluators.register(ClassificationMetricsConfig)
    @staticmethod
    def from_config(
        config: ClassificationMetricsConfig,
        targets: TargetProtocol,
    ):
        metrics = {
            metric.name: build_per_class_matches_metric(metric)
            for metric in config.metrics
        }
        return PerClassEvaluator.build(
            config=config,
            targets=targets,
            metrics=metrics,
            score_fn=get_score_fn(config.score_by),
            include_generics=config.include_generics,
            include=config.include,
            exclude=config.exclude,
        )
Deleted file (@@ -1,126 +0,0 @@), removed content:

from typing import Callable, Dict, List, Literal, Mapping, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.evaluator.base import (
    BaseEvaluator,
    BaseEvaluatorConfig,
    evaluators,
)
from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.matches import (
    DetectionAveragePrecisionConfig,
    MatchesMetric,
    MatchMetricConfig,
    build_match_metric,
)
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol

ScoreFn = Callable[[RawPrediction], float]


def score_by_detection_score(pred: RawPrediction) -> float:
    return pred.detection_score


def score_by_top_class_score(pred: RawPrediction) -> float:
    return pred.class_scores.max()


ScoreFunctionOption = Literal["detection_score", "top_class_score"]
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
    "detection_score": score_by_detection_score,
    "top_class_score": score_by_top_class_score,
}


def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
    return score_functions[name]


class GlobalEvaluatorConfig(BaseEvaluatorConfig):
    name: Literal["detection"] = "detection"
    prefix: str = "detection"
    score_by: ScoreFunctionOption = "detection_score"
    metrics: List[MatchMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )


class GlobalEvaluator(BaseEvaluator):
    def __init__(
        self,
        *args,
        score_fn: ScoreFn,
        metrics: Dict[str, MatchesMetric],
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.metrics = metrics
        self.score_fn = score_fn

    def compute_metrics(
        self,
        eval_outputs: List[ClipMatches],
    ) -> Dict[str, float]:
        scores = {
            name: metric(eval_outputs) for name, metric in self.metrics.items()
        }
        return {
            f"{self.prefix}/{name}": score for name, score in scores.items()
        }

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[ClipMatches]:
        return [
            self.match_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def match_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipMatches:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.filter_sound_event_annotations(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.filter_predictions(pred, clip)
        ]
        scores = [self.score_fn(pred) for pred in preds]

        return match(
            gts,
            preds,
            scores=scores,
            clip=clip,
            targets=self.targets,
            matcher=self.matcher,
        )

    @evaluators.register(GlobalEvaluatorConfig)
    @staticmethod
    def from_config(
        config: GlobalEvaluatorConfig,
        targets: TargetProtocol,
    ):
        metrics = {
            metric.name: build_match_metric(metric)
            for metric in config.metrics
        }
        score_fn = get_score_fn(config.score_by)
        return GlobalEvaluator.build(
            config=config,
            score_fn=score_fn,
            metrics=metrics,
            targets=targets,
        )
Deleted file (@@ -1,133 +0,0 @@), removed content:

from typing import Dict, List, Literal, Sequence

from pydantic import Field, field_validator
from soundevent import data

from batdetect2.evaluate.match import match
from batdetect2.evaluate.metrics.base import (
    BaseMetric,
    BaseMetricConfig,
    metrics_registry,
)
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.evaluate.metrics.detection import DetectionMetric
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol

__all__ = [
    "TopClassEvaluator",
    "TopClassEvaluatorConfig",
]


def top_class_average_precision(
    clip_evaluations: Sequence[ClipMatches],
) -> float:
    y_true = []
    y_score = []
    num_positives = 0

    for clip_eval in clip_evaluations:
        for m in clip_eval.matches:
            is_generic = m.gt_det and (m.gt_class is None)

            # Ignore ground truth sounds with unknown class
            if is_generic:
                continue

            num_positives += int(m.gt_det)

            # Ignore matches that don't correspond to a prediction
            if m.pred_geometry is None:
                continue

            y_true.append(m.gt_det & (m.top_class == m.gt_class))
            y_score.append(m.top_class_score)

    return average_precision(y_true, y_score, num_positives=num_positives)


top_class_metrics = {
    "average_precision": top_class_average_precision,
}


class TopClassEvaluatorConfig(BaseMetricConfig):
    name: Literal["top_class"] = "top_class"
    prefix: str = "top_class"
    metrics: List[str] = Field(default_factory=lambda: ["average_precision"])

    @field_validator("metrics", mode="after")
    @classmethod
    def validate_metrics(cls, v: List[str]) -> List[str]:
        for metric_name in v:
            if metric_name not in top_class_metrics:
                raise ValueError(f"Unknown metric {metric_name}")
        return v


class TopClassEvaluator(BaseMetric):
    def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs):
        super().__init__(*args, **kwargs)
        self.metrics = metrics

    def __call__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> Dict[str, float]:
        clip_evaluations = [
            self.match_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]
        scores = {
            name: metric(clip_evaluations)
            for name, metric in self.metrics.items()
        }
        return {
            f"{self.prefix}/{name}": score for name, score in scores.items()
        }

    def match_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipMatches:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.filter_sound_event_annotations(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.filter_predictions(pred, clip)
        ]
        # Use score of top class for matching
        scores = [pred.class_scores.max() for pred in preds]

        return match(
            gts,
            preds,
            scores=scores,
            clip=clip,
            targets=self.targets,
            matcher=self.matcher,
        )

    @classmethod
    def from_config(
        cls,
        config: TopClassEvaluatorConfig,
        targets: TargetProtocol,
    ):
        metrics = {
            name: top_class_metrics.get(name) for name in config.metrics
        }
        return super().build(
            config=config,
            metrics=metrics,
            targets=targets,
        )


metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator)
src/batdetect2/evaluate/metrics/classification.py (new file, @@ -0,0 +1,253 @@):

from collections import defaultdict
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Mapping,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = []


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    is_generic: bool
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: Mapping[str, List[MatchEval]]


ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


classification_metrics: Registry[ClassificationMetric, []] = Registry(
    "classification_metric"
)


class BaseClassificationConfig(BaseConfig):
    include: Optional[List[str]] = None
    exclude: Optional[List[str]] = None


class BaseClassificationMetric:
    def __init__(
        self,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.include = include
        self.exclude = exclude

    def include_class(self, class_name: str) -> bool:
        if self.include is not None:
            return class_name in self.include

        if self.exclude is not None:
            return class_name not in self.exclude

        return True


class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
    name: Literal["average_precision"] = "average_precision"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True
    label: str = "average_precision"


class ClassificationAveragePrecision(BaseClassificationMetric):
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "average_precision",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        super().__init__(include=include, exclude=exclude)
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)
        num_positives = defaultdict(lambda: 0)

        class_names = set()

        for clip_eval in clip_evaluations:
            for class_name, matches in clip_eval.matches.items():
                class_names.add(class_name)

                for m in matches:
                    # Exclude matches with ground truth sounds where the class
                    # is unknown
                    if m.is_generic and self.ignore_generic:
                        continue

                    is_class = m.true_class == class_name

                    if is_class:
                        num_positives[class_name] += 1

                    # Ignore matches that don't correspond to a prediction
                    if not m.is_prediction and self.ignore_non_predictions:
                        continue

                    y_true[class_name].append(is_class)
                    y_score[class_name].append(m.score)

        class_scores = {
            class_name: average_precision(
                y_true[class_name],
                y_score[class_name],
                num_positives=num_positives[class_name],
            )
            for class_name in class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if v != np.nan])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClassificationAveragePrecisionConfig):
        return ClassificationAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
            include=config.include,
            exclude=config.exclude,
        )


class ClassificationROCAUCConfig(BaseClassificationConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationROCAUC(BaseClassificationMetric):
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
        label: str = "roc_auc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic
        self.label = label
        self.include = include
        self.exclude = exclude

    def __call__(
        self, clip_evaluations: Sequence[ClipEval]
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        class_names = set()

        for clip_eval in clip_evaluations:
            for class_name, matches in clip_eval.matches.items():
                class_names.add(class_name)

                for m in matches:
                    # Exclude matches with ground truth sounds where the class
                    # is unknown
                    if m.is_generic and self.ignore_generic:
                        continue

                    # Ignore matches that don't correspond to a prediction
                    if not m.is_prediction and self.ignore_non_predictions:
                        continue

                    y_true[class_name].append(m.true_class == class_name)
                    y_score[class_name].append(m.score)

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true[class_name],
                    y_score[class_name],
                )
            )
            for class_name in class_names
        }

        mean_score = float(
            np.mean([v for v in class_scores.values() if v != np.nan])
        )

        return {
            f"mean_{self.label}": mean_score,
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if self.include_class(class_name)
            },
        }

    @classification_metrics.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClassificationROCAUCConfig):
        return ClassificationROCAUC(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


ClassificationMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_classification_metrics(config: ClassificationMetricConfig):
    return classification_metrics.build(config)
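As a rough illustration (not part of the commit), a per-class metric could be built from its config and applied to a sequence of ClipEval values assembled by the matching step; the two class names used here are made up.

# Hypothetical sketch: building a classification metric from its config.
config = ClassificationAveragePrecisionConfig(include=["Myotis", "Pipistrellus"])
metric = build_classification_metrics(config)

# `clip_evals` would be a Sequence[ClipEval] produced elsewhere; the call
# returns keys such as "mean_average_precision" and "average_precision/Myotis".
scores = metric(clip_evals)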
src/batdetect2/evaluate/metrics/clip_classification.py (new file, @@ -0,0 +1,135 @@):

from collections import defaultdict
from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    true_classes: Set[str]
    class_scores: Dict[str, float]


ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
    "clip_classification_metric"
)


class ClipClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipClassificationAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                average_precision(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(
        ClipClassificationAveragePrecisionConfig
    )
    @staticmethod
    def from_config(config: ClipClassificationAveragePrecisionConfig):
        return ClipClassificationAveragePrecision(label=config.label)


class ClipClassificationROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipClassificationROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = defaultdict(list)
        y_score = defaultdict(list)

        for clip_eval in clip_evaluations:
            for class_name, score in clip_eval.class_scores.items():
                y_true[class_name].append(class_name in clip_eval.true_classes)
                y_score[class_name].append(score)

        class_scores = {
            class_name: float(
                metrics.roc_auc_score(
                    y_true=y_true[class_name],
                    y_score=y_score[class_name],
                )
            )
            for class_name in y_true
        }

        mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])

        return {
            f"mean_{self.label}": float(mean),
            **{
                f"{self.label}/{class_name}": score
                for class_name, score in class_scores.items()
                if not np.isnan(score)
            },
        }

    @clip_classification_metrics.register(ClipClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClipClassificationROCAUCConfig):
        return ClipClassificationROCAUC(label=config.label)


ClipClassificationMetricConfig = Annotated[
    Union[
        ClipClassificationAveragePrecisionConfig,
        ClipClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipClassificationMetricConfig):
    return clip_classification_metrics.build(config)
src/batdetect2/evaluate/metrics/clip_detection.py (new file, @@ -0,0 +1,173 @@):

from dataclasses import dataclass
from typing import Annotated, Callable, Dict, Literal, Sequence, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.evaluate.metrics.common import average_precision


@dataclass
class ClipEval:
    gt_det: bool
    score: float


ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]

clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
    "clip_detection_metric"
)


class ClipDetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"


class ClipDetectionAveragePrecision:
    def __init__(self, label: str = "average_precision"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = average_precision(y_true=y_true, y_score=y_score)
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionAveragePrecisionConfig):
        return ClipDetectionAveragePrecision(label=config.label)


class ClipDetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"


class ClipDetectionROCAUC:
    def __init__(self, label: str = "roc_auc"):
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            y_true.append(clip_eval.gt_det)
            y_score.append(clip_eval.score)

        score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionROCAUCConfig)
    @staticmethod
    def from_config(config: ClipDetectionROCAUCConfig):
        return ClipDetectionROCAUC(label=config.label)


class ClipDetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class ClipDetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            if clip_eval.gt_det:
                num_positives += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionRecallConfig)
    @staticmethod
    def from_config(config: ClipDetectionRecallConfig):
        return ClipDetectionRecall(
            threshold=config.threshold, label=config.label
        )


class ClipDetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class ClipDetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0
        for clip_eval in clip_evaluations:
            if clip_eval.score >= self.threshold:
                num_detections += 1

            if clip_eval.score >= self.threshold and clip_eval.gt_det:
                true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @clip_detection_metrics.register(ClipDetectionPrecisionConfig)
    @staticmethod
    def from_config(config: ClipDetectionPrecisionConfig):
        return ClipDetectionPrecision(
            threshold=config.threshold, label=config.label
        )


ClipDetectionMetricConfig = Annotated[
    Union[
        ClipDetectionAveragePrecisionConfig,
        ClipDetectionROCAUCConfig,
        ClipDetectionRecallConfig,
        ClipDetectionPrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_clip_metric(config: ClipDetectionMetricConfig):
    return clip_detection_metrics.build(config)
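A small worked example (not part of the commit) of the clip-level detection metrics above; the three ClipEval values are made up.

# Hypothetical sketch: clip-level recall/precision at a 0.5 threshold.
clip_evals = [
    ClipEval(gt_det=True, score=0.9),   # detected, correct
    ClipEval(gt_det=True, score=0.2),   # missed (below threshold)
    ClipEval(gt_det=False, score=0.7),  # false alarm
]

recall = ClipDetectionRecall(threshold=0.5)
precision = ClipDetectionPrecision(threshold=0.5)

print(recall(clip_evals))     # {'recall': 0.5}    -> 1 of 2 positive clips found
print(precision(clip_evals))  # {'precision': 0.5} -> 1 of 2 detections correct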
@@ -1,24 +1,24 @@
-from typing import Optional
+from typing import Optional, Tuple

 import numpy as np

+__all__ = [
+    "compute_precision_recall",
+    "average_precision",
+]

-def average_precision(
+
+def compute_precision_recall(
     y_true,
     y_score,
     num_positives: Optional[int] = None,
-) -> float:
+) -> Tuple[np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)

     if num_positives is None:
         num_positives = y_true.sum()

-    # Remove non-detections
-    valid_inds = y_score > 0
-    y_true = y_true[valid_inds]
-    y_score = y_score[valid_inds]
-
     # Sort by score
     sort_ind = np.argsort(y_score)[::-1]
     y_true_sorted = y_true[sort_ind]
@@ -34,6 +34,19 @@ def average_precision(

     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0
+    return precision, recall
+
+
+def average_precision(
+    y_true,
+    y_score,
+    num_positives: Optional[int] = None,
+) -> float:
+    precision, recall = compute_precision_recall(
+        y_true,
+        y_score,
+        num_positives=num_positives,
+    )

     # pascal 12 way
     mprec = np.hstack((0, precision, 0))
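For context (not part of the commit), this change factors the precision/recall curve out of the scalar score so both can be reused; a hedged usage sketch of the two calls:

# Hypothetical sketch: curve and scalar summary are now separate functions.
y_true = [True, False, True, True]
y_score = [0.9, 0.8, 0.6, 0.3]

precision, recall = compute_precision_recall(y_true, y_score)  # full curve arrays
ap = average_precision(y_true, y_score)                        # PASCAL-style scalar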
src/batdetect2/evaluate/metrics/detection.py (new file, @@ -0,0 +1,226 @@; listing truncated below):

from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "DetectionMetricConfig",
    "DetectionMetric",
    "build_detection_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_prediction: bool
    is_ground_truth: bool
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_non_predictions: bool = True


class DetectionAveragePrecision:
    def __init__(self, label: str, ignore_non_predictions: bool = True):
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                num_positives += int(m.is_ground_truth)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        ap = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: ap}

    @detection_metrics.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        return DetectionAveragePrecision(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    label: str = "roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        label: str = "roc_auc",
        ignore_non_predictions: bool = True,
    ):
        self.label = label
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.is_ground_truth)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @detection_metrics.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            label=config.label,
            ignore_non_predictions=config.ignore_non_predictions,
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    label: str = "recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.label = label
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.is_ground_truth:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @detection_metrics.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold, label=config.label)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    label: str = "precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.is_ground_truth:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @detection_metrics.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(
            threshold=config.threshold,
            label=config.label,
        )


DetectionMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
|
||||||
|
DetectionPrecisionConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_detection_metric(config: DetectionMetricConfig):
|
||||||
|
return detection_metrics.build(config)
|
||||||
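For orientation, a minimal sketch (not part of the commit) of how these match-level detection metrics are driven. It assumes only the classes defined in this new module; the gt/pred objects are left as None for brevity, where real code would pass the matched annotation and prediction.

from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionPrecision,
    DetectionRecall,
    MatchEval,
)

matches = [
    # Matched pair: a ground truth covered by a confident prediction.
    MatchEval(gt=None, pred=None, is_prediction=True, is_ground_truth=True, score=0.9),
    # Unmatched ground truth: counts against recall only.
    MatchEval(gt=None, pred=None, is_prediction=False, is_ground_truth=True, score=0.0),
    # Unmatched prediction above threshold: counts against precision only.
    MatchEval(gt=None, pred=None, is_prediction=True, is_ground_truth=False, score=0.7),
]
clip_evals = [ClipEval(clip=None, matches=matches)]  # a real data.Clip in practice

print(DetectionRecall(threshold=0.5)(clip_evals))     # {'recall': 0.5}
print(DetectionPrecision(threshold=0.5)(clip_evals))  # {'precision': 0.5}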
@ -1,235 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union

import numpy as np
from pydantic import Field
from sklearn import metrics

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import (
    ClipMatches,
)

__all__ = [
    "MatchMetricConfig",
    "MatchesMetric",
    "build_match_metric",
]

MatchesMetric = Callable[[Sequence[ClipMatches]], float]


metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric")


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["detection_average_precision"] = (
        "detection_average_precision"
    )
    ignore_non_predictions: bool = True


class DetectionAveragePrecision:
    def __init__(self, ignore_non_predictions: bool = True):
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                num_positives += int(m.gt_det)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.gt_det)
                y_score.append(m.pred_score)

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        return DetectionAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions
        )


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["top_class_average_precision"] = (
        "top_class_average_precision"
    )
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class TopClassAveragePrecision:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore ground truth sounds with unknown class
                    continue

                num_positives += int(m.gt_det)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.gt_det & (m.top_class == m.gt_class))
                y_score.append(m.top_class_score)

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions
        )


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["detection_roc_auc"] = "detection_roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.gt_det)
                y_score.append(m.pred_score)

        return float(metrics.roc_auc_score(y_true, y_score))

    @metrics_registry.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            ignore_non_predictions=config.ignore_non_predictions
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["detection_recall"] = "detection_recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.gt_det:
                    num_positives += 1

                if m.pred_score >= self.threshold and m.gt_det:
                    true_positives += 1

        if num_positives == 0:
            return 1

        return true_positives / num_positives

    @metrics_registry.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["detection_precision"] = "detection_precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.pred_score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.gt_det:
                    true_positives += 1

        if num_detections == 0:
            return np.nan

        return true_positives / num_detections

    @metrics_registry.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(threshold=config.threshold)


MatchMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
        TopClassAveragePrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_match_metric(config: MatchMetricConfig):
    return metrics_registry.build(config)
@ -1,136 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union

from pydantic import Field
from sklearn import metrics

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import (
    ClipMatches,
)

__all__ = []

PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float]


metrics_registry: Registry[PerClassMatchMetric, []] = Registry(
    "match_metric"
)


class ClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["classification_average_precision"] = (
        "classification_average_precision"
    )
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationAveragePrecision:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
        class_name: str,
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_class = m.gt_class == class_name

                if is_class:
                    num_positives += 1

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                # Exclude matches with ground truth sounds where the class is
                # unknown
                if m.is_generic and self.ignore_generic:
                    continue

                y_true.append(is_class)
                y_score.append(m.pred_class_scores.get(class_name, 0))

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClassificationAveragePrecisionConfig):
        return ClassificationAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ClassificationROCAUCConfig(BaseConfig):
    name: Literal["classification_roc_auc"] = "classification_roc_auc"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationROCAUC:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
        class_name: str,
    ) -> float:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                # Exclude matches with ground truth sounds where the class is
                # unknown
                if m.is_generic and self.ignore_generic:
                    continue

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.gt_class == class_name)
                y_score.append(m.pred_class_scores.get(class_name, 0))

        return float(metrics.roc_auc_score(y_true, y_score))

    @metrics_registry.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClassificationROCAUCConfig):
        return ClassificationROCAUC(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


PerClassMatchMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_per_class_matches_metric(config: PerClassMatchMetricConfig):
    return metrics_registry.build(config)
313 src/batdetect2/evaluate/metrics/top_class.py Normal file
@ -0,0 +1,313 @@
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "TopClassMetricConfig",
    "TopClassMetric",
    "build_top_class_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_ground_truth: bool
    is_generic: bool
    is_prediction: bool
    pred_class: Optional[str]
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True


class TopClassAveragePrecision:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "average_precision",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: score}

    @top_class_metrics.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


class TopClassROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True
    label: str = "roc_auc"


class TopClassROCAUC:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "roc_auc",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @top_class_metrics.register(TopClassROCAUCConfig)
    @staticmethod
    def from_config(config: TopClassROCAUCConfig):
        return TopClassROCAUC(
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


class TopClassRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class TopClassRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.pred_class == m.true_class:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @top_class_metrics.register(TopClassRecallConfig)
    @staticmethod
    def from_config(config: TopClassRecallConfig):
        return TopClassRecall(
            threshold=config.threshold,
            label=config.label,
        )


class TopClassPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class TopClassPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.pred_class == m.true_class:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @top_class_metrics.register(TopClassPrecisionConfig)
    @staticmethod
    def from_config(config: TopClassPrecisionConfig):
        return TopClassPrecision(
            threshold=config.threshold,
            label=config.label,
        )


class BalancedAccuracyConfig(BaseConfig):
    name: Literal["balanced_accuracy"] = "balanced_accuracy"
    label: str = "balanced_accuracy"
    exclude_noise: bool = False
    noise_class: str = "noise"


class BalancedAccuracy:
    def __init__(
        self,
        exclude_noise: bool = True,
        noise_class: str = "noise",
        label: str = "balanced_accuracy",
    ):
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic:
                    # Ignore matches that correspond to a sound event
                    # with unknown class
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore predictions that were not matched to a
                    # ground truth
                    continue

                if m.pred_class is None and self.exclude_noise:
                    # Ignore non-predictions
                    continue

                y_true.append(m.true_class or self.noise_class)
                y_pred.append(m.pred_class or self.noise_class)

        encoder = preprocessing.LabelEncoder()
        encoder.fit(list(set(y_true) | set(y_pred)))

        y_true = encoder.transform(y_true)
        y_pred = encoder.transform(y_pred)
        score = metrics.balanced_accuracy_score(y_true, y_pred)
        return {self.label: score}

    @top_class_metrics.register(BalancedAccuracyConfig)
    @staticmethod
    def from_config(config: BalancedAccuracyConfig):
        return BalancedAccuracy(
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            label=config.label,
        )


TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_metric(config: TopClassMetricConfig):
    return top_class_metrics.build(config)
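A minimal sketch (not part of the commit) of the top-class convention: a match is counted as correct only when the highest-scoring predicted class equals the ground-truth class. It uses only classes from this new module; the gt/pred objects and the species labels ("Myotis", "Pipistrellus") are placeholders.

from batdetect2.evaluate.metrics.top_class import ClipEval, MatchEval, TopClassRecall

matches = [
    # Correct top-class prediction above threshold.
    MatchEval(gt=None, pred=None, is_ground_truth=True, is_generic=False,
              is_prediction=True, pred_class="Myotis", true_class="Myotis", score=0.8),
    # Confident but wrong top class: not a true positive.
    MatchEval(gt=None, pred=None, is_ground_truth=True, is_generic=False,
              is_prediction=True, pred_class="Pipistrellus", true_class="Myotis", score=0.9),
]
clip_evals = [ClipEval(clip=None, matches=matches)]  # a real data.Clip in practice

print(TopClassRecall(threshold=0.5)(clip_evals))  # {'recall': 0.5}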
39
src/batdetect2/evaluate/tasks/__init__.py
Normal file
39
src/batdetect2/evaluate/tasks/__init__.py
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.evaluate.tasks.base import tasks_registry
|
||||||
|
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.clip_classification import (
|
||||||
|
ClipClassificationTaskConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||||
|
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TaskConfig",
|
||||||
|
"build_task",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
TaskConfig = Annotated[
|
||||||
|
Union[
|
||||||
|
ClassificationTaskConfig,
|
||||||
|
DetectionTaskConfig,
|
||||||
|
ClipDetectionTaskConfig,
|
||||||
|
ClipClassificationTaskConfig,
|
||||||
|
TopClassDetectionTaskConfig,
|
||||||
|
],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_task(
|
||||||
|
config: TaskConfig,
|
||||||
|
targets: Optional[TargetProtocol] = None,
|
||||||
|
) -> EvaluatorProtocol:
|
||||||
|
targets = targets or build_targets()
|
||||||
|
return tasks_registry.build(config, targets)
|
||||||
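A sketch (not part of the commit) of building a single task from its config. It assumes the task configs are default-constructible, including the default matching strategy inherited from BaseTaskConfig; when no targets are passed, build_task falls back to build_targets() exactly as in the code above.

from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig

task = build_task(DetectionTaskConfig())
# `task` satisfies EvaluatorProtocol: per-clip outputs come from task.evaluate(...),
# aggregated scores from task.compute_metrics(...), with names prefixed "detection/".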
@ -1,3 +1,5 @@
+from typing import Callable, Dict, Generic, List, Sequence, TypeVar
+
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds
@ -14,14 +16,19 @@ from batdetect2.typing.postprocess import RawPrediction
 from batdetect2.typing.targets import TargetProtocol
 
 __all__ = [
-    "BaseEvaluatorConfig",
-    "BaseEvaluator",
+    "BaseTaskConfig",
+    "BaseTask",
 ]
 
-evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric")
+tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
+    "tasks"
+)
 
-class BaseEvaluatorConfig(BaseConfig):
+T_Output = TypeVar("T_Output")
+
+
+class BaseTaskConfig(BaseConfig):
     prefix: str
     ignore_start_end: float = 0.01
     matching_strategy: MatchConfig = Field(
@ -29,11 +36,13 @@ class BaseEvaluatorConfig(BaseConfig):
     )
 
-class BaseEvaluator(EvaluatorProtocol):
+class BaseTask(EvaluatorProtocol, Generic[T_Output]):
     targets: TargetProtocol
 
     matcher: MatcherProtocol
 
+    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]
+
     ignore_start_end: float
 
     prefix: str
@ -42,15 +51,44 @@ class BaseEvaluator(EvaluatorProtocol):
         self,
         matcher: MatcherProtocol,
         targets: TargetProtocol,
+        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         prefix: str,
         ignore_start_end: float = 0.01,
     ):
         self.matcher = matcher
+        self.metrics = metrics
         self.targets = targets
         self.prefix = prefix
         self.ignore_start_end = ignore_start_end
 
-    def filter_sound_event_annotations(
+    def compute_metrics(
+        self,
+        eval_outputs: List[T_Output],
+    ) -> Dict[str, float]:
+        scores = [metric(eval_outputs) for metric in self.metrics]
+        return {
+            f"{self.prefix}/{name}": score
+            for metric_output in scores
+            for name, score in metric_output.items()
+        }
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[T_Output]:
+        return [
+            self.evaluate_clip(clip_annotation, preds)
+            for clip_annotation, preds in zip(clip_annotations, predictions)
+        ]
+
+    def evaluate_clip(
+        self,
+        clip_annotation: data.ClipAnnotation,
+        predictions: Sequence[RawPrediction],
+    ) -> T_Output: ...
+
+    def include_sound_event_annotation(
         self,
         sound_event_annotation: data.SoundEventAnnotation,
         clip: data.Clip,
@ -68,7 +106,7 @@ class BaseEvaluator(EvaluatorProtocol):
             self.ignore_start_end,
         )
 
-    def filter_predictions(
+    def include_prediction(
         self,
         prediction: RawPrediction,
         clip: data.Clip,
@ -82,14 +120,16 @@ class BaseEvaluator(EvaluatorProtocol):
     @classmethod
     def build(
         cls,
-        config: BaseEvaluatorConfig,
+        config: BaseTaskConfig,
         targets: TargetProtocol,
+        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
         **kwargs,
     ):
         matcher = build_matcher(config.matching_strategy)
         return cls(
             matcher=matcher,
             targets=targets,
+            metrics=metrics,
             prefix=config.prefix,
             ignore_start_end=config.ignore_start_end,
             **kwargs,
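For clarity, a small sketch (not part of the commit) of what BaseTask.compute_metrics does with the per-metric dictionaries: every key is namespaced by the task prefix before logging. Plain Python, mirroring the dict comprehension added above; the numbers are made up.

metric_outputs = [{"average_precision": 0.71}, {"recall": 0.63, "precision": 0.58}]
prefix = "detection"
flattened = {
    f"{prefix}/{name}": score
    for metric_output in metric_outputs
    for name, score in metric_output.items()
}
# {'detection/average_precision': 0.71, 'detection/recall': 0.63, 'detection/precision': 0.58}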
137 src/batdetect2/evaluate/tasks/classification.py Normal file
@ -0,0 +1,137 @@
from typing import (
    List,
    Literal,
    Sequence,
)

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.classification import (
    ClassificationAveragePrecisionConfig,
    ClassificationMetricConfig,
    ClipEval,
    MatchEval,
    build_classification_metrics,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClassificationTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_classification"] = "sound_event_classification"
    prefix: str = "classification"
    metrics: List[ClassificationMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )
    include_generics: bool = True


class ClassificationTask(BaseTask[ClipEval]):
    def __init__(
        self,
        *args,
        include_generics: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.include_generics = include_generics

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]

        per_class_matches = {}

        for class_name in self.targets.class_names:
            class_idx = self.targets.class_names.index(class_name)

            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [float(pred.class_scores[class_idx]) for pred in preds]

            matches = []

            for pred_idx, gt_idx, _ in self.matcher(
                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
                predictions=[pred.geometry for pred in preds],
                scores=scores,
            ):
                gt = gts[gt_idx] if gt_idx is not None else None
                pred = preds[pred_idx] if pred_idx is not None else None

                true_class = (
                    self.targets.encode_class(gt) if gt is not None else None
                )

                score = (
                    float(pred.class_scores[class_idx])
                    if pred is not None
                    else 0
                )

                matches.append(
                    MatchEval(
                        gt=gt,
                        pred=pred,
                        is_prediction=pred is not None,
                        is_ground_truth=gt is not None,
                        is_generic=gt is not None and true_class is None,
                        true_class=true_class,
                        score=score,
                    )
                )

            per_class_matches[class_name] = matches

        return ClipEval(clip=clip, matches=per_class_matches)

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given
            # class
            return True

        return sound_event_class == class_name

    @tasks_registry.register(ClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [
            build_classification_metrics(metric) for metric in config.metrics
        ]
        return ClassificationTask.build(
            config=config,
            targets=targets,
            metrics=metrics,
        )
75 src/batdetect2/evaluate/tasks/clip_classification.py Normal file
@ -0,0 +1,75 @@
from collections import defaultdict
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_classification import (
    ClipClassificationAveragePrecisionConfig,
    ClipClassificationMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipClassificationTaskConfig(BaseTaskConfig):
    name: Literal["clip_classification"] = "clip_classification"
    prefix: str = "clip_classification"
    metrics: List[ClipClassificationMetricConfig] = Field(
        default_factory=lambda: [
            ClipClassificationAveragePrecisionConfig(),
        ]
    )


class ClipClassificationTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_classes = set()
        for sound_event in clip_annotation.sound_events:
            if not self.include_sound_event_annotation(sound_event, clip):
                continue

            class_name = self.targets.encode_class(sound_event)

            if class_name is None:
                continue

            gt_classes.add(class_name)

        pred_scores = defaultdict(float)
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            for class_idx, class_name in enumerate(self.targets.class_names):
                pred_scores[class_name] = max(
                    float(pred.class_scores[class_idx]),
                    pred_scores[class_name],
                )

        return ClipEval(true_classes=gt_classes, class_scores=pred_scores)

    @tasks_registry.register(ClipClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClipClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        return ClipClassificationTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
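A short sketch (not part of the commit) of the clip-level score pooling used above: the clip score for each class is the maximum score of any prediction in the clip. Pure Python with placeholder class names and scores.

from collections import defaultdict

class_names = ["Myotis", "Pipistrellus"]       # placeholder class list
prediction_scores = [[0.2, 0.7], [0.6, 0.1]]   # per-prediction class scores
pred_scores = defaultdict(float)
for scores in prediction_scores:
    for class_idx, class_name in enumerate(class_names):
        pred_scores[class_name] = max(float(scores[class_idx]), pred_scores[class_name])
# dict(pred_scores) == {'Myotis': 0.6, 'Pipistrellus': 0.7}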
66 src/batdetect2/evaluate/tasks/clip_detection.py Normal file
@ -0,0 +1,66 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_detection import (
    ClipDetectionAveragePrecisionConfig,
    ClipDetectionMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipDetectionTaskConfig(BaseTaskConfig):
    name: Literal["clip_detection"] = "clip_detection"
    prefix: str = "clip_detection"
    metrics: List[ClipDetectionMetricConfig] = Field(
        default_factory=lambda: [
            ClipDetectionAveragePrecisionConfig(),
        ]
    )


class ClipDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_det = any(
            self.include_sound_event_annotation(sound_event, clip)
            for sound_event in clip_annotation.sound_events
        )

        pred_score = 0
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            pred_score = max(pred_score, pred.detection_score)

        return ClipEval(
            gt_det=gt_det,
            score=pred_score,
        )

    @tasks_registry.register(ClipDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: ClipDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        return ClipDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
79 src/batdetect2/evaluate/tasks/detection.py Normal file
@ -0,0 +1,79 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionAveragePrecisionConfig,
    DetectionMetricConfig,
    MatchEval,
    build_detection_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class DetectionTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_detection"] = "sound_event_detection"
    prefix: str = "detection"
    metrics: List[DetectionMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )


class DetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        scores = [pred.detection_score for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_prediction=pred is not None,
                    is_ground_truth=gt is not None,
                    score=pred.detection_score if pred is not None else 0,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(DetectionTaskConfig)
    @staticmethod
    def from_config(
        config: DetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_detection_metric(metric) for metric in config.metrics]
        return DetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
101 src/batdetect2/evaluate/tasks/top_class.py Normal file
@ -0,0 +1,101 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.top_class import (
    ClipEval,
    MatchEval,
    TopClassAveragePrecisionConfig,
    TopClassMetricConfig,
    build_top_class_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class TopClassDetectionTaskConfig(BaseTaskConfig):
    name: Literal["top_class_detection"] = "top_class_detection"
    prefix: str = "top_class"
    metrics: List[TopClassMetricConfig] = Field(
        default_factory=lambda: [TopClassAveragePrecisionConfig()]
    )


class TopClassDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        # Take the highest score for each prediction
        scores = [pred.class_scores.max() for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            true_class = (
                self.targets.encode_class(gt) if gt is not None else None
            )

            class_idx = (
                pred.class_scores.argmax() if pred is not None else None
            )

            score = (
                float(pred.class_scores[class_idx]) if pred is not None else 0
            )

            pred_class = (
                self.targets.class_names[class_idx]
                if class_idx is not None
                else None
            )

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_ground_truth=gt is not None,
                    is_prediction=pred is not None,
                    true_class=true_class,
                    is_generic=gt is not None and true_class is None,
                    pred_class=pred_class,
                    score=score,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(TopClassDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: TopClassDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_top_class_metric(metric) for metric in config.metrics]
        return TopClassDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
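A sketch (not part of the commit) of composing several of the new tasks over the same annotations and predictions. `clip_annotations` and `predictions` are placeholders for data produced by the loading and inference steps, which are not shown here.

from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets

targets = build_targets()
tasks = [
    build_task(DetectionTaskConfig(), targets=targets),
    build_task(TopClassDetectionTaskConfig(), targets=targets),
]

scores = {}
for task in tasks:
    outputs = task.evaluate(clip_annotations, predictions)  # placeholders
    scores.update(task.compute_metrics(outputs))
# e.g. {'detection/average_precision': ..., 'top_class/average_precision': ...}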
@ -8,7 +8,7 @@ from loguru import logger
 from soundevent import data
 
 from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.evaluate import build_evaluator
 from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
@ -106,7 +106,7 @@ def train(
         config,
         targets=targets,
         evaluator=build_evaluator(
-            config.train.validation.evaluator,
+            config.train.validation,
             targets=targets,
         ),
         checkpoint_dir=checkpoint_dir,