Make sure threshold is used

This commit is contained in:
mbsantiago 2025-12-12 19:53:15 +00:00
parent f71fe0c2e2
commit 750f9e43c4
5 changed files with 35 additions and 21 deletions

View File

@ -86,9 +86,7 @@ class PRCurve(BasePlot):
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig
plt.close(fig)
@classification_plots.register(PRCurveConfig)

View File

@ -15,7 +15,11 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
from batdetect2.evaluate.affinity import (
AffinityConfig,
TimeAffinityConfig,
build_affinity_function,
)
from batdetect2.typing import (
AffinityFunction,
BatDetect2Prediction,
@ -40,6 +44,8 @@ T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig):
prefix: str
ignore_start_end: float = 0.01
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol
@ -166,6 +172,24 @@ class BaseSEDTask(BaseTask[T_Output]):
self.affinity_threshold = affinity_threshold
self.strict_match = strict_match
@classmethod
def build(
cls,
config: BaseSEDTaskConfig,
targets: TargetProtocol,
**kwargs,
):
affinity = build_affinity_function(config.affinity)
return cls(
affinity=affinity,
affinity_threshold=config.affinity_threshold,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
strict_match=config.strict_match,
targets=targets,
**kwargs,
)
def is_in_bounds(
geometry: data.Geometry,

View File

@ -5,7 +5,6 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig,
ClassificationMetricConfig,
@ -93,6 +92,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
true_class = (
self.targets.encode_class(match.annotation)
@ -131,14 +131,12 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
build_classification_plotter(plot, targets)
for plot in config.plots
]
affinity = build_affinity_function(config.affinity)
return ClassificationTask(
affinity=affinity,
prefix=config.prefix,
return ClassificationTask.build(
config=config,
plots=plots,
targets=targets,
metrics=metrics,
strict_match=config.strict_match,
include_generics=config.include_generics,
)

View File

@ -4,7 +4,6 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.detection import (
ClipEval,
DetectionAveragePrecisionConfig,
@ -60,6 +59,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
affinity=self.affinity,
score=lambda pred: pred.detection_score,
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
matches.append(
MatchEval(
@ -83,12 +83,9 @@ class DetectionTask(BaseSEDTask[ClipEval]):
plots = [
build_detection_plotter(plot, targets) for plot in config.plots
]
affinity = build_affinity_function(config.affinity)
return DetectionTask(
prefix=config.prefix,
affinity=affinity,
return DetectionTask.build(
config=config,
metrics=metrics,
targets=targets,
plots=plots,
strict_match=config.strict_match,
)

View File

@ -4,7 +4,6 @@ from pydantic import Field
from soundevent import data
from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.top_class import (
ClipEval,
MatchEval,
@ -59,6 +58,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
affinity=self.affinity,
score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
):
gt = match.annotation
pred = match.prediction
@ -101,12 +101,9 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
plots = [
build_top_class_plotter(plot, targets) for plot in config.plots
]
affinity = build_affinity_function(config.affinity)
return TopClassDetectionTask(
prefix=config.prefix,
return TopClassDetectionTask.build(
config=config,
plots=plots,
metrics=metrics,
targets=targets,
affinity=affinity,
strict_match=config.strict_match,
)