# Mirror of https://github.com/macaodha/batdetect2.git
# Synced 2026-01-10 00:59:34 +01:00
# 68 lines, 1.8 KiB, Python
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
|
|
|
|
from matplotlib.figure import Figure
|
|
from soundevent import data
|
|
|
|
from batdetect2.evaluate.config import EvaluationConfig
|
|
from batdetect2.evaluate.tasks import build_task
|
|
from batdetect2.targets import build_targets
|
|
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["Evaluator", "build_evaluator"]
|
|
|
|
|
|
class Evaluator:
    """Run a collection of evaluation tasks and combine their results.

    Every call is fanned out to each task in ``tasks``, in order. The task
    objects are called through ``evaluate``, ``compute_metrics`` and
    ``generate_plots``.

    Parameters
    ----------
    targets
        Target definitions. Stored on the instance as ``targets``; this class
        does not read it itself.
    tasks
        The tasks to run. Their order fixes the order of the list returned by
        :meth:`evaluate`.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        tasks: Sequence[EvaluatorProtocol],
    ):
        self.targets = targets
        self.tasks = tasks

    def evaluate(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        predictions: Sequence[Sequence[RawPrediction]],
    ) -> List[Any]:
        """Run every task's ``evaluate`` on the same annotations and predictions.

        Returns
        -------
        List[Any]
            One task-specific output per task, in the same order as
            ``self.tasks``. Pass this list unchanged to
            :meth:`compute_metrics` or :meth:`generate_plots`.
        """
        return [
            task.evaluate(clip_annotations, predictions) for task in self.tasks
        ]

    def compute_metrics(self, eval_outputs: List[Any]) -> Dict[str, float]:
        """Merge the metric dictionaries of every task into one dictionary.

        Tasks are merged in order, so if two tasks report the same metric
        name the later task's value overwrites the earlier one.

        Raises
        ------
        ValueError
            If ``eval_outputs`` does not contain exactly one entry per task.
        """
        results: Dict[str, float] = {}
        for task, outputs in self._pair_with_tasks(eval_outputs):
            results.update(task.compute_metrics(outputs))
        return results

    def generate_plots(
        self,
        eval_outputs: List[Any],
    ) -> Iterable[Tuple[str, Figure]]:
        """Yield ``(name, figure)`` pairs from every task, in task order.

        This is a generator. The length check on ``eval_outputs`` (which
        raises ``ValueError`` on mismatch) therefore runs when iteration
        starts, not when the method is called.
        """
        for task, outputs in self._pair_with_tasks(eval_outputs):
            # Unpack each item so that every yielded value is a 2-tuple.
            for name, fig in task.generate_plots(outputs):
                yield name, fig

    def _pair_with_tasks(
        self, eval_outputs: Iterable[Any]
    ) -> List[Tuple[EvaluatorProtocol, Any]]:
        """Pair each task with its evaluation output, rejecting mismatches.

        A bare ``zip`` truncates to the shorter input. Without this check, a
        wrong-length ``eval_outputs`` would silently drop metrics or plots.
        """
        # Materialise first so that any iterable (not only a list) can be
        # length-checked.
        outputs = list(eval_outputs)
        if len(outputs) != len(self.tasks):
            raise ValueError(
                "Expected one evaluation output per task "
                f"({len(self.tasks)}), got {len(outputs)}."
            )
        return list(zip(self.tasks, outputs))
|
|
|
|
|
|
def build_evaluator(
    config: Optional[Union[EvaluationConfig, dict]] = None,
    targets: Optional[TargetProtocol] = None,
) -> EvaluatorProtocol:
    """Build an :class:`Evaluator` from an evaluation configuration.

    Parameters
    ----------
    config
        Evaluation configuration. ``None`` means ``EvaluationConfig()``.
        Any value that is not already an ``EvaluationConfig`` (e.g. a plain
        ``dict``) is validated with ``EvaluationConfig.model_validate``.
    targets
        Target definitions shared by every task. ``None`` means they are
        built with ``build_targets()``.

    Returns
    -------
    EvaluatorProtocol
        An ``Evaluator`` holding one task, built with ``build_task``, per
        entry in ``config.tasks``.
    """
    # Compare against None explicitly: a truthiness test would discard a
    # caller-supplied targets object that happens to evaluate as falsy.
    if targets is None:
        targets = build_targets()

    if config is None:
        config = EvaluationConfig()
    elif not isinstance(config, EvaluationConfig):
        config = EvaluationConfig.model_validate(config)

    tasks = [build_task(task, targets=targets) for task in config.tasks]
    return Evaluator(targets=targets, tasks=tasks)
|