model.py

import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.masking import TriangularCausalMask, ProbMask
from models.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from models.decoder import Decoder, DecoderLayer
from models.attn import FullAttention, ProbAttention, AttentionLayer
from models.embed import DataEmbedding


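# Informer encoder-decoder for long-sequence forecasting: ProbSparse (attn='prob')
# or full attention, optional distilling ConvLayers between encoder layers when
# distil=True, a two-attention decoder, and a linear projection head.
# forward() returns only the last pred_len decoder steps.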
class Informer(nn.Module):
    def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
                 factor=5, d_model=512, n_heads=8, e_layers=3, d_layers=2, d_ff=512,
                 dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu',
                 output_attention=False, distil=True, mix=True,
                 device=torch.device('cuda:0')):
        super(Informer, self).__init__()
        self.pred_len = out_len
        self.attn = attn
        self.output_attention = output_attention

        # Encoding
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
        # Attention
        Attn = ProbAttention if attn == 'prob' else FullAttention
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention),
                                   d_model, n_heads, mix=False),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation
                ) for l in range(e_layers)
            ],
            [
                ConvLayer(
                    d_model
                ) for l in range(e_layers - 1)
            ] if distil else None,
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False),
                                   d_model, n_heads, mix=mix),
                    AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False),
                                   d_model, n_heads, mix=False),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True)
        # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        # dec_out = self.end_conv1(dec_out)
        # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2)
        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]

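# InformerStack: builds one encoder per entry in e_layers and combines them with
# EncoderStack over the input-length indices in inp_lens; the embeddings, decoder,
# and projection head are identical to Informer above.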
class InformerStack(nn.Module):
    def __init__(self, enc_in, dec_in, c_out, seq_len, label_len, out_len,
                 factor=5, d_model=512, n_heads=8, e_layers=[3, 2, 1], d_layers=2, d_ff=512,
                 dropout=0.0, attn='prob', embed='fixed', freq='h', activation='gelu',
                 output_attention=False, distil=True, mix=True,
                 device=torch.device('cuda:0')):
        super(InformerStack, self).__init__()
        self.pred_len = out_len
        self.attn = attn
        self.output_attention = output_attention

        # Encoding
        self.enc_embedding = DataEmbedding(enc_in, d_model, embed, freq, dropout)
        self.dec_embedding = DataEmbedding(dec_in, d_model, embed, freq, dropout)
        # Attention
        Attn = ProbAttention if attn == 'prob' else FullAttention
        # Encoder
        inp_lens = list(range(len(e_layers)))  # [0, 1, 2, ...] you can customize here
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(Attn(False, factor, attention_dropout=dropout, output_attention=output_attention),
                                       d_model, n_heads, mix=False),
                        d_model,
                        d_ff,
                        dropout=dropout,
                        activation=activation
                    ) for l in range(el)
                ],
                [
                    ConvLayer(
                        d_model
                    ) for l in range(el - 1)
                ] if distil else None,
                norm_layer=torch.nn.LayerNorm(d_model)
            ) for el in e_layers]
        self.encoder = EncoderStack(encoders, inp_lens)
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(Attn(True, factor, attention_dropout=dropout, output_attention=False),
                                   d_model, n_heads, mix=mix),
                    AttentionLayer(FullAttention(False, factor, attention_dropout=dropout, output_attention=False),
                                   d_model, n_heads, mix=False),
                    d_model,
                    d_ff,
                    dropout=dropout,
                    activation=activation,
                )
                for l in range(d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )
        # self.end_conv1 = nn.Conv1d(in_channels=label_len+out_len, out_channels=out_len, kernel_size=1, bias=True)
        # self.end_conv2 = nn.Conv1d(in_channels=d_model, out_channels=c_out, kernel_size=1, bias=True)
        self.projection = nn.Linear(d_model, c_out, bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        # dec_out = self.end_conv1(dec_out)
        # dec_out = self.end_conv2(dec_out.transpose(2,1)).transpose(1,2)
        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
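
# --- Hypothetical usage sketch (not part of the original model.py) -------------
# A minimal CPU smoke test with random tensors, assuming the accompanying utils/
# and models/ packages match the reference Informer implementation, and assuming
# the default embed='fixed', freq='h' setup in which x_mark_* carries 4 calendar
# features per step (month, day, weekday, hour). Shapes follow the constructor
# arguments above.
if __name__ == '__main__':
    batch_size, seq_len, label_len, pred_len = 2, 96, 48, 24
    enc_in = dec_in = c_out = 7

    model = Informer(enc_in, dec_in, c_out, seq_len, label_len, pred_len,
                     e_layers=2, d_layers=1, device=torch.device('cpu'))

    x_enc = torch.randn(batch_size, seq_len, enc_in)
    x_mark_enc = torch.zeros(batch_size, seq_len, 4, dtype=torch.long)
    # Decoder input: label_len known steps followed by pred_len zero placeholders.
    x_dec = torch.cat([torch.randn(batch_size, label_len, dec_in),
                       torch.zeros(batch_size, pred_len, dec_in)], dim=1)
    x_mark_dec = torch.zeros(batch_size, label_len + pred_len, 4, dtype=torch.long)

    out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
    print(out.shape)  # expected: torch.Size([2, 24, 7]) == [batch, pred_len, c_out]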

Research on a multimodal stock-price prediction system based on MindSpore (Informer, LSTM, RNN)