Compare commits

...

2 Commits

Author      SHA1        Message                              Date
mbsantiago  c1945ebdb7  Minor difference in loss fixed       2025-08-16 18:49:30 +01:00
mbsantiago  13da5d2814  Remove unnecessary extra conv layer  2025-08-16 18:32:35 +01:00
4 changed files with 11 additions and 12 deletions

.gitignore (vendored), 3 changed lines

@@ -120,3 +120,6 @@ notebooks/lightning_logs
 # Intermediate artifacts
 example_data/preprocessed
+# Dev notebooks
+notebooks/tmp


@@ -94,7 +94,10 @@ class Net2DFast(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+            num_filts // 4,
+            self.num_classes + 1,
+            kernel_size=1,
+            padding=0,
         )
         if self.emb_dim > 0:


@@ -26,7 +26,6 @@ from soundevent import data
 from torch import nn
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.models.blocks import ConvBlock
 from batdetect2.models.bottleneck import (
     DEFAULT_BOTTLENECK_CONFIG,
     BottleneckConfig,
@@ -89,7 +88,6 @@ class Backbone(BackboneModel):
     def __init__(
         self,
         input_height: int,
-        out_channels: int,
         encoder: Encoder,
         decoder: Decoder,
         bottleneck: nn.Module,
@@ -116,16 +114,12 @@ class Backbone(BackboneModel):
         """
         super().__init__()
 
         self.input_height = input_height
-        self.out_channels = out_channels
         self.encoder = encoder
         self.decoder = decoder
         self.bottleneck = bottleneck
 
-        self.final_conv = ConvBlock(
-            in_channels=decoder.out_channels,
-            out_channels=out_channels,
-        )
+        self.out_channels = decoder.out_channels
 
         # Down/Up scaling factor. Need to ensure inputs are divisible by
         # this factor in order to be processed by the down/up scaling layers
@@ -164,7 +158,7 @@ class Backbone(BackboneModel):
 
         # Restore original size
         x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
-        return self.final_conv(x)
+        return x
 
 
 class BackboneConfig(BaseConfig):
@@ -299,7 +293,6 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
 
     return Backbone(
         input_height=config.input_height,
-        out_channels=config.out_channels,
         encoder=encoder,
         decoder=decoder,
         bottleneck=bottleneck,
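
Taken together, the backbone hunks above are the "Remove unnecessary extra conv layer" commit: the trailing ConvBlock is dropped, the decoder output is returned directly, and out_channels is read from the decoder instead of being passed to the constructor. Below is a minimal standalone sketch of that wiring, not code from the repository; ToyDecoder and SimplifiedBackbone are illustrative stand-ins for the real Decoder and Backbone classes.

import torch
from torch import nn


class ToyDecoder(nn.Module):
    # Stand-in decoder that exposes `out_channels`, as the real Decoder does.
    def __init__(self, out_channels: int = 32):
        super().__init__()
        self.out_channels = out_channels
        self.conv = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv(x)


class SimplifiedBackbone(nn.Module):
    # After the change: no trailing ConvBlock and no out_channels argument.
    def __init__(self, decoder: nn.Module):
        super().__init__()
        self.decoder = decoder
        # The backbone's output width is whatever the decoder produces.
        self.out_channels = decoder.out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Previously the result passed through self.final_conv; now it is returned as-is.
        return self.decoder(x)


decoder = ToyDecoder(out_channels=32)
backbone = SimplifiedBackbone(decoder)
features = backbone(torch.randn(1, 32, 64, 64))
assert features.shape[1] == backbone.out_channels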


@@ -210,7 +210,7 @@ class FocalLoss(nn.Module):
             pos_loss = pos_loss * torch.tensor(self.class_weights)
 
         if self.mask_zero:
-            valid_mask = gt.any(dim=1, keepdim=True).float()
+            valid_mask = (gt.sum(1) > 0).float().unsqueeze(1)
             pos_loss = pos_loss * valid_mask
             neg_loss = neg_loss * valid_mask
@@ -476,7 +476,7 @@ def build_loss(
     size_loss_fn = BBoxLoss()
 
-    return LossFunction(
+    return LossFunction(  # type: ignore
         size_loss=size_loss_fn,
         classification_loss=classification_loss_fn,
         detection_loss=detection_loss_fn,
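
The first FocalLoss hunk swaps the valid-mask computation from gt.any(dim=1, keepdim=True) to a sum-based test. For the non-negative per-class target heatmaps this loss operates on, the two forms produce identical masks; the snippet below is a standalone sanity check, not code from the repository, and assumes gt has shape (batch, classes, height, width).

import torch

# Fake non-negative class heatmaps: (batch, classes, height, width).
gt = torch.rand(2, 5, 16, 16)
# Zero out a region so the mask is non-trivial.
gt[:, :, :8, :] = 0.0

old_mask = gt.any(dim=1, keepdim=True).float()   # removed formulation
new_mask = (gt.sum(1) > 0).float().unsqueeze(1)  # added formulation

assert old_mask.shape == new_mask.shape == (2, 1, 16, 16)
assert torch.equal(old_mask, new_mask)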