mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Make sure threshold is used
This commit is contained in:
parent
f71fe0c2e2
commit
750f9e43c4
@ -86,9 +86,7 @@ class PRCurve(BasePlot):
|
||||
|
||||
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||
ax.set_title(class_name)
|
||||
|
||||
yield f"{self.label}/{class_name}", fig
|
||||
|
||||
plt.close(fig)
|
||||
|
||||
@classification_plots.register(PRCurveConfig)
|
||||
|
||||
@ -15,7 +15,11 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig, Registry
|
||||
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
|
||||
from batdetect2.evaluate.affinity import (
|
||||
AffinityConfig,
|
||||
TimeAffinityConfig,
|
||||
build_affinity_function,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AffinityFunction,
|
||||
BatDetect2Prediction,
|
||||
@ -40,6 +44,8 @@ T_Output = TypeVar("T_Output")
|
||||
class BaseTaskConfig(BaseConfig):
|
||||
prefix: str
|
||||
|
||||
ignore_start_end: float = 0.01
|
||||
|
||||
|
||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
targets: TargetProtocol
|
||||
@ -166,6 +172,24 @@ class BaseSEDTask(BaseTask[T_Output]):
|
||||
self.affinity_threshold = affinity_threshold
|
||||
self.strict_match = strict_match
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls,
|
||||
config: BaseSEDTaskConfig,
|
||||
targets: TargetProtocol,
|
||||
**kwargs,
|
||||
):
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return cls(
|
||||
affinity=affinity,
|
||||
affinity_threshold=config.affinity_threshold,
|
||||
prefix=config.prefix,
|
||||
ignore_start_end=config.ignore_start_end,
|
||||
strict_match=config.strict_match,
|
||||
targets=targets,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def is_in_bounds(
|
||||
geometry: data.Geometry,
|
||||
|
||||
@ -5,7 +5,6 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.classification import (
|
||||
ClassificationAveragePrecisionConfig,
|
||||
ClassificationMetricConfig,
|
||||
@ -93,6 +92,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
affinity=self.affinity,
|
||||
score=partial(get_class_score, class_idx=class_idx),
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
true_class = (
|
||||
self.targets.encode_class(match.annotation)
|
||||
@ -131,14 +131,12 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
||||
build_classification_plotter(plot, targets)
|
||||
for plot in config.plots
|
||||
]
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return ClassificationTask(
|
||||
affinity=affinity,
|
||||
prefix=config.prefix,
|
||||
return ClassificationTask.build(
|
||||
config=config,
|
||||
plots=plots,
|
||||
targets=targets,
|
||||
metrics=metrics,
|
||||
strict_match=config.strict_match,
|
||||
include_generics=config.include_generics,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.detection import (
|
||||
ClipEval,
|
||||
DetectionAveragePrecisionConfig,
|
||||
@ -60,6 +59,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.detection_score,
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
matches.append(
|
||||
MatchEval(
|
||||
@ -83,12 +83,9 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
||||
plots = [
|
||||
build_detection_plotter(plot, targets) for plot in config.plots
|
||||
]
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return DetectionTask(
|
||||
prefix=config.prefix,
|
||||
affinity=affinity,
|
||||
return DetectionTask.build(
|
||||
config=config,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
plots=plots,
|
||||
strict_match=config.strict_match,
|
||||
)
|
||||
|
||||
@ -4,7 +4,6 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.evaluation import match_detections_and_gts
|
||||
|
||||
from batdetect2.evaluate.affinity import build_affinity_function
|
||||
from batdetect2.evaluate.metrics.top_class import (
|
||||
ClipEval,
|
||||
MatchEval,
|
||||
@ -59,6 +58,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
affinity=self.affinity,
|
||||
score=lambda pred: pred.class_scores.max(),
|
||||
strict_match=self.strict_match,
|
||||
affinity_threshold=self.affinity_threshold,
|
||||
):
|
||||
gt = match.annotation
|
||||
pred = match.prediction
|
||||
@ -101,12 +101,9 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
||||
plots = [
|
||||
build_top_class_plotter(plot, targets) for plot in config.plots
|
||||
]
|
||||
affinity = build_affinity_function(config.affinity)
|
||||
return TopClassDetectionTask(
|
||||
prefix=config.prefix,
|
||||
return TopClassDetectionTask.build(
|
||||
config=config,
|
||||
plots=plots,
|
||||
metrics=metrics,
|
||||
targets=targets,
|
||||
affinity=affinity,
|
||||
strict_match=config.strict_match,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user