diff --git a/model/attention/AFT.py b/model/attention/AFT.py
new file mode 100644
index 0000000..dc8419f
--- /dev/null
+++ b/model/attention/AFT.py
@@ -0,0 +1,44 @@
+"""
+MindSpore implementation of 'AFT Attention'
+Refer to "An Attention Free Transformer"
+"""
+import mindspore as ms
+from mindspore import nn
+from mindspore import ops
+
+
+class AFT_FULL(nn.Cell):
+    """ AFT Attention """
+    def __init__(self, d_model, n=49, simple=False):
+        super().__init__()
+        self.fc_q = nn.Dense(d_model, d_model)
+        self.fc_k = nn.Dense(d_model, d_model)
+        self.fc_v = nn.Dense(d_model, d_model)
+        if simple:
+            self.position_biases = ms.ops.zeros((n, n))
+        else:
+            self.position_biases = ms.Parameter(ms.ops.ones((n, n)))
+        self.d_model = d_model
+        self.n = n
+        self.sigmoid = nn.Sigmoid()
+
+    def construct(self, x):
+        B, N, D = x.shape
+
+        q = self.fc_q(x)
+        k = self.fc_k(x).view(1, B, N, D)
+        v = self.fc_v(x).view(1, B, N, D)
+
+        numerator = ops.sum(ops.exp(k + self.position_biases.view(N, 1, -1, 1)) * v, dim=2)
+        denominator = ops.sum(ops.exp(k + self.position_biases.view(N, 1, -1, 1)), dim=2)
+
+        out = numerator / denominator
+        out = self.sigmoid(q) * (out.permute(1, 0, 2))
+        return out
+
+
+if __name__ == "__main__":
+    dummy_input = ms.ops.randn((50, 49, 512))
+    aft_full = AFT_FULL(d_model=512, n=49)
+    output = aft_full(dummy_input)
+    print(output.shape)
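
Note: the broadcasting in construct() vectorizes the AFT-full update from the paper, Y_t = sigmoid(Q_t) * (sum_{t'} exp(K_{t'} + w_{t,t'}) * V_{t'}) / (sum_{t'} exp(K_{t'} + w_{t,t'})), with w the learned pairwise position biases. The sketch below is a minimal, loop-based NumPy rendering of that same computation for a single sample, included only to make the broadcasting readable; the function name aft_full_reference and the per-sample (N, D) shapes are illustrative assumptions and not part of this diff.

import numpy as np

def aft_full_reference(q, k, v, w):
    """Loop-based AFT-full for one sample (illustrative sketch).
    q, k, v: (N, D) projected queries/keys/values; w: (N, N) position biases."""
    N, D = q.shape
    out = np.zeros((N, D))
    for t in range(N):
        # For target position t, weight each source position t' (per feature)
        # by exp(K_{t'} + w_{t, t'}), then take the weighted average of V.
        weights = np.exp(k + w[t][:, None])                        # (N, D)
        out[t] = (weights * v).sum(axis=0) / weights.sum(axis=0)   # (D,)
    # Elementwise gate by sigmoid(Q), as in the AFT-full formula.
    return 1.0 / (1.0 + np.exp(-q)) * out

For a batch of samples this loop agrees with AFT_FULL.construct(), whose (1, B, N, D) views plus the (N, 1, N, 1) bias reshape compute the same sums over source positions in one step.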