mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add function to facilitate task evaluation
This commit is contained in:
parent
dbd2d30ead
commit
4ecbc2b734
@ -194,10 +194,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
detection_data.geometry.item()
|
||||
)
|
||||
else:
|
||||
start_time = detection_data.start_time
|
||||
end_time = detection_data.end_time
|
||||
low_freq = detection_data.low_freq
|
||||
high_freq = detection_data.high_freq
|
||||
start_time = detection_data.start_time.item()
|
||||
end_time = detection_data.end_time.item()
|
||||
low_freq = detection_data.low_freq.item()
|
||||
high_freq = detection_data.high_freq.item()
|
||||
geometry = data.BoundingBox.model_construct(
|
||||
coordinates=[start_time, low_freq, end_time, high_freq]
|
||||
)
|
||||
|
||||
@ -29,6 +29,7 @@ __all__ = [
|
||||
"ClassificationMetric",
|
||||
"ClassificationMetricConfig",
|
||||
"build_classification_metric",
|
||||
"compute_precision_recall_curves",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -52,6 +52,14 @@ def average_precision(
|
||||
num_positives=num_positives,
|
||||
)
|
||||
|
||||
# pascal 12 way
|
||||
return _average_precision(recall, precision)
|
||||
|
||||
|
||||
def _average_precision(
|
||||
recall: np.ndarray,
|
||||
precision: np.ndarray,
|
||||
) -> float:
|
||||
# pascal 12 way
|
||||
mprec = np.hstack((0, precision, 0))
|
||||
mrec = np.hstack((0, recall, 1))
|
||||
@ -59,5 +67,4 @@ def average_precision(
|
||||
mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
|
||||
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
|
||||
ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
|
||||
|
||||
return ave_prec
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (
|
||||
Annotated,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from typing import Annotated, Optional, Union
|
||||
from typing import Annotated, Optional, Sequence, Union
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.evaluate.tasks.base import tasks_registry
|
||||
from batdetect2.evaluate.tasks.base import BaseTaskConfig, tasks_registry
|
||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||
from batdetect2.evaluate.tasks.clip_classification import (
|
||||
ClipClassificationTaskConfig,
|
||||
@ -11,11 +12,16 @@ from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
EvaluatorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TaskConfig",
|
||||
"build_task",
|
||||
"evaluate_task",
|
||||
]
|
||||
|
||||
|
||||
@ -37,3 +43,25 @@ def build_task(
|
||||
) -> EvaluatorProtocol:
|
||||
targets = targets or build_targets()
|
||||
return tasks_registry.build(config, targets)
|
||||
|
||||
|
||||
def evaluate_task(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
task: Optional["str"] = None,
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
config: Optional[Union[TaskConfig, dict]] = None,
|
||||
):
|
||||
if isinstance(config, BaseTaskConfig):
|
||||
task_obj = build_task(config, targets)
|
||||
return task_obj.evaluate(clip_annotations, predictions)
|
||||
|
||||
if task is None:
|
||||
raise ValueError(
|
||||
"Task must be specified if a full config is not provided.",
|
||||
)
|
||||
|
||||
config_class = tasks_registry.get_config_type(task)
|
||||
config = config_class.model_validate(config or {}) # type: ignore
|
||||
task_obj = build_task(config, targets) # type: ignore
|
||||
return task_obj.evaluate(clip_annotations, predictions)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user