Minor difference in loss fixed

This commit is contained in:
mbsantiago 2025-08-16 18:49:30 +01:00
parent 13da5d2814
commit c1945ebdb7

View File

@ -210,7 +210,7 @@ class FocalLoss(nn.Module):
pos_loss = pos_loss * torch.tensor(self.class_weights) pos_loss = pos_loss * torch.tensor(self.class_weights)
if self.mask_zero: if self.mask_zero:
valid_mask = gt.any(dim=1, keepdim=True).float() valid_mask = (gt.sum(1) > 0).float().unsqueeze(1)
pos_loss = pos_loss * valid_mask pos_loss = pos_loss * valid_mask
neg_loss = neg_loss * valid_mask neg_loss = neg_loss * valid_mask
@ -476,7 +476,7 @@ def build_loss(
size_loss_fn = BBoxLoss() size_loss_fn = BBoxLoss()
return LossFunction( return LossFunction( # type: ignore
size_loss=size_loss_fn, size_loss=size_loss_fn,
classification_loss=classification_loss_fn, classification_loss=classification_loss_fn,
detection_loss=detection_loss_fn, detection_loss=detection_loss_fn,