from math import sqrt

import torch
import torch.nn as nn
import torch.nn.functional as F
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Projects the input into query, key, and value spaces and returns
    ``softmax(Q @ K^T / sqrt(dim_k)) @ V``.
    """

    def __init__(self, input_dim: int, dim_k: int, dim_v: int) -> None:
        """
        Args:
            input_dim: Size of the last dimension of the input.
            dim_k: Projection size shared by queries and keys.
            dim_v: Projection size of values (output feature size).
        """
        super().__init__()
        self.q = nn.Linear(input_dim, dim_k)
        self.k = nn.Linear(input_dim, dim_k)
        self.v = nn.Linear(input_dim, dim_v)
        # Scale the dot-product logits by sqrt(dim_k) to keep their
        # magnitude independent of the key dimension.
        self.norm = sqrt(dim_k)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention.

        Args:
            x: Input of shape ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim_v)``.
        """
        Q = self.q(x)  # (batch, seq_len, dim_k)
        K = self.k(x)  # (batch, seq_len, dim_k)
        V = self.v(x)  # (batch, seq_len, dim_v)
        # Use the functional softmax instead of allocating a fresh
        # nn.Softmax module object on every forward pass.
        atten = F.softmax(torch.bmm(Q, K.transpose(1, 2)) / self.norm, dim=-1)
        return torch.bmm(atten, V)
if __name__ == "__main__":
    # Smoke test: a batch of 4 sequences, length 3, feature size 2,
    # attended with dim_k=4 and dim_v=5.
    sample = torch.randn(4, 3, 2)
    attention = SelfAttention(2, 4, 5)
    result = attention(sample)
    print(result)