diff --git a/src/batdetect2/models/detectors.py b/src/batdetect2/models/detectors.py index c5ab691..854f71e 100644 --- a/src/batdetect2/models/detectors.py +++ b/src/batdetect2/models/detectors.py @@ -114,8 +114,8 @@ class Detector(DetectionModel): `(B, C_out, H, W)`. """ features = self.backbone(spec) - detection = self.detector_head(features) classification = self.classifier_head(features) + detection = classification.sum(dim=1, keep_dim=True) size_preds = self.bbox_head(features) return ModelOutput( detection_probs=detection,