Minor fixes

This commit is contained in:
Santiago Martinez Balvanera 2026-04-24 16:21:47 +01:00
parent 4303d4e42d
commit a5fdf438e2
3 changed files with 25 additions and 18 deletions

View File

@ -1,4 +1,6 @@
import operator
from collections.abc import Callable, Sequence from collections.abc import Callable, Sequence
from functools import partial
from typing import Annotated, Literal from typing import Annotated, Literal
from pydantic import Field from pydantic import Field
@ -78,25 +80,23 @@ class DurationConfig(BaseConfig):
seconds: float seconds: float
def _build_comparator( def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
operator: Operator, value: float if op == "gt":
) -> Callable[[float], bool]: return partial(operator.gt, value)
if operator == "gt":
return lambda x: x > value
if operator == "gte": if op == "gte":
return lambda x: x >= value return partial(operator.ge, value)
if operator == "lt": if op == "lt":
return lambda x: x < value return partial(operator.lt, value)
if operator == "lte": if op == "lte":
return lambda x: x <= value return partial(operator.le, value)
if operator == "eq": if op == "eq":
return lambda x: x == value return partial(operator.eq, value)
raise ValueError(f"Invalid operator {operator}") raise ValueError(f"Invalid operator {op}")
class Duration: class Duration:

View File

@ -204,15 +204,19 @@ class ClassificationROCAUC(BaseClassificationMetric):
ignore_generic=self.ignore_generic, ignore_generic=self.ignore_generic,
) )
class_scores = { class_scores = {}
class_name: float(
for class_name in self.targets.class_names:
if len(y_true[class_name]) == 0:
class_scores[class_name] = np.nan
continue
class_scores[class_name] = float(
metrics.roc_auc_score( metrics.roc_auc_score(
y_true[class_name], y_true[class_name],
y_score[class_name], y_score[class_name],
) )
) )
for class_name in self.targets.class_names
}
mean_score = float( mean_score = float(
np.mean([v for v in class_scores.values() if v != np.nan]) np.mean([v for v in class_scores.values() if v != np.nan])

View File

@ -133,6 +133,9 @@ class DetectionROCAUC:
y_true.append(m.is_ground_truth) y_true.append(m.is_ground_truth)
y_score.append(m.score) y_score.append(m.score)
if len(y_true) == 0:
return {self.label: np.nan}
score = float(metrics.roc_auc_score(y_true, y_score)) score = float(metrics.roc_auc_score(y_true, y_score))
return {self.label: score} return {self.label: score}