From d7ddf72c731550855e2bf60805af7462341805c7 Mon Sep 17 00:00:00 2001 From: Santiago Martinez Date: Wed, 22 Feb 2023 14:55:08 +0000 Subject: [PATCH] Made explicit imports in detector/models.py --- bat_detect/detector/model_helpers.py | 14 ++++++++++---- bat_detect/detector/models.py | 18 ++++++++++++++---- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/bat_detect/detector/model_helpers.py b/bat_detect/detector/model_helpers.py index 94657be..b05f361 100644 --- a/bat_detect/detector/model_helpers.py +++ b/bat_detect/detector/model_helpers.py @@ -1,11 +1,17 @@ -import math - -import numpy as np import torch -import torch.nn as nn +from torch import nn import torch.nn.functional as F +__all__ = [ + "SelfAttention", + "ConvBlockDownCoordF", + "ConvBlockDownStandard", + "ConvBlockUpF", + "ConvBlockUpStandard", +] + + class SelfAttention(nn.Module): def __init__(self, ip_dim, att_dim): super(SelfAttention, self).__init__() diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py index e9f8941..4f76fe6 100644 --- a/bat_detect/detector/models.py +++ b/bat_detect/detector/models.py @@ -1,12 +1,22 @@ -import numpy as np import torch import torch.fft -import torch.nn as nn import torch.nn.functional as F -import torchvision from torch import nn -from .model_helpers import * +from .model_helpers import ( + SelfAttention, + ConvBlockDownCoordF, + ConvBlockDownStandard, + ConvBlockUpF, + ConvBlockUpStandard, +) + + +__all__ = [ + "Net2DFast", + "Net2DFastNoAttn", + "Net2DFastNoCoordConv", +] class Net2DFast(nn.Module):