mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Minor difference in loss fixed
This commit is contained in:
parent
13da5d2814
commit
c1945ebdb7
@ -210,7 +210,7 @@ class FocalLoss(nn.Module):
|
|||||||
pos_loss = pos_loss * torch.tensor(self.class_weights)
|
pos_loss = pos_loss * torch.tensor(self.class_weights)
|
||||||
|
|
||||||
if self.mask_zero:
|
if self.mask_zero:
|
||||||
valid_mask = gt.any(dim=1, keepdim=True).float()
|
valid_mask = (gt.sum(1) > 0).float().unsqueeze(1)
|
||||||
pos_loss = pos_loss * valid_mask
|
pos_loss = pos_loss * valid_mask
|
||||||
neg_loss = neg_loss * valid_mask
|
neg_loss = neg_loss * valid_mask
|
||||||
|
|
||||||
@ -476,7 +476,7 @@ def build_loss(
|
|||||||
|
|
||||||
size_loss_fn = BBoxLoss()
|
size_loss_fn = BBoxLoss()
|
||||||
|
|
||||||
return LossFunction(
|
return LossFunction( # type: ignore
|
||||||
size_loss=size_loss_fn,
|
size_loss=size_loss_fn,
|
||||||
classification_loss=classification_loss_fn,
|
classification_loss=classification_loss_fn,
|
||||||
detection_loss=detection_loss_fn,
|
detection_loss=detection_loss_fn,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user