diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 9acd36d..1e39031 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -174,6 +174,22 @@ class SelfAttention(nn.Module): return op + def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor: + x = x.squeeze(2).permute(0, 2, 1) + + key = torch.matmul( + x, self.key_fun.weight.T + ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0) + query = torch.matmul( + x, self.query_fun.weight.T + ) + self.query_fun.bias.unsqueeze(0).unsqueeze(0) + + kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / ( + self.temperature * self.att_dim + ) + att_weights = F.softmax(kk_qq, 1) + return att_weights + class ConvConfig(BaseConfig): """Configuration for a basic ConvBlock."""