Minor fixes

This commit is contained in:
Santiago Martinez Balvanera 2026-04-24 16:21:47 +01:00
parent 4303d4e42d
commit a5fdf438e2
3 changed files with 25 additions and 18 deletions

View File

@ -1,4 +1,6 @@
import operator
from collections.abc import Callable, Sequence
from functools import partial
from typing import Annotated, Literal
from pydantic import Field
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
if op == "gt":
return partial(operator.gt, value)
if operator == "gte":
return lambda x: x >= value
if op == "gte":
return partial(operator.ge, value)
if operator == "lt":
return lambda x: x < value
if op == "lt":
return partial(operator.lt, value)
if operator == "lte":
return lambda x: x <= value
if op == "lte":
return partial(operator.le, value)
if operator == "eq":
return lambda x: x == value
if op == "eq":
return partial(operator.eq, value)
raise ValueError(f"Invalid operator {operator}")
raise ValueError(f"Invalid operator {op}")
class Duration:

View File

@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_generic=self.ignore_generic,
)
class_scores = {
class_name: float(
class_scores = {}
for class_name in self.targets.class_names:
if len(y_true[class_name]) == 0:
class_scores[class_name] = np.nan
continue
class_scores[class_name] = float(
metrics.roc_auc_score(
y_true[class_name],
y_score[class_name],
)
)
for class_name in self.targets.class_names
}
mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan])

View File

@ -133,6 +133,9 @@ class DetectionROCAUC:
y_true.append(m.is_ground_truth)
y_score.append(m.score)
if len(y_true) == 0:
return {self.label: np.nan}
score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score}