AFT Attention

Branch: pull/3/head
Huyf9 committed 2 years ago · parent commit 9f8d55fbf1
1 changed file with 44 additions and 0 deletions:
  1. +44 -0  model/attention/AFT.py
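
The module added below implements AFT-full (Eq. 2 of the paper): for query position t,

    Y_t = \sigma(Q_t) \odot \frac{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'}) \odot V_{t'}}{\sum_{t'=1}^{T} \exp(K_{t'} + w_{t,t'})}

where w is the learned pairwise position-bias matrix and \odot is element-wise multiplication. With `simple=True` the biases are fixed to zero, which reduces the operation to AFT-simple.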

model/attention/AFT.py

@@ -0,0 +1,44 @@
"""
MindSpore implementation of 'AFT Attention'
Refer to "An Attention Free Transformer"
"""
import mindspore as ms
from mindspore import nn
from mindspore import ops


class AFT_FULL(nn.Cell):
    """AFT-full attention.

    The input sequence length must equal `n`, the side of the pairwise
    position-bias matrix.
    """
    def __init__(self, d_model, n=49, simple=False):
        super().__init__()
        self.fc_q = nn.Dense(d_model, d_model)
        self.fc_k = nn.Dense(d_model, d_model)
        self.fc_v = nn.Dense(d_model, d_model)
        if simple:
            # AFT-simple: pairwise position biases fixed to zero.
            self.position_biases = ops.zeros((n, n))
        else:
            # AFT-full: a learnable (n, n) pairwise position-bias matrix.
            self.position_biases = ms.Parameter(ops.ones((n, n)), name="position_biases")
        self.d_model = d_model
        self.n = n
        self.sigmoid = nn.Sigmoid()

    def construct(self, x):
        B, N, D = x.shape  # batch size, sequence length (= self.n), d_model

        q = self.fc_q(x)                   # (B, N, D)
        k = self.fc_k(x).view(1, B, N, D)  # (1, B, N, D)
        v = self.fc_v(x).view(1, B, N, D)  # (1, B, N, D)

        # Broadcast the (N, N) biases against k: for each query position t,
        # exp(K_{t'} + w_{t,t'}) is reduced over the key positions t'.
        w = self.position_biases.view(N, 1, N, 1)       # (N, 1, N, 1)
        numerator = ops.sum(ops.exp(k + w) * v, dim=2)  # (N, B, D)
        denominator = ops.sum(ops.exp(k + w), dim=2)    # (N, B, D)

        out = numerator / denominator                 # (N, B, D)
        out = self.sigmoid(q) * out.permute(1, 0, 2)  # (B, N, D)
        return out


if __name__ == "__main__":
    dummy_input = ops.randn((50, 49, 512))  # (batch, seq_len, d_model)
    aft_full = AFT_FULL(d_model=512, n=49)
    output = aft_full(dummy_input)
    print(output.shape)  # expected: (50, 49, 512)
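
As a sanity check on the broadcasting in `construct`, the same computation can be written as an explicit loop over query positions. A minimal sketch, not part of this commit; it assumes MindSpore 2.x, NumPy, and the `aft_full` / `dummy_input` / `output` variables from the demo above:

import numpy as np

def aft_full_reference(model, x):
    """Naive per-query-position AFT-full, for checking construct()."""
    q = model.sigmoid(model.fc_q(x)).asnumpy()  # (B, N, D), already gated
    k = model.fc_k(x).asnumpy()                 # (B, N, D)
    v = model.fc_v(x).asnumpy()                 # (B, N, D)
    w = model.position_biases.asnumpy()         # (N, N)
    B, N, D = k.shape
    out = np.zeros((B, N, D), dtype=k.dtype)
    for t in range(N):
        # Weights over key positions t' for query position t.
        weights = np.exp(k + w[t][None, :, None])  # (B, N, D)
        out[:, t] = q[:, t] * (weights * v).sum(1) / weights.sum(1)
    return out

print(np.allclose(aft_full_reference(aft_full, dummy_input),
                  output.asnumpy(), atol=1e-4))  # expect True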
