mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Minor difference in loss fixed
This commit is contained in:
parent
13da5d2814
commit
c1945ebdb7
@ -210,7 +210,7 @@ class FocalLoss(nn.Module):
|
||||
pos_loss = pos_loss * torch.tensor(self.class_weights)
|
||||
|
||||
if self.mask_zero:
|
||||
valid_mask = gt.any(dim=1, keepdim=True).float()
|
||||
valid_mask = (gt.sum(1) > 0).float().unsqueeze(1)
|
||||
pos_loss = pos_loss * valid_mask
|
||||
neg_loss = neg_loss * valid_mask
|
||||
|
||||
@ -476,7 +476,7 @@ def build_loss(
|
||||
|
||||
size_loss_fn = BBoxLoss()
|
||||
|
||||
return LossFunction(
|
||||
return LossFunction( # type: ignore
|
||||
size_loss=size_loss_fn,
|
||||
classification_loss=classification_loss_fn,
|
||||
detection_loss=detection_loss_fn,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user