
embed.py (4.4 kB)

import math

import numpy as np

import mindspore.nn as nn
import mindspore.ops as ops
import mindspore.common.dtype as mstype
import mindspore.common.initializer as init
from mindspore import Tensor, Parameter


class PositionalEmbedding(nn.Cell):
    """Sinusoidal positional encoding, precomputed up to max_len positions."""

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        pe = np.zeros((max_len, d_model), dtype=np.float32)
        position = np.arange(0, max_len, dtype=np.float32)[:, None]
        div_term = np.exp(np.arange(0, d_model, 2, dtype=np.float32)
                          * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = np.sin(position * div_term)
        pe[:, 1::2] = np.cos(position * div_term)
        pe = pe[None, :, :]  # add batch dimension: (1, max_len, d_model)
        self.pe = Parameter(Tensor(pe, mstype.float32), requires_grad=False)

    def construct(self, x):
        # Return the encodings for the first seq_len positions.
        return self.pe[:, :x.shape[1]]


class TokenEmbedding(nn.Cell):
    """Projects the raw input series to d_model channels with a 1-D convolution."""

    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        # The original torch Informer uses circular padding (padding=1 on
        # torch>=1.5.0); MindSpore's Conv1d has no circular mode, so zero
        # padding of 1 is used here, which keeps the sequence length
        # unchanged for kernel_size=3.
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, pad_mode='pad', padding=1,
                                   has_bias=True)
        for _, m in self.cells_and_names():
            if isinstance(m, nn.Conv1d):
                m.weight.set_data(init.initializer(
                    init.HeNormal(), m.weight.shape, m.weight.dtype))
                m.bias.set_data(init.initializer(
                    init.Zero(), m.bias.shape, m.bias.dtype))

    def construct(self, x):
        # (batch, seq_len, c_in) -> (batch, c_in, seq_len) for Conv1d, then back.
        x = self.tokenConv(x.transpose(0, 2, 1)).transpose(0, 2, 1)
        return x


class FixedEmbedding(nn.Cell):
    """Embedding lookup with frozen sinusoidal weights."""

    def __init__(self, c_in, d_model):
        super(FixedEmbedding, self).__init__()
        w = np.zeros((c_in, d_model), dtype=np.float32)
        position = np.arange(0, c_in, dtype=np.float32)[:, None]
        div_term = np.exp(np.arange(0, d_model, 2, dtype=np.float32)
                          * -(math.log(10000.0) / d_model))
        w[:, 0::2] = np.sin(position * div_term)
        w[:, 1::2] = np.cos(position * div_term)
        self.emb = nn.Embedding(c_in, d_model,
                                embedding_table=Tensor(w, mstype.float32))
        self.emb.embedding_table.requires_grad = False

    def construct(self, x):
        # stop_gradient is MindSpore's counterpart to torch's detach().
        return ops.stop_gradient(self.emb(x))


class TemporalEmbedding(nn.Cell):
    """Sum of embeddings for month, day, weekday, hour (and minute when freq='t')."""

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super(TemporalEmbedding, self).__init__()
        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        if freq == 't':
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def construct(self, x):
        x = x.astype(mstype.int32)
        minute_x = self.minute_embed(x[:, :, 4]) if hasattr(self, 'minute_embed') else 0.
        hour_x = self.hour_embed(x[:, :, 3])
        weekday_x = self.weekday_embed(x[:, :, 2])
        day_x = self.day_embed(x[:, :, 1])
        month_x = self.month_embed(x[:, :, 0])
        return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Cell):
    """Linear projection of continuous time features (embed_type='timeF')."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super(TimeFeatureEmbedding, self).__init__()
        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Dense(d_inp, d_model)

    def construct(self, x):
        return self.embed(x)


class DataEmbedding(nn.Cell):
    """Informer input embedding: value + positional + temporal, then dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super(DataEmbedding, self).__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        self.temporal_embedding = TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq) \
            if embed_type != 'timeF' else \
            TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
        # nn.Dropout(p=...) is the MindSpore >= 2.0 keyword (older releases use keep_prob).
        self.dropout = nn.Dropout(p=dropout)

    def construct(self, x, x_mark):
        x = self.value_embedding(x) + self.position_embedding(x) + self.temporal_embedding(x_mark)
        return self.dropout(x)
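
For a quick sanity check, a minimal usage sketch of DataEmbedding follows. The batch size, sequence length, channel count, and d_model are illustrative values, not taken from the project's training configuration; with freq='h', x_mark is expected to carry four time-feature columns (month, day, weekday, hour) in that order.

    import numpy as np
    from mindspore import Tensor

    # Illustrative shapes: batch of 2, sequence length 96, 7 input series.
    embed = DataEmbedding(c_in=7, d_model=512, embed_type='fixed', freq='h')
    x = Tensor(np.random.randn(2, 96, 7).astype(np.float32))   # raw series values
    x_mark = Tensor(np.zeros((2, 96, 4), dtype=np.float32))    # time-feature stamps
    out = embed(x, x_mark)
    print(out.shape)  # expected: (2, 96, 512)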

Research on a MindSpore-based multimodal stock price prediction system: Informer, LSTM, RNN