diff --git a/src/batdetect2/postprocess/postprocessor.py b/src/batdetect2/postprocess/postprocessor.py index d30e6d1..b6864ea 100644 --- a/src/batdetect2/postprocess/postprocessor.py +++ b/src/batdetect2/postprocess/postprocessor.py @@ -63,7 +63,14 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): def forward( self, output: ModelOutput, + detection_threshold: float | None = None, ) -> list[ClipDetectionsTensor]: + threshold = ( + self.detection_threshold + if detection_threshold is None + else detection_threshold + ) + detection_heatmap = non_max_suppression( output.detection_probs.detach(), kernel_size=self.nms_kernel_size, @@ -78,7 +85,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol): feature_heatmap=output.features, classification_heatmap=output.class_probs, max_detections=max_detections, - threshold=self.detection_threshold, + threshold=threshold, ) return [ diff --git a/src/batdetect2/postprocess/types.py b/src/batdetect2/postprocess/types.py index 61e18e3..10d1511 100644 --- a/src/batdetect2/postprocess/types.py +++ b/src/batdetect2/postprocess/types.py @@ -81,5 +81,8 @@ class ClipPrediction: class PostprocessorProtocol(Protocol): def __call__( - self, output: "ModelOutput" + self, + output: "ModelOutput", + *, + detection_threshold: float | None = None, ) -> list[ClipDetectionsTensor]: ...