mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Make sure threshold is used
This commit is contained in:
parent
f71fe0c2e2
commit
750f9e43c4
@ -86,9 +86,7 @@ class PRCurve(BasePlot):
|
|||||||
|
|
||||||
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
ax = plot_pr_curve(precision, recall, thresholds, ax=ax)
|
||||||
ax.set_title(class_name)
|
ax.set_title(class_name)
|
||||||
|
|
||||||
yield f"{self.label}/{class_name}", fig
|
yield f"{self.label}/{class_name}", fig
|
||||||
|
|
||||||
plt.close(fig)
|
plt.close(fig)
|
||||||
|
|
||||||
@classification_plots.register(PRCurveConfig)
|
@classification_plots.register(PRCurveConfig)
|
||||||
|
|||||||
@ -15,7 +15,11 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig, Registry
|
from batdetect2.core import BaseConfig, Registry
|
||||||
from batdetect2.evaluate.affinity import AffinityConfig, TimeAffinityConfig
|
from batdetect2.evaluate.affinity import (
|
||||||
|
AffinityConfig,
|
||||||
|
TimeAffinityConfig,
|
||||||
|
build_affinity_function,
|
||||||
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
@ -40,6 +44,8 @@ T_Output = TypeVar("T_Output")
|
|||||||
class BaseTaskConfig(BaseConfig):
|
class BaseTaskConfig(BaseConfig):
|
||||||
prefix: str
|
prefix: str
|
||||||
|
|
||||||
|
ignore_start_end: float = 0.01
|
||||||
|
|
||||||
|
|
||||||
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||||
targets: TargetProtocol
|
targets: TargetProtocol
|
||||||
@ -166,6 +172,24 @@ class BaseSEDTask(BaseTask[T_Output]):
|
|||||||
self.affinity_threshold = affinity_threshold
|
self.affinity_threshold = affinity_threshold
|
||||||
self.strict_match = strict_match
|
self.strict_match = strict_match
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def build(
|
||||||
|
cls,
|
||||||
|
config: BaseSEDTaskConfig,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
affinity = build_affinity_function(config.affinity)
|
||||||
|
return cls(
|
||||||
|
affinity=affinity,
|
||||||
|
affinity_threshold=config.affinity_threshold,
|
||||||
|
prefix=config.prefix,
|
||||||
|
ignore_start_end=config.ignore_start_end,
|
||||||
|
strict_match=config.strict_match,
|
||||||
|
targets=targets,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_in_bounds(
|
def is_in_bounds(
|
||||||
geometry: data.Geometry,
|
geometry: data.Geometry,
|
||||||
|
|||||||
@ -5,7 +5,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_detections_and_gts
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
from batdetect2.evaluate.affinity import build_affinity_function
|
|
||||||
from batdetect2.evaluate.metrics.classification import (
|
from batdetect2.evaluate.metrics.classification import (
|
||||||
ClassificationAveragePrecisionConfig,
|
ClassificationAveragePrecisionConfig,
|
||||||
ClassificationMetricConfig,
|
ClassificationMetricConfig,
|
||||||
@ -93,6 +92,7 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
|||||||
affinity=self.affinity,
|
affinity=self.affinity,
|
||||||
score=partial(get_class_score, class_idx=class_idx),
|
score=partial(get_class_score, class_idx=class_idx),
|
||||||
strict_match=self.strict_match,
|
strict_match=self.strict_match,
|
||||||
|
affinity_threshold=self.affinity_threshold,
|
||||||
):
|
):
|
||||||
true_class = (
|
true_class = (
|
||||||
self.targets.encode_class(match.annotation)
|
self.targets.encode_class(match.annotation)
|
||||||
@ -131,14 +131,12 @@ class ClassificationTask(BaseSEDTask[ClipEval]):
|
|||||||
build_classification_plotter(plot, targets)
|
build_classification_plotter(plot, targets)
|
||||||
for plot in config.plots
|
for plot in config.plots
|
||||||
]
|
]
|
||||||
affinity = build_affinity_function(config.affinity)
|
return ClassificationTask.build(
|
||||||
return ClassificationTask(
|
config=config,
|
||||||
affinity=affinity,
|
|
||||||
prefix=config.prefix,
|
|
||||||
plots=plots,
|
plots=plots,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
strict_match=config.strict_match,
|
include_generics=config.include_generics,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_detections_and_gts
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
from batdetect2.evaluate.affinity import build_affinity_function
|
|
||||||
from batdetect2.evaluate.metrics.detection import (
|
from batdetect2.evaluate.metrics.detection import (
|
||||||
ClipEval,
|
ClipEval,
|
||||||
DetectionAveragePrecisionConfig,
|
DetectionAveragePrecisionConfig,
|
||||||
@ -60,6 +59,7 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
affinity=self.affinity,
|
affinity=self.affinity,
|
||||||
score=lambda pred: pred.detection_score,
|
score=lambda pred: pred.detection_score,
|
||||||
strict_match=self.strict_match,
|
strict_match=self.strict_match,
|
||||||
|
affinity_threshold=self.affinity_threshold,
|
||||||
):
|
):
|
||||||
matches.append(
|
matches.append(
|
||||||
MatchEval(
|
MatchEval(
|
||||||
@ -83,12 +83,9 @@ class DetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
plots = [
|
plots = [
|
||||||
build_detection_plotter(plot, targets) for plot in config.plots
|
build_detection_plotter(plot, targets) for plot in config.plots
|
||||||
]
|
]
|
||||||
affinity = build_affinity_function(config.affinity)
|
return DetectionTask.build(
|
||||||
return DetectionTask(
|
config=config,
|
||||||
prefix=config.prefix,
|
|
||||||
affinity=affinity,
|
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
plots=plots,
|
plots=plots,
|
||||||
strict_match=config.strict_match,
|
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.evaluation import match_detections_and_gts
|
from soundevent.evaluation import match_detections_and_gts
|
||||||
|
|
||||||
from batdetect2.evaluate.affinity import build_affinity_function
|
|
||||||
from batdetect2.evaluate.metrics.top_class import (
|
from batdetect2.evaluate.metrics.top_class import (
|
||||||
ClipEval,
|
ClipEval,
|
||||||
MatchEval,
|
MatchEval,
|
||||||
@ -59,6 +58,7 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
affinity=self.affinity,
|
affinity=self.affinity,
|
||||||
score=lambda pred: pred.class_scores.max(),
|
score=lambda pred: pred.class_scores.max(),
|
||||||
strict_match=self.strict_match,
|
strict_match=self.strict_match,
|
||||||
|
affinity_threshold=self.affinity_threshold,
|
||||||
):
|
):
|
||||||
gt = match.annotation
|
gt = match.annotation
|
||||||
pred = match.prediction
|
pred = match.prediction
|
||||||
@ -101,12 +101,9 @@ class TopClassDetectionTask(BaseSEDTask[ClipEval]):
|
|||||||
plots = [
|
plots = [
|
||||||
build_top_class_plotter(plot, targets) for plot in config.plots
|
build_top_class_plotter(plot, targets) for plot in config.plots
|
||||||
]
|
]
|
||||||
affinity = build_affinity_function(config.affinity)
|
return TopClassDetectionTask.build(
|
||||||
return TopClassDetectionTask(
|
config=config,
|
||||||
prefix=config.prefix,
|
|
||||||
plots=plots,
|
plots=plots,
|
||||||
metrics=metrics,
|
metrics=metrics,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
affinity=affinity,
|
|
||||||
strict_match=config.strict_match,
|
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user