mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Remove unnecessary extra conv layer
This commit is contained in:
parent
78a0975864
commit
13da5d2814
3
.gitignore
vendored
3
.gitignore
vendored
@ -120,3 +120,6 @@ notebooks/lightning_logs
|
|||||||
|
|
||||||
# Intermediate artifacts
|
# Intermediate artifacts
|
||||||
example_data/preprocessed
|
example_data/preprocessed
|
||||||
|
|
||||||
|
# Dev notebooks
|
||||||
|
notebooks/tmp
|
||||||
|
|||||||
@ -94,7 +94,10 @@ class Net2DFast(nn.Module):
|
|||||||
num_filts // 4, 2, kernel_size=1, padding=0
|
num_filts // 4, 2, kernel_size=1, padding=0
|
||||||
)
|
)
|
||||||
self.conv_classes_op = nn.Conv2d(
|
self.conv_classes_op = nn.Conv2d(
|
||||||
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
|
num_filts // 4,
|
||||||
|
self.num_classes + 1,
|
||||||
|
kernel_size=1,
|
||||||
|
padding=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.emb_dim > 0:
|
if self.emb_dim > 0:
|
||||||
|
|||||||
@ -26,7 +26,6 @@ from soundevent import data
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.blocks import ConvBlock
|
|
||||||
from batdetect2.models.bottleneck import (
|
from batdetect2.models.bottleneck import (
|
||||||
DEFAULT_BOTTLENECK_CONFIG,
|
DEFAULT_BOTTLENECK_CONFIG,
|
||||||
BottleneckConfig,
|
BottleneckConfig,
|
||||||
@ -89,7 +88,6 @@ class Backbone(BackboneModel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
input_height: int,
|
input_height: int,
|
||||||
out_channels: int,
|
|
||||||
encoder: Encoder,
|
encoder: Encoder,
|
||||||
decoder: Decoder,
|
decoder: Decoder,
|
||||||
bottleneck: nn.Module,
|
bottleneck: nn.Module,
|
||||||
@ -116,16 +114,12 @@ class Backbone(BackboneModel):
|
|||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.input_height = input_height
|
self.input_height = input_height
|
||||||
self.out_channels = out_channels
|
|
||||||
|
|
||||||
self.encoder = encoder
|
self.encoder = encoder
|
||||||
self.decoder = decoder
|
self.decoder = decoder
|
||||||
self.bottleneck = bottleneck
|
self.bottleneck = bottleneck
|
||||||
|
|
||||||
self.final_conv = ConvBlock(
|
self.out_channels = decoder.out_channels
|
||||||
in_channels=decoder.out_channels,
|
|
||||||
out_channels=out_channels,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Down/Up scaling factor. Need to ensure inputs are divisible by
|
# Down/Up scaling factor. Need to ensure inputs are divisible by
|
||||||
# this factor in order to be processed by the down/up scaling layers
|
# this factor in order to be processed by the down/up scaling layers
|
||||||
@ -164,7 +158,7 @@ class Backbone(BackboneModel):
|
|||||||
# Restore original size
|
# Restore original size
|
||||||
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
x = _restore_pad(x, h_pad=h_pad, w_pad=w_pad)
|
||||||
|
|
||||||
return self.final_conv(x)
|
return x
|
||||||
|
|
||||||
|
|
||||||
class BackboneConfig(BaseConfig):
|
class BackboneConfig(BaseConfig):
|
||||||
@ -299,7 +293,6 @@ def build_backbone(config: BackboneConfig) -> BackboneModel:
|
|||||||
|
|
||||||
return Backbone(
|
return Backbone(
|
||||||
input_height=config.input_height,
|
input_height=config.input_height,
|
||||||
out_channels=config.out_channels,
|
|
||||||
encoder=encoder,
|
encoder=encoder,
|
||||||
decoder=decoder,
|
decoder=decoder,
|
||||||
bottleneck=bottleneck,
|
bottleneck=bottleneck,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user