Add method to compute attention weights

This commit is contained in:
mbsantiago 2025-11-22 00:34:31 +00:00
parent 4ecbc2b734
commit 16c401b1da

View File

@ -174,6 +174,22 @@ class SelfAttention(nn.Module):
return op
def compute_attention_weights(self, x: torch.Tensor) -> torch.Tensor:
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
x, self.key_fun.weight.T
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
query = torch.matmul(
x, self.query_fun.weight.T
) + self.query_fun.bias.unsqueeze(0).unsqueeze(0)
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
self.temperature * self.att_dim
)
att_weights = F.softmax(kk_qq, 1)
return att_weights
class ConvConfig(BaseConfig):
"""Configuration for a basic ConvBlock."""