
exp_informer.py

from data.data_loader import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom, Dataset_Pred
from exp.exp_basic import Exp_Basic
from models.model import Informer, InformerStack

from utils.tools import EarlyStopping, adjust_learning_rate
from utils.metrics import metric

import numpy as np

import torch
import torch.nn as nn
from torch import optim
from torch.utils.data import DataLoader

import os
import time

import warnings
warnings.filterwarnings('ignore')

class Exp_Informer(Exp_Basic):
    """Training / evaluation / inference wrapper around the Informer and InformerStack models."""

    def __init__(self, args):
        super(Exp_Informer, self).__init__(args)

    def _build_model(self):
        # Map the model name to its class and instantiate it with the configured hyperparameters.
        model_dict = {
            'informer': Informer,
            'informerstack': InformerStack,
        }
        if self.args.model == 'informer' or self.args.model == 'informerstack':
            # InformerStack takes a list of encoder depths (s_layers) instead of a single e_layers value.
            e_layers = self.args.e_layers if self.args.model == 'informer' else self.args.s_layers
            model = model_dict[self.args.model](
                self.args.enc_in,
                self.args.dec_in,
                self.args.c_out,
                self.args.seq_len,
                self.args.label_len,
                self.args.pred_len,
                self.args.factor,
                self.args.d_model,
                self.args.n_heads,
                e_layers,  # self.args.e_layers,
                self.args.d_layers,
                self.args.d_ff,
                self.args.dropout,
                self.args.attn,
                self.args.embed,
                self.args.freq,
                self.args.activation,
                self.args.output_attention,
                self.args.distil,
                self.args.mix,
                self.device
            ).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _get_data(self, flag):
        args = self.args

        data_dict = {
            'ETTh1': Dataset_ETT_hour,
            'ETTh2': Dataset_ETT_hour,
            'ETTm1': Dataset_ETT_minute,
            'ETTm2': Dataset_ETT_minute,
            'WTH': Dataset_Custom,
            'ECL': Dataset_Custom,
            'Solar': Dataset_Custom,
            'custom': Dataset_Custom,
        }
        Data = data_dict[self.args.data]
        timeenc = 0 if args.embed != 'timeF' else 1

        # Loader settings depend on the split: no shuffling for test/pred,
        # and prediction uses batch size 1 with the finer-grained detail_freq.
        if flag == 'test':
            shuffle_flag = False; drop_last = True; batch_size = args.batch_size; freq = args.freq
        elif flag == 'pred':
            shuffle_flag = False; drop_last = False; batch_size = 1; freq = args.detail_freq
            Data = Dataset_Pred
        else:
            shuffle_flag = True; drop_last = True; batch_size = args.batch_size; freq = args.freq
        data_set = Data(
            root_path=args.root_path,
            data_path=args.data_path,
            flag=flag,
            size=[args.seq_len, args.label_len, args.pred_len],
            features=args.features,
            target=args.target,
            inverse=args.inverse,
            timeenc=timeenc,
            freq=freq,
            cols=args.cols
        )
        print(flag, len(data_set))
        data_loader = DataLoader(
            data_set,
            batch_size=batch_size,
            shuffle=shuffle_flag,
            num_workers=args.num_workers,
            drop_last=drop_last)

        return data_set, data_loader

    def _select_optimizer(self):
        model_optim = optim.Adam(self.model.parameters(), lr=self.args.learning_rate)
        return model_optim

    def _select_criterion(self):
        criterion = nn.MSELoss()
        return criterion

    def vali(self, vali_data, vali_loader, criterion):
        # Evaluate the current model on a validation/test loader; no weight updates are performed.
        self.model.eval()
        total_loss = []
        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(vali_loader):
            pred, true = self._process_one_batch(
                vali_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            loss = criterion(pred.detach().cpu(), true.detach().cpu())
            total_loss.append(loss)
        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss

    def train(self, setting):
        train_data, train_loader = self._get_data(flag='train')
        vali_data, vali_loader = self._get_data(flag='val')
        test_data, test_loader = self._get_data(flag='test')

        path = os.path.join(self.args.checkpoints, setting)
        if not os.path.exists(path):
            os.makedirs(path)

        time_now = time.time()

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        if self.args.use_amp:
            scaler = torch.cuda.amp.GradScaler()

        for epoch in range(self.args.train_epochs):
            iter_count = 0
            train_loss = []

            self.model.train()
            epoch_time = time.time()
            for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_loader):
                iter_count += 1

                model_optim.zero_grad()
                pred, true = self._process_one_batch(
                    train_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
                loss = criterion(pred, true)
                train_loss.append(loss.item())

                if (i + 1) % 100 == 0:
                    print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
                    speed = (time.time() - time_now) / iter_count
                    left_time = speed * ((self.args.train_epochs - epoch) * train_steps - i)
                    print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
                    iter_count = 0
                    time_now = time.time()

                if self.args.use_amp:
                    scaler.scale(loss).backward()
                    scaler.step(model_optim)
                    scaler.update()
                else:
                    loss.backward()
                    model_optim.step()

            print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
            train_loss = np.average(train_loss)
            vali_loss = self.vali(vali_data, vali_loader, criterion)
            test_loss = self.vali(test_data, test_loader, criterion)

            print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Vali Loss: {3:.7f} Test Loss: {4:.7f}".format(
                epoch + 1, train_steps, train_loss, vali_loss, test_loss))
            # Early stopping tracks the validation loss and writes the best checkpoint under `path`.
            early_stopping(vali_loss, self.model, path)
            if early_stopping.early_stop:
                print("Early stopping")
                break

            adjust_learning_rate(model_optim, epoch + 1, self.args)

        best_model_path = path + '/' + 'checkpoint.pth'
        self.model.load_state_dict(torch.load(best_model_path))

        return self.model

    def test(self, setting):
        test_data, test_loader = self._get_data(flag='test')

        self.model.eval()

        preds = []
        trues = []

        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(test_loader):
            pred, true = self._process_one_batch(
                test_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            preds.append(pred.detach().cpu().numpy())
            trues.append(true.detach().cpu().numpy())

        preds = np.array(preds)
        trues = np.array(trues)
        print('test shape:', preds.shape, trues.shape)
        # Merge the batch dimension: (num_batches, batch_size, pred_len, c_out) -> (num_samples, pred_len, c_out).
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])
        trues = trues.reshape(-1, trues.shape[-2], trues.shape[-1])
        print('test shape:', preds.shape, trues.shape)

        # result save
        folder_path = './results1/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        mae, mse, rmse, mape, mspe = metric(preds, trues)
        print('mse:{}, mae:{}'.format(mse, mae))

        np.save(folder_path + 'metrics.npy', np.array([mae, mse, rmse, mape, mspe]))
        np.save(folder_path + 'pred.npy', preds)
        np.save(folder_path + 'true.npy', trues)

        return

    def predict(self, setting, load=False):
        pred_data, pred_loader = self._get_data(flag='pred')

        if load:
            # Optionally reload the best checkpoint saved during training.
            path = os.path.join(self.args.checkpoints, setting)
            best_model_path = path + '/' + 'checkpoint.pth'
            self.model.load_state_dict(torch.load(best_model_path))

        self.model.eval()

        preds = []

        for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(pred_loader):
            pred, true = self._process_one_batch(
                pred_data, batch_x, batch_y, batch_x_mark, batch_y_mark)
            preds.append(pred.detach().cpu().numpy())

        preds = np.array(preds)
        preds = preds.reshape(-1, preds.shape[-2], preds.shape[-1])

        # result save
        folder_path = './results1/' + setting + '/'
        if not os.path.exists(folder_path):
            os.makedirs(folder_path)

        np.save(folder_path + 'real_prediction.npy', preds)

        return

    def _process_one_batch(self, dataset_object, batch_x, batch_y, batch_x_mark, batch_y_mark):
        batch_x = batch_x.float().to(self.device)
        batch_y = batch_y.float()

        batch_x_mark = batch_x_mark.float().to(self.device)
        batch_y_mark = batch_y_mark.float().to(self.device)

        # decoder input: the known label_len tokens followed by pred_len placeholder tokens (zeros or ones)
        if self.args.padding == 0:
            dec_inp = torch.zeros([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
        elif self.args.padding == 1:
            dec_inp = torch.ones([batch_y.shape[0], self.args.pred_len, batch_y.shape[-1]]).float()
        dec_inp = torch.cat([batch_y[:, :self.args.label_len, :], dec_inp], dim=1).float().to(self.device)
        # encoder - decoder
        if self.args.use_amp:
            with torch.cuda.amp.autocast():
                if self.args.output_attention:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
                else:
                    outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        else:
            if self.args.output_attention:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)[0]
            else:
                outputs = self.model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
        if self.args.inverse:
            outputs = dataset_object.inverse_transform(outputs)
        # In 'MS' mode only the last column (the target) of batch_y is kept for the loss.
        f_dim = -1 if self.args.features == 'MS' else 0
        batch_y = batch_y[:, -self.args.pred_len:, f_dim:].to(self.device)

        return outputs, batch_y
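
For reference, a minimal usage sketch follows. It assumes this file lives at exp/exp_informer.py (as the exp.exp_basic import suggests) and that an args namespace supplies every attribute read above; the attribute names mirror the options consumed by Exp_Informer and Exp_Basic, while the concrete values are illustrative ETTh1-style defaults, not part of this file.

# --- Usage sketch (not part of exp_informer.py): how the class above is typically driven ---
# Assumption: args provides all attributes accessed by Exp_Informer/Exp_Basic; values are illustrative.
from argparse import Namespace
from exp.exp_informer import Exp_Informer

args = Namespace(
    # model / data selection
    model='informer', data='ETTh1', root_path='./data/ETT/', data_path='ETTh1.csv',
    features='M', target='OT', freq='h', detail_freq='h', checkpoints='./checkpoints/',
    cols=None, inverse=False,
    # sequence lengths and architecture
    seq_len=96, label_len=48, pred_len=24, enc_in=7, dec_in=7, c_out=7,
    d_model=512, n_heads=8, e_layers=2, d_layers=1, s_layers=[3, 2, 1], d_ff=2048,
    factor=5, padding=0, distil=True, dropout=0.05, attn='prob', embed='timeF',
    activation='gelu', output_attention=False, mix=True,
    # training / runtime
    num_workers=0, train_epochs=6, batch_size=32, patience=3, learning_rate=1e-4,
    use_amp=False, use_gpu=False, gpu=0, use_multi_gpu=False, devices='0',
)

setting = 'informer_ETTh1_ftM_sl96_ll48_pl24'  # tag naming the checkpoint and result folders

exp = Exp_Informer(args)
exp.train(setting)               # fit; best weights saved to ./checkpoints/<setting>/checkpoint.pth
exp.test(setting)                # evaluate on the test split; metrics and predictions go to ./results1/<setting>/
exp.predict(setting, load=True)  # forecast beyond the dataset; saves real_prediction.npy

The setting string is only a folder tag here; the repository's entry script builds a longer, fully descriptive one from the hyperparameters.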

Research on a multimodal stock price prediction system based on MindSpore: Informer, LSTM, RNN