mirror of https://github.com/macaodha/batdetect2.git
Task/Metrics restructure
commit df2abff654 (parent d6ddc4514c)
@@ -123,10 +123,7 @@ class BatDetect2API:
             config=config.postprocess,
         )
 
-        evaluator = build_evaluator(
-            config=config.evaluation.evaluator,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
         # NOTE: Better to have a separate instance of
         # preprocessor and postprocessor as these may be moved
@@ -178,10 +175,7 @@ class BatDetect2API:
             config=config.postprocess,
         )
 
-        evaluator = build_evaluator(
-            config=config.evaluation.evaluator,
-            targets=targets,
-        )
+        evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
         return cls(
             config=config,
@@ -1,11 +1,14 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
 from batdetect2.evaluate.evaluate import evaluate
-from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator
+from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.tasks import TaskConfig, build_task
 
 __all__ = [
     "EvaluationConfig",
-    "load_evaluation_config",
-    "evaluate",
-    "MultipleEvaluator",
+    "Evaluator",
+    "TaskConfig",
     "build_evaluator",
+    "build_task",
+    "evaluate",
+    "load_evaluation_config",
 ]
@@ -1,13 +1,14 @@
-from typing import Optional
+from typing import List, Optional
 
 from pydantic import Field
 from soundevent import data
 
 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.evaluate.evaluator import (
-    EvaluatorConfig,
-    MultipleEvaluatorConfig,
+from batdetect2.evaluate.tasks import (
+    TaskConfig,
 )
+from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
+from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.logging import CSVLoggerConfig, LoggerConfig
 
 __all__ = [
@@ -17,7 +18,12 @@ __all__ = [
 
 
 class EvaluationConfig(BaseConfig):
-    evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig)
+    tasks: List[TaskConfig] = Field(
+        default_factory=lambda: [
+            DetectionTaskConfig(),
+            ClassificationTaskConfig(),
+        ]
+    )
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
 
 
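For orientation, a minimal sketch of the reworked config in use. It assumes only what the hunk above shows (a `tasks` list defaulting to detection plus classification, and a CSV `logger`); the dict keys for the individual task configs are illustrative, assuming they keep the `name` discriminator used by the other configs in this commit.

from batdetect2.evaluate.config import EvaluationConfig

# Defaults: one detection task, one classification task, CSV logging.
config = EvaluationConfig()

# Being a pydantic model, the same config can be validated from a plain
# mapping (e.g. parsed YAML). The "name" keys below are assumptions.
config = EvaluationConfig.model_validate(
    {"tasks": [{"name": "detection"}, {"name": "classification"}]}
)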
@@ -55,10 +55,7 @@ def evaluate(
         num_workers=num_workers,
     )
 
-    evaluator = build_evaluator(
-        config=config.evaluation.evaluator,
-        targets=targets,
-    )
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
     logger = build_logger(
         config.evaluation.logger,
67  src/batdetect2/evaluate/evaluator.py  (new file)
@@ -0,0 +1,67 @@
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union

from matplotlib.figure import Figure
from soundevent import data

from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol

__all__ = [
    "Evaluator",
    "build_evaluator",
]


class Evaluator:
    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.tasks = tasks

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        results = {}

        for task, outputs in zip(self.tasks, eval_outputs):
            results.update(task.compute_metrics(outputs))

        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        for task, outputs in zip(self.tasks, eval_outputs):
            for name, fig in task.generate_plots(outputs):
                yield name, fig


def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()

    if config is None:
        config = EvaluationConfig()

    if not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    return Evaluator(
        targets=targets,
        tasks=[build_task(task, targets=targets) for task in config.tasks],
    )
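A rough usage sketch of the new Evaluator facade, under the assumption that clip annotations and raw predictions are produced elsewhere (the empty lists below are placeholders, not part of this commit):

from batdetect2.evaluate.evaluator import build_evaluator

# Placeholders: in practice these come from the annotated dataset and the
# model's postprocessor, respectively.
clip_annotations: list = []
predictions: list = []

# Defaults: targets from build_targets(), tasks from EvaluationConfig().
evaluator = build_evaluator()

# One raw output per configured task, in order.
eval_outputs = evaluator.evaluate(clip_annotations, predictions)

# Each task's metrics are merged into a single flat dict keyed by metric name.
metrics = evaluator.compute_metrics(eval_outputs)

# Plots are yielded lazily as (name, matplotlib Figure) pairs.
for name, fig in evaluator.generate_plots(eval_outputs):
    fig.savefig(f"{name.replace('/', '_')}.png")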
@ -1,114 +0,0 @@
|
||||
from typing import (
|
||||
Annotated,
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.evaluate.evaluator.base import evaluators
|
||||
from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig
|
||||
from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig
|
||||
from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import (
|
||||
EvaluatorProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"EvaluatorConfig",
|
||||
"build_evaluator",
|
||||
]
|
||||
|
||||
|
||||
EvaluatorConfig = Annotated[
|
||||
Union[
|
||||
ClassificationMetricsConfig,
|
||||
GlobalEvaluatorConfig,
|
||||
ClipMetricsConfig,
|
||||
"MultipleEvaluatorConfig",
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
class MultipleEvaluatorConfig(BaseConfig):
|
||||
name: Literal["multiple_evaluations"] = "multiple_evaluations"
|
||||
evaluations: List[EvaluatorConfig] = Field(
|
||||
default_factory=lambda: [
|
||||
ClassificationMetricsConfig(),
|
||||
GlobalEvaluatorConfig(),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class MultipleEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
evaluators: Sequence[EvaluatorProtocol],
|
||||
):
|
||||
self.targets = targets
|
||||
self.evaluators = evaluators
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
) -> List[Any]:
|
||||
return [
|
||||
evaluator.evaluate(
|
||||
clip_annotations,
|
||||
predictions,
|
||||
)
|
||||
for evaluator in self.evaluators
|
||||
]
|
||||
|
||||
def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
|
||||
results = {}
|
||||
|
||||
for evaluator, outputs in zip(self.evaluators, eval_outputs):
|
||||
results.update(evaluator.compute_metrics(outputs))
|
||||
|
||||
return results
|
||||
|
||||
def generate_plots(
|
||||
self,
|
||||
eval_outputs: List[Any],
|
||||
) -> Iterable[Tuple[str, Figure]]:
|
||||
for evaluator, outputs in zip(self.evaluators, eval_outputs):
|
||||
for name, fig in evaluator.generate_plots(outputs):
|
||||
yield name, fig
|
||||
|
||||
@evaluators.register(MultipleEvaluatorConfig)
|
||||
@staticmethod
|
||||
def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol):
|
||||
return MultipleEvaluator(
|
||||
evaluators=[
|
||||
build_evaluator(conf, targets=targets)
|
||||
for conf in config.evaluations
|
||||
],
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
|
||||
def build_evaluator(
|
||||
config: Optional[EvaluatorConfig] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
|
||||
config = config or MultipleEvaluatorConfig()
|
||||
return evaluators.build(config, targets)
|
||||
@ -1,163 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Literal, Sequence, Set
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.evaluator.base import (
|
||||
BaseEvaluator,
|
||||
BaseEvaluatorConfig,
|
||||
evaluators,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipInfo:
|
||||
gt_det: bool
|
||||
gt_classes: Set[str]
|
||||
pred_score: float
|
||||
pred_class_scores: Dict[str, float]
|
||||
|
||||
|
||||
ClipMetric = Callable[[Sequence[ClipInfo]], float]
|
||||
|
||||
|
||||
def clip_detection_average_precision(
|
||||
clip_evaluations: Sequence[ClipInfo],
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
y_true.append(clip_eval.gt_det)
|
||||
y_score.append(clip_eval.pred_score)
|
||||
|
||||
return average_precision(y_true=y_true, y_score=y_score)
|
||||
|
||||
|
||||
def clip_detection_roc_auc(
|
||||
clip_evaluations: Sequence[ClipInfo],
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
y_true.append(clip_eval.gt_det)
|
||||
y_score.append(clip_eval.pred_score)
|
||||
|
||||
return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
|
||||
|
||||
|
||||
clip_metrics = {
|
||||
"average_precision": clip_detection_average_precision,
|
||||
"roc_auc": clip_detection_roc_auc,
|
||||
}
|
||||
|
||||
|
||||
class ClipMetricsConfig(BaseEvaluatorConfig):
|
||||
name: Literal["clip"] = "clip"
|
||||
prefix: str = "clip"
|
||||
metrics: List[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"average_precision",
|
||||
"roc_auc",
|
||||
]
|
||||
)
|
||||
|
||||
@field_validator("metrics", mode="after")
|
||||
@classmethod
|
||||
def validate_metrics(cls, v: List[str]) -> List[str]:
|
||||
for metric_name in v:
|
||||
if metric_name not in clip_metrics:
|
||||
raise ValueError(f"Unknown metric {metric_name}")
|
||||
return v
|
||||
|
||||
|
||||
class ClipEvaluator(BaseEvaluator):
|
||||
def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.metrics = metrics
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
) -> List[ClipInfo]:
|
||||
return [
|
||||
self.match_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
||||
]
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
eval_outputs: List[ClipInfo],
|
||||
) -> Dict[str, float]:
|
||||
scores = {
|
||||
name: metric(eval_outputs) for name, metric in self.metrics.items()
|
||||
}
|
||||
return {
|
||||
f"{self.prefix}/{name}": score for name, score in scores.items()
|
||||
}
|
||||
|
||||
def match_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
) -> ClipInfo:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
gt_det = False
|
||||
gt_classes = set()
|
||||
for sound_event in clip_annotation.sound_events:
|
||||
if self.filter_sound_event_annotations(sound_event, clip):
|
||||
continue
|
||||
|
||||
gt_det = True
|
||||
class_name = self.targets.encode_class(sound_event)
|
||||
|
||||
if class_name is None:
|
||||
continue
|
||||
|
||||
gt_classes.add(class_name)
|
||||
|
||||
pred_score = 0
|
||||
pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0)
|
||||
for pred in predictions:
|
||||
if self.filter_predictions(pred, clip):
|
||||
continue
|
||||
|
||||
pred_score = max(pred_score, pred.detection_score)
|
||||
|
||||
for class_name, class_score in zip(
|
||||
self.targets.class_names,
|
||||
pred.class_scores,
|
||||
):
|
||||
pred_class_scores[class_name] = max(
|
||||
pred_class_scores[class_name],
|
||||
class_score,
|
||||
)
|
||||
|
||||
return ClipInfo(
|
||||
gt_det=gt_det,
|
||||
gt_classes=gt_classes,
|
||||
pred_score=pred_score,
|
||||
pred_class_scores=pred_class_scores,
|
||||
)
|
||||
|
||||
@evaluators.register(ClipMetricsConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: ClipMetricsConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
metrics = {name: clip_metrics.get(name) for name in config.metrics}
|
||||
return ClipEvaluator.build(
|
||||
config=config,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
)
|
||||
@ -1,219 +0,0 @@
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.evaluator.base import (
|
||||
BaseEvaluator,
|
||||
BaseEvaluatorConfig,
|
||||
evaluators,
|
||||
)
|
||||
from batdetect2.evaluate.match import match
|
||||
from batdetect2.evaluate.metrics.per_class_matches import (
|
||||
ClassificationAveragePrecisionConfig,
|
||||
PerClassMatchMetric,
|
||||
PerClassMatchMetricConfig,
|
||||
build_per_class_matches_metric,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipMatches,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
ScoreFn = Callable[[RawPrediction, int], float]
|
||||
|
||||
|
||||
def score_by_class_score(pred: RawPrediction, class_index: int) -> float:
|
||||
return float(pred.class_scores[class_index])
|
||||
|
||||
|
||||
def score_by_adjusted_class_score(
|
||||
pred: RawPrediction,
|
||||
class_index: int,
|
||||
) -> float:
|
||||
return float(pred.class_scores[class_index]) * pred.detection_score
|
||||
|
||||
|
||||
ScoreFunctionOption = Literal["class_score", "adjusted_class_score"]
|
||||
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
|
||||
"class_score": score_by_class_score,
|
||||
"adjusted_class_score": score_by_adjusted_class_score,
|
||||
}
|
||||
|
||||
|
||||
def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
|
||||
return score_functions[name]
|
||||
|
||||
|
||||
class ClassificationMetricsConfig(BaseEvaluatorConfig):
|
||||
name: Literal["classification"] = "classification"
|
||||
prefix: str = "classification"
|
||||
include_generics: bool = True
|
||||
score_by: ScoreFunctionOption = "class_score"
|
||||
metrics: List[PerClassMatchMetricConfig] = Field(
|
||||
default_factory=lambda: [ClassificationAveragePrecisionConfig()]
|
||||
)
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class PerClassEvaluator(BaseEvaluator):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
metrics: Dict[str, PerClassMatchMetric],
|
||||
score_fn: ScoreFn,
|
||||
include_generics: bool = True,
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.score_fn = score_fn
|
||||
self.metrics = metrics
|
||||
|
||||
self.include_generics = include_generics
|
||||
|
||||
self.include = include
|
||||
self.exclude = exclude
|
||||
|
||||
self.selected = self.targets.class_names
|
||||
if include is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name in include
|
||||
]
|
||||
|
||||
if exclude is not None:
|
||||
self.selected = [
|
||||
class_name
|
||||
for class_name in self.selected
|
||||
if class_name not in exclude
|
||||
]
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
) -> Dict[str, List[ClipMatches]]:
|
||||
ret = defaultdict(list)
|
||||
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions):
|
||||
matches = self.match_clip(clip_annotation, preds)
|
||||
for class_name, clip_eval in matches.items():
|
||||
ret[class_name].append(clip_eval)
|
||||
|
||||
return ret
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
eval_outputs: Dict[str, List[ClipMatches]],
|
||||
) -> Dict[str, float]:
|
||||
results = {}
|
||||
|
||||
for metric_name, metric in self.metrics.items():
|
||||
class_scores = {
|
||||
class_name: metric(eval_outputs[class_name], class_name)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
mean = float(
|
||||
np.mean([v for v in class_scores.values() if v != np.nan])
|
||||
)
|
||||
|
||||
results[f"{self.prefix}/mean_{metric_name}"] = mean
|
||||
|
||||
for class_name, value in class_scores.items():
|
||||
if class_name not in self.selected:
|
||||
continue
|
||||
|
||||
results[f"{self.prefix}/{metric_name}/{class_name}"] = value
|
||||
|
||||
return results
|
||||
|
||||
def match_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
) -> Dict[str, ClipMatches]:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
preds = [
|
||||
pred for pred in predictions if self.filter_predictions(pred, clip)
|
||||
]
|
||||
|
||||
all_gts = [
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if self.filter_sound_event_annotations(sound_event, clip)
|
||||
]
|
||||
|
||||
ret = {}
|
||||
|
||||
for class_name in self.targets.class_names:
|
||||
class_idx = self.targets.class_names.index(class_name)
|
||||
|
||||
# Only match to targets of the given class
|
||||
gts = [
|
||||
sound_event
|
||||
for sound_event in all_gts
|
||||
if self.is_class(sound_event, class_name)
|
||||
]
|
||||
scores = [self.score_fn(pred, class_idx) for pred in preds]
|
||||
|
||||
ret[class_name] = match(
|
||||
gts,
|
||||
preds,
|
||||
clip=clip,
|
||||
scores=scores,
|
||||
targets=self.targets,
|
||||
matcher=self.matcher,
|
||||
)
|
||||
|
||||
return ret
|
||||
|
||||
def is_class(
|
||||
self,
|
||||
sound_event: data.SoundEventAnnotation,
|
||||
class_name: str,
|
||||
) -> bool:
|
||||
sound_event_class = self.targets.encode_class(sound_event)
|
||||
|
||||
if sound_event_class is None and self.include_generics:
|
||||
# Sound events that are generic could be of the given
|
||||
# class
|
||||
return True
|
||||
|
||||
return sound_event_class == class_name
|
||||
|
||||
@evaluators.register(ClassificationMetricsConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: ClassificationMetricsConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
metrics = {
|
||||
metric.name: build_per_class_matches_metric(metric)
|
||||
for metric in config.metrics
|
||||
}
|
||||
return PerClassEvaluator.build(
|
||||
config=config,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
score_fn=get_score_fn(config.score_by),
|
||||
include_generics=config.include_generics,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
@ -1,126 +0,0 @@
|
||||
from typing import Callable, Dict, List, Literal, Mapping, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.evaluator.base import (
|
||||
BaseEvaluator,
|
||||
BaseEvaluatorConfig,
|
||||
evaluators,
|
||||
)
|
||||
from batdetect2.evaluate.match import match
|
||||
from batdetect2.evaluate.metrics.matches import (
|
||||
DetectionAveragePrecisionConfig,
|
||||
MatchesMetric,
|
||||
MatchMetricConfig,
|
||||
build_match_metric,
|
||||
)
|
||||
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
|
||||
|
||||
ScoreFn = Callable[[RawPrediction], float]
|
||||
|
||||
|
||||
def score_by_detection_score(pred: RawPrediction) -> float:
|
||||
return pred.detection_score
|
||||
|
||||
|
||||
def score_by_top_class_score(pred: RawPrediction) -> float:
|
||||
return pred.class_scores.max()
|
||||
|
||||
|
||||
ScoreFunctionOption = Literal["detection_score", "top_class_score"]
|
||||
score_functions: Mapping[ScoreFunctionOption, ScoreFn] = {
|
||||
"detection_score": score_by_detection_score,
|
||||
"top_class_score": score_by_top_class_score,
|
||||
}
|
||||
|
||||
|
||||
def get_score_fn(name: ScoreFunctionOption) -> ScoreFn:
|
||||
return score_functions[name]
|
||||
|
||||
|
||||
class GlobalEvaluatorConfig(BaseEvaluatorConfig):
|
||||
name: Literal["detection"] = "detection"
|
||||
prefix: str = "detection"
|
||||
score_by: ScoreFunctionOption = "detection_score"
|
||||
metrics: List[MatchMetricConfig] = Field(
|
||||
default_factory=lambda: [DetectionAveragePrecisionConfig()]
|
||||
)
|
||||
|
||||
|
||||
class GlobalEvaluator(BaseEvaluator):
|
||||
def __init__(
|
||||
self,
|
||||
*args,
|
||||
score_fn: ScoreFn,
|
||||
metrics: Dict[str, MatchesMetric],
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.metrics = metrics
|
||||
self.score_fn = score_fn
|
||||
|
||||
def compute_metrics(
|
||||
self,
|
||||
eval_outputs: List[ClipMatches],
|
||||
) -> Dict[str, float]:
|
||||
scores = {
|
||||
name: metric(eval_outputs) for name, metric in self.metrics.items()
|
||||
}
|
||||
return {
|
||||
f"{self.prefix}/{name}": score for name, score in scores.items()
|
||||
}
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
) -> List[ClipMatches]:
|
||||
return [
|
||||
self.match_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
||||
]
|
||||
|
||||
def match_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
) -> ClipMatches:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
gts = [
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if self.filter_sound_event_annotations(sound_event, clip)
|
||||
]
|
||||
preds = [
|
||||
pred for pred in predictions if self.filter_predictions(pred, clip)
|
||||
]
|
||||
scores = [self.score_fn(pred) for pred in preds]
|
||||
|
||||
return match(
|
||||
gts,
|
||||
preds,
|
||||
scores=scores,
|
||||
clip=clip,
|
||||
targets=self.targets,
|
||||
matcher=self.matcher,
|
||||
)
|
||||
|
||||
@evaluators.register(GlobalEvaluatorConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: GlobalEvaluatorConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
metrics = {
|
||||
metric.name: build_match_metric(metric)
|
||||
for metric in config.metrics
|
||||
}
|
||||
score_fn = get_score_fn(config.score_by)
|
||||
return GlobalEvaluator.build(
|
||||
config=config,
|
||||
score_fn=score_fn,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
)
|
||||
@ -1,133 +0,0 @@
|
||||
from typing import Dict, List, Literal, Sequence
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.match import match
|
||||
from batdetect2.evaluate.metrics.base import (
|
||||
BaseMetric,
|
||||
BaseMetricConfig,
|
||||
metrics_registry,
|
||||
)
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.evaluate.metrics.detection import DetectionMetric
|
||||
from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"TopClassEvaluator",
|
||||
"TopClassEvaluatorConfig",
|
||||
]
|
||||
|
||||
|
||||
def top_class_average_precision(
|
||||
clip_evaluations: Sequence[ClipMatches],
|
||||
) -> float:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
is_generic = m.gt_det and (m.gt_class is None)
|
||||
|
||||
# Ignore ground truth sounds with unknown class
|
||||
if is_generic:
|
||||
continue
|
||||
|
||||
num_positives += int(m.gt_det)
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if m.pred_geometry is None:
|
||||
continue
|
||||
|
||||
y_true.append(m.gt_det & (m.top_class == m.gt_class))
|
||||
y_score.append(m.top_class_score)
|
||||
|
||||
return average_precision(y_true, y_score, num_positives=num_positives)
|
||||
|
||||
|
||||
top_class_metrics = {
|
||||
"average_precision": top_class_average_precision,
|
||||
}
|
||||
|
||||
|
||||
class TopClassEvaluatorConfig(BaseMetricConfig):
|
||||
name: Literal["top_class"] = "top_class"
|
||||
prefix: str = "top_class"
|
||||
metrics: List[str] = Field(default_factory=lambda: ["average_precision"])
|
||||
|
||||
@field_validator("metrics", mode="after")
|
||||
@classmethod
|
||||
def validate_metrics(cls, v: List[str]) -> List[str]:
|
||||
for metric_name in v:
|
||||
if metric_name not in top_class_metrics:
|
||||
raise ValueError(f"Unknown metric {metric_name}")
|
||||
return v
|
||||
|
||||
|
||||
class TopClassEvaluator(BaseMetric):
|
||||
def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.metrics = metrics
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
) -> Dict[str, float]:
|
||||
clip_evaluations = [
|
||||
self.match_clip(clip_annotation, preds)
|
||||
for clip_annotation, preds in zip(clip_annotations, predictions)
|
||||
]
|
||||
scores = {
|
||||
name: metric(clip_evaluations)
|
||||
for name, metric in self.metrics.items()
|
||||
}
|
||||
return {
|
||||
f"{self.prefix}/{name}": score for name, score in scores.items()
|
||||
}
|
||||
|
||||
def match_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
) -> ClipMatches:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
gts = [
|
||||
sound_event
|
||||
for sound_event in clip_annotation.sound_events
|
||||
if self.filter_sound_event_annotations(sound_event, clip)
|
||||
]
|
||||
preds = [
|
||||
pred for pred in predictions if self.filter_predictions(pred, clip)
|
||||
]
|
||||
# Use score of top class for matching
|
||||
scores = [pred.class_scores.max() for pred in preds]
|
||||
|
||||
return match(
|
||||
gts,
|
||||
preds,
|
||||
scores=scores,
|
||||
clip=clip,
|
||||
targets=self.targets,
|
||||
matcher=self.matcher,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: TopClassEvaluatorConfig,
|
||||
targets: TargetProtocol,
|
||||
):
|
||||
metrics = {
|
||||
name: top_class_metrics.get(name) for name in config.metrics
|
||||
}
|
||||
return super().build(
|
||||
config=config,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
|
||||
metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator)
|
||||
253
src/batdetect2/evaluate/metrics/classification.py
Normal file
253
src/batdetect2/evaluate/metrics/classification.py
Normal file
@ -0,0 +1,253 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
|
||||
__all__ = []
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
is_generic: bool
|
||||
true_class: Optional[str]
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEval:
|
||||
clip: data.Clip
|
||||
matches: Mapping[str, List[MatchEval]]
|
||||
|
||||
|
||||
ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
|
||||
|
||||
classification_metrics: Registry[ClassificationMetric, []] = Registry(
|
||||
"classification_metric"
|
||||
)
|
||||
|
||||
|
||||
class BaseClassificationConfig(BaseConfig):
|
||||
include: Optional[List[str]] = None
|
||||
exclude: Optional[List[str]] = None
|
||||
|
||||
|
||||
class BaseClassificationMetric:
|
||||
def __init__(
|
||||
self,
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.include = include
|
||||
self.exclude = exclude
|
||||
|
||||
def include_class(self, class_name: str) -> bool:
|
||||
if self.include is not None:
|
||||
return class_name in self.include
|
||||
|
||||
if self.exclude is not None:
|
||||
return class_name not in self.exclude
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class ClassificationAveragePrecisionConfig(BaseClassificationConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
label: str = "average_precision"
|
||||
|
||||
|
||||
class ClassificationAveragePrecision(BaseClassificationMetric):
|
||||
def __init__(
|
||||
self,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
label: str = "average_precision",
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
super().__init__(include=include, exclude=exclude)
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEval]
|
||||
) -> Dict[str, float]:
|
||||
y_true = defaultdict(list)
|
||||
y_score = defaultdict(list)
|
||||
num_positives = defaultdict(lambda: 0)
|
||||
|
||||
class_names = set()
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for class_name, matches in clip_eval.matches.items():
|
||||
class_names.add(class_name)
|
||||
|
||||
for m in matches:
|
||||
# Exclude matches with ground truth sounds where the class
|
||||
# is unknown
|
||||
if m.is_generic and self.ignore_generic:
|
||||
continue
|
||||
|
||||
is_class = m.true_class == class_name
|
||||
|
||||
if is_class:
|
||||
num_positives[class_name] += 1
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
y_true[class_name].append(is_class)
|
||||
y_score[class_name].append(m.score)
|
||||
|
||||
class_scores = {
|
||||
class_name: average_precision(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
num_positives=num_positives[class_name],
|
||||
)
|
||||
for class_name in class_names
|
||||
}
|
||||
|
||||
mean_score = float(
|
||||
np.mean([v for v in class_scores.values() if v != np.nan])
|
||||
)
|
||||
|
||||
return {
|
||||
f"mean_{self.label}": mean_score,
|
||||
**{
|
||||
f"{self.label}/{class_name}": score
|
||||
for class_name, score in class_scores.items()
|
||||
if self.include_class(class_name)
|
||||
},
|
||||
}
|
||||
|
||||
@classification_metrics.register(ClassificationAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClassificationAveragePrecisionConfig):
|
||||
return ClassificationAveragePrecision(
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
label=config.label,
|
||||
include=config.include,
|
||||
exclude=config.exclude,
|
||||
)
|
||||
|
||||
|
||||
class ClassificationROCAUCConfig(BaseClassificationConfig):
|
||||
name: Literal["roc_auc"] = "roc_auc"
|
||||
label: str = "roc_auc"
|
||||
ignore_non_predictions: bool = True
|
||||
ignore_generic: bool = True
|
||||
|
||||
|
||||
class ClassificationROCAUC(BaseClassificationMetric):
|
||||
def __init__(
|
||||
self,
|
||||
ignore_non_predictions: bool = True,
|
||||
ignore_generic: bool = True,
|
||||
label: str = "roc_auc",
|
||||
include: Optional[List[str]] = None,
|
||||
exclude: Optional[List[str]] = None,
|
||||
):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.ignore_generic = ignore_generic
|
||||
self.label = label
|
||||
self.include = include
|
||||
self.exclude = exclude
|
||||
|
||||
def __call__(
|
||||
self, clip_evaluations: Sequence[ClipEval]
|
||||
) -> Dict[str, float]:
|
||||
y_true = defaultdict(list)
|
||||
y_score = defaultdict(list)
|
||||
|
||||
class_names = set()
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for class_name, matches in clip_eval.matches.items():
|
||||
class_names.add(class_name)
|
||||
|
||||
for m in matches:
|
||||
# Exclude matches with ground truth sounds where the class
|
||||
# is unknown
|
||||
if m.is_generic and self.ignore_generic:
|
||||
continue
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
y_true[class_name].append(m.true_class == class_name)
|
||||
y_score[class_name].append(m.score)
|
||||
|
||||
class_scores = {
|
||||
class_name: float(
|
||||
metrics.roc_auc_score(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
)
|
||||
)
|
||||
for class_name in class_names
|
||||
}
|
||||
|
||||
mean_score = float(
|
||||
np.mean([v for v in class_scores.values() if v != np.nan])
|
||||
)
|
||||
|
||||
return {
|
||||
f"mean_{self.label}": mean_score,
|
||||
**{
|
||||
f"{self.label}/{class_name}": score
|
||||
for class_name, score in class_scores.items()
|
||||
if self.include_class(class_name)
|
||||
},
|
||||
}
|
||||
|
||||
@classification_metrics.register(ClassificationROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClassificationROCAUCConfig):
|
||||
return ClassificationROCAUC(
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
ignore_generic=config.ignore_generic,
|
||||
label=config.label,
|
||||
)
|
||||
|
||||
|
||||
ClassificationMetricConfig = Annotated[
|
||||
Union[
|
||||
ClassificationAveragePrecisionConfig,
|
||||
ClassificationROCAUCConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_classification_metrics(config: ClassificationMetricConfig):
|
||||
return classification_metrics.build(config)
|
||||
135
src/batdetect2/evaluate/metrics/clip_classification.py
Normal file
135
src/batdetect2/evaluate/metrics/clip_classification.py
Normal file
@ -0,0 +1,135 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEval:
|
||||
true_classes: Set[str]
|
||||
class_scores: Dict[str, float]
|
||||
|
||||
|
||||
ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
|
||||
clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry(
|
||||
"clip_classification_metric"
|
||||
)
|
||||
|
||||
|
||||
class ClipClassificationAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
|
||||
class ClipClassificationAveragePrecision:
|
||||
def __init__(self, label: str = "average_precision"):
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = defaultdict(list)
|
||||
y_score = defaultdict(list)
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for class_name, score in clip_eval.class_scores.items():
|
||||
y_true[class_name].append(class_name in clip_eval.true_classes)
|
||||
y_score[class_name].append(score)
|
||||
|
||||
class_scores = {
|
||||
class_name: float(
|
||||
average_precision(
|
||||
y_true=y_true[class_name],
|
||||
y_score=y_score[class_name],
|
||||
)
|
||||
)
|
||||
for class_name in y_true
|
||||
}
|
||||
|
||||
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
|
||||
|
||||
return {
|
||||
f"mean_{self.label}": float(mean),
|
||||
**{
|
||||
f"{self.label}/{class_name}": score
|
||||
for class_name, score in class_scores.items()
|
||||
if not np.isnan(score)
|
||||
},
|
||||
}
|
||||
|
||||
@clip_classification_metrics.register(
|
||||
ClipClassificationAveragePrecisionConfig
|
||||
)
|
||||
@staticmethod
|
||||
def from_config(config: ClipClassificationAveragePrecisionConfig):
|
||||
return ClipClassificationAveragePrecision(label=config.label)
|
||||
|
||||
|
||||
class ClipClassificationROCAUCConfig(BaseConfig):
|
||||
name: Literal["roc_auc"] = "roc_auc"
|
||||
label: str = "roc_auc"
|
||||
|
||||
|
||||
class ClipClassificationROCAUC:
|
||||
def __init__(self, label: str = "roc_auc"):
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = defaultdict(list)
|
||||
y_score = defaultdict(list)
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for class_name, score in clip_eval.class_scores.items():
|
||||
y_true[class_name].append(class_name in clip_eval.true_classes)
|
||||
y_score[class_name].append(score)
|
||||
|
||||
class_scores = {
|
||||
class_name: float(
|
||||
metrics.roc_auc_score(
|
||||
y_true=y_true[class_name],
|
||||
y_score=y_score[class_name],
|
||||
)
|
||||
)
|
||||
for class_name in y_true
|
||||
}
|
||||
|
||||
mean = np.mean([v for v in class_scores.values() if not np.isnan(v)])
|
||||
|
||||
return {
|
||||
f"mean_{self.label}": float(mean),
|
||||
**{
|
||||
f"{self.label}/{class_name}": score
|
||||
for class_name, score in class_scores.items()
|
||||
if not np.isnan(score)
|
||||
},
|
||||
}
|
||||
|
||||
@clip_classification_metrics.register(ClipClassificationROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClipClassificationROCAUCConfig):
|
||||
return ClipClassificationROCAUC(label=config.label)
|
||||
|
||||
|
||||
ClipClassificationMetricConfig = Annotated[
|
||||
Union[
|
||||
ClipClassificationAveragePrecisionConfig,
|
||||
ClipClassificationROCAUCConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_clip_metric(config: ClipClassificationMetricConfig):
|
||||
return clip_classification_metrics.build(config)
|
||||
173
src/batdetect2/evaluate/metrics/clip_detection.py
Normal file
173
src/batdetect2/evaluate/metrics/clip_detection.py
Normal file
@ -0,0 +1,173 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Callable, Dict, Literal, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEval:
|
||||
gt_det: bool
|
||||
score: float
|
||||
|
||||
|
||||
ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
|
||||
clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry(
|
||||
"clip_detection_metric"
|
||||
)
|
||||
|
||||
|
||||
class ClipDetectionAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
|
||||
|
||||
class ClipDetectionAveragePrecision:
|
||||
def __init__(self, label: str = "average_precision"):
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
y_true.append(clip_eval.gt_det)
|
||||
y_score.append(clip_eval.score)
|
||||
|
||||
score = average_precision(y_true=y_true, y_score=y_score)
|
||||
return {self.label: score}
|
||||
|
||||
@clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClipDetectionAveragePrecisionConfig):
|
||||
return ClipDetectionAveragePrecision(label=config.label)
|
||||
|
||||
|
||||
class ClipDetectionROCAUCConfig(BaseConfig):
|
||||
name: Literal["roc_auc"] = "roc_auc"
|
||||
label: str = "roc_auc"
|
||||
|
||||
|
||||
class ClipDetectionROCAUC:
|
||||
def __init__(self, label: str = "roc_auc"):
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
y_true.append(clip_eval.gt_det)
|
||||
y_score.append(clip_eval.score)
|
||||
|
||||
score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score))
|
||||
return {self.label: score}
|
||||
|
||||
@clip_detection_metrics.register(ClipDetectionROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClipDetectionROCAUCConfig):
|
||||
return ClipDetectionROCAUC(label=config.label)
|
||||
|
||||
|
||||
class ClipDetectionRecallConfig(BaseConfig):
|
||||
name: Literal["recall"] = "recall"
|
||||
threshold: float = 0.5
|
||||
label: str = "recall"
|
||||
|
||||
|
||||
class ClipDetectionRecall:
|
||||
def __init__(self, threshold: float, label: str = "recall"):
|
||||
self.threshold = threshold
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
num_positives = 0
|
||||
true_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
if clip_eval.gt_det:
|
||||
num_positives += 1
|
||||
|
||||
if clip_eval.score >= self.threshold and clip_eval.gt_det:
|
||||
true_positives += 1
|
||||
|
||||
if num_positives == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = true_positives / num_positives
|
||||
return {self.label: score}
|
||||
|
||||
@clip_detection_metrics.register(ClipDetectionRecallConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClipDetectionRecallConfig):
|
||||
return ClipDetectionRecall(
|
||||
threshold=config.threshold, label=config.label
|
||||
)
|
||||
|
||||
|
||||
class ClipDetectionPrecisionConfig(BaseConfig):
|
||||
name: Literal["precision"] = "precision"
|
||||
threshold: float = 0.5
|
||||
label: str = "precision"
|
||||
|
||||
|
||||
class ClipDetectionPrecision:
|
||||
def __init__(self, threshold: float, label: str = "precision"):
|
||||
self.threshold = threshold
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
num_detections = 0
|
||||
true_positives = 0
|
||||
for clip_eval in clip_evaluations:
|
||||
if clip_eval.score >= self.threshold:
|
||||
num_detections += 1
|
||||
|
||||
if clip_eval.score >= self.threshold and clip_eval.gt_det:
|
||||
true_positives += 1
|
||||
|
||||
if num_detections == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = true_positives / num_detections
|
||||
return {self.label: score}
|
||||
|
||||
@clip_detection_metrics.register(ClipDetectionPrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ClipDetectionPrecisionConfig):
|
||||
return ClipDetectionPrecision(
|
||||
threshold=config.threshold, label=config.label
|
||||
)
|
||||
|
||||
|
||||
ClipDetectionMetricConfig = Annotated[
|
||||
Union[
|
||||
ClipDetectionAveragePrecisionConfig,
|
||||
ClipDetectionROCAUCConfig,
|
||||
ClipDetectionRecallConfig,
|
||||
ClipDetectionPrecisionConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_clip_metric(config: ClipDetectionMetricConfig):
|
||||
return clip_detection_metrics.build(config)
|
||||
@@ -1,24 +1,24 @@
-from typing import Optional
+from typing import Optional, Tuple
 
 import numpy as np
 
 __all__ = [
+    "compute_precision_recall",
     "average_precision",
 ]
 
-def average_precision(
+
+def compute_precision_recall(
     y_true,
     y_score,
     num_positives: Optional[int] = None,
-) -> float:
+) -> Tuple[np.ndarray, np.ndarray]:
     y_true = np.array(y_true)
     y_score = np.array(y_score)
 
     if num_positives is None:
         num_positives = y_true.sum()
 
     # Remove non-detections
     valid_inds = y_score > 0
     y_true = y_true[valid_inds]
     y_score = y_score[valid_inds]
 
     # Sort by score
     sort_ind = np.argsort(y_score)[::-1]
     y_true_sorted = y_true[sort_ind]
@@ -34,6 +34,19 @@ def average_precision(
 
     precision[np.isnan(precision)] = 0
     recall[np.isnan(recall)] = 0
+    return precision, recall
 
+
+def average_precision(
+    y_true,
+    y_score,
+    num_positives: Optional[int] = None,
+) -> float:
+    precision, recall = compute_precision_recall(
+        y_true,
+        y_score,
+        num_positives=num_positives,
+    )
+
     # pascal 12 way
     mprec = np.hstack((0, precision, 0))
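A small worked example of the split introduced above, with arbitrary scores. The `num_positives` argument lets the caller account for ground-truth sounds that never received any prediction and therefore do not appear in `y_true`/`y_score`:

from batdetect2.evaluate.metrics.common import (
    average_precision,
    compute_precision_recall,
)

y_true = [True, False, True, True, False]
y_score = [0.9, 0.8, 0.7, 0.4, 0.2]

# The refactor exposes the intermediate precision/recall arrays...
precision, recall = compute_precision_recall(y_true, y_score)

# ...while average_precision now just integrates them (the "pascal 12 way"
# interpolation in the hunk above). Here we pretend one extra ground-truth
# sound was never detected, so num_positives is 3 matched + 1 missed.
ap = average_precision(y_true, y_score, num_positives=4)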
226
src/batdetect2/evaluate/metrics/detection.py
Normal file
226
src/batdetect2/evaluate/metrics/detection.py
Normal file
@ -0,0 +1,226 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Sequence,
|
||||
Union,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import RawPrediction
|
||||
|
||||
__all__ = [
|
||||
"DetectionMetricConfig",
|
||||
"DetectionMetric",
|
||||
"build_detection_metric",
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatchEval:
|
||||
gt: Optional[data.SoundEventAnnotation]
|
||||
pred: Optional[RawPrediction]
|
||||
|
||||
is_prediction: bool
|
||||
is_ground_truth: bool
|
||||
score: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClipEval:
|
||||
clip: data.Clip
|
||||
matches: List[MatchEval]
|
||||
|
||||
|
||||
DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]
|
||||
|
||||
|
||||
detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric")
|
||||
|
||||
|
||||
class DetectionAveragePrecisionConfig(BaseConfig):
|
||||
name: Literal["average_precision"] = "average_precision"
|
||||
label: str = "average_precision"
|
||||
ignore_non_predictions: bool = True
|
||||
|
||||
|
||||
class DetectionAveragePrecision:
|
||||
def __init__(self, label: str, ignore_non_predictions: bool = True):
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evals: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
y_true = []
|
||||
y_score = []
|
||||
num_positives = 0
|
||||
|
||||
for clip_eval in clip_evals:
|
||||
for m in clip_eval.matches:
|
||||
num_positives += int(m.is_ground_truth)
|
||||
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
continue
|
||||
|
||||
y_true.append(m.is_ground_truth)
|
||||
y_score.append(m.score)
|
||||
|
||||
ap = average_precision(y_true, y_score, num_positives=num_positives)
|
||||
return {self.label: ap}
|
||||
|
||||
@detection_metrics.register(DetectionAveragePrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DetectionAveragePrecisionConfig):
|
||||
return DetectionAveragePrecision(
|
||||
label=config.label,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
)
|
||||
|
||||
|
||||
class DetectionROCAUCConfig(BaseConfig):
|
||||
name: Literal["roc_auc"] = "roc_auc"
|
||||
label: str = "roc_auc"
|
||||
ignore_non_predictions: bool = True
|
||||
|
||||
|
||||
class DetectionROCAUC:
|
||||
def __init__(
|
||||
self,
|
||||
label: str = "roc_auc",
|
||||
ignore_non_predictions: bool = True,
|
||||
):
|
||||
self.label = label
|
||||
self.ignore_non_predictions = ignore_non_predictions
|
||||
|
||||
def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
|
||||
y_true: List[bool] = []
|
||||
y_score: List[float] = []
|
||||
|
||||
for clip_eval in clip_evals:
|
||||
for m in clip_eval.matches:
|
||||
if not m.is_prediction and self.ignore_non_predictions:
|
||||
# Ignore matches that don't correspond to a prediction
|
||||
continue
|
||||
|
||||
y_true.append(m.is_ground_truth)
|
||||
y_score.append(m.score)
|
||||
|
||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||
return {self.label: score}
|
||||
|
||||
@detection_metrics.register(DetectionROCAUCConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DetectionROCAUCConfig):
|
||||
return DetectionROCAUC(
|
||||
label=config.label,
|
||||
ignore_non_predictions=config.ignore_non_predictions,
|
||||
)
|
||||
|
||||
|
||||
class DetectionRecallConfig(BaseConfig):
|
||||
name: Literal["recall"] = "recall"
|
||||
label: str = "recall"
|
||||
threshold: float = 0.5
|
||||
|
||||
|
||||
class DetectionRecall:
|
||||
def __init__(self, threshold: float, label: str = "recall"):
|
||||
self.label = label
|
||||
self.threshold = threshold
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
num_positives = 0
|
||||
true_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
if m.is_ground_truth:
|
||||
num_positives += 1
|
||||
|
||||
if m.score >= self.threshold and m.is_ground_truth:
|
||||
true_positives += 1
|
||||
|
||||
if num_positives == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = true_positives / num_positives
|
||||
return {self.label: score}
|
||||
|
||||
@detection_metrics.register(DetectionRecallConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DetectionRecallConfig):
|
||||
return DetectionRecall(threshold=config.threshold, label=config.label)
|
||||
|
||||
|
||||
class DetectionPrecisionConfig(BaseConfig):
|
||||
name: Literal["precision"] = "precision"
|
||||
label: str = "precision"
|
||||
threshold: float = 0.5
|
||||
|
||||
|
||||
class DetectionPrecision:
|
||||
def __init__(self, threshold: float, label: str = "precision"):
|
||||
self.threshold = threshold
|
||||
self.label = label
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
clip_evaluations: Sequence[ClipEval],
|
||||
) -> Dict[str, float]:
|
||||
num_detections = 0
|
||||
true_positives = 0
|
||||
|
||||
for clip_eval in clip_evaluations:
|
||||
for m in clip_eval.matches:
|
||||
is_detection = m.score >= self.threshold
|
||||
|
||||
if is_detection:
|
||||
num_detections += 1
|
||||
|
||||
if is_detection and m.is_ground_truth:
|
||||
true_positives += 1
|
||||
|
||||
if num_detections == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = true_positives / num_detections
|
||||
return {self.label: score}
|
||||
|
||||
@detection_metrics.register(DetectionPrecisionConfig)
|
||||
@staticmethod
|
||||
def from_config(config: DetectionPrecisionConfig):
|
||||
return DetectionPrecision(
|
||||
threshold=config.threshold,
|
||||
label=config.label,
|
||||
)
|
||||
|
||||
|
||||
DetectionMetricConfig = Annotated[
|
||||
Union[
|
||||
DetectionAveragePrecisionConfig,
|
||||
DetectionROCAUCConfig,
|
||||
DetectionRecallConfig,
|
||||
DetectionPrecisionConfig,
|
||||
],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_detection_metric(config: DetectionMetricConfig):
|
||||
return detection_metrics.build(config)
|
||||
@ -1,235 +0,0 @@
|
||||
from typing import Annotated, Callable, Literal, Sequence, Union
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from sklearn import metrics
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.metrics.common import average_precision
|
||||
from batdetect2.typing import (
|
||||
ClipMatches,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"MatchMetricConfig",
|
||||
"MatchesMetric",
|
||||
"build_match_metric",
|
||||
]
|
||||
|
||||
MatchesMetric = Callable[[Sequence[ClipMatches]], float]
|
||||
|
||||
|
||||
metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric")


class DetectionAveragePrecisionConfig(BaseConfig):
    name: Literal["detection_average_precision"] = (
        "detection_average_precision"
    )
    ignore_non_predictions: bool = True


class DetectionAveragePrecision:
    def __init__(self, ignore_non_predictions: bool = True):
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                num_positives += int(m.gt_det)

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.gt_det)
                y_score.append(m.pred_score)

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(DetectionAveragePrecisionConfig)
    @staticmethod
    def from_config(config: DetectionAveragePrecisionConfig):
        return DetectionAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions
        )


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["top_class_average_precision"] = (
        "top_class_average_precision"
    )
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class TopClassAveragePrecision:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore ground truth sounds with unknown class
                    continue

                num_positives += int(m.gt_det)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.gt_det & (m.top_class == m.gt_class))
                y_score.append(m.top_class_score)

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions
        )


class DetectionROCAUCConfig(BaseConfig):
    name: Literal["detection_roc_auc"] = "detection_roc_auc"
    ignore_non_predictions: bool = True


class DetectionROCAUC:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions

    def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore matches that don't correspond to a prediction
                    continue

                y_true.append(m.gt_det)
                y_score.append(m.pred_score)

        return float(metrics.roc_auc_score(y_true, y_score))

    @metrics_registry.register(DetectionROCAUCConfig)
    @staticmethod
    def from_config(config: DetectionROCAUCConfig):
        return DetectionROCAUC(
            ignore_non_predictions=config.ignore_non_predictions
        )


class DetectionRecallConfig(BaseConfig):
    name: Literal["detection_recall"] = "detection_recall"
    threshold: float = 0.5


class DetectionRecall:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.gt_det:
                    num_positives += 1

                if m.pred_score >= self.threshold and m.gt_det:
                    true_positives += 1

        if num_positives == 0:
            return 1

        return true_positives / num_positives

    @metrics_registry.register(DetectionRecallConfig)
    @staticmethod
    def from_config(config: DetectionRecallConfig):
        return DetectionRecall(threshold=config.threshold)


class DetectionPrecisionConfig(BaseConfig):
    name: Literal["detection_precision"] = "detection_precision"
    threshold: float = 0.5


class DetectionPrecision:
    def __init__(self, threshold: float):
        self.threshold = threshold

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
    ) -> float:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.pred_score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.gt_det:
                    true_positives += 1

        if num_detections == 0:
            return np.nan

        return true_positives / num_detections

    @metrics_registry.register(DetectionPrecisionConfig)
    @staticmethod
    def from_config(config: DetectionPrecisionConfig):
        return DetectionPrecision(threshold=config.threshold)


MatchMetricConfig = Annotated[
    Union[
        DetectionAveragePrecisionConfig,
        DetectionROCAUCConfig,
        DetectionRecallConfig,
        DetectionPrecisionConfig,
        TopClassAveragePrecisionConfig,
    ],
    Field(discriminator="name"),
]


def build_match_metric(config: MatchMetricConfig):
    return metrics_registry.build(config)
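Note (not part of the diff): a minimal usage sketch of the builder above, assuming `clip_matches` is a Sequence[ClipMatches] produced by the matching step elsewhere in the evaluation pipeline. Configs are dispatched to their registered metric class via the registry.

# Hypothetical usage sketch; threshold value chosen only for illustration.
recall_metric = build_match_metric(DetectionRecallConfig(threshold=0.3))
precision_metric = build_match_metric(DetectionPrecisionConfig(threshold=0.3))

recall = recall_metric(clip_matches)        # fraction of ground-truth sounds detected
precision = precision_metric(clip_matches)  # fraction of detections that hit a ground truth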
@ -1,136 +0,0 @@
from typing import Annotated, Callable, Literal, Sequence, Union

from pydantic import Field
from sklearn import metrics

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import (
    ClipMatches,
)

__all__ = []

PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float]


metrics_registry: Registry[PerClassMatchMetric, []] = Registry(
    "match_metric"
)


class ClassificationAveragePrecisionConfig(BaseConfig):
    name: Literal["classification_average_precision"] = (
        "classification_average_precision"
    )
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationAveragePrecision:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
        class_name: str,
    ) -> float:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_class = m.gt_class == class_name

                if is_class:
                    num_positives += 1

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                # Exclude matches with ground truth sounds where the class is
                # unknown
                if m.is_generic and self.ignore_generic:
                    continue

                y_true.append(is_class)
                y_score.append(m.pred_class_scores.get(class_name, 0))

        return average_precision(y_true, y_score, num_positives=num_positives)

    @metrics_registry.register(ClassificationAveragePrecisionConfig)
    @staticmethod
    def from_config(config: ClassificationAveragePrecisionConfig):
        return ClassificationAveragePrecision(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


class ClassificationROCAUCConfig(BaseConfig):
    name: Literal["classification_roc_auc"] = "classification_roc_auc"
    ignore_non_predictions: bool = True
    ignore_generic: bool = True


class ClassificationROCAUC:
    def __init__(
        self,
        ignore_non_predictions: bool = True,
        ignore_generic: bool = True,
    ):
        self.ignore_non_predictions = ignore_non_predictions
        self.ignore_generic = ignore_generic

    def __call__(
        self,
        clip_evaluations: Sequence[ClipMatches],
        class_name: str,
    ) -> float:
        y_true = []
        y_score = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                # Exclude matches with ground truth sounds where the class is
                # unknown
                if m.is_generic and self.ignore_generic:
                    continue

                # Ignore matches that don't correspond to a prediction
                if not m.is_prediction and self.ignore_non_predictions:
                    continue

                y_true.append(m.gt_class == class_name)
                y_score.append(m.pred_class_scores.get(class_name, 0))

        return float(metrics.roc_auc_score(y_true, y_score))

    @metrics_registry.register(ClassificationROCAUCConfig)
    @staticmethod
    def from_config(config: ClassificationROCAUCConfig):
        return ClassificationROCAUC(
            ignore_non_predictions=config.ignore_non_predictions,
            ignore_generic=config.ignore_generic,
        )


PerClassMatchMetricConfig = Annotated[
    Union[
        ClassificationAveragePrecisionConfig,
        ClassificationROCAUCConfig,
    ],
    Field(discriminator="name"),
]


def build_per_class_matches_metric(config: PerClassMatchMetricConfig):
    return metrics_registry.build(config)
313 src/batdetect2/evaluate/metrics/top_class.py Normal file
@ -0,0 +1,313 @@
from dataclasses import dataclass
from typing import (
    Annotated,
    Callable,
    Dict,
    List,
    Literal,
    Optional,
    Sequence,
    Union,
)

import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from soundevent import data

from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.metrics.common import average_precision
from batdetect2.typing import RawPrediction

__all__ = [
    "TopClassMetricConfig",
    "TopClassMetric",
    "build_top_class_metric",
]


@dataclass
class MatchEval:
    gt: Optional[data.SoundEventAnnotation]
    pred: Optional[RawPrediction]

    is_ground_truth: bool
    is_generic: bool
    is_prediction: bool
    pred_class: Optional[str]
    true_class: Optional[str]
    score: float


@dataclass
class ClipEval:
    clip: data.Clip
    matches: List[MatchEval]


TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]]


top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric")


class TopClassAveragePrecisionConfig(BaseConfig):
    name: Literal["average_precision"] = "average_precision"
    label: str = "average_precision"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True


class TopClassAveragePrecision:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "average_precision",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(
        self,
        clip_evals: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true = []
        y_score = []
        num_positives = 0

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                num_positives += int(m.is_ground_truth)

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = average_precision(y_true, y_score, num_positives=num_positives)
        return {self.label: score}

    @top_class_metrics.register(TopClassAveragePrecisionConfig)
    @staticmethod
    def from_config(config: TopClassAveragePrecisionConfig):
        return TopClassAveragePrecision(
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


class TopClassROCAUCConfig(BaseConfig):
    name: Literal["roc_auc"] = "roc_auc"
    ignore_generic: bool = True
    ignore_non_predictions: bool = True
    label: str = "roc_auc"


class TopClassROCAUC:
    def __init__(
        self,
        ignore_generic: bool = True,
        ignore_non_predictions: bool = True,
        label: str = "roc_auc",
    ):
        self.ignore_generic = ignore_generic
        self.ignore_non_predictions = ignore_non_predictions
        self.label = label

    def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]:
        y_true: List[bool] = []
        y_score: List[float] = []

        for clip_eval in clip_evals:
            for m in clip_eval.matches:
                if m.is_generic and self.ignore_generic:
                    # Ignore gt sounds with unknown class
                    continue

                if not m.is_prediction and self.ignore_non_predictions:
                    # Ignore non predictions
                    continue

                y_true.append(m.pred_class == m.true_class)
                y_score.append(m.score)

        score = float(metrics.roc_auc_score(y_true, y_score))
        return {self.label: score}

    @top_class_metrics.register(TopClassROCAUCConfig)
    @staticmethod
    def from_config(config: TopClassROCAUCConfig):
        return TopClassROCAUC(
            ignore_generic=config.ignore_generic,
            label=config.label,
        )


class TopClassRecallConfig(BaseConfig):
    name: Literal["recall"] = "recall"
    threshold: float = 0.5
    label: str = "recall"


class TopClassRecall:
    def __init__(self, threshold: float, label: str = "recall"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_positives = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_ground_truth:
                    num_positives += 1

                if m.score >= self.threshold and m.pred_class == m.true_class:
                    true_positives += 1

        if num_positives == 0:
            return {self.label: np.nan}

        score = true_positives / num_positives
        return {self.label: score}

    @top_class_metrics.register(TopClassRecallConfig)
    @staticmethod
    def from_config(config: TopClassRecallConfig):
        return TopClassRecall(
            threshold=config.threshold,
            label=config.label,
        )


class TopClassPrecisionConfig(BaseConfig):
    name: Literal["precision"] = "precision"
    threshold: float = 0.5
    label: str = "precision"


class TopClassPrecision:
    def __init__(self, threshold: float, label: str = "precision"):
        self.threshold = threshold
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        num_detections = 0
        true_positives = 0

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                is_detection = m.score >= self.threshold

                if is_detection:
                    num_detections += 1

                if is_detection and m.pred_class == m.true_class:
                    true_positives += 1

        if num_detections == 0:
            return {self.label: np.nan}

        score = true_positives / num_detections
        return {self.label: score}

    @top_class_metrics.register(TopClassPrecisionConfig)
    @staticmethod
    def from_config(config: TopClassPrecisionConfig):
        return TopClassPrecision(
            threshold=config.threshold,
            label=config.label,
        )


class BalancedAccuracyConfig(BaseConfig):
    name: Literal["balanced_accuracy"] = "balanced_accuracy"
    label: str = "balanced_accuracy"
    exclude_noise: bool = False
    noise_class: str = "noise"


class BalancedAccuracy:
    def __init__(
        self,
        exclude_noise: bool = True,
        noise_class: str = "noise",
        label: str = "balanced_accuracy",
    ):
        self.exclude_noise = exclude_noise
        self.noise_class = noise_class
        self.label = label

    def __call__(
        self,
        clip_evaluations: Sequence[ClipEval],
    ) -> Dict[str, float]:
        y_true: List[str] = []
        y_pred: List[str] = []

        for clip_eval in clip_evaluations:
            for m in clip_eval.matches:
                if m.is_generic:
                    # Ignore matches that correspond to a sound event
                    # with unknown class
                    continue

                if not m.is_ground_truth and self.exclude_noise:
                    # Ignore predictions that were not matched to a
                    # ground truth
                    continue

                if m.pred_class is None and self.exclude_noise:
                    # Ignore non-predictions
                    continue

                y_true.append(m.true_class or self.noise_class)
                y_pred.append(m.pred_class or self.noise_class)

        encoder = preprocessing.LabelEncoder()
        encoder.fit(list(set(y_true) | set(y_pred)))

        y_true = encoder.transform(y_true)
        y_pred = encoder.transform(y_pred)
        score = metrics.balanced_accuracy_score(y_true, y_pred)
        return {self.label: score}

    @top_class_metrics.register(BalancedAccuracyConfig)
    @staticmethod
    def from_config(config: BalancedAccuracyConfig):
        return BalancedAccuracy(
            exclude_noise=config.exclude_noise,
            noise_class=config.noise_class,
            label=config.label,
        )


TopClassMetricConfig = Annotated[
    Union[
        TopClassAveragePrecisionConfig,
        TopClassROCAUCConfig,
        TopClassRecallConfig,
        TopClassPrecisionConfig,
        BalancedAccuracyConfig,
    ],
    Field(discriminator="name"),
]


def build_top_class_metric(config: TopClassMetricConfig):
    return top_class_metrics.build(config)
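Note (not part of the diff): a minimal sketch of how these top-class metrics could be combined, assuming `clip_evals` is a Sequence[ClipEval] produced by the top-class task. Each metric returns a {label: score} mapping, so custom labels keep the aggregated results distinguishable.

# Hypothetical usage sketch; labels and threshold are illustrative.
ap = build_top_class_metric(TopClassAveragePrecisionConfig(label="ap"))
recall_at_05 = build_top_class_metric(TopClassRecallConfig(threshold=0.5, label="recall_0.5"))

results = {}
for metric in (ap, recall_at_05):
    results.update(metric(clip_evals))  # e.g. {"ap": ..., "recall_0.5": ...}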
39 src/batdetect2/evaluate/tasks/__init__.py Normal file
@ -0,0 +1,39 @@
from typing import Annotated, Optional, Union

from pydantic import Field

from batdetect2.evaluate.tasks.base import tasks_registry
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.clip_classification import (
    ClipClassificationTaskConfig,
)
from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, TargetProtocol

__all__ = [
    "TaskConfig",
    "build_task",
]


TaskConfig = Annotated[
    Union[
        ClassificationTaskConfig,
        DetectionTaskConfig,
        ClipDetectionTaskConfig,
        ClipClassificationTaskConfig,
        TopClassDetectionTaskConfig,
    ],
    Field(discriminator="name"),
]


def build_task(
    config: TaskConfig,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    targets = targets or build_targets()
    return tasks_registry.build(config, targets)
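Note (not part of the diff): `TaskConfig` is a discriminated union keyed on `name`, so a task is selected by the config type used (or by the `name` field when configs are loaded from files). A minimal sketch, assuming the default targets are acceptable:

# Hypothetical usage sketch.
task = build_task(DetectionTaskConfig())  # falls back to build_targets() when no targets are given
# `task` satisfies EvaluatorProtocol; see BaseTask.evaluate / compute_metrics in tasks/base.py below.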
@ -1,3 +1,5 @@
from typing import Callable, Dict, Generic, List, Sequence, TypeVar

from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
@ -14,14 +16,19 @@ from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol

__all__ = [
    "BaseEvaluatorConfig",
    "BaseEvaluator",
    "BaseTaskConfig",
    "BaseTask",
]

evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric")
tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry(
    "tasks"
)


class BaseEvaluatorConfig(BaseConfig):
T_Output = TypeVar("T_Output")


class BaseTaskConfig(BaseConfig):
    prefix: str
    ignore_start_end: float = 0.01
    matching_strategy: MatchConfig = Field(
@ -29,11 +36,13 @@ class BaseEvaluatorConfig(BaseConfig):
    )


class BaseEvaluator(EvaluatorProtocol):
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
    targets: TargetProtocol

    matcher: MatcherProtocol

    metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]]

    ignore_start_end: float

    prefix: str
@ -42,15 +51,44 @@ class BaseEvaluator(EvaluatorProtocol):
        self,
        matcher: MatcherProtocol,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        prefix: str,
        ignore_start_end: float = 0.01,
    ):
        self.matcher = matcher
        self.metrics = metrics
        self.targets = targets
        self.prefix = prefix
        self.ignore_start_end = ignore_start_end

    def filter_sound_event_annotations(
    def compute_metrics(
        self,
        eval_outputs: List[T_Output],
    ) -> Dict[str, float]:
        scores = [metric(eval_outputs) for metric in self.metrics]
        return {
            f"{self.prefix}/{name}": score
            for metric_output in scores
            for name, score in metric_output.items()
        }

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[T_Output]:
        return [
            self.evaluate_clip(clip_annotation, preds)
            for clip_annotation, preds in zip(clip_annotations, predictions)
        ]

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> T_Output: ...

    def include_sound_event_annotation(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
        clip: data.Clip,
@ -68,7 +106,7 @@ class BaseEvaluator(EvaluatorProtocol):
            self.ignore_start_end,
        )

    def filter_predictions(
    def include_prediction(
        self,
        prediction: RawPrediction,
        clip: data.Clip,
@ -82,14 +120,16 @@ class BaseEvaluator(EvaluatorProtocol):
    @classmethod
    def build(
        cls,
        config: BaseEvaluatorConfig,
        config: BaseTaskConfig,
        targets: TargetProtocol,
        metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]],
        **kwargs,
    ):
        matcher = build_matcher(config.matching_strategy)
        return cls(
            matcher=matcher,
            targets=targets,
            metrics=metrics,
            prefix=config.prefix,
            ignore_start_end=config.ignore_start_end,
            **kwargs,
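Note (not part of the diff): a sketch of the aggregation contract introduced by `BaseTask`: each metric maps a list of per-clip outputs to a {name: score} dict, and `compute_metrics` namespaces the keys with the task prefix. Variable names and values below are placeholders for illustration only.

# Hypothetical illustration; `task`, `clip_annotations`, and `predictions` are assumed to exist.
outputs = task.evaluate(clip_annotations, predictions)  # List[T_Output], one entry per clip
scores = task.compute_metrics(outputs)
# e.g. a detection task (prefix="detection") might yield
# {"detection/average_precision": 0.82, "detection/recall": 0.75}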
137 src/batdetect2/evaluate/tasks/classification.py Normal file
@ -0,0 +1,137 @@
from typing import (
    List,
    Literal,
    Sequence,
)

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.classification import (
    ClassificationAveragePrecisionConfig,
    ClassificationMetricConfig,
    ClipEval,
    MatchEval,
    build_classification_metrics,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClassificationTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_classification"] = "sound_event_classification"
    prefix: str = "classification"
    metrics: List[ClassificationMetricConfig] = Field(
        default_factory=lambda: [ClassificationAveragePrecisionConfig()]
    )
    include_generics: bool = True


class ClassificationTask(BaseTask[ClipEval]):
    def __init__(
        self,
        *args,
        include_generics: bool = True,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.include_generics = include_generics

    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]

        all_gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]

        per_class_matches = {}

        for class_name in self.targets.class_names:
            class_idx = self.targets.class_names.index(class_name)

            # Only match to targets of the given class
            gts = [
                sound_event
                for sound_event in all_gts
                if self.is_class(sound_event, class_name)
            ]
            scores = [float(pred.class_scores[class_idx]) for pred in preds]

            matches = []

            for pred_idx, gt_idx, _ in self.matcher(
                ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
                predictions=[pred.geometry for pred in preds],
                scores=scores,
            ):
                gt = gts[gt_idx] if gt_idx is not None else None
                pred = preds[pred_idx] if pred_idx is not None else None

                true_class = (
                    self.targets.encode_class(gt) if gt is not None else None
                )

                score = (
                    float(pred.class_scores[class_idx])
                    if pred is not None
                    else 0
                )

                matches.append(
                    MatchEval(
                        gt=gt,
                        pred=pred,
                        is_prediction=pred is not None,
                        is_ground_truth=gt is not None,
                        is_generic=gt is not None and true_class is None,
                        true_class=true_class,
                        score=score,
                    )
                )

            per_class_matches[class_name] = matches

        return ClipEval(clip=clip, matches=per_class_matches)

    def is_class(
        self,
        sound_event: data.SoundEventAnnotation,
        class_name: str,
    ) -> bool:
        sound_event_class = self.targets.encode_class(sound_event)

        if sound_event_class is None and self.include_generics:
            # Sound events that are generic could be of the given
            # class
            return True

        return sound_event_class == class_name

    @tasks_registry.register(ClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [
            build_classification_metrics(metric) for metric in config.metrics
        ]
        return ClassificationTask.build(
            config=config,
            targets=targets,
            metrics=metrics,
        )
75 src/batdetect2/evaluate/tasks/clip_classification.py Normal file
@ -0,0 +1,75 @@
from collections import defaultdict
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_classification import (
    ClipClassificationAveragePrecisionConfig,
    ClipClassificationMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipClassificationTaskConfig(BaseTaskConfig):
    name: Literal["clip_classification"] = "clip_classification"
    prefix: str = "clip_classification"
    metrics: List[ClipClassificationMetricConfig] = Field(
        default_factory=lambda: [
            ClipClassificationAveragePrecisionConfig(),
        ]
    )


class ClipClassificationTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_classes = set()
        for sound_event in clip_annotation.sound_events:
            if not self.include_sound_event_annotation(sound_event, clip):
                continue

            class_name = self.targets.encode_class(sound_event)

            if class_name is None:
                continue

            gt_classes.add(class_name)

        pred_scores = defaultdict(float)
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            for class_idx, class_name in enumerate(self.targets.class_names):
                pred_scores[class_name] = max(
                    float(pred.class_scores[class_idx]),
                    pred_scores[class_name],
                )

        return ClipEval(true_classes=gt_classes, class_scores=pred_scores)

    @tasks_registry.register(ClipClassificationTaskConfig)
    @staticmethod
    def from_config(
        config: ClipClassificationTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        return ClipClassificationTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
66 src/batdetect2/evaluate/tasks/clip_detection.py Normal file
@ -0,0 +1,66 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.clip_detection import (
    ClipDetectionAveragePrecisionConfig,
    ClipDetectionMetricConfig,
    ClipEval,
    build_clip_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class ClipDetectionTaskConfig(BaseTaskConfig):
    name: Literal["clip_detection"] = "clip_detection"
    prefix: str = "clip_detection"
    metrics: List[ClipDetectionMetricConfig] = Field(
        default_factory=lambda: [
            ClipDetectionAveragePrecisionConfig(),
        ]
    )


class ClipDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gt_det = any(
            self.include_sound_event_annotation(sound_event, clip)
            for sound_event in clip_annotation.sound_events
        )

        pred_score = 0
        for pred in predictions:
            if not self.include_prediction(pred, clip):
                continue

            pred_score = max(pred_score, pred.detection_score)

        return ClipEval(
            gt_det=gt_det,
            score=pred_score,
        )

    @tasks_registry.register(ClipDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: ClipDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_clip_metric(metric) for metric in config.metrics]
        return ClipDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
79 src/batdetect2/evaluate/tasks/detection.py Normal file
@ -0,0 +1,79 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionAveragePrecisionConfig,
    DetectionMetricConfig,
    MatchEval,
    build_detection_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class DetectionTaskConfig(BaseTaskConfig):
    name: Literal["sound_event_detection"] = "sound_event_detection"
    prefix: str = "detection"
    metrics: List[DetectionMetricConfig] = Field(
        default_factory=lambda: [DetectionAveragePrecisionConfig()]
    )


class DetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        scores = [pred.detection_score for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_prediction=pred is not None,
                    is_ground_truth=gt is not None,
                    score=pred.detection_score if pred is not None else 0,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(DetectionTaskConfig)
    @staticmethod
    def from_config(
        config: DetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_detection_metric(metric) for metric in config.metrics]
        return DetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
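Note (not part of the diff): the detection task inherits the `BaseTaskConfig` fields, so the metric prefix, edge margin, and matching strategy can be overridden per task. A minimal sketch, assuming `targets` is an existing TargetProtocol instance:

# Hypothetical configuration sketch; values are illustrative only.
config = DetectionTaskConfig(prefix="det", ignore_start_end=0.02)
task = build_task(config, targets=targets)  # build_task from batdetect2.evaluate.tasks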
101 src/batdetect2/evaluate/tasks/top_class.py Normal file
@ -0,0 +1,101 @@
from typing import List, Literal, Sequence

from pydantic import Field
from soundevent import data

from batdetect2.evaluate.metrics.top_class import (
    ClipEval,
    MatchEval,
    TopClassAveragePrecisionConfig,
    TopClassMetricConfig,
    build_top_class_metric,
)
from batdetect2.evaluate.tasks.base import (
    BaseTask,
    BaseTaskConfig,
    tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol


class TopClassDetectionTaskConfig(BaseTaskConfig):
    name: Literal["top_class_detection"] = "top_class_detection"
    prefix: str = "top_class"
    metrics: List[TopClassMetricConfig] = Field(
        default_factory=lambda: [TopClassAveragePrecisionConfig()]
    )


class TopClassDetectionTask(BaseTask[ClipEval]):
    def evaluate_clip(
        self,
        clip_annotation: data.ClipAnnotation,
        predictions: Sequence[RawPrediction],
    ) -> ClipEval:
        clip = clip_annotation.clip

        gts = [
            sound_event
            for sound_event in clip_annotation.sound_events
            if self.include_sound_event_annotation(sound_event, clip)
        ]
        preds = [
            pred for pred in predictions if self.include_prediction(pred, clip)
        ]
        # Take the highest score for each prediction
        scores = [pred.class_scores.max() for pred in preds]

        matches = []
        for pred_idx, gt_idx, _ in self.matcher(
            ground_truth=[se.sound_event.geometry for se in gts],  # type: ignore
            predictions=[pred.geometry for pred in preds],
            scores=scores,
        ):
            gt = gts[gt_idx] if gt_idx is not None else None
            pred = preds[pred_idx] if pred_idx is not None else None

            true_class = (
                self.targets.encode_class(gt) if gt is not None else None
            )

            class_idx = (
                pred.class_scores.argmax() if pred is not None else None
            )

            score = (
                float(pred.class_scores[class_idx]) if pred is not None else 0
            )

            pred_class = (
                self.targets.class_names[class_idx]
                if class_idx is not None
                else None
            )

            matches.append(
                MatchEval(
                    gt=gt,
                    pred=pred,
                    is_ground_truth=gt is not None,
                    is_prediction=pred is not None,
                    true_class=true_class,
                    is_generic=gt is not None and true_class is None,
                    pred_class=pred_class,
                    score=score,
                )
            )

        return ClipEval(clip=clip, matches=matches)

    @tasks_registry.register(TopClassDetectionTaskConfig)
    @staticmethod
    def from_config(
        config: TopClassDetectionTaskConfig,
        targets: TargetProtocol,
    ):
        metrics = [build_top_class_metric(metric) for metric in config.metrics]
        return TopClassDetectionTask.build(
            config=config,
            metrics=metrics,
            targets=targets,
        )
@ -8,7 +8,7 @@ from loguru import logger
from soundevent import data

from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate import build_evaluator
from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
@ -106,7 +106,7 @@ def train(
        config,
        targets=targets,
        evaluator=build_evaluator(
            config.train.validation.evaluator,
            config.train.validation,
            targets=targets,
        ),
        checkpoint_dir=checkpoint_dir,