"""Value, positional and temporal input embeddings implemented with MindSpore.

The layers mirror the PyTorch originals they were ported from; every tensor
operation below uses MindSpore (or NumPy for one-off constant tables), so the
module no longer depends on ``torch``.
"""
import math

import numpy as np

import mindspore.common.dtype as mstype
import mindspore.common.initializer as init
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Parameter, Tensor


def _sinusoid_table(n_position, d_model):
    """Return a ``(n_position, d_model)`` float32 NumPy sinusoid table.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = exp(-2k * ln(10000) / d_model)``.

    Args:
        n_position: Number of rows (positions / vocabulary entries).
        d_model: Number of columns (embedding width).

    Returns:
        A ``numpy.ndarray`` of dtype float32.
    """
    position = np.arange(n_position, dtype=np.float32)[:, np.newaxis]
    div_term = np.exp(
        np.arange(0, d_model, 2, dtype=np.float32)
        * -(math.log(10000.0) / d_model)
    )
    angles = position * div_term

    table = np.zeros((n_position, d_model), dtype=np.float32)
    table[:, 0::2] = np.sin(angles)
    # For odd d_model there is one fewer odd column than even columns, so the
    # cosine half is trimmed to fit instead of failing on a shape mismatch.
    table[:, 1::2] = np.cos(angles[:, : d_model // 2])
    return table


class PositionalEmbedding(nn.Cell):
    """Fixed sinusoidal positional encoding, sliced to the input length."""

    def __init__(self, d_model, max_len=5000):
        """Precompute the encoding table once.

        Args:
            d_model: Embedding width.
            max_len: Largest sequence length that can be encoded.
        """
        super(PositionalEmbedding, self).__init__()
        # Leading axis of size 1 lets the table broadcast over the batch.
        table = _sinusoid_table(max_len, d_model)[np.newaxis]
        # Stored as a non-trainable Parameter (stand-in for a torch buffer).
        self.pe = Parameter(
            Tensor(table, mstype.float32), name='pe', requires_grad=False
        )

    def construct(self, x):
        """Return the first ``x.shape[1]`` positions, shape ``(1, L, d_model)``."""
        return self.pe[:, :x.shape[1]]


class TokenEmbedding(nn.Cell):
    """Project input channels to ``d_model`` with a circular 1-D convolution."""

    def __init__(self, c_in, d_model):
        """Build the kernel-size-3 convolution.

        Args:
            c_in: Number of input channels (last axis of the input).
            d_model: Number of output channels.
        """
        super(TokenEmbedding, self).__init__()
        # MindSpore's Conv1d has no circular pad mode, so ``construct`` wraps
        # the sequence by hand and the convolution itself runs unpadded.
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            pad_mode='valid',
            has_bias=True,
            weight_init=init.HeNormal(mode='fan_in', nonlinearity='leaky_relu'),
            bias_init='zeros',
        )

    def construct(self, x):
        """Embed ``x`` of shape ``(batch, length, c_in)``.

        Returns:
            Tensor of shape ``(batch, length, d_model)``.
        """
        # Conv1d expects channels before length.
        x = ops.transpose(x, (0, 2, 1))
        # Circular padding of one step per side: last step in front, first
        # step at the back, so the output length equals the input length.
        x = ops.concat((x[:, :, -1:], x, x[:, :, :1]), axis=2)
        x = self.tokenConv(x)
        return ops.transpose(x, (0, 2, 1))


class FixedEmbedding(nn.Cell):
    """Embedding lookup whose table is a frozen sinusoid table."""

    def __init__(self, c_in, d_model):
        """Create the frozen lookup table.

        Args:
            c_in: Vocabulary size (number of distinct indices).
            d_model: Embedding width.
        """
        super(FixedEmbedding, self).__init__()
        table = Tensor(_sinusoid_table(c_in, d_model), mstype.float32)
        self.emb = nn.Embedding(c_in, d_model, embedding_table=table)
        # The table is a constant and must not be updated by the optimizer.
        self.emb.embedding_table.requires_grad = False

    def construct(self, x):
        """Look up integer indices ``x``; no gradient flows through the result."""
        return ops.stop_gradient(self.emb(x))


class TemporalEmbedding(nn.Cell):
    """Sum of per-field embeddings for calendar features."""

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        """Create one embedding per calendar field.

        Args:
            d_model: Embedding width.
            embed_type: ``'fixed'`` selects ``FixedEmbedding``; any other
                value selects a trainable ``nn.Embedding``.
            freq: ``'t'`` additionally enables the minute embedding.
        """
        super(TemporalEmbedding, self).__init__()

        minute_size = 4
        hour_size = 24
        weekday_size = 7
        day_size = 32
        month_size = 13

        Embed = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        # Decided once here so ``construct`` branches on a plain constant
        # rather than probing attributes at run time.
        self.has_minute = freq == 't'
        if self.has_minute:
            self.minute_embed = Embed(minute_size, d_model)
        self.hour_embed = Embed(hour_size, d_model)
        self.weekday_embed = Embed(weekday_size, d_model)
        self.day_embed = Embed(day_size, d_model)
        self.month_embed = Embed(month_size, d_model)

    def construct(self, x):
        """Embed time marks ``x`` and sum the per-field embeddings.

        The last axis of ``x`` is read as: 0 month, 1 day, 2 weekday, 3 hour
        and, when the minute embedding is enabled, 4 minute.
        """
        x = x.astype(mstype.int32)

        total = (
            self.hour_embed(x[:, :, 3])
            + self.weekday_embed(x[:, :, 2])
            + self.day_embed(x[:, :, 1])
            + self.month_embed(x[:, :, 0])
        )
        if self.has_minute:
            total = total + self.minute_embed(x[:, :, 4])
        return total


class TimeFeatureEmbedding(nn.Cell):
    """Dense projection of continuous time features to ``d_model``."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        """Create the projection.

        Args:
            d_model: Output width.
            embed_type: Accepted for signature parity; not used here.
            freq: Frequency code selecting the number of input features.

        Raises:
            KeyError: If ``freq`` is not one of the known frequency codes.
        """
        super(TimeFeatureEmbedding, self).__init__()

        freq_map = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        d_inp = freq_map[freq]
        self.embed = nn.Dense(d_inp, d_model)

    def construct(self, x):
        """Project the last axis of ``x`` from ``d_inp`` to ``d_model``."""
        return self.embed(x)


class DataEmbedding(nn.Cell):
    """Value + positional + temporal embedding followed by dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        """Assemble the three embeddings.

        Args:
            c_in: Number of input channels of the value series.
            d_model: Embedding width.
            embed_type: ``'timeF'`` selects ``TimeFeatureEmbedding``; any
                other value is forwarded to ``TemporalEmbedding``.
            freq: Frequency code forwarded to the temporal embedding.
            dropout: Drop probability (uses the ``p`` keyword of
                ``nn.Dropout``; NOTE(review): confirm the installed MindSpore
                version accepts ``p``).
        """
        super(DataEmbedding, self).__init__()

        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type != 'timeF':
            self.temporal_embedding = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )
        else:
            self.temporal_embedding = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq
            )

        self.dropout = nn.Dropout(p=dropout)

    def construct(self, x, x_mark):
        """Embed values ``x`` with their time marks ``x_mark``.

        Returns:
            The dropout-regularised sum of the three embeddings.
        """
        x = (
            self.value_embedding(x)
            + self.position_embedding(x)
            + self.temporal_embedding(x_mark)
        )

        return self.dropout(x)