mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Add method to compute attention weights
This commit is contained in:
parent
4ecbc2b734
commit
16c401b1da
@ -174,6 +174,22 @@ class SelfAttention(nn.Module):
|
|||||||
|
|
||||||
return op
|
return op
|
||||||
|
|
||||||
|
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.squeeze(2).permute(0, 2, 1)
|
||||||
|
|
||||||
|
key = torch.matmul(
|
||||||
|
x, self.key_fun.weight.T
|
||||||
|
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||||
|
query = torch.matmul(
|
||||||
|
x, self.query_fun.weight.T
|
||||||
|
) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
|
||||||
|
|
||||||
|
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
|
||||||
|
self.temperature * self.att_dim
|
||||||
|
)
|
||||||
|
att_weights = F.softmax(kk_qq, 1)
|
||||||
|
return att_weights
|
||||||
|
|
||||||
|
|
||||||
class ConvConfig(BaseConfig):
|
class ConvConfig(BaseConfig):
|
||||||
"""Configuration for a basic ConvBlock."""
|
"""Configuration for a basic ConvBlock."""
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user