Make sure threshold is used

This commit is contained in:
mbsantiago 2025-12-12 19:53:15 +00:00
parent f71fe0c2e2
commit 750f9e43c4
5 changed files with 35 additions and 21 deletions

View File

@ -86,9 +86,7 @@ class PRCurve(BasePlot):
ax = plot_pr_curve(precision, recall, thresholds, ax=ax) ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
ax.set_title(class_name) ax.set_title(class_name)
yield f"{self.label}/{class_name}", fig yield f"{self.label}/{class_name}", fig
plt.close(fig) plt.close(fig)
@classification_plots.register(PRCurveConfig) @classification_plots.register(PRCurveConfig)

View File

@ -15,7 +15,11 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig, Registry from batdetect2.core import BaseConfig, Registry
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig from batdetect2.evaluate.affinity import (
AffinityConfig,
TimeAffinityConfig,
build_affinity_function,
)
from batdetect2.typing import ( from batdetect2.typing import (
AffinityFunction, AffinityFunction,
BatDetect2Prediction, BatDetect2Prediction,
@ -40,6 +44,8 @@ T_Output = TypeVar("T_Output")
class BaseTaskConfig(BaseConfig): class BaseTaskConfig(BaseConfig):
prefix: str prefix: str
ignore_start_end: float = 0.01
class BaseTask(EvaluatorProtocol, Generic[T_Output]): class BaseTask(EvaluatorProtocol, Generic[T_Output]):
targets: TargetProtocol targets: TargetProtocol
@ -166,6 +172,24 @@ class BaseSEDTask(BaseTask[T_Output]):
self.affinity_threshold = affinity_threshold self.affinity_threshold = affinity_threshold
self.strict_match = strict_match self.strict_match = strict_match
@classmethod
def build(
cls,
config: BaseSEDTaskConfig,
targets: TargetProtocol,
**kwargs,
):
affinity = build_affinity_function(config.affinity)
return cls(
affinity=affinity,
affinity_threshold=config.affinity_threshold,
prefix=config.prefix,
ignore_start_end=config.ignore_start_end,
strict_match=config.strict_match,
targets=targets,
**kwargs,
)
def is_in_bounds( def is_in_bounds(
geometry: data.Geometry, geometry: data.Geometry,

View File

@ -5,7 +5,6 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.classification import ( from batdetect2.evaluate.metrics.classification import (
ClassificationAveragePrecisionConfig, ClassificationAveragePrecisionConfig,
ClassificationMetricConfig, ClassificationMetricConfig,
@ -93,6 +92,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
affinity=self.affinity, affinity=self.affinity,
score=partial(get_class_score, class_idx=class_idx), score=partial(get_class_score, class_idx=class_idx),
strict_match=self.strict_match, strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
true_class = ( true_class = (
self.targets.encode_class(match.annotation) self.targets.encode_class(match.annotation)
@ -131,14 +131,12 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
build_classification_plotter(plot, targets) build_classification_plotter(plot, targets)
for plot in config.plots for plot in config.plots
] ]
affinity = build_affinity_function(config.affinity) return ClassificationTask.build(
return ClassificationTask( config=config,
affinity=affinity,
prefix=config.prefix,
plots=plots, plots=plots,
targets=targets, targets=targets,
metrics=metrics, metrics=metrics,
strict_match=config.strict_match, include_generics=config.include_generics,
) )

View File

@ -4,7 +4,6 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.detection import ( from batdetect2.evaluate.metrics.detection import (
ClipEval, ClipEval,
DetectionAveragePrecisionConfig, DetectionAveragePrecisionConfig,
@ -60,6 +59,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
affinity=self.affinity, affinity=self.affinity,
score=lambda pred: pred.detection_score, score=lambda pred: pred.detection_score,
strict_match=self.strict_match, strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
matches.append( matches.append(
MatchEval( MatchEval(
@ -83,12 +83,9 @@ class DetectionTask(BaseSEDTask[ClipEval]):
plots = [ plots = [
build_detection_plotter(plot, targets) for plot in config.plots build_detection_plotter(plot, targets) for plot in config.plots
] ]
affinity = build_affinity_function(config.affinity) return DetectionTask.build(
return DetectionTask( config=config,
prefix=config.prefix,
affinity=affinity,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
plots=plots, plots=plots,
strict_match=config.strict_match,
) )

View File

@ -4,7 +4,6 @@ from pydantic import Field
from soundevent import data from soundevent import data
from soundevent.evaluation import match_detections_and_gts from soundevent.evaluation import match_detections_and_gts
from batdetect2.evaluate.affinity import build_affinity_function
from batdetect2.evaluate.metrics.top_class import ( from batdetect2.evaluate.metrics.top_class import (
ClipEval, ClipEval,
MatchEval, MatchEval,
@ -59,6 +58,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
affinity=self.affinity, affinity=self.affinity,
score=lambda pred: pred.class_scores.max(), score=lambda pred: pred.class_scores.max(),
strict_match=self.strict_match, strict_match=self.strict_match,
affinity_threshold=self.affinity_threshold,
): ):
gt = match.annotation gt = match.annotation
pred = match.prediction pred = match.prediction
@ -101,12 +101,9 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
plots = [ plots = [
build_top_class_plotter(plot, targets) for plot in config.plots build_top_class_plotter(plot, targets) for plot in config.plots
] ]
affinity = build_affinity_function(config.affinity) return TopClassDetectionTask.build(
return TopClassDetectionTask( config=config,
prefix=config.prefix,
plots=plots, plots=plots,
metrics=metrics, metrics=metrics,
targets=targets, targets=targets,
affinity=affinity,
strict_match=config.strict_match,
) )