From df2abff654a01db9a2323ba24f56b741d873ce53 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 26 Sep 2025 15:23:38 +0100 Subject: [PATCH] Task/Metrics restructure --- src/batdetect2/api/base.py | 10 +- src/batdetect2/evaluate/__init__.py | 11 +- src/batdetect2/evaluate/config.py | 16 +- src/batdetect2/evaluate/evaluate.py | 5 +- src/batdetect2/evaluate/evaluator.py | 67 ++++ src/batdetect2/evaluate/evaluator/__init__.py | 114 ------- src/batdetect2/evaluate/evaluator/clip.py | 163 --------- src/batdetect2/evaluate/evaluator/multiple.py | 0 .../evaluate/evaluator/per_class.py | 219 ------------ src/batdetect2/evaluate/evaluator/single.py | 126 ------- .../evaluate/evaluator/top_class.py | 133 -------- .../evaluate/metrics/classification.py | 253 ++++++++++++++ .../evaluate/metrics/clip_classification.py | 135 ++++++++ .../evaluate/metrics/clip_detection.py | 173 ++++++++++ src/batdetect2/evaluate/metrics/common.py | 29 +- src/batdetect2/evaluate/metrics/detection.py | 226 +++++++++++++ src/batdetect2/evaluate/metrics/matches.py | 235 ------------- .../evaluate/metrics/per_class_matches.py | 136 -------- src/batdetect2/evaluate/metrics/top_class.py | 313 ++++++++++++++++++ src/batdetect2/evaluate/tasks/__init__.py | 39 +++ .../evaluate/{evaluator => tasks}/base.py | 56 +++- .../evaluate/tasks/classification.py | 137 ++++++++ .../evaluate/tasks/clip_classification.py | 75 +++++ .../evaluate/tasks/clip_detection.py | 66 ++++ src/batdetect2/evaluate/tasks/detection.py | 79 +++++ src/batdetect2/evaluate/tasks/top_class.py | 101 ++++++ src/batdetect2/train/train.py | 4 +- 27 files changed, 1756 insertions(+), 1165 deletions(-) create mode 100644 src/batdetect2/evaluate/evaluator.py delete mode 100644 src/batdetect2/evaluate/evaluator/__init__.py delete mode 100644 src/batdetect2/evaluate/evaluator/clip.py delete mode 100644 src/batdetect2/evaluate/evaluator/multiple.py delete mode 100644 src/batdetect2/evaluate/evaluator/per_class.py delete mode 100644 src/batdetect2/evaluate/evaluator/single.py delete mode 100644 src/batdetect2/evaluate/evaluator/top_class.py create mode 100644 src/batdetect2/evaluate/metrics/classification.py create mode 100644 src/batdetect2/evaluate/metrics/clip_classification.py create mode 100644 src/batdetect2/evaluate/metrics/clip_detection.py create mode 100644 src/batdetect2/evaluate/metrics/detection.py delete mode 100644 src/batdetect2/evaluate/metrics/matches.py delete mode 100644 src/batdetect2/evaluate/metrics/per_class_matches.py create mode 100644 src/batdetect2/evaluate/metrics/top_class.py create mode 100644 src/batdetect2/evaluate/tasks/__init__.py rename src/batdetect2/evaluate/{evaluator => tasks}/base.py (60%) create mode 100644 src/batdetect2/evaluate/tasks/classification.py create mode 100644 src/batdetect2/evaluate/tasks/clip_classification.py create mode 100644 src/batdetect2/evaluate/tasks/clip_detection.py create mode 100644 src/batdetect2/evaluate/tasks/detection.py create mode 100644 src/batdetect2/evaluate/tasks/top_class.py diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api/base.py index 556db7d..11992a8 100644 --- a/src/batdetect2/api/base.py +++ b/src/batdetect2/api/base.py @@ -123,10 +123,7 @@ class BatDetect2API: config=config.postprocess, ) - evaluator = build_evaluator( - config=config.evaluation.evaluator, - targets=targets, - ) + evaluator = build_evaluator(config=config.evaluation, targets=targets) # NOTE: Better to have a separate instance of # preprocessor and postprocessor as these may be moved @@ -178,10 +175,7 @@ class 
BatDetect2API: config=config.postprocess, ) - evaluator = build_evaluator( - config=config.evaluation.evaluator, - targets=targets, - ) + evaluator = build_evaluator(config=config.evaluation, targets=targets) return cls( config=config, diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 03d31db..07fa19e 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,11 +1,14 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.evaluate import evaluate -from batdetect2.evaluate.evaluator import MultipleEvaluator, build_evaluator +from batdetect2.evaluate.evaluator import Evaluator, build_evaluator +from batdetect2.evaluate.tasks import TaskConfig, build_task __all__ = [ "EvaluationConfig", - "load_evaluation_config", - "evaluate", - "MultipleEvaluator", + "Evaluator", + "TaskConfig", "build_evaluator", + "build_task", + "evaluate", + "load_evaluation_config", ] diff --git a/src/batdetect2/evaluate/config.py b/src/batdetect2/evaluate/config.py index de1ffae..4d5510c 100644 --- a/src/batdetect2/evaluate/config.py +++ b/src/batdetect2/evaluate/config.py @@ -1,13 +1,14 @@ -from typing import Optional +from typing import List, Optional from pydantic import Field from soundevent import data from batdetect2.core.configs import BaseConfig, load_config -from batdetect2.evaluate.evaluator import ( - EvaluatorConfig, - MultipleEvaluatorConfig, +from batdetect2.evaluate.tasks import ( + TaskConfig, ) +from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig +from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.logging import CSVLoggerConfig, LoggerConfig __all__ = [ @@ -17,7 +18,12 @@ __all__ = [ class EvaluationConfig(BaseConfig): - evaluator: EvaluatorConfig = Field(default_factory=MultipleEvaluatorConfig) + tasks: List[TaskConfig] = Field( + default_factory=lambda: [ + DetectionTaskConfig(), + ClassificationTaskConfig(), + ] + ) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index a151107..2fd723f 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -55,10 +55,7 @@ def evaluate( num_workers=num_workers, ) - evaluator = build_evaluator( - config=config.evaluation.evaluator, - targets=targets, - ) + evaluator = build_evaluator(config=config.evaluation, targets=targets) logger = build_logger( config.evaluation.logger, diff --git a/src/batdetect2/evaluate/evaluator.py b/src/batdetect2/evaluate/evaluator.py new file mode 100644 index 0000000..8126dda --- /dev/null +++ b/src/batdetect2/evaluate/evaluator.py @@ -0,0 +1,67 @@ +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union + +from matplotlib.figure import Figure +from soundevent import data + +from batdetect2.evaluate.config import EvaluationConfig +from batdetect2.evaluate.tasks import build_task +from batdetect2.targets import build_targets +from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol + +__all__ = [ + "Evaluator", + "build_evaluator", +] + + +class Evaluator: + def __init__( + self, + targets: TargetProtocol, + tasks: Sequence[EvaluatorProtocol], + ): + self.targets = targets + self.tasks = tasks + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[Any]: + return [ + 
task.evaluate(clip_annotations, predictions) for task in self.tasks + ] + + def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]: + results = {} + + for task, outputs in zip(self.tasks, eval_outputs): + results.update(task.compute_metrics(outputs)) + + return results + + def generate_plots( + self, + eval_outputs: List[Any], + ) -> Iterable[Tuple[str, Figure]]: + for task, outputs in zip(self.tasks, eval_outputs): + for name, fig in task.generate_plots(outputs): + yield name, fig + + +def build_evaluator( + config: Optional[Union[EvaluationConfig, dict]] = None, + targets: Optional[TargetProtocol] = None, +) -> EvaluatorProtocol: + targets = targets or build_targets() + + if config is None: + config = EvaluationConfig() + + if not isinstance(config, EvaluationConfig): + config = EvaluationConfig.model_validate(config) + + return Evaluator( + targets=targets, + tasks=[build_task(task, targets=targets) for task in config.tasks], + ) diff --git a/src/batdetect2/evaluate/evaluator/__init__.py b/src/batdetect2/evaluate/evaluator/__init__.py deleted file mode 100644 index 92b31a9..0000000 --- a/src/batdetect2/evaluate/evaluator/__init__.py +++ /dev/null @@ -1,114 +0,0 @@ -from typing import ( - Annotated, - Any, - Dict, - Iterable, - List, - Literal, - Optional, - Sequence, - Tuple, - Union, -) - -from matplotlib.figure import Figure -from pydantic import Field -from soundevent import data - -from batdetect2.core.configs import BaseConfig -from batdetect2.evaluate.evaluator.base import evaluators -from batdetect2.evaluate.evaluator.clip import ClipMetricsConfig -from batdetect2.evaluate.evaluator.per_class import ClassificationMetricsConfig -from batdetect2.evaluate.evaluator.single import GlobalEvaluatorConfig -from batdetect2.targets import build_targets -from batdetect2.typing import ( - EvaluatorProtocol, - RawPrediction, - TargetProtocol, -) - -__all__ = [ - "EvaluatorConfig", - "build_evaluator", -] - - -EvaluatorConfig = Annotated[ - Union[ - ClassificationMetricsConfig, - GlobalEvaluatorConfig, - ClipMetricsConfig, - "MultipleEvaluatorConfig", - ], - Field(discriminator="name"), -] - - -class MultipleEvaluatorConfig(BaseConfig): - name: Literal["multiple_evaluations"] = "multiple_evaluations" - evaluations: List[EvaluatorConfig] = Field( - default_factory=lambda: [ - ClassificationMetricsConfig(), - GlobalEvaluatorConfig(), - ] - ) - - -class MultipleEvaluator: - def __init__( - self, - targets: TargetProtocol, - evaluators: Sequence[EvaluatorProtocol], - ): - self.targets = targets - self.evaluators = evaluators - - def evaluate( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> List[Any]: - return [ - evaluator.evaluate( - clip_annotations, - predictions, - ) - for evaluator in self.evaluators - ] - - def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]: - results = {} - - for evaluator, outputs in zip(self.evaluators, eval_outputs): - results.update(evaluator.compute_metrics(outputs)) - - return results - - def generate_plots( - self, - eval_outputs: List[Any], - ) -> Iterable[Tuple[str, Figure]]: - for evaluator, outputs in zip(self.evaluators, eval_outputs): - for name, fig in evaluator.generate_plots(outputs): - yield name, fig - - @evaluators.register(MultipleEvaluatorConfig) - @staticmethod - def from_config(config: MultipleEvaluatorConfig, targets: TargetProtocol): - return MultipleEvaluator( - evaluators=[ - build_evaluator(conf, targets=targets) - for conf in 
config.evaluations - ], - targets=targets, - ) - - -def build_evaluator( - config: Optional[EvaluatorConfig] = None, - targets: Optional[TargetProtocol] = None, -) -> EvaluatorProtocol: - targets = targets or build_targets() - - config = config or MultipleEvaluatorConfig() - return evaluators.build(config, targets) diff --git a/src/batdetect2/evaluate/evaluator/clip.py b/src/batdetect2/evaluate/evaluator/clip.py deleted file mode 100644 index 1556bc5..0000000 --- a/src/batdetect2/evaluate/evaluator/clip.py +++ /dev/null @@ -1,163 +0,0 @@ -from collections import defaultdict -from dataclasses import dataclass -from typing import Callable, Dict, List, Literal, Sequence, Set - -from pydantic import Field, field_validator -from sklearn import metrics -from soundevent import data - -from batdetect2.evaluate.evaluator.base import ( - BaseEvaluator, - BaseEvaluatorConfig, - evaluators, -) -from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing.postprocess import RawPrediction -from batdetect2.typing.targets import TargetProtocol - - -@dataclass -class ClipInfo: - gt_det: bool - gt_classes: Set[str] - pred_score: float - pred_class_scores: Dict[str, float] - - -ClipMetric = Callable[[Sequence[ClipInfo]], float] - - -def clip_detection_average_precision( - clip_evaluations: Sequence[ClipInfo], -) -> float: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - y_true.append(clip_eval.gt_det) - y_score.append(clip_eval.pred_score) - - return average_precision(y_true=y_true, y_score=y_score) - - -def clip_detection_roc_auc( - clip_evaluations: Sequence[ClipInfo], -) -> float: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - y_true.append(clip_eval.gt_det) - y_score.append(clip_eval.pred_score) - - return float(metrics.roc_auc_score(y_true=y_true, y_score=y_score)) - - -clip_metrics = { - "average_precision": clip_detection_average_precision, - "roc_auc": clip_detection_roc_auc, -} - - -class ClipMetricsConfig(BaseEvaluatorConfig): - name: Literal["clip"] = "clip" - prefix: str = "clip" - metrics: List[str] = Field( - default_factory=lambda: [ - "average_precision", - "roc_auc", - ] - ) - - @field_validator("metrics", mode="after") - @classmethod - def validate_metrics(cls, v: List[str]) -> List[str]: - for metric_name in v: - if metric_name not in clip_metrics: - raise ValueError(f"Unknown metric {metric_name}") - return v - - -class ClipEvaluator(BaseEvaluator): - def __init__(self, *args, metrics: Dict[str, ClipMetric], **kwargs): - super().__init__(*args, **kwargs) - self.metrics = metrics - - def evaluate( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> List[ClipInfo]: - return [ - self.match_clip(clip_annotation, preds) - for clip_annotation, preds in zip(clip_annotations, predictions) - ] - - def compute_metrics( - self, - eval_outputs: List[ClipInfo], - ) -> Dict[str, float]: - scores = { - name: metric(eval_outputs) for name, metric in self.metrics.items() - } - return { - f"{self.prefix}/{name}": score for name, score in scores.items() - } - - def match_clip( - self, - clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], - ) -> ClipInfo: - clip = clip_annotation.clip - - gt_det = False - gt_classes = set() - for sound_event in clip_annotation.sound_events: - if self.filter_sound_event_annotations(sound_event, clip): - continue - - gt_det = True - class_name = self.targets.encode_class(sound_event) - - if class_name is None: - 
continue - - gt_classes.add(class_name) - - pred_score = 0 - pred_class_scores: defaultdict[str, float] = defaultdict(lambda: 0) - for pred in predictions: - if self.filter_predictions(pred, clip): - continue - - pred_score = max(pred_score, pred.detection_score) - - for class_name, class_score in zip( - self.targets.class_names, - pred.class_scores, - ): - pred_class_scores[class_name] = max( - pred_class_scores[class_name], - class_score, - ) - - return ClipInfo( - gt_det=gt_det, - gt_classes=gt_classes, - pred_score=pred_score, - pred_class_scores=pred_class_scores, - ) - - @evaluators.register(ClipMetricsConfig) - @staticmethod - def from_config( - config: ClipMetricsConfig, - targets: TargetProtocol, - ): - metrics = {name: clip_metrics.get(name) for name in config.metrics} - return ClipEvaluator.build( - config=config, - metrics=metrics, - targets=targets, - ) diff --git a/src/batdetect2/evaluate/evaluator/multiple.py b/src/batdetect2/evaluate/evaluator/multiple.py deleted file mode 100644 index e69de29..0000000 diff --git a/src/batdetect2/evaluate/evaluator/per_class.py b/src/batdetect2/evaluate/evaluator/per_class.py deleted file mode 100644 index c5177d8..0000000 --- a/src/batdetect2/evaluate/evaluator/per_class.py +++ /dev/null @@ -1,219 +0,0 @@ -from collections import defaultdict -from typing import ( - Callable, - Dict, - List, - Literal, - Mapping, - Optional, - Sequence, -) - -import numpy as np -from pydantic import Field -from soundevent import data - -from batdetect2.evaluate.evaluator.base import ( - BaseEvaluator, - BaseEvaluatorConfig, - evaluators, -) -from batdetect2.evaluate.match import match -from batdetect2.evaluate.metrics.per_class_matches import ( - ClassificationAveragePrecisionConfig, - PerClassMatchMetric, - PerClassMatchMetricConfig, - build_per_class_matches_metric, -) -from batdetect2.typing import ( - ClipMatches, - RawPrediction, - TargetProtocol, -) - -ScoreFn = Callable[[RawPrediction, int], float] - - -def score_by_class_score(pred: RawPrediction, class_index: int) -> float: - return float(pred.class_scores[class_index]) - - -def score_by_adjusted_class_score( - pred: RawPrediction, - class_index: int, -) -> float: - return float(pred.class_scores[class_index]) * pred.detection_score - - -ScoreFunctionOption = Literal["class_score", "adjusted_class_score"] -score_functions: Mapping[ScoreFunctionOption, ScoreFn] = { - "class_score": score_by_class_score, - "adjusted_class_score": score_by_adjusted_class_score, -} - - -def get_score_fn(name: ScoreFunctionOption) -> ScoreFn: - return score_functions[name] - - -class ClassificationMetricsConfig(BaseEvaluatorConfig): - name: Literal["classification"] = "classification" - prefix: str = "classification" - include_generics: bool = True - score_by: ScoreFunctionOption = "class_score" - metrics: List[PerClassMatchMetricConfig] = Field( - default_factory=lambda: [ClassificationAveragePrecisionConfig()] - ) - include: Optional[List[str]] = None - exclude: Optional[List[str]] = None - - -class PerClassEvaluator(BaseEvaluator): - def __init__( - self, - *args, - metrics: Dict[str, PerClassMatchMetric], - score_fn: ScoreFn, - include_generics: bool = True, - include: Optional[List[str]] = None, - exclude: Optional[List[str]] = None, - **kwargs, - ): - super().__init__(*args, **kwargs) - - self.score_fn = score_fn - self.metrics = metrics - - self.include_generics = include_generics - - self.include = include - self.exclude = exclude - - self.selected = self.targets.class_names - if include is not None: - 
self.selected = [ - class_name - for class_name in self.selected - if class_name in include - ] - - if exclude is not None: - self.selected = [ - class_name - for class_name in self.selected - if class_name not in exclude - ] - - def evaluate( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> Dict[str, List[ClipMatches]]: - ret = defaultdict(list) - - for clip_annotation, preds in zip(clip_annotations, predictions): - matches = self.match_clip(clip_annotation, preds) - for class_name, clip_eval in matches.items(): - ret[class_name].append(clip_eval) - - return ret - - def compute_metrics( - self, - eval_outputs: Dict[str, List[ClipMatches]], - ) -> Dict[str, float]: - results = {} - - for metric_name, metric in self.metrics.items(): - class_scores = { - class_name: metric(eval_outputs[class_name], class_name) - for class_name in self.targets.class_names - } - mean = float( - np.mean([v for v in class_scores.values() if v != np.nan]) - ) - - results[f"{self.prefix}/mean_{metric_name}"] = mean - - for class_name, value in class_scores.items(): - if class_name not in self.selected: - continue - - results[f"{self.prefix}/{metric_name}/{class_name}"] = value - - return results - - def match_clip( - self, - clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], - ) -> Dict[str, ClipMatches]: - clip = clip_annotation.clip - - preds = [ - pred for pred in predictions if self.filter_predictions(pred, clip) - ] - - all_gts = [ - sound_event - for sound_event in clip_annotation.sound_events - if self.filter_sound_event_annotations(sound_event, clip) - ] - - ret = {} - - for class_name in self.targets.class_names: - class_idx = self.targets.class_names.index(class_name) - - # Only match to targets of the given class - gts = [ - sound_event - for sound_event in all_gts - if self.is_class(sound_event, class_name) - ] - scores = [self.score_fn(pred, class_idx) for pred in preds] - - ret[class_name] = match( - gts, - preds, - clip=clip, - scores=scores, - targets=self.targets, - matcher=self.matcher, - ) - - return ret - - def is_class( - self, - sound_event: data.SoundEventAnnotation, - class_name: str, - ) -> bool: - sound_event_class = self.targets.encode_class(sound_event) - - if sound_event_class is None and self.include_generics: - # Sound events that are generic could be of the given - # class - return True - - return sound_event_class == class_name - - @evaluators.register(ClassificationMetricsConfig) - @staticmethod - def from_config( - config: ClassificationMetricsConfig, - targets: TargetProtocol, - ): - metrics = { - metric.name: build_per_class_matches_metric(metric) - for metric in config.metrics - } - return PerClassEvaluator.build( - config=config, - targets=targets, - metrics=metrics, - score_fn=get_score_fn(config.score_by), - include_generics=config.include_generics, - include=config.include, - exclude=config.exclude, - ) diff --git a/src/batdetect2/evaluate/evaluator/single.py b/src/batdetect2/evaluate/evaluator/single.py deleted file mode 100644 index 4e3d22b..0000000 --- a/src/batdetect2/evaluate/evaluator/single.py +++ /dev/null @@ -1,126 +0,0 @@ -from typing import Callable, Dict, List, Literal, Mapping, Sequence - -from pydantic import Field -from soundevent import data - -from batdetect2.evaluate.evaluator.base import ( - BaseEvaluator, - BaseEvaluatorConfig, - evaluators, -) -from batdetect2.evaluate.match import match -from batdetect2.evaluate.metrics.matches import ( - 
DetectionAveragePrecisionConfig, - MatchesMetric, - MatchMetricConfig, - build_match_metric, -) -from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol - -ScoreFn = Callable[[RawPrediction], float] - - -def score_by_detection_score(pred: RawPrediction) -> float: - return pred.detection_score - - -def score_by_top_class_score(pred: RawPrediction) -> float: - return pred.class_scores.max() - - -ScoreFunctionOption = Literal["detection_score", "top_class_score"] -score_functions: Mapping[ScoreFunctionOption, ScoreFn] = { - "detection_score": score_by_detection_score, - "top_class_score": score_by_top_class_score, -} - - -def get_score_fn(name: ScoreFunctionOption) -> ScoreFn: - return score_functions[name] - - -class GlobalEvaluatorConfig(BaseEvaluatorConfig): - name: Literal["detection"] = "detection" - prefix: str = "detection" - score_by: ScoreFunctionOption = "detection_score" - metrics: List[MatchMetricConfig] = Field( - default_factory=lambda: [DetectionAveragePrecisionConfig()] - ) - - -class GlobalEvaluator(BaseEvaluator): - def __init__( - self, - *args, - score_fn: ScoreFn, - metrics: Dict[str, MatchesMetric], - **kwargs, - ): - super().__init__(*args, **kwargs) - self.metrics = metrics - self.score_fn = score_fn - - def compute_metrics( - self, - eval_outputs: List[ClipMatches], - ) -> Dict[str, float]: - scores = { - name: metric(eval_outputs) for name, metric in self.metrics.items() - } - return { - f"{self.prefix}/{name}": score for name, score in scores.items() - } - - def evaluate( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> List[ClipMatches]: - return [ - self.match_clip(clip_annotation, preds) - for clip_annotation, preds in zip(clip_annotations, predictions) - ] - - def match_clip( - self, - clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], - ) -> ClipMatches: - clip = clip_annotation.clip - - gts = [ - sound_event - for sound_event in clip_annotation.sound_events - if self.filter_sound_event_annotations(sound_event, clip) - ] - preds = [ - pred for pred in predictions if self.filter_predictions(pred, clip) - ] - scores = [self.score_fn(pred) for pred in preds] - - return match( - gts, - preds, - scores=scores, - clip=clip, - targets=self.targets, - matcher=self.matcher, - ) - - @evaluators.register(GlobalEvaluatorConfig) - @staticmethod - def from_config( - config: GlobalEvaluatorConfig, - targets: TargetProtocol, - ): - metrics = { - metric.name: build_match_metric(metric) - for metric in config.metrics - } - score_fn = get_score_fn(config.score_by) - return GlobalEvaluator.build( - config=config, - score_fn=score_fn, - metrics=metrics, - targets=targets, - ) diff --git a/src/batdetect2/evaluate/evaluator/top_class.py b/src/batdetect2/evaluate/evaluator/top_class.py deleted file mode 100644 index 149447e..0000000 --- a/src/batdetect2/evaluate/evaluator/top_class.py +++ /dev/null @@ -1,133 +0,0 @@ -from typing import Dict, List, Literal, Sequence - -from pydantic import Field, field_validator -from soundevent import data - -from batdetect2.evaluate.match import match -from batdetect2.evaluate.metrics.base import ( - BaseMetric, - BaseMetricConfig, - metrics_registry, -) -from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.evaluate.metrics.detection import DetectionMetric -from batdetect2.typing import ClipMatches, RawPrediction, TargetProtocol - -__all__ = [ - "TopClassEvaluator", - "TopClassEvaluatorConfig", -] - - -def 
top_class_average_precision( - clip_evaluations: Sequence[ClipMatches], -) -> float: - y_true = [] - y_score = [] - num_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - is_generic = m.gt_det and (m.gt_class is None) - - # Ignore ground truth sounds with unknown class - if is_generic: - continue - - num_positives += int(m.gt_det) - - # Ignore matches that don't correspond to a prediction - if m.pred_geometry is None: - continue - - y_true.append(m.gt_det & (m.top_class == m.gt_class)) - y_score.append(m.top_class_score) - - return average_precision(y_true, y_score, num_positives=num_positives) - - -top_class_metrics = { - "average_precision": top_class_average_precision, -} - - -class TopClassEvaluatorConfig(BaseMetricConfig): - name: Literal["top_class"] = "top_class" - prefix: str = "top_class" - metrics: List[str] = Field(default_factory=lambda: ["average_precision"]) - - @field_validator("metrics", mode="after") - @classmethod - def validate_metrics(cls, v: List[str]) -> List[str]: - for metric_name in v: - if metric_name not in top_class_metrics: - raise ValueError(f"Unknown metric {metric_name}") - return v - - -class TopClassEvaluator(BaseMetric): - def __init__(self, *args, metrics: Dict[str, DetectionMetric], **kwargs): - super().__init__(*args, **kwargs) - self.metrics = metrics - - def __call__( - self, - clip_annotations: Sequence[data.ClipAnnotation], - predictions: Sequence[Sequence[RawPrediction]], - ) -> Dict[str, float]: - clip_evaluations = [ - self.match_clip(clip_annotation, preds) - for clip_annotation, preds in zip(clip_annotations, predictions) - ] - scores = { - name: metric(clip_evaluations) - for name, metric in self.metrics.items() - } - return { - f"{self.prefix}/{name}": score for name, score in scores.items() - } - - def match_clip( - self, - clip_annotation: data.ClipAnnotation, - predictions: Sequence[RawPrediction], - ) -> ClipMatches: - clip = clip_annotation.clip - - gts = [ - sound_event - for sound_event in clip_annotation.sound_events - if self.filter_sound_event_annotations(sound_event, clip) - ] - preds = [ - pred for pred in predictions if self.filter_predictions(pred, clip) - ] - # Use score of top class for matching - scores = [pred.class_scores.max() for pred in preds] - - return match( - gts, - preds, - scores=scores, - clip=clip, - targets=self.targets, - matcher=self.matcher, - ) - - @classmethod - def from_config( - cls, - config: TopClassEvaluatorConfig, - targets: TargetProtocol, - ): - metrics = { - name: top_class_metrics.get(name) for name in config.metrics - } - return super().build( - config=config, - metrics=metrics, - targets=targets, - ) - - -metrics_registry.register(TopClassEvaluatorConfig, TopClassEvaluator) diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py new file mode 100644 index 0000000..399c852 --- /dev/null +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -0,0 +1,253 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import ( + Annotated, + Callable, + Dict, + List, + Literal, + Mapping, + Optional, + Sequence, + Union, +) + +import numpy as np +from pydantic import Field +from sklearn import metrics +from soundevent import data + +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing import RawPrediction + +__all__ = [] + + +@dataclass +class MatchEval: + gt: Optional[data.SoundEventAnnotation] 
+ pred: Optional[RawPrediction] + + is_prediction: bool + is_ground_truth: bool + is_generic: bool + true_class: Optional[str] + score: float + + +@dataclass +class ClipEval: + clip: data.Clip + matches: Mapping[str, List[MatchEval]] + + +ClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] + + +classification_metrics: Registry[ClassificationMetric, []] = Registry( + "classification_metric" +) + + +class BaseClassificationConfig(BaseConfig): + include: Optional[List[str]] = None + exclude: Optional[List[str]] = None + + +class BaseClassificationMetric: + def __init__( + self, + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + self.include = include + self.exclude = exclude + + def include_class(self, class_name: str) -> bool: + if self.include is not None: + return class_name in self.include + + if self.exclude is not None: + return class_name not in self.exclude + + return True + + +class ClassificationAveragePrecisionConfig(BaseClassificationConfig): + name: Literal["average_precision"] = "average_precision" + ignore_non_predictions: bool = True + ignore_generic: bool = True + label: str = "average_precision" + + +class ClassificationAveragePrecision(BaseClassificationMetric): + def __init__( + self, + ignore_non_predictions: bool = True, + ignore_generic: bool = True, + label: str = "average_precision", + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + super().__init__(include=include, exclude=exclude) + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + self.label = label + + def __call__( + self, clip_evaluations: Sequence[ClipEval] + ) -> Dict[str, float]: + y_true = defaultdict(list) + y_score = defaultdict(list) + num_positives = defaultdict(lambda: 0) + + class_names = set() + + for clip_eval in clip_evaluations: + for class_name, matches in clip_eval.matches.items(): + class_names.add(class_name) + + for m in matches: + # Exclude matches with ground truth sounds where the class + # is unknown + if m.is_generic and self.ignore_generic: + continue + + is_class = m.true_class == class_name + + if is_class: + num_positives[class_name] += 1 + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true[class_name].append(is_class) + y_score[class_name].append(m.score) + + class_scores = { + class_name: average_precision( + y_true[class_name], + y_score[class_name], + num_positives=num_positives[class_name], + ) + for class_name in class_names + } + + mean_score = float( + np.mean([v for v in class_scores.values() if v != np.nan]) + ) + + return { + f"mean_{self.label}": mean_score, + **{ + f"{self.label}/{class_name}": score + for class_name, score in class_scores.items() + if self.include_class(class_name) + }, + } + + @classification_metrics.register(ClassificationAveragePrecisionConfig) + @staticmethod + def from_config(config: ClassificationAveragePrecisionConfig): + return ClassificationAveragePrecision( + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + label=config.label, + include=config.include, + exclude=config.exclude, + ) + + +class ClassificationROCAUCConfig(BaseClassificationConfig): + name: Literal["roc_auc"] = "roc_auc" + label: str = "roc_auc" + ignore_non_predictions: bool = True + ignore_generic: bool = True + + +class ClassificationROCAUC(BaseClassificationMetric): + def __init__( + self, + ignore_non_predictions: 
bool = True, + ignore_generic: bool = True, + label: str = "roc_auc", + include: Optional[List[str]] = None, + exclude: Optional[List[str]] = None, + ): + self.ignore_non_predictions = ignore_non_predictions + self.ignore_generic = ignore_generic + self.label = label + self.include = include + self.exclude = exclude + + def __call__( + self, clip_evaluations: Sequence[ClipEval] + ) -> Dict[str, float]: + y_true = defaultdict(list) + y_score = defaultdict(list) + + class_names = set() + + for clip_eval in clip_evaluations: + for class_name, matches in clip_eval.matches.items(): + class_names.add(class_name) + + for m in matches: + # Exclude matches with ground truth sounds where the class + # is unknown + if m.is_generic and self.ignore_generic: + continue + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true[class_name].append(m.true_class == class_name) + y_score[class_name].append(m.score) + + class_scores = { + class_name: float( + metrics.roc_auc_score( + y_true[class_name], + y_score[class_name], + ) + ) + for class_name in class_names + } + + mean_score = float( + np.mean([v for v in class_scores.values() if v != np.nan]) + ) + + return { + f"mean_{self.label}": mean_score, + **{ + f"{self.label}/{class_name}": score + for class_name, score in class_scores.items() + if self.include_class(class_name) + }, + } + + @classification_metrics.register(ClassificationROCAUCConfig) + @staticmethod + def from_config(config: ClassificationROCAUCConfig): + return ClassificationROCAUC( + ignore_non_predictions=config.ignore_non_predictions, + ignore_generic=config.ignore_generic, + label=config.label, + ) + + +ClassificationMetricConfig = Annotated[ + Union[ + ClassificationAveragePrecisionConfig, + ClassificationROCAUCConfig, + ], + Field(discriminator="name"), +] + + +def build_classification_metrics(config: ClassificationMetricConfig): + return classification_metrics.build(config) diff --git a/src/batdetect2/evaluate/metrics/clip_classification.py b/src/batdetect2/evaluate/metrics/clip_classification.py new file mode 100644 index 0000000..5554b9c --- /dev/null +++ b/src/batdetect2/evaluate/metrics/clip_classification.py @@ -0,0 +1,135 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Annotated, Callable, Dict, Literal, Sequence, Set, Union + +import numpy as np +from pydantic import Field +from sklearn import metrics + +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry +from batdetect2.evaluate.metrics.common import average_precision + + +@dataclass +class ClipEval: + true_classes: Set[str] + class_scores: Dict[str, float] + + +ClipClassificationMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] + +clip_classification_metrics: Registry[ClipClassificationMetric, []] = Registry( + "clip_classification_metric" +) + + +class ClipClassificationAveragePrecisionConfig(BaseConfig): + name: Literal["average_precision"] = "average_precision" + label: str = "average_precision" + + +class ClipClassificationAveragePrecision: + def __init__(self, label: str = "average_precision"): + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = defaultdict(list) + y_score = defaultdict(list) + + for clip_eval in clip_evaluations: + for class_name, score in clip_eval.class_scores.items(): + y_true[class_name].append(class_name in clip_eval.true_classes) + 
y_score[class_name].append(score) + + class_scores = { + class_name: float( + average_precision( + y_true=y_true[class_name], + y_score=y_score[class_name], + ) + ) + for class_name in y_true + } + + mean = np.mean([v for v in class_scores.values() if not np.isnan(v)]) + + return { + f"mean_{self.label}": float(mean), + **{ + f"{self.label}/{class_name}": score + for class_name, score in class_scores.items() + if not np.isnan(score) + }, + } + + @clip_classification_metrics.register( + ClipClassificationAveragePrecisionConfig + ) + @staticmethod + def from_config(config: ClipClassificationAveragePrecisionConfig): + return ClipClassificationAveragePrecision(label=config.label) + + +class ClipClassificationROCAUCConfig(BaseConfig): + name: Literal["roc_auc"] = "roc_auc" + label: str = "roc_auc" + + +class ClipClassificationROCAUC: + def __init__(self, label: str = "roc_auc"): + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = defaultdict(list) + y_score = defaultdict(list) + + for clip_eval in clip_evaluations: + for class_name, score in clip_eval.class_scores.items(): + y_true[class_name].append(class_name in clip_eval.true_classes) + y_score[class_name].append(score) + + class_scores = { + class_name: float( + metrics.roc_auc_score( + y_true=y_true[class_name], + y_score=y_score[class_name], + ) + ) + for class_name in y_true + } + + mean = np.mean([v for v in class_scores.values() if not np.isnan(v)]) + + return { + f"mean_{self.label}": float(mean), + **{ + f"{self.label}/{class_name}": score + for class_name, score in class_scores.items() + if not np.isnan(score) + }, + } + + @clip_classification_metrics.register(ClipClassificationROCAUCConfig) + @staticmethod + def from_config(config: ClipClassificationROCAUCConfig): + return ClipClassificationROCAUC(label=config.label) + + +ClipClassificationMetricConfig = Annotated[ + Union[ + ClipClassificationAveragePrecisionConfig, + ClipClassificationROCAUCConfig, + ], + Field(discriminator="name"), +] + + +def build_clip_metric(config: ClipClassificationMetricConfig): + return clip_classification_metrics.build(config) diff --git a/src/batdetect2/evaluate/metrics/clip_detection.py b/src/batdetect2/evaluate/metrics/clip_detection.py new file mode 100644 index 0000000..df4b99d --- /dev/null +++ b/src/batdetect2/evaluate/metrics/clip_detection.py @@ -0,0 +1,173 @@ +from dataclasses import dataclass +from typing import Annotated, Callable, Dict, Literal, Sequence, Union + +import numpy as np +from pydantic import Field +from sklearn import metrics + +from batdetect2.core.configs import BaseConfig +from batdetect2.core.registries import Registry +from batdetect2.evaluate.metrics.common import average_precision + + +@dataclass +class ClipEval: + gt_det: bool + score: float + + +ClipDetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] + +clip_detection_metrics: Registry[ClipDetectionMetric, []] = Registry( + "clip_detection_metric" +) + + +class ClipDetectionAveragePrecisionConfig(BaseConfig): + name: Literal["average_precision"] = "average_precision" + label: str = "average_precision" + + +class ClipDetectionAveragePrecision: + def __init__(self, label: str = "average_precision"): + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + y_true.append(clip_eval.gt_det) + y_score.append(clip_eval.score) + + score = average_precision(y_true=y_true, 
y_score=y_score) + return {self.label: score} + + @clip_detection_metrics.register(ClipDetectionAveragePrecisionConfig) + @staticmethod + def from_config(config: ClipDetectionAveragePrecisionConfig): + return ClipDetectionAveragePrecision(label=config.label) + + +class ClipDetectionROCAUCConfig(BaseConfig): + name: Literal["roc_auc"] = "roc_auc" + label: str = "roc_auc" + + +class ClipDetectionROCAUC: + def __init__(self, label: str = "roc_auc"): + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = [] + y_score = [] + + for clip_eval in clip_evaluations: + y_true.append(clip_eval.gt_det) + y_score.append(clip_eval.score) + + score = float(metrics.roc_auc_score(y_true=y_true, y_score=y_score)) + return {self.label: score} + + @clip_detection_metrics.register(ClipDetectionROCAUCConfig) + @staticmethod + def from_config(config: ClipDetectionROCAUCConfig): + return ClipDetectionROCAUC(label=config.label) + + +class ClipDetectionRecallConfig(BaseConfig): + name: Literal["recall"] = "recall" + threshold: float = 0.5 + label: str = "recall" + + +class ClipDetectionRecall: + def __init__(self, threshold: float, label: str = "recall"): + self.threshold = threshold + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_positives = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + if clip_eval.gt_det: + num_positives += 1 + + if clip_eval.score >= self.threshold and clip_eval.gt_det: + true_positives += 1 + + if num_positives == 0: + return {self.label: np.nan} + + score = true_positives / num_positives + return {self.label: score} + + @clip_detection_metrics.register(ClipDetectionRecallConfig) + @staticmethod + def from_config(config: ClipDetectionRecallConfig): + return ClipDetectionRecall( + threshold=config.threshold, label=config.label + ) + + +class ClipDetectionPrecisionConfig(BaseConfig): + name: Literal["precision"] = "precision" + threshold: float = 0.5 + label: str = "precision" + + +class ClipDetectionPrecision: + def __init__(self, threshold: float, label: str = "precision"): + self.threshold = threshold + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_detections = 0 + true_positives = 0 + for clip_eval in clip_evaluations: + if clip_eval.score >= self.threshold: + num_detections += 1 + + if clip_eval.score >= self.threshold and clip_eval.gt_det: + true_positives += 1 + + if num_detections == 0: + return {self.label: np.nan} + + score = true_positives / num_detections + return {self.label: score} + + @clip_detection_metrics.register(ClipDetectionPrecisionConfig) + @staticmethod + def from_config(config: ClipDetectionPrecisionConfig): + return ClipDetectionPrecision( + threshold=config.threshold, label=config.label + ) + + +ClipDetectionMetricConfig = Annotated[ + Union[ + ClipDetectionAveragePrecisionConfig, + ClipDetectionROCAUCConfig, + ClipDetectionRecallConfig, + ClipDetectionPrecisionConfig, + ], + Field(discriminator="name"), +] + + +def build_clip_metric(config: ClipDetectionMetricConfig): + return clip_detection_metrics.build(config) diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py index 4375477..44ce045 100644 --- a/src/batdetect2/evaluate/metrics/common.py +++ b/src/batdetect2/evaluate/metrics/common.py @@ -1,24 +1,24 @@ -from typing import Optional +from typing import Optional, Tuple import numpy as np +__all__ = [ 
+ "compute_precision_recall", + "average_precision", +] -def average_precision( + +def compute_precision_recall( y_true, y_score, num_positives: Optional[int] = None, -) -> float: +) -> Tuple[np.ndarray, np.ndarray]: y_true = np.array(y_true) y_score = np.array(y_score) if num_positives is None: num_positives = y_true.sum() - # Remove non-detections - valid_inds = y_score > 0 - y_true = y_true[valid_inds] - y_score = y_score[valid_inds] - # Sort by score sort_ind = np.argsort(y_score)[::-1] y_true_sorted = y_true[sort_ind] @@ -34,6 +34,19 @@ def average_precision( precision[np.isnan(precision)] = 0 recall[np.isnan(recall)] = 0 + return precision, recall + + +def average_precision( + y_true, + y_score, + num_positives: Optional[int] = None, +) -> float: + precision, recall = compute_precision_recall( + y_true, + y_score, + num_positives=num_positives, + ) # pascal 12 way mprec = np.hstack((0, precision, 0)) diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py new file mode 100644 index 0000000..f687392 --- /dev/null +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -0,0 +1,226 @@ +from dataclasses import dataclass +from typing import ( + Annotated, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) + +import numpy as np +from pydantic import Field +from sklearn import metrics +from soundevent import data + +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing import RawPrediction + +__all__ = [ + "DetectionMetricConfig", + "DetectionMetric", + "build_detection_metric", +] + + +@dataclass +class MatchEval: + gt: Optional[data.SoundEventAnnotation] + pred: Optional[RawPrediction] + + is_prediction: bool + is_ground_truth: bool + score: float + + +@dataclass +class ClipEval: + clip: data.Clip + matches: List[MatchEval] + + +DetectionMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] + + +detection_metrics: Registry[DetectionMetric, []] = Registry("detection_metric") + + +class DetectionAveragePrecisionConfig(BaseConfig): + name: Literal["average_precision"] = "average_precision" + label: str = "average_precision" + ignore_non_predictions: bool = True + + +class DetectionAveragePrecision: + def __init__(self, label: str, ignore_non_predictions: bool = True): + self.ignore_non_predictions = ignore_non_predictions + self.label = label + + def __call__( + self, + clip_evals: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evals: + for m in clip_eval.matches: + num_positives += int(m.is_ground_truth) + + # Ignore matches that don't correspond to a prediction + if not m.is_prediction and self.ignore_non_predictions: + continue + + y_true.append(m.is_ground_truth) + y_score.append(m.score) + + ap = average_precision(y_true, y_score, num_positives=num_positives) + return {self.label: ap} + + @detection_metrics.register(DetectionAveragePrecisionConfig) + @staticmethod + def from_config(config: DetectionAveragePrecisionConfig): + return DetectionAveragePrecision( + label=config.label, + ignore_non_predictions=config.ignore_non_predictions, + ) + + +class DetectionROCAUCConfig(BaseConfig): + name: Literal["roc_auc"] = "roc_auc" + label: str = "roc_auc" + ignore_non_predictions: bool = True + + +class DetectionROCAUC: + def __init__( + self, + label: str = "roc_auc", + ignore_non_predictions: bool = True, + ): + self.label = label + self.ignore_non_predictions = 
ignore_non_predictions + + def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]: + y_true: List[bool] = [] + y_score: List[float] = [] + + for clip_eval in clip_evals: + for m in clip_eval.matches: + if not m.is_prediction and self.ignore_non_predictions: + # Ignore matches that don't correspond to a prediction + continue + + y_true.append(m.is_ground_truth) + y_score.append(m.score) + + score = float(metrics.roc_auc_score(y_true, y_score)) + return {self.label: score} + + @detection_metrics.register(DetectionROCAUCConfig) + @staticmethod + def from_config(config: DetectionROCAUCConfig): + return DetectionROCAUC( + label=config.label, + ignore_non_predictions=config.ignore_non_predictions, + ) + + +class DetectionRecallConfig(BaseConfig): + name: Literal["recall"] = "recall" + label: str = "recall" + threshold: float = 0.5 + + +class DetectionRecall: + def __init__(self, threshold: float, label: str = "recall"): + self.label = label + self.threshold = threshold + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_positives = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if m.is_ground_truth: + num_positives += 1 + + if m.score >= self.threshold and m.is_ground_truth: + true_positives += 1 + + if num_positives == 0: + return {self.label: np.nan} + + score = true_positives / num_positives + return {self.label: score} + + @detection_metrics.register(DetectionRecallConfig) + @staticmethod + def from_config(config: DetectionRecallConfig): + return DetectionRecall(threshold=config.threshold, label=config.label) + + +class DetectionPrecisionConfig(BaseConfig): + name: Literal["precision"] = "precision" + label: str = "precision" + threshold: float = 0.5 + + +class DetectionPrecision: + def __init__(self, threshold: float, label: str = "precision"): + self.threshold = threshold + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_detections = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + is_detection = m.score >= self.threshold + + if is_detection: + num_detections += 1 + + if is_detection and m.is_ground_truth: + true_positives += 1 + + if num_detections == 0: + return {self.label: np.nan} + + score = true_positives / num_detections + return {self.label: score} + + @detection_metrics.register(DetectionPrecisionConfig) + @staticmethod + def from_config(config: DetectionPrecisionConfig): + return DetectionPrecision( + threshold=config.threshold, + label=config.label, + ) + + +DetectionMetricConfig = Annotated[ + Union[ + DetectionAveragePrecisionConfig, + DetectionROCAUCConfig, + DetectionRecallConfig, + DetectionPrecisionConfig, + ], + Field(discriminator="name"), +] + + +def build_detection_metric(config: DetectionMetricConfig): + return detection_metrics.build(config) diff --git a/src/batdetect2/evaluate/metrics/matches.py b/src/batdetect2/evaluate/metrics/matches.py deleted file mode 100644 index 0c3ec12..0000000 --- a/src/batdetect2/evaluate/metrics/matches.py +++ /dev/null @@ -1,235 +0,0 @@ -from typing import Annotated, Callable, Literal, Sequence, Union - -import numpy as np -from pydantic import Field -from sklearn import metrics - -from batdetect2.core import BaseConfig, Registry -from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import ( - ClipMatches, -) - -__all__ = [ - "MatchMetricConfig", - "MatchesMetric", - 
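# Illustrative usage sketch (not part of this patch): the detection metrics
# defined above are built from their pydantic configs through the
# `detection_metrics` registry. The `clip` variable is assumed to be an
# existing soundevent `data.Clip`; in practice the detection task constructs
# the ClipEval/MatchEval objects from matched ground truth and predictions.
from batdetect2.evaluate.metrics.detection import (
    ClipEval,
    DetectionRecallConfig,
    MatchEval,
    build_detection_metric,
)

recall_metric = build_detection_metric(DetectionRecallConfig(threshold=0.3))
clip_evals = [
    ClipEval(
        clip=clip,  # assumed to exist
        matches=[
            MatchEval(gt=None, pred=None, is_prediction=True,
                      is_ground_truth=True, score=0.7),
            MatchEval(gt=None, pred=None, is_prediction=False,
                      is_ground_truth=True, score=0.0),
        ],
    )
]
print(recall_metric(clip_evals))  # {"recall": 0.5}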
"build_match_metric", -] - -MatchesMetric = Callable[[Sequence[ClipMatches]], float] - - -metrics_registry: Registry[MatchesMetric, []] = Registry("match_metric") - - -class DetectionAveragePrecisionConfig(BaseConfig): - name: Literal["detection_average_precision"] = ( - "detection_average_precision" - ) - ignore_non_predictions: bool = True - - -class DetectionAveragePrecision: - def __init__(self, ignore_non_predictions: bool = True): - self.ignore_non_predictions = ignore_non_predictions - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - ) -> float: - y_true = [] - y_score = [] - num_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - num_positives += int(m.gt_det) - - # Ignore matches that don't correspond to a prediction - if not m.is_prediction and self.ignore_non_predictions: - continue - - y_true.append(m.gt_det) - y_score.append(m.pred_score) - - return average_precision(y_true, y_score, num_positives=num_positives) - - @metrics_registry.register(DetectionAveragePrecisionConfig) - @staticmethod - def from_config(config: DetectionAveragePrecisionConfig): - return DetectionAveragePrecision( - ignore_non_predictions=config.ignore_non_predictions - ) - - -class TopClassAveragePrecisionConfig(BaseConfig): - name: Literal["top_class_average_precision"] = ( - "top_class_average_precision" - ) - ignore_non_predictions: bool = True - ignore_generic: bool = True - - -class TopClassAveragePrecision: - def __init__( - self, - ignore_non_predictions: bool = True, - ignore_generic: bool = True, - ): - self.ignore_non_predictions = ignore_non_predictions - self.ignore_generic = ignore_generic - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - ) -> float: - y_true = [] - y_score = [] - num_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - if m.is_generic and self.ignore_generic: - # Ignore ground truth sounds with unknown class - continue - - num_positives += int(m.gt_det) - - if not m.is_prediction and self.ignore_non_predictions: - # Ignore matches that don't correspond to a prediction - continue - - y_true.append(m.gt_det & (m.top_class == m.gt_class)) - y_score.append(m.top_class_score) - - return average_precision(y_true, y_score, num_positives=num_positives) - - @metrics_registry.register(TopClassAveragePrecisionConfig) - @staticmethod - def from_config(config: TopClassAveragePrecisionConfig): - return TopClassAveragePrecision( - ignore_non_predictions=config.ignore_non_predictions - ) - - -class DetectionROCAUCConfig(BaseConfig): - name: Literal["detection_roc_auc"] = "detection_roc_auc" - ignore_non_predictions: bool = True - - -class DetectionROCAUC: - def __init__( - self, - ignore_non_predictions: bool = True, - ): - self.ignore_non_predictions = ignore_non_predictions - - def __call__(self, clip_evaluations: Sequence[ClipMatches]) -> float: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - if not m.is_prediction and self.ignore_non_predictions: - # Ignore matches that don't correspond to a prediction - continue - - y_true.append(m.gt_det) - y_score.append(m.pred_score) - - return float(metrics.roc_auc_score(y_true, y_score)) - - @metrics_registry.register(DetectionROCAUCConfig) - @staticmethod - def from_config(config: DetectionROCAUCConfig): - return DetectionROCAUC( - ignore_non_predictions=config.ignore_non_predictions - ) - - -class DetectionRecallConfig(BaseConfig): - name: Literal["detection_recall"] = "detection_recall" - 
threshold: float = 0.5 - - -class DetectionRecall: - def __init__(self, threshold: float): - self.threshold = threshold - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - ) -> float: - num_positives = 0 - true_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - if m.gt_det: - num_positives += 1 - - if m.pred_score >= self.threshold and m.gt_det: - true_positives += 1 - - if num_positives == 0: - return 1 - - return true_positives / num_positives - - @metrics_registry.register(DetectionRecallConfig) - @staticmethod - def from_config(config: DetectionRecallConfig): - return DetectionRecall(threshold=config.threshold) - - -class DetectionPrecisionConfig(BaseConfig): - name: Literal["detection_precision"] = "detection_precision" - threshold: float = 0.5 - - -class DetectionPrecision: - def __init__(self, threshold: float): - self.threshold = threshold - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - ) -> float: - num_detections = 0 - true_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - is_detection = m.pred_score >= self.threshold - - if is_detection: - num_detections += 1 - - if is_detection and m.gt_det: - true_positives += 1 - - if num_detections == 0: - return np.nan - - return true_positives / num_detections - - @metrics_registry.register(DetectionPrecisionConfig) - @staticmethod - def from_config(config: DetectionPrecisionConfig): - return DetectionPrecision(threshold=config.threshold) - - -MatchMetricConfig = Annotated[ - Union[ - DetectionAveragePrecisionConfig, - DetectionROCAUCConfig, - DetectionRecallConfig, - DetectionPrecisionConfig, - TopClassAveragePrecisionConfig, - ], - Field(discriminator="name"), -] - - -def build_match_metric(config: MatchMetricConfig): - return metrics_registry.build(config) diff --git a/src/batdetect2/evaluate/metrics/per_class_matches.py b/src/batdetect2/evaluate/metrics/per_class_matches.py deleted file mode 100644 index 51e0a8a..0000000 --- a/src/batdetect2/evaluate/metrics/per_class_matches.py +++ /dev/null @@ -1,136 +0,0 @@ -from typing import Annotated, Callable, Literal, Sequence, Union - -from pydantic import Field -from sklearn import metrics - -from batdetect2.core import BaseConfig, Registry -from batdetect2.evaluate.metrics.common import average_precision -from batdetect2.typing import ( - ClipMatches, -) - -__all__ = [] - -PerClassMatchMetric = Callable[[Sequence[ClipMatches], str], float] - - -metrics_registry: Registry[PerClassMatchMetric, []] = Registry( - "match_metric" -) - - -class ClassificationAveragePrecisionConfig(BaseConfig): - name: Literal["classification_average_precision"] = ( - "classification_average_precision" - ) - ignore_non_predictions: bool = True - ignore_generic: bool = True - - -class ClassificationAveragePrecision: - def __init__( - self, - ignore_non_predictions: bool = True, - ignore_generic: bool = True, - ): - self.ignore_non_predictions = ignore_non_predictions - self.ignore_generic = ignore_generic - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - class_name: str, - ) -> float: - y_true = [] - y_score = [] - num_positives = 0 - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - is_class = m.gt_class == class_name - - if is_class: - num_positives += 1 - - # Ignore matches that don't correspond to a prediction - if not m.is_prediction and self.ignore_non_predictions: - continue - - # Exclude matches with ground truth sounds where the class is - # unknown - if 
m.is_generic and self.ignore_generic: - continue - - y_true.append(is_class) - y_score.append(m.pred_class_scores.get(class_name, 0)) - - return average_precision(y_true, y_score, num_positives=num_positives) - - @metrics_registry.register(ClassificationAveragePrecisionConfig) - @staticmethod - def from_config(config: ClassificationAveragePrecisionConfig): - return ClassificationAveragePrecision( - ignore_non_predictions=config.ignore_non_predictions, - ignore_generic=config.ignore_generic, - ) - - -class ClassificationROCAUCConfig(BaseConfig): - name: Literal["classification_roc_auc"] = "classification_roc_auc" - ignore_non_predictions: bool = True - ignore_generic: bool = True - - -class ClassificationROCAUC: - def __init__( - self, - ignore_non_predictions: bool = True, - ignore_generic: bool = True, - ): - self.ignore_non_predictions = ignore_non_predictions - self.ignore_generic = ignore_generic - - def __call__( - self, - clip_evaluations: Sequence[ClipMatches], - class_name: str, - ) -> float: - y_true = [] - y_score = [] - - for clip_eval in clip_evaluations: - for m in clip_eval.matches: - # Exclude matches with ground truth sounds where the class is - # unknown - if m.is_generic and self.ignore_generic: - continue - - # Ignore matches that don't correspond to a prediction - if not m.is_prediction and self.ignore_non_predictions: - continue - - y_true.append(m.gt_class == class_name) - y_score.append(m.pred_class_scores.get(class_name, 0)) - - return float(metrics.roc_auc_score(y_true, y_score)) - - @metrics_registry.register(ClassificationROCAUCConfig) - @staticmethod - def from_config(config: ClassificationROCAUCConfig): - return ClassificationROCAUC( - ignore_non_predictions=config.ignore_non_predictions, - ignore_generic=config.ignore_generic, - ) - - -PerClassMatchMetricConfig = Annotated[ - Union[ - ClassificationAveragePrecisionConfig, - ClassificationROCAUCConfig, - ], - Field(discriminator="name"), -] - - -def build_per_class_matches_metric(config: PerClassMatchMetricConfig): - return metrics_registry.build(config) diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py new file mode 100644 index 0000000..ee837c8 --- /dev/null +++ b/src/batdetect2/evaluate/metrics/top_class.py @@ -0,0 +1,313 @@ +from dataclasses import dataclass +from typing import ( + Annotated, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Union, +) + +import numpy as np +from pydantic import Field +from sklearn import metrics, preprocessing +from soundevent import data + +from batdetect2.core import BaseConfig, Registry +from batdetect2.evaluate.metrics.common import average_precision +from batdetect2.typing import RawPrediction + +__all__ = [ + "TopClassMetricConfig", + "TopClassMetric", + "build_top_class_metric", +] + + +@dataclass +class MatchEval: + gt: Optional[data.SoundEventAnnotation] + pred: Optional[RawPrediction] + + is_ground_truth: bool + is_generic: bool + is_prediction: bool + pred_class: Optional[str] + true_class: Optional[str] + score: float + + +@dataclass +class ClipEval: + clip: data.Clip + matches: List[MatchEval] + + +TopClassMetric = Callable[[Sequence[ClipEval]], Dict[str, float]] + + +top_class_metrics: Registry[TopClassMetric, []] = Registry("top_class_metric") + + +class TopClassAveragePrecisionConfig(BaseConfig): + name: Literal["average_precision"] = "average_precision" + label: str = "average_precision" + ignore_generic: bool = True + ignore_non_predictions: bool = True + + +class 
TopClassAveragePrecision: + def __init__( + self, + ignore_generic: bool = True, + ignore_non_predictions: bool = True, + label: str = "average_precision", + ): + self.ignore_generic = ignore_generic + self.ignore_non_predictions = ignore_non_predictions + self.label = label + + def __call__( + self, + clip_evals: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true = [] + y_score = [] + num_positives = 0 + + for clip_eval in clip_evals: + for m in clip_eval.matches: + if m.is_generic and self.ignore_generic: + # Ignore gt sounds with unknown class + continue + + num_positives += int(m.is_ground_truth) + + if not m.is_prediction and self.ignore_non_predictions: + # Ignore non predictions + continue + + y_true.append(m.pred_class == m.true_class) + y_score.append(m.score) + + score = average_precision(y_true, y_score, num_positives=num_positives) + return {self.label: score} + + @top_class_metrics.register(TopClassAveragePrecisionConfig) + @staticmethod + def from_config(config: TopClassAveragePrecisionConfig): + return TopClassAveragePrecision( + ignore_generic=config.ignore_generic, + label=config.label, + ) + + +class TopClassROCAUCConfig(BaseConfig): + name: Literal["roc_auc"] = "roc_auc" + ignore_generic: bool = True + ignore_non_predictions: bool = True + label: str = "roc_auc" + + +class TopClassROCAUC: + def __init__( + self, + ignore_generic: bool = True, + ignore_non_predictions: bool = True, + label: str = "roc_auc", + ): + self.ignore_generic = ignore_generic + self.ignore_non_predictions = ignore_non_predictions + self.label = label + + def __call__(self, clip_evals: Sequence[ClipEval]) -> Dict[str, float]: + y_true: List[bool] = [] + y_score: List[float] = [] + + for clip_eval in clip_evals: + for m in clip_eval.matches: + if m.is_generic and self.ignore_generic: + # Ignore gt sounds with unknown class + continue + + if not m.is_prediction and self.ignore_non_predictions: + # Ignore non predictions + continue + + y_true.append(m.pred_class == m.true_class) + y_score.append(m.score) + + score = float(metrics.roc_auc_score(y_true, y_score)) + return {self.label: score} + + @top_class_metrics.register(TopClassROCAUCConfig) + @staticmethod + def from_config(config: TopClassROCAUCConfig): + return TopClassROCAUC( + ignore_generic=config.ignore_generic, + label=config.label, + ) + + +class TopClassRecallConfig(BaseConfig): + name: Literal["recall"] = "recall" + threshold: float = 0.5 + label: str = "recall" + + +class TopClassRecall: + def __init__(self, threshold: float, label: str = "recall"): + self.threshold = threshold + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_positives = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if m.is_ground_truth: + num_positives += 1 + + if m.score >= self.threshold and m.pred_class == m.true_class: + true_positives += 1 + + if num_positives == 0: + return {self.label: np.nan} + + score = true_positives / num_positives + return {self.label: score} + + @top_class_metrics.register(TopClassRecallConfig) + @staticmethod + def from_config(config: TopClassRecallConfig): + return TopClassRecall( + threshold=config.threshold, + label=config.label, + ) + + +class TopClassPrecisionConfig(BaseConfig): + name: Literal["precision"] = "precision" + threshold: float = 0.5 + label: str = "precision" + + +class TopClassPrecision: + def __init__(self, threshold: float, label: str = "precision"): + self.threshold = threshold + self.label = label + 
+ def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + num_detections = 0 + true_positives = 0 + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + is_detection = m.score >= self.threshold + + if is_detection: + num_detections += 1 + + if is_detection and m.pred_class == m.true_class: + true_positives += 1 + + if num_detections == 0: + return {self.label: np.nan} + + score = true_positives / num_detections + return {self.label: score} + + @top_class_metrics.register(TopClassPrecisionConfig) + @staticmethod + def from_config(config: TopClassPrecisionConfig): + return TopClassPrecision( + threshold=config.threshold, + label=config.label, + ) + + +class BalancedAccuracyConfig(BaseConfig): + name: Literal["balanced_accuracy"] = "balanced_accuracy" + label: str = "balanced_accuracy" + exclude_noise: bool = False + noise_class: str = "noise" + + +class BalancedAccuracy: + def __init__( + self, + exclude_noise: bool = True, + noise_class: str = "noise", + label: str = "balanced_accuracy", + ): + self.exclude_noise = exclude_noise + self.noise_class = noise_class + self.label = label + + def __call__( + self, + clip_evaluations: Sequence[ClipEval], + ) -> Dict[str, float]: + y_true: List[str] = [] + y_pred: List[str] = [] + + for clip_eval in clip_evaluations: + for m in clip_eval.matches: + if m.is_generic: + # Ignore matches that correspond to a sound event + # with unknown class + continue + + if not m.is_ground_truth and self.exclude_noise: + # Ignore predictions that were not matched to a + # ground truth + continue + + if m.pred_class is None and self.exclude_noise: + # Ignore non-predictions + continue + + y_true.append(m.true_class or self.noise_class) + y_pred.append(m.pred_class or self.noise_class) + + encoder = preprocessing.LabelEncoder() + encoder.fit(list(set(y_true) | set(y_pred))) + + y_true = encoder.transform(y_true) + y_pred = encoder.transform(y_pred) + score = metrics.balanced_accuracy_score(y_true, y_pred) + return {self.label: score} + + @top_class_metrics.register(BalancedAccuracyConfig) + @staticmethod + def from_config(config: BalancedAccuracyConfig): + return BalancedAccuracy( + exclude_noise=config.exclude_noise, + noise_class=config.noise_class, + label=config.label, + ) + + +TopClassMetricConfig = Annotated[ + Union[ + TopClassAveragePrecisionConfig, + TopClassROCAUCConfig, + TopClassRecallConfig, + TopClassPrecisionConfig, + BalancedAccuracyConfig, + ], + Field(discriminator="name"), +] + + +def build_top_class_metric(config: TopClassMetricConfig): + return top_class_metrics.build(config) diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py new file mode 100644 index 0000000..4b62c16 --- /dev/null +++ b/src/batdetect2/evaluate/tasks/__init__.py @@ -0,0 +1,39 @@ +from typing import Annotated, Optional, Union + +from pydantic import Field + +from batdetect2.evaluate.tasks.base import tasks_registry +from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig +from batdetect2.evaluate.tasks.clip_classification import ( + ClipClassificationTaskConfig, +) +from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig +from batdetect2.evaluate.tasks.detection import DetectionTaskConfig +from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig +from batdetect2.targets import build_targets +from batdetect2.typing import EvaluatorProtocol, TargetProtocol + +__all__ = [ + "TaskConfig", + "build_task", +] + + +TaskConfig = 
Annotated[ + Union[ + ClassificationTaskConfig, + DetectionTaskConfig, + ClipDetectionTaskConfig, + ClipClassificationTaskConfig, + TopClassDetectionTaskConfig, + ], + Field(discriminator="name"), +] + + +def build_task( + config: TaskConfig, + targets: Optional[TargetProtocol] = None, +) -> EvaluatorProtocol: + targets = targets or build_targets() + return tasks_registry.build(config, targets) diff --git a/src/batdetect2/evaluate/evaluator/base.py b/src/batdetect2/evaluate/tasks/base.py similarity index 60% rename from src/batdetect2/evaluate/evaluator/base.py rename to src/batdetect2/evaluate/tasks/base.py index 8248ee5..29b61ae 100644 --- a/src/batdetect2/evaluate/evaluator/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -1,3 +1,5 @@ +from typing import Callable, Dict, Generic, List, Sequence, TypeVar + from pydantic import Field from soundevent import data from soundevent.geometry import compute_bounds @@ -14,14 +16,19 @@ from batdetect2.typing.postprocess import RawPrediction from batdetect2.typing.targets import TargetProtocol __all__ = [ - "BaseEvaluatorConfig", - "BaseEvaluator", + "BaseTaskConfig", + "BaseTask", ] -evaluators: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry("metric") +tasks_registry: Registry[EvaluatorProtocol, [TargetProtocol]] = Registry( + "tasks" +) -class BaseEvaluatorConfig(BaseConfig): +T_Output = TypeVar("T_Output") + + +class BaseTaskConfig(BaseConfig): prefix: str ignore_start_end: float = 0.01 matching_strategy: MatchConfig = Field( @@ -29,11 +36,13 @@ class BaseEvaluatorConfig(BaseConfig): ) -class BaseEvaluator(EvaluatorProtocol): +class BaseTask(EvaluatorProtocol, Generic[T_Output]): targets: TargetProtocol matcher: MatcherProtocol + metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]] + ignore_start_end: float prefix: str @@ -42,15 +51,44 @@ class BaseEvaluator(EvaluatorProtocol): self, matcher: MatcherProtocol, targets: TargetProtocol, + metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], prefix: str, ignore_start_end: float = 0.01, ): self.matcher = matcher + self.metrics = metrics self.targets = targets self.prefix = prefix self.ignore_start_end = ignore_start_end - def filter_sound_event_annotations( + def compute_metrics( + self, + eval_outputs: List[T_Output], + ) -> Dict[str, float]: + scores = [metric(eval_outputs) for metric in self.metrics] + return { + f"{self.prefix}/{name}": score + for metric_output in scores + for name, score in metric_output.items() + } + + def evaluate( + self, + clip_annotations: Sequence[data.ClipAnnotation], + predictions: Sequence[Sequence[RawPrediction]], + ) -> List[T_Output]: + return [ + self.evaluate_clip(clip_annotation, preds) + for clip_annotation, preds in zip(clip_annotations, predictions) + ] + + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> T_Output: ... 
+ + def include_sound_event_annotation( self, sound_event_annotation: data.SoundEventAnnotation, clip: data.Clip, @@ -68,7 +106,7 @@ class BaseEvaluator(EvaluatorProtocol): self.ignore_start_end, ) - def filter_predictions( + def include_prediction( self, prediction: RawPrediction, clip: data.Clip, @@ -82,14 +120,16 @@ class BaseEvaluator(EvaluatorProtocol): @classmethod def build( cls, - config: BaseEvaluatorConfig, + config: BaseTaskConfig, targets: TargetProtocol, + metrics: List[Callable[[Sequence[T_Output]], Dict[str, float]]], **kwargs, ): matcher = build_matcher(config.matching_strategy) return cls( matcher=matcher, targets=targets, + metrics=metrics, prefix=config.prefix, ignore_start_end=config.ignore_start_end, **kwargs, diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py new file mode 100644 index 0000000..3481bd1 --- /dev/null +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -0,0 +1,137 @@ +from typing import ( + List, + Literal, + Sequence, +) + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.metrics.classification import ( + ClassificationAveragePrecisionConfig, + ClassificationMetricConfig, + ClipEval, + MatchEval, + build_classification_metrics, +) +from batdetect2.evaluate.tasks.base import ( + BaseTask, + BaseTaskConfig, + tasks_registry, +) +from batdetect2.typing import RawPrediction, TargetProtocol + + +class ClassificationTaskConfig(BaseTaskConfig): + name: Literal["sound_event_classification"] = "sound_event_classification" + prefix: str = "classification" + metrics: List[ClassificationMetricConfig] = Field( + default_factory=lambda: [ClassificationAveragePrecisionConfig()] + ) + include_generics: bool = True + + +class ClassificationTask(BaseTask[ClipEval]): + def __init__( + self, + *args, + include_generics: bool = True, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.include_generics = include_generics + + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEval: + clip = clip_annotation.clip + + preds = [ + pred for pred in predictions if self.include_prediction(pred, clip) + ] + + all_gts = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.include_sound_event_annotation(sound_event, clip) + ] + + per_class_matches = {} + + for class_name in self.targets.class_names: + class_idx = self.targets.class_names.index(class_name) + + # Only match to targets of the given class + gts = [ + sound_event + for sound_event in all_gts + if self.is_class(sound_event, class_name) + ] + scores = [float(pred.class_scores[class_idx]) for pred in preds] + + matches = [] + + for pred_idx, gt_idx, _ in self.matcher( + ground_truth=[se.sound_event.geometry for se in gts], # type: ignore + predictions=[pred.geometry for pred in preds], + scores=scores, + ): + gt = gts[gt_idx] if gt_idx is not None else None + pred = preds[pred_idx] if pred_idx is not None else None + + true_class = ( + self.targets.encode_class(gt) if gt is not None else None + ) + + score = ( + float(pred.class_scores[class_idx]) + if pred is not None + else 0 + ) + + matches.append( + MatchEval( + gt=gt, + pred=pred, + is_prediction=pred is not None, + is_ground_truth=gt is not None, + is_generic=gt is not None and true_class is None, + true_class=true_class, + score=score, + ) + ) + + per_class_matches[class_name] = matches + + return ClipEval(clip=clip, matches=per_class_matches) + + def is_class( + self, + 
sound_event: data.SoundEventAnnotation, + class_name: str, + ) -> bool: + sound_event_class = self.targets.encode_class(sound_event) + + if sound_event_class is None and self.include_generics: + # Sound events that are generic could be of the given + # class + return True + + return sound_event_class == class_name + + @tasks_registry.register(ClassificationTaskConfig) + @staticmethod + def from_config( + config: ClassificationTaskConfig, + targets: TargetProtocol, + ): + metrics = [ + build_classification_metrics(metric) for metric in config.metrics + ] + return ClassificationTask.build( + config=config, + targets=targets, + metrics=metrics, + ) diff --git a/src/batdetect2/evaluate/tasks/clip_classification.py b/src/batdetect2/evaluate/tasks/clip_classification.py new file mode 100644 index 0000000..91392ee --- /dev/null +++ b/src/batdetect2/evaluate/tasks/clip_classification.py @@ -0,0 +1,75 @@ +from collections import defaultdict +from typing import List, Literal, Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.metrics.clip_classification import ( + ClipClassificationAveragePrecisionConfig, + ClipClassificationMetricConfig, + ClipEval, + build_clip_metric, +) +from batdetect2.evaluate.tasks.base import ( + BaseTask, + BaseTaskConfig, + tasks_registry, +) +from batdetect2.typing import RawPrediction, TargetProtocol + + +class ClipClassificationTaskConfig(BaseTaskConfig): + name: Literal["clip_classification"] = "clip_classification" + prefix: str = "clip_classification" + metrics: List[ClipClassificationMetricConfig] = Field( + default_factory=lambda: [ + ClipClassificationAveragePrecisionConfig(), + ] + ) + + +class ClipClassificationTask(BaseTask[ClipEval]): + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEval: + clip = clip_annotation.clip + + gt_classes = set() + for sound_event in clip_annotation.sound_events: + if not self.include_sound_event_annotation(sound_event, clip): + continue + + class_name = self.targets.encode_class(sound_event) + + if class_name is None: + continue + + gt_classes.add(class_name) + + pred_scores = defaultdict(float) + for pred in predictions: + if not self.include_prediction(pred, clip): + continue + + for class_idx, class_name in enumerate(self.targets.class_names): + pred_scores[class_name] = max( + float(pred.class_scores[class_idx]), + pred_scores[class_name], + ) + + return ClipEval(true_classes=gt_classes, class_scores=pred_scores) + + @tasks_registry.register(ClipClassificationTaskConfig) + @staticmethod + def from_config( + config: ClipClassificationTaskConfig, + targets: TargetProtocol, + ): + metrics = [build_clip_metric(metric) for metric in config.metrics] + return ClipClassificationTask.build( + config=config, + metrics=metrics, + targets=targets, + ) diff --git a/src/batdetect2/evaluate/tasks/clip_detection.py b/src/batdetect2/evaluate/tasks/clip_detection.py new file mode 100644 index 0000000..15f2311 --- /dev/null +++ b/src/batdetect2/evaluate/tasks/clip_detection.py @@ -0,0 +1,66 @@ +from typing import List, Literal, Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.metrics.clip_detection import ( + ClipDetectionAveragePrecisionConfig, + ClipDetectionMetricConfig, + ClipEval, + build_clip_metric, +) +from batdetect2.evaluate.tasks.base import ( + BaseTask, + BaseTaskConfig, + tasks_registry, +) +from batdetect2.typing import RawPrediction, TargetProtocol + + +class 
ClipDetectionTaskConfig(BaseTaskConfig): + name: Literal["clip_detection"] = "clip_detection" + prefix: str = "clip_detection" + metrics: List[ClipDetectionMetricConfig] = Field( + default_factory=lambda: [ + ClipDetectionAveragePrecisionConfig(), + ] + ) + + +class ClipDetectionTask(BaseTask[ClipEval]): + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEval: + clip = clip_annotation.clip + + gt_det = any( + self.include_sound_event_annotation(sound_event, clip) + for sound_event in clip_annotation.sound_events + ) + + pred_score = 0 + for pred in predictions: + if not self.include_prediction(pred, clip): + continue + + pred_score = max(pred_score, pred.detection_score) + + return ClipEval( + gt_det=gt_det, + score=pred_score, + ) + + @tasks_registry.register(ClipDetectionTaskConfig) + @staticmethod + def from_config( + config: ClipDetectionTaskConfig, + targets: TargetProtocol, + ): + metrics = [build_clip_metric(metric) for metric in config.metrics] + return ClipDetectionTask.build( + config=config, + metrics=metrics, + targets=targets, + ) diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py new file mode 100644 index 0000000..1308d43 --- /dev/null +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -0,0 +1,79 @@ +from typing import List, Literal, Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.metrics.detection import ( + ClipEval, + DetectionAveragePrecisionConfig, + DetectionMetricConfig, + MatchEval, + build_detection_metric, +) +from batdetect2.evaluate.tasks.base import ( + BaseTask, + BaseTaskConfig, + tasks_registry, +) +from batdetect2.typing import RawPrediction, TargetProtocol + + +class DetectionTaskConfig(BaseTaskConfig): + name: Literal["sound_event_detection"] = "sound_event_detection" + prefix: str = "detection" + metrics: List[DetectionMetricConfig] = Field( + default_factory=lambda: [DetectionAveragePrecisionConfig()] + ) + + +class DetectionTask(BaseTask[ClipEval]): + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEval: + clip = clip_annotation.clip + + gts = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.include_sound_event_annotation(sound_event, clip) + ] + preds = [ + pred for pred in predictions if self.include_prediction(pred, clip) + ] + scores = [pred.detection_score for pred in preds] + + matches = [] + for pred_idx, gt_idx, _ in self.matcher( + ground_truth=[se.sound_event.geometry for se in gts], # type: ignore + predictions=[pred.geometry for pred in preds], + scores=scores, + ): + gt = gts[gt_idx] if gt_idx is not None else None + pred = preds[pred_idx] if pred_idx is not None else None + + matches.append( + MatchEval( + gt=gt, + pred=pred, + is_prediction=pred is not None, + is_ground_truth=gt is not None, + score=pred.detection_score if pred is not None else 0, + ) + ) + + return ClipEval(clip=clip, matches=matches) + + @tasks_registry.register(DetectionTaskConfig) + @staticmethod + def from_config( + config: DetectionTaskConfig, + targets: TargetProtocol, + ): + metrics = [build_detection_metric(metric) for metric in config.metrics] + return DetectionTask.build( + config=config, + metrics=metrics, + targets=targets, + ) diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py new file mode 100644 index 0000000..c2a215c --- /dev/null +++ 
b/src/batdetect2/evaluate/tasks/top_class.py @@ -0,0 +1,101 @@ +from typing import List, Literal, Sequence + +from pydantic import Field +from soundevent import data + +from batdetect2.evaluate.metrics.top_class import ( + ClipEval, + MatchEval, + TopClassAveragePrecisionConfig, + TopClassMetricConfig, + build_top_class_metric, +) +from batdetect2.evaluate.tasks.base import ( + BaseTask, + BaseTaskConfig, + tasks_registry, +) +from batdetect2.typing import RawPrediction, TargetProtocol + + +class TopClassDetectionTaskConfig(BaseTaskConfig): + name: Literal["top_class_detection"] = "top_class_detection" + prefix: str = "top_class" + metrics: List[TopClassMetricConfig] = Field( + default_factory=lambda: [TopClassAveragePrecisionConfig()] + ) + + +class TopClassDetectionTask(BaseTask[ClipEval]): + def evaluate_clip( + self, + clip_annotation: data.ClipAnnotation, + predictions: Sequence[RawPrediction], + ) -> ClipEval: + clip = clip_annotation.clip + + gts = [ + sound_event + for sound_event in clip_annotation.sound_events + if self.include_sound_event_annotation(sound_event, clip) + ] + preds = [ + pred for pred in predictions if self.include_prediction(pred, clip) + ] + # Take the highest score for each prediction + scores = [pred.class_scores.max() for pred in preds] + + matches = [] + for pred_idx, gt_idx, _ in self.matcher( + ground_truth=[se.sound_event.geometry for se in gts], # type: ignore + predictions=[pred.geometry for pred in preds], + scores=scores, + ): + gt = gts[gt_idx] if gt_idx is not None else None + pred = preds[pred_idx] if pred_idx is not None else None + + true_class = ( + self.targets.encode_class(gt) if gt is not None else None + ) + + class_idx = ( + pred.class_scores.argmax() if pred is not None else None + ) + + score = ( + float(pred.class_scores[class_idx]) if pred is not None else 0 + ) + + pred_class = ( + self.targets.class_names[class_idx] + if class_idx is not None + else None + ) + + matches.append( + MatchEval( + gt=gt, + pred=pred, + is_ground_truth=gt is not None, + is_prediction=pred is not None, + true_class=true_class, + is_generic=gt is not None and true_class is None, + pred_class=pred_class, + score=score, + ) + ) + + return ClipEval(clip=clip, matches=matches) + + @tasks_registry.register(TopClassDetectionTaskConfig) + @staticmethod + def from_config( + config: TopClassDetectionTaskConfig, + targets: TargetProtocol, + ): + metrics = [build_top_class_metric(metric) for metric in config.metrics] + return TopClassDetectionTask.build( + config=config, + metrics=metrics, + targets=targets, + ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 0dfc36c..ff030fe 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -8,7 +8,7 @@ from loguru import logger from soundevent import data from batdetect2.audio import build_audio_loader -from batdetect2.evaluate.evaluator import build_evaluator +from batdetect2.evaluate import build_evaluator from batdetect2.logging import build_logger from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets @@ -106,7 +106,7 @@ def train( config, targets=targets, evaluator=build_evaluator( - config.train.validation.evaluator, + config.train.validation, targets=targets, ), checkpoint_dir=checkpoint_dir,
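Usage sketch (illustrative, not part of the diff): how the restructured task API introduced above could be driven end to end, assuming the patch is applied. The helper name `evaluate_detection` and its parameters are hypothetical, and the exact metric keys depend on each metric's label; only `build_task`, `DetectionTaskConfig`, `evaluate`, and `compute_metrics` come from the changes in this patch.

from typing import Dict, Sequence

from soundevent import data

from batdetect2.evaluate.tasks import build_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.typing import RawPrediction


def evaluate_detection(
    clip_annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[Sequence[RawPrediction]],
) -> Dict[str, float]:
    # DetectionTaskConfig defaults to prefix="detection" and a single
    # average-precision metric; with no explicit targets, build_task
    # falls back to build_targets().
    task = build_task(DetectionTaskConfig())

    # One evaluation output per clip, produced with the configured
    # matching strategy.
    clip_evals = task.evaluate(clip_annotations, predictions)

    # Metric keys follow the "<prefix>/<metric label>" pattern,
    # e.g. "detection/..." (exact labels depend on the metric configs).
    return task.compute_metrics(clip_evals)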