
Delete 'model/attention/SimplifiedSelfAttention.py'

v1
limingjuan 2 years ago
parent commit 7b6c11b316
1 changed file with 0 additions and 49 deletions
  1. model/attention/SimplifiedSelfAttention.py (+0, -49)

model/attention/SimplifiedSelfAttention.py

@@ -1,49 +0,0 @@
"""
MindSpore implementation of 'SimplifiedSelfAttention'
"""
import math
import mindspore as ms
from mindspore import nn


class SimplifiedScaledDotProductAttention(nn.Cell):
""" Simplified Scale Dot Product Attention """
def __init__(self, d_model, h=1, drop_rate=0.1):
super().__init__()
self.d_model = d_model
self.d_k = d_model // h
self.d_v = d_model // h
self.h = h

self.fc_o = nn.Dense(h * self.d_v, d_model)
self.dropout = nn.Dropout(p=drop_rate)

def construct(self, queries, keys=None, values=None, attention_mask=None, attention_weights=None):
if keys is None:
keys = queries
if values is None:
values = queries

B, Nq = queries.shape[:2]
Nk = keys.shape[1]
q = queries.view(B, Nq, self.h, self.d_k).permute(0, 2, 1, 3)
k = keys.view(B, Nk, self.h, self.d_k).permute(0, 2, 3, 1)
v = values.view(B, Nk, self.h, self.d_v).permute(0, 2, 1, 3)

att = ms.ops.matmul(q, k) / math.sqrt(self.d_k)
if attention_weights is not None:
att = att * attention_weights
if attention_mask is not None:
att = att.masked_fill(attention_mask, -float('inf'))

att = ms.ops.softmax(att, axis=-1)
att = self.dropout(att)
att = ms.ops.matmul(att, v).permute(0, 2, 1, 3).view(B, Nq, self.h * self.d_v)
return self.fc_o(att)


if __name__ == '__main__':
dummy_input = ms.ops.randn(50, 49, 512)
ssa = SimplifiedScaledDotProductAttention(d_model=512, h=8)
output = ssa(dummy_input, dummy_input, dummy_input)
print(output.shape)

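A minimal usage sketch (not part of the commit) showing how the deleted SimplifiedScaledDotProductAttention could be driven with a boolean attention mask, which the original __main__ block does not exercise. The mask shape, the True-means-masked convention, and the MindSpore 2.x calls below are assumptions, and it presumes the class is still importable from somewhere.

import mindspore as ms

batch, seq_len, d_model, heads = 4, 49, 512, 8
x = ms.ops.randn(batch, seq_len, d_model)
# Boolean mask with the full attention-score shape (batch, heads, Nq, Nk);
# True marks positions whose scores are set to -inf before the softmax.
mask = ms.ops.randn(batch, heads, seq_len, seq_len) > 1.5

ssa = SimplifiedScaledDotProductAttention(d_model=d_model, h=heads)
out = ssa(x, attention_mask=mask)
print(out.shape)  # expected: (4, 49, 512)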