Add detection threshold override option to postprocessor

This commit is contained in:
mbsantiago 2026-03-28 14:25:18 +00:00
parent e4bbde9995
commit dfe19c68f0
2 changed files with 12 additions and 2 deletions

View File

@ -63,7 +63,14 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward( def forward(
self, self,
output: ModelOutput, output: ModelOutput,
detection_threshold: float | None = None,
) -> list[ClipDetectionsTensor]: ) -> list[ClipDetectionsTensor]:
threshold = (
self.detection_threshold
if detection_threshold is None
else detection_threshold
)
detection_heatmap = non_max_suppression( detection_heatmap = non_max_suppression(
output.detection_probs.detach(), output.detection_probs.detach(),
kernel_size=self.nms_kernel_size, kernel_size=self.nms_kernel_size,
@ -78,7 +85,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
feature_heatmap=output.features, feature_heatmap=output.features,
classification_heatmap=output.class_probs, classification_heatmap=output.class_probs,
max_detections=max_detections, max_detections=max_detections,
threshold=self.detection_threshold, threshold=threshold,
) )
return [ return [

View File

@ -81,5 +81,8 @@ class ClipPrediction:
class PostprocessorProtocol(Protocol): class PostprocessorProtocol(Protocol):
def __call__( def __call__(
self, output: "ModelOutput" self,
output: "ModelOutput",
*,
detection_threshold: float | None = None,
) -> list[ClipDetectionsTensor]: ... ) -> list[ClipDetectionsTensor]: ...