diff --git a/src/batdetect2/data/predictions/raw.py b/src/batdetect2/data/predictions/raw.py
index 08cd3d3..814784d 100644
--- a/src/batdetect2/data/predictions/raw.py
+++ b/src/batdetect2/data/predictions/raw.py
@@ -194,10 +194,10 @@ class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
                 detection_data.geometry.item()
             )
         else:
-            start_time = detection_data.start_time
-            end_time = detection_data.end_time
-            low_freq = detection_data.low_freq
-            high_freq = detection_data.high_freq
+            start_time = detection_data.start_time.item()
+            end_time = detection_data.end_time.item()
+            low_freq = detection_data.low_freq.item()
+            high_freq = detection_data.high_freq.item()
             geometry = data.BoundingBox.model_construct(
                 coordinates=[start_time, low_freq, end_time, high_freq]
             )
diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py
index d55d27b..c18a9b1 100644
--- a/src/batdetect2/evaluate/metrics/classification.py
+++ b/src/batdetect2/evaluate/metrics/classification.py
@@ -29,6 +29,7 @@
 __all__ = [
     "ClassificationMetric",
     "ClassificationMetricConfig",
     "build_classification_metric",
+    "compute_precision_recall_curves",
 ]
 
diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py
index 0aa632d..dfc47bc 100644
--- a/src/batdetect2/evaluate/metrics/common.py
+++ b/src/batdetect2/evaluate/metrics/common.py
@@ -52,6 +52,14 @@ def average_precision(
         num_positives=num_positives,
     )
 
+    # pascal 12 way
+    return _average_precision(recall, precision)
+
+
+def _average_precision(
+    recall: np.ndarray,
+    precision: np.ndarray,
+) -> float:
     # pascal 12 way
     mprec = np.hstack((0, precision, 0))
     mrec = np.hstack((0, recall, 1))
@@ -59,5 +67,4 @@
         mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
     inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
     ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
-
     return ave_prec
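# A minimal, self-contained sketch of the "pascal 12 way" (all-points, PASCAL
# VOC 2012-style) interpolated average precision that the `_average_precision`
# helper extracted in metrics/common.py above computes. The `recall` and
# `precision` arrays and the exact loop bounds are illustrative assumptions,
# not code copied from the repository.
import numpy as np

recall = np.array([0.1, 0.4, 0.4, 0.7, 1.0])
precision = np.array([1.0, 0.8, 0.6, 0.5, 0.4])

# Pad the curve so it spans recall 0..1, then sweep right-to-left so that
# precision becomes monotonically non-increasing (the interpolation step).
mprec = np.hstack((0, precision, 0))
mrec = np.hstack((0, recall, 1))
for ii in range(len(mprec) - 2, -1, -1):
    mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])

# Sum interpolated precision over the points where recall actually increases.
inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
ave_prec = float(((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum())
print(ave_prec)  # approximately 0.61 for the toy arrays above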
diff --git a/src/batdetect2/evaluate/metrics/top_class.py b/src/batdetect2/evaluate/metrics/top_class.py
index 16b0fd7..d16575b 100644
--- a/src/batdetect2/evaluate/metrics/top_class.py
+++ b/src/batdetect2/evaluate/metrics/top_class.py
@@ -1,3 +1,4 @@
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import (
     Annotated,
diff --git a/src/batdetect2/evaluate/tasks/__init__.py b/src/batdetect2/evaluate/tasks/__init__.py
index 4b62c16..5f4947d 100644
--- a/src/batdetect2/evaluate/tasks/__init__.py
+++ b/src/batdetect2/evaluate/tasks/__init__.py
@@ -1,8 +1,9 @@
-from typing import Annotated, Optional, Union
+from typing import Annotated, Optional, Sequence, Union
 
 from pydantic import Field
+from soundevent import data
 
-from batdetect2.evaluate.tasks.base import tasks_registry
+from batdetect2.evaluate.tasks.base import BaseTaskConfig, tasks_registry
 from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
 from batdetect2.evaluate.tasks.clip_classification import (
     ClipClassificationTaskConfig,
@@ -11,11 +12,16 @@ from batdetect2.evaluate.tasks.clip_detection import ClipDetectionTaskConfig
 from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
 from batdetect2.evaluate.tasks.top_class import TopClassDetectionTaskConfig
 from batdetect2.targets import build_targets
-from batdetect2.typing import EvaluatorProtocol, TargetProtocol
+from batdetect2.typing import (
+    BatDetect2Prediction,
+    EvaluatorProtocol,
+    TargetProtocol,
+)
 
 __all__ = [
     "TaskConfig",
     "build_task",
+    "evaluate_task",
 ]
 
 
@@ -37,3 +43,25 @@ def build_task(
 ) -> EvaluatorProtocol:
     targets = targets or build_targets()
     return tasks_registry.build(config, targets)
+
+
+def evaluate_task(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    predictions: Sequence[BatDetect2Prediction],
+    task: Optional[str] = None,
+    targets: Optional[TargetProtocol] = None,
+    config: Optional[Union[TaskConfig, dict]] = None,
+):
+    if isinstance(config, BaseTaskConfig):
+        task_obj = build_task(config, targets)
+        return task_obj.evaluate(clip_annotations, predictions)
+
+    if task is None:
+        raise ValueError(
+            "Task must be specified if a full config is not provided.",
+        )
+
+    config_class = tasks_registry.get_config_type(task)
+    config = config_class.model_validate(config or {})  # type: ignore
+    task_obj = build_task(config, targets)  # type: ignore
+    return task_obj.evaluate(clip_annotations, predictions)
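# A hedged usage sketch for the `evaluate_task` helper added above. Only the
# signature and the config-resolution logic come from the diff; the task name
# "detection", the assumption that DetectionTaskConfig constructs with default
# field values, and this wrapper function are illustrative, not confirmed APIs.
from typing import Sequence

from soundevent import data

from batdetect2.evaluate.tasks import evaluate_task
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.typing import BatDetect2Prediction


def run_example_evaluation(
    clip_annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[BatDetect2Prediction],
):
    # Path 1: pass a task name; evaluate_task looks up the registered config
    # class via tasks_registry.get_config_type and validates an (optional)
    # dict of overrides against it.
    by_name = evaluate_task(
        clip_annotations,
        predictions,
        task="detection",  # assumed registry key, for illustration only
    )

    # Path 2: pass a fully-built config object; because it is a BaseTaskConfig
    # instance, evaluate_task builds the task from it directly and the `task`
    # argument is not needed.
    by_config = evaluate_task(
        clip_annotations,
        predictions,
        config=DetectionTaskConfig(),  # assumes all fields have defaults
    )
    return by_name, by_config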