
attn.py 6.0 kB

import numpy as np
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype
# TriangularCausalMask and ProbMask are the Informer mask helpers; the import
# path below is an assumption about where they live in this repository.
from masking import TriangularCausalMask, ProbMask


class FullAttention(nn.Cell):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        # MindSpore >= 2.0: p is the drop probability (older versions use keep_prob)
        self.dropout = nn.Dropout(p=attention_dropout)
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, queries, keys, values, attn_mask):
        # queries: (B, L, H, E), keys: (B, S, H, E), values: (B, S, H, D)
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / np.sqrt(E)

        # (B, H, L, E) @ (B, H, E, S) -> (B, H, L, S)
        scores = ops.matmul(queries.swapaxes(1, 2), keys.transpose((0, 2, 3, 1)))
        if self.mask_flag:
            if attn_mask is None:
                # the device argument of the PyTorch version is not needed in MindSpore
                attn_mask = TriangularCausalMask(B, L)
            # assumes mask == 1 marks positions that may attend; -1e9 stands in for
            # -inf so that 0 * (-inf) = NaN cannot occur and blocked positions
            # vanish after softmax
            scores = scores + (1 - attn_mask.mask.astype(scores.dtype)) * (-1e9)

        A = self.dropout(self.softmax(scores * scale))
        # (B, H, L, S) @ (B, H, S, D) -> (B, H, L, D) -> (B, L, H, D)
        V = ops.matmul(A, values.swapaxes(1, 2)).swapaxes(1, 2)

        if self.output_attention:
            return V, A
        else:
            return V, None
class ProbAttention(nn.Cell):
    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(p=attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):
        # Q: (B, H, L_Q, E), K: (B, H, L_K, E)
        B, H, L_K, E = K.shape
        _, _, L_Q, D = Q.shape

        # sample sample_k keys for every query to estimate the sparsity measure M
        K_expand = ops.broadcast_to(K.unsqueeze(-3), (B, H, L_Q, L_K, E))
        index_sample = Tensor(np.random.randint(0, L_K, size=(L_Q, sample_k)), mstype.int32)
        K_sample = K_expand[:, :, ops.arange(L_Q).unsqueeze(1), index_sample, :]
        Q_K_sample = ops.matmul(Q.unsqueeze(-2), K_sample.transpose((0, 1, 2, 4, 3))).squeeze(-2)

        # sparsity measurement M = max(QK) - mean(QK); keep the n_top "active" queries
        M = Q_K_sample.max(axis=-1) - Q_K_sample.sum(axis=-1) / L_K
        _, M_top = ops.topk(M, n_top, sorted=False)

        # compute the full Q*K product only for the selected queries
        Q_reduce = Q[ops.arange(B).unsqueeze(1).unsqueeze(2),
                     ops.arange(H).unsqueeze(0).unsqueeze(2),
                     M_top, :]
        Q_K = ops.matmul(Q_reduce, K.transpose((0, 1, 3, 2)))

        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # lazy queries attend uniformly: initialise the context with the mean of V
            V_sum = V.mean(axis=-2)
            contex = ops.broadcast_to(V_sum.unsqueeze(-2), (B, H, L_Q, V_sum.shape[-1])).copy()
        else:  # use mask
            assert L_Q == L_V  # requires L_Q == L_V, i.e. self-attention only
            contex = V.cumsum(axis=-2)
        return contex

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores)
            # positions where mask is True are blocked; -1e9 stands in for -inf
            scores = ops.masked_fill(scores, attn_mask.mask, Tensor(-1e9, scores.dtype))

        attn = ops.softmax(scores, axis=-1)

        # overwrite the context rows of the selected "active" queries
        context_in[ops.arange(B).unsqueeze(1).unsqueeze(2),
                   ops.arange(H).unsqueeze(0).unsqueeze(2),
                   index, :] = ops.matmul(attn, V).astype(context_in.dtype)
        if self.output_attention:
            attns = ops.ones((B, H, L_V, L_V), attn.dtype) / L_V
            attns[ops.arange(B).unsqueeze(1).unsqueeze(2),
                  ops.arange(H).unsqueeze(0).unsqueeze(2),
                  index, :] = attn
            return (context_in, attns)
        else:
            return (context_in, None)

    def construct(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        # (B, L, H, D) -> (B, H, L, D)
        queries = queries.swapaxes(2, 1)
        keys = keys.swapaxes(2, 1)
        values = values.swapaxes(2, 1)

        U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
        u = self.factor * np.ceil(np.log(L_Q)).astype('int').item()       # c*ln(L_q)
        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # add scale factor
        scale = self.scale or 1.0 / np.sqrt(D)
        if scale is not None:
            scores_top = scores_top * scale

        # get the initial context
        context = self._get_initial_context(values, L_Q)
        # update the context with the selected top-u queries
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        # (B, H, L, D) -> (B, L, H, D)
        return context.swapaxes(2, 1), attn
class AttentionLayer(nn.Cell):
    def __init__(self, attention, d_model, n_heads,
                 d_keys=None, d_values=None, mix=False):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        # nn.Dense is the MindSpore counterpart of torch.nn.Linear
        self.query_projection = nn.Dense(d_model, d_keys * n_heads)
        self.key_projection = nn.Dense(d_model, d_keys * n_heads)
        self.value_projection = nn.Dense(d_model, d_values * n_heads)
        self.out_projection = nn.Dense(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def construct(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # project and split into H heads: (B, L, d_model) -> (B, L, H, d_head)
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries,
            keys,
            values,
            attn_mask
        )
        if self.mix:
            out = out.swapaxes(2, 1)
        out = out.view(B, L, -1)

        return self.out_projection(out), attn
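
For quick verification, the snippet below (not part of attn.py) is a minimal smoke test, a sketch rather than code from this repository: it wraps ProbAttention in an AttentionLayer and checks the output shape on random data. It assumes the classes above are importable and that MindSpore runs in PyNative mode, where the tensor fancy indexing used in ProbAttention is supported; the batch size, sequence length, d_model and n_heads values are illustrative.

import numpy as np
import mindspore as ms
from mindspore import Tensor

ms.set_context(mode=ms.PYNATIVE_MODE)

B, L, d_model, n_heads = 2, 96, 512, 8
x = Tensor(np.random.randn(B, L, d_model).astype(np.float32))

# mask_flag=False: plain (non-causal) ProbSparse self-attention
layer = AttentionLayer(
    ProbAttention(mask_flag=False, factor=5, attention_dropout=0.1),
    d_model=d_model, n_heads=n_heads)

out, attn = layer(x, x, x, attn_mask=None)  # self-attention: queries = keys = values
print(out.shape)  # expected (2, 96, 512)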

Research on a multimodal stock price prediction system based on MindSpore (Informer, LSTM, RNN)