diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index a93b4ea..56bc092 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -210,7 +210,7 @@ class FocalLoss(nn.Module): pos_loss = pos_loss * torch.tensor(self.class_weights) if self.mask_zero: - valid_mask = gt.any(dim=1, keepdim=True).float() + valid_mask = (gt.sum(1) > 0).float().unsqueeze(1) pos_loss = pos_loss * valid_mask neg_loss = neg_loss * valid_mask @@ -476,7 +476,7 @@ def build_loss( size_loss_fn = BBoxLoss() - return LossFunction( + return LossFunction( # type: ignore size_loss=size_loss_fn, classification_loss=classification_loss_fn, detection_loss=detection_loss_fn,