You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

encoder.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. im
  5. class ConvLayer(nn.Module):
  6. def __init__(self, c_in):
  7. super(ConvLayer, self).__init__()
  8. padding = 1 if torch.__version__>='1.5.0' else 2
  9. self.downConv = nn.Conv1d(in_channels=c_in,
  10. out_channels=c_in,
  11. kernel_size=3,
  12. padding=padding,
  13. padding_mode='circular')
  14. self.norm = nn.BatchNorm1d(c_in)
  15. self.activation = nn.ELU()
  16. self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
  17. def forward(self, x):
  18. x = self.downConv(x.permute(0, 2, 1))
  19. x = self.norm(x)
  20. x = self.activation(x)
  21. x = self.maxPool(x)
  22. x = x.transpose(1,2)
  23. return x
  24. class EncoderLayer(nn.Module):
  25. def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
  26. super(EncoderLayer, self).__init__()
  27. d_ff = d_ff or 4*d_model
  28. self.attention = attention
  29. self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
  30. self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
  31. self.norm1 = nn.LayerNorm(d_model)
  32. self.norm2 = nn.LayerNorm(d_model)
  33. self.dropout = nn.Dropout(dropout)
  34. self.activation = F.relu if activation == "relu" else F.gelu
  35. def forward(self, x, attn_mask=None):
  36. # x [B, L, D]
  37. # x = x + self.dropout(self.attention(
  38. # x, x, x,
  39. # attn_mask = attn_mask
  40. # ))
  41. new_x, attn = self.attention(
  42. x, x, x,
  43. attn_mask = attn_mask
  44. )
  45. x = x + self.dropout(new_x)
  46. y = x = self.norm1(x)
  47. y = self.dropout(self.activation(self.conv1(y.transpose(-1,1))))
  48. y = self.dropout(self.conv2(y).transpose(-1,1))
  49. return self.norm2(x+y), attn
  50. class Encoder(nn.Module):
  51. def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
  52. super(Encoder, self).__init__()
  53. self.attn_layers = nn.ModuleList(attn_layers)
  54. self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
  55. self.norm = norm_layer
  56. def forward(self, x, attn_mask=None):
  57. # x [B, L, D]
  58. attns = []
  59. if self.conv_layers is not None:
  60. for attn_layer, conv_layer in zip(self.attn_layers, self.conv_layers):
  61. x, attn = attn_layer(x, attn_mask=attn_mask)
  62. x = conv_layer(x)
  63. attns.append(attn)
  64. x, attn = self.attn_layers[-1](x, attn_mask=attn_mask)
  65. attns.append(attn)
  66. else:
  67. for attn_layer in self.attn_layers:
  68. x, attn = attn_layer(x, attn_mask=attn_mask)
  69. attns.append(attn)
  70. if self.norm is not None:
  71. x = self.norm(x)
  72. return x, attns
  73. class EncoderStack(nn.Module):
  74. def __init__(self, encoders, inp_lens):
  75. super(EncoderStack, self).__init__()
  76. self.encoders = nn.ModuleList(encoders)
  77. self.inp_lens = inp_lens
  78. def forward(self, x, attn_mask=None):
  79. # x [B, L, D]
  80. x_stack = []; attns = []
  81. for i_len, encoder in zip(self.inp_lens, self.encoders):
  82. inp_len = x.shape[1]//(2**i_len)
  83. x_s, attn = encoder(x[:, -inp_len:, :])
  84. x_stack.append(x_s); attns.append(attn)
  85. x_stack = torch.cat(x_stack, -2)
  86. return x_stack, attns

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN