mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add detection threshold override option to postprocessor
This commit is contained in:
parent
e4bbde9995
commit
dfe19c68f0
@ -63,7 +63,14 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetectionsTensor]:
|
) -> list[ClipDetectionsTensor]:
|
||||||
|
threshold = (
|
||||||
|
self.detection_threshold
|
||||||
|
if detection_threshold is None
|
||||||
|
else detection_threshold
|
||||||
|
)
|
||||||
|
|
||||||
detection_heatmap = non_max_suppression(
|
detection_heatmap = non_max_suppression(
|
||||||
output.detection_probs.detach(),
|
output.detection_probs.detach(),
|
||||||
kernel_size=self.nms_kernel_size,
|
kernel_size=self.nms_kernel_size,
|
||||||
@ -78,7 +85,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
feature_heatmap=output.features,
|
feature_heatmap=output.features,
|
||||||
classification_heatmap=output.class_probs,
|
classification_heatmap=output.class_probs,
|
||||||
max_detections=max_detections,
|
max_detections=max_detections,
|
||||||
threshold=self.detection_threshold,
|
threshold=threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -81,5 +81,8 @@ class ClipPrediction:
|
|||||||
|
|
||||||
class PostprocessorProtocol(Protocol):
|
class PostprocessorProtocol(Protocol):
|
||||||
def __call__(
|
def __call__(
|
||||||
self, output: "ModelOutput"
|
self,
|
||||||
|
output: "ModelOutput",
|
||||||
|
*,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetectionsTensor]: ...
|
) -> list[ClipDetectionsTensor]: ...
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user