From c1945ebdb7e9b69d5413f53937aefa7e68944ce1 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 16 Aug 2025 18:49:30 +0100 Subject: [PATCH] Minor difference in loss fixed --- src/batdetect2/train/losses.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py index a93b4ea..56bc092 100644 --- a/src/batdetect2/train/losses.py +++ b/src/batdetect2/train/losses.py @@ -210,7 +210,7 @@ class FocalLoss(nn.Module): pos_loss = pos_loss * torch.tensor(self.class_weights) if self.mask_zero: - valid_mask = gt.any(dim=1, keepdim=True).float() + valid_mask = (gt.sum(1) > 0).float().unsqueeze(1) pos_loss = pos_loss * valid_mask neg_loss = neg_loss * valid_mask @@ -476,7 +476,7 @@ def build_loss( size_loss_fn = BBoxLoss() - return LossFunction( + return LossFunction( # type: ignore size_loss=size_loss_fn, classification_loss=classification_loss_fn, detection_loss=detection_loss_fn,