mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Minor fixes
This commit is contained in:
parent
4303d4e42d
commit
a5fdf438e2
@ -1,4 +1,6 @@
|
||||
import operator
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
|
||||
seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
|
||||
if op == "gt":
|
||||
return partial(operator.gt, value)
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
if op == "gte":
|
||||
return partial(operator.ge, value)
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
if op == "lt":
|
||||
return partial(operator.lt, value)
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
if op == "lte":
|
||||
return partial(operator.le, value)
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
if op == "eq":
|
||||
return partial(operator.eq, value)
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
raise ValueError(f"Invalid operator {op}")
|
||||
|
||||
|
||||
class Duration:
|
||||
|
||||
@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
|
||||
ignore_generic=self.ignore_generic,
|
||||
)
|
||||
|
||||
class_scores = {
|
||||
class_name: float(
|
||||
class_scores = {}
|
||||
|
||||
for class_name in self.targets.class_names:
|
||||
if len(y_true[class_name]) == 0:
|
||||
class_scores[class_name] = np.nan
|
||||
continue
|
||||
|
||||
class_scores[class_name] = float(
|
||||
metrics.roc_auc_score(
|
||||
y_true[class_name],
|
||||
y_score[class_name],
|
||||
)
|
||||
)
|
||||
for class_name in self.targets.class_names
|
||||
}
|
||||
|
||||
mean_score = float(
|
||||
np.mean([v for v in class_scores.values() if v != np.nan])
|
||||
|
||||
@ -133,6 +133,9 @@ class DetectionROCAUC:
|
||||
y_true.append(m.is_ground_truth)
|
||||
y_score.append(m.score)
|
||||
|
||||
if len(y_true) == 0:
|
||||
return {self.label: np.nan}
|
||||
|
||||
score = float(metrics.roc_auc_score(y_true, y_score))
|
||||
return {self.label: score}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user