
embed.py 4.1 kB

import torch
import torch.nn as nn
import math


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Slice the precomputed table to the input sequence length.
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Circular padding behaviour changed in PyTorch 1.5; compare
        # version components numerically, not as strings.
        major, minor = (int(v) for v in
                        torch.__version__.split('+')[0].split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding,
                                   padding_mode='circular')
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in',
                                        nonlinearity='leaky_relu')

    def forward(self, x):
        # (batch, seq_len, c_in) -> (batch, seq_len, d_model)
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        # Sinusoidal lookup table wrapped in a frozen nn.Embedding.
        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float()
                    * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()
        # Expected column order: [month, day, weekday, hour, minute].
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])
        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        # Number of continuous time features per sampling frequency.
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model)

    def forward(self, x):
        return self.embed(x)


class DataEmbedding(nn.Module):
    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = (
            TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
            if embed_type != 'timeF'
            else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq))
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        # Sum value, positional, and temporal embeddings, then apply dropout.
        x = (self.value_embedding(x)
             + self.position_embedding(x)
             + self.temporal_embedding(x_mark))
        return self.dropout(x)
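
For orientation, a minimal usage sketch of DataEmbedding follows. The batch size, sequence length, channel count, and d_model are illustrative assumptions, not values taken from this repository.

# Usage sketch (all shapes and hyperparameter values below are
# illustrative assumptions, not taken from the repository).
if __name__ == '__main__':
    batch, seq_len, c_in, d_model = 32, 96, 7, 512

    x = torch.randn(batch, seq_len, c_in)  # raw multivariate values

    # Calendar marks in the column order TemporalEmbedding.forward
    # expects, [month, day, weekday, hour]; each value must stay
    # inside the corresponding embedding table's size.
    month = torch.randint(1, 13, (batch, seq_len, 1))
    day = torch.randint(1, 32, (batch, seq_len, 1))
    weekday = torch.randint(0, 7, (batch, seq_len, 1))
    hour = torch.randint(0, 24, (batch, seq_len, 1))
    x_mark = torch.cat([month, day, weekday, hour], dim=-1)

    embedding = DataEmbedding(c_in=c_in, d_model=d_model,
                              embed_type='fixed', freq='h')
    out = embedding(x, x_mark)
    print(out.shape)  # torch.Size([32, 96, 512])

With freq='t' the model would also expect a fifth minute column (bucketed into 4 quarter-hour bins), and with embed_type='timeF' x_mark would instead carry continuous time features fed to TimeFeatureEmbedding.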

Research on a Multimodal Stock Price Prediction System Based on MindSpore (Informer, LSTM, RNN)