From 13da5d28144285a904289cf68c5fe39c730b794f Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Sat, 16 Aug 2025 18:32:35 +0100 Subject: [PATCH] Remove unnecessary extra conv layer --- .gitignore | 3 +++ src/batdetect2/detector/models.py | 5 ++++- src/batdetect2/models/backbones.py | 11 ++--------- 3 files changed, 9 insertions(+), 10 deletions(-) diff --git a/.gitignore b/.gitignore index 7787be9..15f192a 100644 --- a/.gitignore +++ b/.gitignore @@ -120,3 +120,6 @@ notebooks/lightning_logs # Intermediate artifacts example_data/preprocessed + +# Dev notebooks +notebooks/tmp diff --git a/src/batdetect2/detector/models.py b/src/batdetect2/detector/models.py index 105ddaf..2bffeab 100644 --- a/src/batdetect2/detector/models.py +++ b/src/batdetect2/detector/models.py @@ -94,7 +94,10 @@ class Net2DFast(nn.Module): num_filts // 4, 2, kernel_size=1, padding=0 ) self.conv_classes_op = nn.Conv2d( - num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0 + num_filts // 4, + self.num_classes + 1, + kernel_size=1, + padding=0, ) if self.emb_dim > 0: diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index e7f1ada..be55932 100644 --- a/src/batdetect2/models/backbones.py +++ b/src/batdetect2/models/backbones.py @@ -26,7 +26,6 @@ from soundevent import data from torch import nn from batdetect2.configs import BaseConfig, load_config -from batdetect2.models.blocks import ConvBlock from batdetect2.models.bottleneck import ( DEFAULT_BOTTLENECK_CONFIG, BottleneckConfig, @@ -89,7 +88,6 @@ class Backbone(BackboneModel): def __init__( self, input_height: int, - out_channels: int, encoder: Encoder, decoder: Decoder, bottleneck: nn.Module, @@ -116,16 +114,12 @@ class Backbone(BackboneModel): """ super().__init__() self.input_height = input_height - self.out_channels = out_channels self.encoder = encoder self.decoder = decoder self.bottleneck = bottleneck - self.final_conv = ConvBlock( - in_channels=decoder.out_channels, - out_channels=out_channels, - ) + self.out_channels = decoder.out_channels # Down/Up scaling factor. Need to ensure inputs are divisible by # this factor in order to be processed by the down/up scaling layers @@ -164,7 +158,7 @@ class Backbone(BackboneModel): # Restore original size x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad) - return self.final_conv(x) + return x class BackboneConfig(BaseConfig): @@ -299,7 +293,6 @@ def build_backbone(config: BackboneConfig) -> BackboneModel: return Backbone( input_height=config.input_height, - out_channels=config.out_channels, encoder=encoder, decoder=decoder, bottleneck=bottleneck,