diff --git a/src/batdetect2/data/conditions/sound_events.py b/src/batdetect2/data/conditions/sound_events.py index 5786d16..b0096d8 100644 --- a/src/batdetect2/data/conditions/sound_events.py +++ b/src/batdetect2/data/conditions/sound_events.py @@ -1,4 +1,6 @@ +import operator from collections.abc import Callable, Sequence +from functools import partial from typing import Annotated, Literal from pydantic import Field @@ -78,25 +80,23 @@ class DurationConfig(BaseConfig): seconds: float -def _build_comparator( - operator: Operator, value: float -) -> Callable[[float], bool]: - if operator == "gt": - return lambda x: x > value +def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]: + if op == "gt": + return partial(operator.gt, value) - if operator == "gte": - return lambda x: x >= value + if op == "gte": + return partial(operator.ge, value) - if operator == "lt": - return lambda x: x < value + if op == "lt": + return partial(operator.lt, value) - if operator == "lte": - return lambda x: x <= value + if op == "lte": + return partial(operator.le, value) - if operator == "eq": - return lambda x: x == value + if op == "eq": + return partial(operator.eq, value) - raise ValueError(f"Invalid operator {operator}") + raise ValueError(f"Invalid operator {op}") class Duration: diff --git a/src/batdetect2/evaluate/metrics/classification.py b/src/batdetect2/evaluate/metrics/classification.py index daf3e31..8be33a5 100644 --- a/src/batdetect2/evaluate/metrics/classification.py +++ b/src/batdetect2/evaluate/metrics/classification.py @@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric): ignore_generic=self.ignore_generic, ) - class_scores = { - class_name: float( + class_scores = {} + + for class_name in self.targets.class_names: + if len(y_true[class_name]) == 0: + class_scores[class_name] = np.nan + continue + + class_scores[class_name] = float( metrics.roc_auc_score( y_true[class_name], y_score[class_name], ) ) - for class_name in self.targets.class_names - } mean_score = float( np.mean([v for v in class_scores.values() if v != np.nan]) diff --git a/src/batdetect2/evaluate/metrics/detection.py b/src/batdetect2/evaluate/metrics/detection.py index 59fa0b6..c0ae0c9 100644 --- a/src/batdetect2/evaluate/metrics/detection.py +++ b/src/batdetect2/evaluate/metrics/detection.py @@ -133,6 +133,9 @@ class DetectionROCAUC: y_true.append(m.is_ground_truth) y_score.append(m.score) + if len(y_true) == 0: + return {self.label: np.nan} + score = float(metrics.roc_auc_score(y_true, y_score)) return {self.label: score}