
attn.py 6.2 kB

"""Attention modules for the Informer model, implemented with MindSpore 2.x APIs."""
import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor
from mindspore.common import dtype as mstype

# NOTE: the mask helpers follow the original Informer code base; the import path
# is an assumption and may need to be adapted to this repository's layout.
from utils.masking import TriangularCausalMask, ProbMask


class FullAttention(nn.Cell):
    """Canonical scaled dot-product attention."""

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        # MindSpore >= 2.0: `p` is the drop probability (older versions use keep_prob)
        self.dropout = nn.Dropout(p=attention_dropout)
        self.matmul = ops.BatchMatMul()
        self.softmax = nn.Softmax(axis=-1)

    def construct(self, queries, keys, values, attn_mask):
        # queries: (B, L, H, E); keys: (B, S, H, E); values: (B, S, H, D)
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / np.sqrt(E)

        # (B, H, L, E) x (B, H, E, S) -> (B, H, L, S)
        scores = self.matmul(queries.transpose(0, 2, 1, 3), keys.transpose(0, 2, 3, 1))

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L)
            # positions where mask is True are not attended to (same convention as ProbAttention below)
            scores = ops.masked_fill(scores, attn_mask.mask, Tensor(-np.inf, scores.dtype))

        A = self.dropout(self.softmax(scale * scores))
        # (B, H, L, S) x (B, H, S, D) -> (B, H, L, D) -> (B, L, H, D)
        V = self.matmul(A, values.transpose(0, 2, 1, 3)).transpose(0, 2, 1, 3)

        if self.output_attention:
            return V, A
        return V, None


class ProbAttention(nn.Cell):
    """ProbSparse self-attention from Informer.

    NOTE: the tensor fancy indexing below mirrors the original implementation and
    generally requires PyNative mode; graph mode may need gather/scatter ops instead.
    """

    def __init__(self, mask_flag=True, factor=5, scale=None, attention_dropout=0.1, output_attention=False):
        super(ProbAttention, self).__init__()
        self.factor = factor
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(p=attention_dropout)

    def _prob_QK(self, Q, K, sample_k, n_top):  # n_top: c*ln(L_q)
        # Q, K: (B, H, L, E)
        B, H, L_K, E = K.shape
        _, _, L_Q, _ = Q.shape

        # score each query against sample_k randomly sampled keys
        K_expand = ops.broadcast_to(K.expand_dims(-3), (B, H, L_Q, L_K, E))
        index_sample = Tensor(np.random.randint(0, L_K, size=(L_Q, sample_k)), mstype.int32)
        K_sample = K_expand[:, :, ops.arange(L_Q).expand_dims(1), index_sample, :]
        Q_K_sample = ops.matmul(Q.expand_dims(-2), K_sample.swapaxes(-2, -1)).squeeze(-2)

        # sparsity measurement: max of the sampled scores minus their mean
        M = Q_K_sample.max(axis=-1) - Q_K_sample.sum(-1) / L_K
        M_top = ops.topk(M, n_top, sorted=False)[1]

        # use only the top-u ("active") queries for the full Q.K^T
        Q_reduce = Q[ops.arange(B)[:, None, None],
                     ops.arange(H)[None, :, None],
                     M_top, :]                                  # (B, H, n_top, E)
        Q_K = ops.matmul(Q_reduce, K.swapaxes(-2, -1))          # (B, H, n_top, L_K)
        return Q_K, M_top

    def _get_initial_context(self, V, L_Q):
        B, H, L_V, D = V.shape
        if not self.mask_flag:
            # without masking, every query starts from the mean of the values
            V_sum = V.mean(axis=-2)
            context = ops.broadcast_to(V_sum.expand_dims(-2), (B, H, L_Q, V_sum.shape[-1])).copy()
        else:
            # with masking, the cumulative sum preserves causality
            assert L_Q == L_V  # requires L_Q == L_V, i.e. self-attention only
            context = V.cumsum(axis=-2)
        return context

    def _update_context(self, context_in, V, scores, index, L_Q, attn_mask):
        B, H, L_V, D = V.shape

        if self.mask_flag:
            attn_mask = ProbMask(B, H, L_Q, index, scores)
            scores = ops.masked_fill(scores, attn_mask.mask, Tensor(-np.inf, scores.dtype))

        attn = ops.softmax(scores, axis=-1)

        # write the attention output of the selected queries back into the context
        context_in[ops.arange(B)[:, None, None],
                   ops.arange(H)[None, :, None],
                   index, :] = ops.matmul(attn, V).astype(context_in.dtype)

        if self.output_attention:
            # unselected queries keep a uniform attention distribution
            attns = ops.ones((B, H, L_V, L_V), attn.dtype) / L_V
            attns[ops.arange(B)[:, None, None], ops.arange(H)[None, :, None], index, :] = attn
            return context_in, attns
        return context_in, None

    def construct(self, queries, keys, values, attn_mask):
        B, L_Q, H, D = queries.shape
        _, L_K, _, _ = keys.shape

        # (B, L, H, D) -> (B, H, L, D)
        queries = queries.swapaxes(2, 1)
        keys = keys.swapaxes(2, 1)
        values = values.swapaxes(2, 1)

        U_part = self.factor * int(np.ceil(np.log(L_K)))  # c*ln(L_k)
        u = self.factor * int(np.ceil(np.log(L_Q)))       # c*ln(L_q)
        U_part = U_part if U_part < L_K else L_K
        u = u if u < L_Q else L_Q

        scores_top, index = self._prob_QK(queries, keys, sample_k=U_part, n_top=u)

        # apply the scale factor
        scale = self.scale or 1.0 / np.sqrt(D)
        scores_top = scores_top * scale

        # build the lazy initial context, then refresh the top-u query positions
        context = self._get_initial_context(values, L_Q)
        context, attn = self._update_context(context, values, scores_top, index, L_Q, attn_mask)

        return context.swapaxes(2, 1), attn


class AttentionLayer(nn.Cell):
    """Multi-head wrapper: project Q/K/V, run the inner attention, project the output."""

    def __init__(self, attention, d_model, n_heads,
                 d_keys=None, d_values=None, mix=False):
        super(AttentionLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Dense(d_model, d_keys * n_heads)
        self.key_projection = nn.Dense(d_model, d_keys * n_heads)
        self.value_projection = nn.Dense(d_model, d_values * n_heads)
        self.out_projection = nn.Dense(d_values * n_heads, d_model)
        self.n_heads = n_heads
        self.mix = mix

    def construct(self, queries, keys, values, attn_mask):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # split the model dimension into n_heads heads
        queries = self.query_projection(queries).reshape(B, L, H, -1)
        keys = self.key_projection(keys).reshape(B, S, H, -1)
        values = self.value_projection(values).reshape(B, S, H, -1)

        out, attn = self.inner_attention(queries, keys, values, attn_mask)
        if self.mix:
            out = out.swapaxes(2, 1)
        out = out.reshape(B, L, -1)

        return self.out_projection(out), attn
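
For reference, a minimal usage sketch of the layer above (my own example, not part of attn.py): it assumes the MindSpore port shown here, PyNative mode because of the tensor fancy indexing in ProbAttention, and illustrative shapes and hyper-parameters rather than the project's actual configuration.

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor
    from mindspore.common import dtype as mstype

    # PyNative (dynamic graph) mode is assumed for the advanced tensor indexing.
    ms.set_context(mode=ms.PYNATIVE_MODE)

    B, L, d_model, n_heads = 2, 96, 512, 8          # illustrative sizes
    x = Tensor(np.random.randn(B, L, d_model), mstype.float32)

    layer = AttentionLayer(
        ProbAttention(mask_flag=False, factor=5, attention_dropout=0.1),
        d_model, n_heads)

    out, attn = layer(x, x, x, attn_mask=None)
    print(out.shape)  # (2, 96, 512)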

Research on a MindSpore-based multimodal stock price prediction system: Informer, LSTM, RNN