Small fix

This commit is contained in:
mbsantiago 2025-08-18 10:25:14 +01:00
parent 9d1497b3f4
commit 98f83e8b34

View File

@ -114,8 +114,8 @@ class Detector(DetectionModel):
`(B, C_out, H, W)`. `(B, C_out, H, W)`.
""" """
features = self.backbone(spec) features = self.backbone(spec)
detection = self.detector_head(features)
classification = self.classifier_head(features) classification = self.classifier_head(features)
detection = classification.sum(dim=1, keep_dim=True)
size_preds = self.bbox_head(features) size_preds = self.bbox_head(features)
return ModelOutput( return ModelOutput(
detection_probs=detection, detection_probs=detection,