You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

encoder.py 3.6 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  class ConvLayer(nn.Module):
      """Distilling layer that roughly halves the sequence length.

      Pipeline: circular ``Conv1d`` (kernel 3, ``c_in`` -> ``c_in`` channels)
      -> ``BatchNorm1d`` -> ``ELU`` -> ``MaxPool1d(kernel_size=3, stride=2,
      padding=1)``.

      Args:
          c_in: number of feature channels; input and output both carry
              ``c_in`` channels.

      The input is channels-last, ``[B, L, c_in]``: ``forward`` permutes it to
      ``[B, c_in, L]`` for ``Conv1d`` and transposes back at the end. With the
      length-preserving convolution (``padding=1``) the output shape is
      ``[B, (L - 1) // 2 + 1, c_in]``.
      """

      def __init__(self, c_in):
          super(ConvLayer, self).__init__()
          # PyTorch < 1.5 presumably handled circular ``padding`` differently and
          # needed 2 to stay length-preserving -- TODO confirm against the old
          # Conv1d source. Compare the version numerically: as plain strings the
          # comparison is lexicographic ('1.10.0' < '1.5.0'), which is wrong.
          padding = 1 if self._version_at_least(torch.__version__, (1, 5)) else 2
          self.downConv = nn.Conv1d(in_channels=c_in,
                                    out_channels=c_in,
                                    kernel_size=3,
                                    padding=padding,
                                    padding_mode='circular')
          self.norm = nn.BatchNorm1d(c_in)
          self.activation = nn.ELU()
          self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

      @staticmethod
      def _version_at_least(version, minimum):
          """Return True if dotted ``version`` is >= the integer tuple ``minimum``.

          Only the leading ``len(minimum)`` numeric components are compared, as
          integers, so suffixes such as ``'+cu118'`` or ``'a0'`` are ignored.
          """
          import re

          parts = tuple(int(p) for p in re.findall(r"\d+", str(version))[:len(minimum)])
          return parts >= tuple(minimum)

      def forward(self, x):
          """Map ``[B, L, c_in]`` to ``[B, (L - 1) // 2 + 1, c_in]``."""
          x = self.downConv(x.permute(0, 2, 1))  # channels-last -> channels-first
          x = self.norm(x)
          x = self.activation(x)
          x = self.maxPool(x)  # stride 2 halves the length
          x = x.transpose(1, 2)  # back to channels-last
          return x
  class EncoderLayer(nn.Module):
      """Single encoder block: self-attention, then a 1x1-conv feed-forward net.

      Each sub-layer is wrapped in dropout + residual add, followed by
      ``LayerNorm`` (post-norm ordering).

      Args:
          attention: callable invoked as ``attention(x, x, x, attn_mask=...)``
              that returns a pair ``(output, attn)``.
          d_model: feature size of the input/output (last axis).
          d_ff: hidden width of the feed-forward net; a falsy value means
              ``4 * d_model``.
          dropout: dropout probability used after attention and in the FFN.
          activation: ``"relu"`` selects ``F.relu``; any other value selects
              ``F.gelu``.
      """

      def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
          super(EncoderLayer, self).__init__()
          hidden_width = d_ff if d_ff else d_model * 4
          self.attention = attention
          # kernel_size=1 convolutions act position-wise, i.e. a per-step MLP.
          self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=hidden_width, kernel_size=1)
          self.conv2 = nn.Conv1d(in_channels=hidden_width, out_channels=d_model, kernel_size=1)
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.dropout = nn.Dropout(dropout)
          if activation == "relu":
              self.activation = F.relu
          else:
              self.activation = F.gelu

      def forward(self, x, attn_mask=None):
          """Return ``(encoded, attn)`` for ``x`` of shape ``[B, L, D]``."""
          attended, attn = self.attention(x, x, x, attn_mask=attn_mask)
          residual = self.norm1(x + self.dropout(attended))
          # Conv1d wants channels on axis 1, so swap the feature and time axes.
          hidden = self.dropout(self.activation(self.conv1(residual.transpose(-1, 1))))
          ffn_out = self.dropout(self.conv2(hidden).transpose(-1, 1))
          return self.norm2(residual + ffn_out), attn
  class Encoder(nn.Module):
      """Stack of attention layers with optional distilling layers in between.

      With ``conv_layers`` the forward pass is
      ``attn_0 -> conv_0 -> attn_1 -> conv_1 -> ... -> attn_{n-1}``, so exactly
      ``len(attn_layers) - 1`` conv layers are required. Without them the
      attention layers are simply chained.

      Args:
          attn_layers: modules called as ``layer(x, attn_mask=...)`` that return
              ``(x, attn)``.
          conv_layers: optional modules called as ``conv(x)`` between
              consecutive attention layers.
          norm_layer: optional callable applied to the final output.

      Raises:
          ValueError: if ``conv_layers`` is given and its length is not
              ``len(attn_layers) - 1``. Other lengths previously silently
              skipped attention layers or re-applied the last one.
      """

      def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
          super(Encoder, self).__init__()
          self.attn_layers = nn.ModuleList(attn_layers)
          self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
          if self.conv_layers is not None and len(self.conv_layers) != len(self.attn_layers) - 1:
              raise ValueError(
                  "Encoder expects len(conv_layers) == len(attn_layers) - 1, got "
                  "{} conv layers for {} attention layers".format(
                      len(self.conv_layers), len(self.attn_layers)))
          self.norm = norm_layer

      def forward(self, x, attn_mask=None):
          """Return ``(x, attns)``: the encoded input and one attn per attention layer.

          ``x`` is ``[B, L, D]``; conv layers may shorten ``L``.
          """
          attns = []
          n_conv = len(self.conv_layers) if self.conv_layers is not None else 0
          for i, attn_layer in enumerate(self.attn_layers):
              x, attn = attn_layer(x, attn_mask=attn_mask)
              attns.append(attn)
              # Distil between attention layers; none after the last one.
              if i < n_conv:
                  x = self.conv_layers[i](x)
          if self.norm is not None:
              x = self.norm(x)
          return x, attns
  class EncoderStack(nn.Module):
      """Run several encoders on shrinking suffixes of the input and concatenate.

      Encoder ``k`` sees the last ``x.shape[1] // 2**inp_lens[k]`` time steps of
      ``x``. The outputs are concatenated along the sequence axis (dim -2).

      Args:
          encoders: modules called as ``encoder(x_slice)`` that return
              ``(x, attn)``.
          inp_lens: per-encoder exponents, paired with ``encoders`` by ``zip``.
      """

      def __init__(self, encoders, inp_lens):
          super(EncoderStack, self).__init__()
          self.encoders = nn.ModuleList(encoders)
          self.inp_lens = inp_lens

      def forward(self, x, attn_mask=None):
          """Return ``(x_stack, attns)`` for ``x`` of shape ``[B, L, D]``.

          ``attn_mask`` is accepted for interface compatibility but is NOT
          forwarded to the encoders (each sees a shorter slice).

          Raises:
              ValueError: if ``L // 2**i_len`` is 0 for some encoder. Slicing
                  with ``-0`` would otherwise select the whole sequence.
          """
          seq_len = x.shape[1]
          x_stack = []
          attns = []
          for i_len, encoder in zip(self.inp_lens, self.encoders):
              inp_len = seq_len // (2 ** i_len)
              if inp_len < 1:
                  raise ValueError(
                      "sequence length {} is too short for inp_len exponent {}".format(
                          seq_len, i_len))
              x_s, attn = encoder(x[:, -inp_len:, :])
              x_stack.append(x_s)
              attns.append(attn)
          x_stack = torch.cat(x_stack, -2)
          return x_stack, attns

基于MindSpore的多模态股票价格预测系统研究 (Research on a MindSpore-based multimodal stock price prediction system) — Informer, LSTM, RNN