diff --git a/bat_detect/detector/model_helpers.py b/bat_detect/detector/model_helpers.py index 94657be..b05f361 100644 --- a/bat_detect/detector/model_helpers.py +++ b/bat_detect/detector/model_helpers.py @@ -1,11 +1,17 @@ -import math - -import numpy as np import torch -import torch.nn as nn +from torch import nn import torch.nn.functional as F +__all__ = [ + "SelfAttention", + "ConvBlockDownCoordF", + "ConvBlockDownStandard", + "ConvBlockUpF", + "ConvBlockUpStandard", +] + + class SelfAttention(nn.Module): def __init__(self, ip_dim, att_dim): super(SelfAttention, self).__init__() diff --git a/bat_detect/detector/models.py b/bat_detect/detector/models.py index e9f8941..4f76fe6 100644 --- a/bat_detect/detector/models.py +++ b/bat_detect/detector/models.py @@ -1,12 +1,22 @@ -import numpy as np import torch import torch.fft -import torch.nn as nn import torch.nn.functional as F -import torchvision from torch import nn -from .model_helpers import * +from .model_helpers import ( + SelfAttention, + ConvBlockDownCoordF, + ConvBlockDownStandard, + ConvBlockUpF, + ConvBlockUpStandard, +) + + +__all__ = [ + "Net2DFast", + "Net2DFastNoAttn", + "Net2DFastNoCoordConv", +] class Net2DFast(nn.Module):