mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Remove unnecessary extra conv layer
This commit is contained in:
parent
78a0975864
commit
13da5d2814
3
.gitignore
vendored
3
.gitignore
vendored
@ -120,3 +120,6 @@ notebooks/lightning_logs
|
||||
|
||||
# Intermediate artifacts
|
||||
example_data/preprocessed
|
||||
|
||||
# Dev notebooks
|
||||
notebooks/tmp
|
||||
|
||||
@ -94,7 +94,10 @@ class Net2DFast(nn.Module):
|
||||
num_filts // 4, 2, kernel_size=1, padding=0
|
||||
)
|
||||
self.conv_classes_op = nn.Conv2d(
|
||||
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
|
||||
num_filts // 4,
|
||||
self.num_classes + 1,
|
||||
kernel_size=1,
|
||||
padding=0,
|
||||
)
|
||||
|
||||
if self.emb_dim > 0:
|
||||
|
||||
@ -26,7 +26,6 @@ from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.models.blocks import ConvBlock
|
||||
from batdetect2.models.bottleneck import (
|
||||
DEFAULT_BOTTLENECK_CONFIG,
|
||||
BottleneckConfig,
|
||||
@ -89,7 +88,6 @@ class Backbone(BackboneModel):
|
||||
def __init__(
|
||||
self,
|
||||
input_height: int,
|
||||
out_channels: int,
|
||||
encoder: Encoder,
|
||||
decoder: Decoder,
|
||||
bottleneck: nn.Module,
|
||||
@ -116,16 +114,12 @@ class Backbone(BackboneModel):
|
||||
"""
|
||||
super().__init__()
|
||||
self.input_height = input_height
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.encoder = encoder
|
||||
self.decoder = decoder
|
||||
self.bottleneck = bottleneck
|
||||
|
||||
self.final_conv = ConvBlock(
|
||||
in_channels=decoder.out_channels,
|
||||
out_channels=out_channels,
|
||||
)
|
||||
self.out_channels = decoder.out_channels
|
||||
|
||||
# Down/Up scaling factor. Need to ensure inputs are divisible by
|
||||
# this factor in order to be processed by the down/up scaling layers
|
||||
@ -164,7 +158,7 @@ class Backbone(BackboneModel):
|
||||
# Restore original size
|
||||
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||
|
||||
return self.final_conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class BackboneConfig(BaseConfig):
|
||||
@ -299,7 +293,6 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
||||
|
||||
return Backbone(
|
||||
input_height=config.input_height,
|
||||
out_channels=config.out_channels,
|
||||
encoder=encoder,
|
||||
decoder=decoder,
|
||||
bottleneck=bottleneck,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user