mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add detection threshold override option to postprocessor
This commit is contained in:
parent
e4bbde9995
commit
dfe19c68f0
@ -63,7 +63,14 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
def forward(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetectionsTensor]:
|
||||
threshold = (
|
||||
self.detection_threshold
|
||||
if detection_threshold is None
|
||||
else detection_threshold
|
||||
)
|
||||
|
||||
detection_heatmap = non_max_suppression(
|
||||
output.detection_probs.detach(),
|
||||
kernel_size=self.nms_kernel_size,
|
||||
@ -78,7 +85,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
feature_heatmap=output.features,
|
||||
classification_heatmap=output.class_probs,
|
||||
max_detections=max_detections,
|
||||
threshold=self.detection_threshold,
|
||||
threshold=threshold,
|
||||
)
|
||||
|
||||
return [
|
||||
|
||||
@ -81,5 +81,8 @@ class ClipPrediction:
|
||||
|
||||
class PostprocessorProtocol(Protocol):
|
||||
def __call__(
|
||||
self, output: "ModelOutput"
|
||||
self,
|
||||
output: "ModelOutput",
|
||||
*,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetectionsTensor]: ...
|
||||
|
||||
Loading…
Reference in New Issue
Block a user