From 750f9e43c439a574cfe11e5671b9a6972cbdbf73 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 12 Dec 2025 19:53:15 +0000 Subject: [PATCH] Make sure threshold is used --- .../evaluate/plots/classification.py | 2 -- src/batdetect2/evaluate/tasks/base.py | 26 ++++++++++++++++++- .../evaluate/tasks/classification.py | 10 +++---- src/batdetect2/evaluate/tasks/detection.py | 9 +++---- src/batdetect2/evaluate/tasks/top_class.py | 9 +++---- 5 files changed, 35 insertions(+), 21 deletions(-) diff --git a/src/batdetect2/evaluate/plots/classification.py b/src/batdetect2/evaluate/plots/classification.py index a9d8a6e..a84d08c 100644 --- a/src/batdetect2/evaluate/plots/classification.py +++ b/src/batdetect2/evaluate/plots/classification.py @@ -86,9 +86,7 @@ class PRCurve(BasePlot): ax = plot_pr_curve(precision, recall, thresholds, ax=ax) ax.set_title(class_name) - yield f"{self.label}/{class_name}", fig - plt.close(fig) @classification_plots.register(PRCurveConfig) diff --git a/src/batdetect2/evaluate/tasks/base.py b/src/batdetect2/evaluate/tasks/base.py index 64fd5b9..9523ff5 100644 --- a/src/batdetect2/evaluate/tasks/base.py +++ b/src/batdetect2/evaluate/tasks/base.py @@ -15,7 +15,11 @@ from soundevent import data from soundevent.geometry import compute_bounds from batdetect2.core import BaseConfig, Registry -from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig +from batdetect2.evaluate.affinity import ( + AffinityConfig, + TimeAffinityConfig, + build_affinity_function, +) from batdetect2.typing import ( AffinityFunction, BatDetect2Prediction, @@ -40,6 +44,8 @@ T_Output = TypeVar("T_Output") class BaseTaskConfig(BaseConfig): prefix: str + ignore_start_end: float = 0.01 + class BaseTask(EvaluatorProtocol, Generic[T_Output]): targets: TargetProtocol @@ -166,6 +172,24 @@ class BaseSEDTask(BaseTask[T_Output]): self.affinity_threshold = affinity_threshold self.strict_match = strict_match + @classmethod + def build( + cls, + config: BaseSEDTaskConfig, + targets: TargetProtocol, + **kwargs, + ): + affinity = build_affinity_function(config.affinity) + return cls( + affinity=affinity, + affinity_threshold=config.affinity_threshold, + prefix=config.prefix, + ignore_start_end=config.ignore_start_end, + strict_match=config.strict_match, + targets=targets, + **kwargs, + ) + def is_in_bounds( geometry: data.Geometry, diff --git a/src/batdetect2/evaluate/tasks/classification.py b/src/batdetect2/evaluate/tasks/classification.py index 1a1fc6a..ca7185b 100644 --- a/src/batdetect2/evaluate/tasks/classification.py +++ b/src/batdetect2/evaluate/tasks/classification.py @@ -5,7 +5,6 @@ from pydantic import Field from soundevent import data from soundevent.evaluation import match_detections_and_gts -from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.classification import ( ClassificationAveragePrecisionConfig, ClassificationMetricConfig, @@ -93,6 +92,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]): affinity=self.affinity, score=partial(get_class_score, class_idx=class_idx), strict_match=self.strict_match, + affinity_threshold=self.affinity_threshold, ): true_class = ( self.targets.encode_class(match.annotation) @@ -131,14 +131,12 @@ class ClassificationTask(BaseSEDTask[ClipEval]): build_classification_plotter(plot, targets) for plot in config.plots ] - affinity = build_affinity_function(config.affinity) - return ClassificationTask( - affinity=affinity, - prefix=config.prefix, + return ClassificationTask.build( + config=config, plots=plots, targets=targets, metrics=metrics, - strict_match=config.strict_match, + include_generics=config.include_generics, ) diff --git a/src/batdetect2/evaluate/tasks/detection.py b/src/batdetect2/evaluate/tasks/detection.py index 3ba034e..9ac9535 100644 --- a/src/batdetect2/evaluate/tasks/detection.py +++ b/src/batdetect2/evaluate/tasks/detection.py @@ -4,7 +4,6 @@ from pydantic import Field from soundevent import data from soundevent.evaluation import match_detections_and_gts -from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.detection import ( ClipEval, DetectionAveragePrecisionConfig, @@ -60,6 +59,7 @@ class DetectionTask(BaseSEDTask[ClipEval]): affinity=self.affinity, score=lambda pred: pred.detection_score, strict_match=self.strict_match, + affinity_threshold=self.affinity_threshold, ): matches.append( MatchEval( @@ -83,12 +83,9 @@ class DetectionTask(BaseSEDTask[ClipEval]): plots = [ build_detection_plotter(plot, targets) for plot in config.plots ] - affinity = build_affinity_function(config.affinity) - return DetectionTask( - prefix=config.prefix, - affinity=affinity, + return DetectionTask.build( + config=config, metrics=metrics, targets=targets, plots=plots, - strict_match=config.strict_match, ) diff --git a/src/batdetect2/evaluate/tasks/top_class.py b/src/batdetect2/evaluate/tasks/top_class.py index a625ecc..00b1a73 100644 --- a/src/batdetect2/evaluate/tasks/top_class.py +++ b/src/batdetect2/evaluate/tasks/top_class.py @@ -4,7 +4,6 @@ from pydantic import Field from soundevent import data from soundevent.evaluation import match_detections_and_gts -from batdetect2.evaluate.affinity import build_affinity_function from batdetect2.evaluate.metrics.top_class import ( ClipEval, MatchEval, @@ -59,6 +58,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]): affinity=self.affinity, score=lambda pred: pred.class_scores.max(), strict_match=self.strict_match, + affinity_threshold=self.affinity_threshold, ): gt = match.annotation pred = match.prediction @@ -101,12 +101,9 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]): plots = [ build_top_class_plotter(plot, targets) for plot in config.plots ] - affinity = build_affinity_function(config.affinity) - return TopClassDetectionTask( - prefix=config.prefix, + return TopClassDetectionTask.build( + config=config, plots=plots, metrics=metrics, targets=targets, - affinity=affinity, - strict_match=config.strict_match, )