From dfe19c68f0df8eecc820831662aee502b87e9c47 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 28 Mar 2026 14:25:18 +0000 Subject: [PATCH] Add detection threshold override option to postprocessor --- src/batdetect2/postprocess/postprocessor.py | 9 ++++++++- src/batdetect2/postprocess/types.py | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index d30e6d1..b6864ea 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -63,7 +63,14 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): def forward( self, output: ModelOutput, + detection_threshold: float | None = None, ) -> list[ClipDetectionsTensor]: + threshold = ( + self.detection_threshold + if detection_threshold is None + else detection_threshold + ) + detection_heatmap = non_max_suppression( output.detection_probs.detach(), kernel_size=self.nms_kernel_size, @@ -78,7 +85,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): feature_heatmap=output.features, classification_heatmap=output.class_probs, max_detections=max_detections, - threshold=self.detection_threshold, + threshold=threshold, ) return [ diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py index 61e18e3..10d1511 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/postprocess/types.py @@ -81,5 +81,8 @@ class ClipPrediction: class PostprocessorProtocol(Protocol): def __call__( - self, output: "ModelOutput" + self, + output: "ModelOutput", + *, + detection_threshold: float | None = None, ) -> list[ClipDetectionsTensor]: ...