
embed.py 4.1 kB

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class PositionalEmbedding(nn.Module):
    """Sinusoidal positional encoding, giving the model a notion of token order."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model).float()
        pe.requires_grad = False

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # Return the encodings for the first x.size(1) positions.
        return self.pe[:, :x.size(1)]


class TokenEmbedding(nn.Module):
    """Projects the c_in input variables to d_model channels with a circular 1D convolution."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # Conv1d's circular padding behavior changed in PyTorch 1.5; a plain string
        # comparison misorders versions like '1.10.0', so compare (major, minor) numerically.
        major, minor = (int(v) for v in torch.__version__.split('+')[0].split('.')[:2])
        padding = 1 if (major, minor) >= (1, 5) else 2
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular')
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # (batch, seq_len, c_in) -> (batch, c_in, seq_len) for Conv1d, then back.
        x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
        return x


class FixedEmbedding(nn.Module):
    """An nn.Embedding whose weights are frozen sinusoidal encodings rather than learned values."""

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        w = torch.zeros(c_in, d_model).float()
        w.requires_grad = False

        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        w[:, 0::2] = torch.sin(position * div_term)
        w[:, 1::2] = torch.cos(position * div_term)

        self.emb = nn.Embedding(c_in, d_model)
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
    """Embeds integer calendar features: month, day, weekday, hour, and (for freq='t') minute."""

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()
        minute_size = 4    # minutes bucketed into 15-minute steps
        hour_size = 24
        weekday_size = 7
        day_size = 32      # one past the max so 1-indexed days (1-31) fit
        month_size = 13    # one past the max so 1-indexed months (1-12) fit

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def forward(self, x):
        x = x.long()
        # x holds time features in the order [month, day, weekday, hour, minute].
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])
        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
    """Linearly projects real-valued time features; the input width depends on the frequency."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        # Number of time features produced for each sampling frequency.
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Linear(d_inp, d_model)

    def forward(self, x):
        return self.embed(x)


class DataEmbedding(nn.Module):
    """Sums the value, positional, and temporal embeddings, then applies dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            self.temporal_embedding = TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)
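
DataEmbedding is the entry point the model's inputs pass through before the encoder and decoder. A minimal usage sketch, assuming a batch of 32 hourly sequences with 96 time steps and 7 input variables (these shapes are illustrative, not taken from this repository):

    import torch

    batch, seq_len, c_in, d_model = 32, 96, 7, 512
    embed = DataEmbedding(c_in=c_in, d_model=d_model, embed_type='fixed', freq='h')

    x = torch.randn(batch, seq_len, c_in)           # raw series values
    # Integer calendar features in the order [month, day, weekday, hour]:
    x_mark = torch.stack([
        torch.randint(1, 13, (batch, seq_len)),     # month: 1-12
        torch.randint(1, 32, (batch, seq_len)),     # day of month: 1-31
        torch.randint(0, 7, (batch, seq_len)),      # weekday: 0-6
        torch.randint(0, 24, (batch, seq_len)),     # hour: 0-23
    ], dim=-1)

    out = embed(x, x_mark)
    print(out.shape)  # torch.Size([32, 96, 512])

With freq='h' only the first four time-feature columns are used; freq='t' additionally expects a minute column at index 4.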

Research on a multimodal stock price prediction system based on MindSpore, using Informer, LSTM, and RNN.