You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 2.8 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import numpy as np
  2. import torch
  3. def adjust_learning_rate(optimizer, epoch, args):
  4. # lr = args.learning_rate * (0.2 ** (epoch // 2))
  5. if args.lradj=='type1':
  6. lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch-1) // 1))}
  7. elif args.lradj=='type2':
  8. lr_adjust = {
  9. 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
  10. 10: 5e-7, 15: 1e-7, 20: 5e-8
  11. }
  12. if epoch in lr_adjust.keys():
  13. lr = lr_adjust[epoch]
  14. for param_group in optimizer.param_groups:
  15. param_group['lr'] = lr
  16. print('Updating learning rate to {}'.format(lr))
  17. class EarlyStopping:
  18. def __init__(self, patience=7, verbose=False, delta=0):
  19. self.patience = patience
  20. self.verbose = verbose
  21. self.counter = 0
  22. self.best_score = None
  23. self.early_stop = False
  24. self.val_loss_min = np.Inf
  25. self.delta = delta
  26. def __call__(self, val_loss, model, path):
  27. score = -val_loss
  28. if self.best_score is None:
  29. self.best_score = score
  30. self.save_checkpoint(val_loss, model, path)
  31. elif score < self.best_score + self.delta:
  32. self.counter += 1
  33. print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
  34. if self.counter >= self.patience:
  35. self.early_stop = True
  36. else:
  37. self.best_score = score
  38. self.save_checkpoint(val_loss, model, path)
  39. self.counter = 0
  40. def save_checkpoint(self, val_loss, model, path):
  41. if self.verbose:
  42. print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
  43. torch.save(model.state_dict(), path+'/'+'checkpoint.pth')
  44. self.val_loss_min = val_loss
  45. class dotdict(dict):
  46. """dot.notation access to dictionary attributes"""
  47. __getattr__ = dict.get
  48. __setattr__ = dict.__setitem__
  49. __delattr__ = dict.__delitem__
  50. class StandardScaler():
  51. def __init__(self):
  52. self.mean = 0.
  53. self.std = 1.
  54. def fit(self, data):
  55. self.mean = data.mean(0)
  56. self.std = data.std(0)
  57. def transform(self, data):
  58. mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
  59. std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
  60. return (data - mean) / std
  61. def inverse_transform(self, data):
  62. mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
  63. std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
  64. if data.shape[-1] != mean.shape[-1]:
  65. mean = mean[-1:]
  66. std = std[-1:]
  67. return (data * std) + mean

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN