mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add method to compute attention weights
This commit is contained in:
parent
4ecbc2b734
commit
16c401b1da
@ -174,6 +174,22 @@ class SelfAttention(nn.Module):
|
||||
|
||||
return op
|
||||
|
||||
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.squeeze(2).permute(0, 2, 1)
|
||||
|
||||
key = torch.matmul(
|
||||
x, self.key_fun.weight.T
|
||||
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||
query = torch.matmul(
|
||||
x, self.query_fun.weight.T
|
||||
) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||
|
||||
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
|
||||
self.temperature * self.att_dim
|
||||
)
|
||||
att_weights = F.softmax(kk_qq, 1)
|
||||
return att_weights
|
||||
|
||||
|
||||
class ConvConfig(BaseConfig):
|
||||
"""Configuration for a basic ConvBlock."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user