You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 2.6 kB

2 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. import numpy as np
  2. import mindspore.numpy as mnp
  3. import mindspore
  4. from mindspore import Tensor, Parameter
  5. def adjust_learning_rate(optimizer, epoch, args):
  6. if args.lradj == 'type1':
  7. lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch-1) // 1))}
  8. elif args.lradj == 'type2':
  9. lr_adjust = {
  10. 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
  11. 10: 5e-7, 15: 1e-7, 20: 5e-8
  12. }
  13. if epoch in lr_adjust.keys():
  14. lr = lr_adjust[epoch]
  15. for param_group in optimizer.parameters():
  16. param_group.set_lr(lr)
  17. print('Updating learning rate to {}'.format(lr))
  18. class EarlyStopping:
  19. def __init__(self, patience=7, verbose=False, delta=0):
  20. self.patience = patience
  21. self.verbose = verbose
  22. self.counter = 0
  23. self.best_score = None
  24. self.early_stop = False
  25. self.val_loss_min = np.Inf
  26. self.delta = delta
  27. def __call__(self, val_loss, model, path):
  28. score = -val_loss
  29. if self.best_score is None:
  30. self.best_score = score
  31. self.save_checkpoint(val_loss, model, path)
  32. elif score < self.best_score + self.delta:
  33. self.counter += 1
  34. print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
  35. if self.counter >= self.patience:
  36. self.early_stop = True
  37. else:
  38. self.best_score = score
  39. self.save_checkpoint(val_loss, model, path)
  40. self.counter = 0
  41. def save_checkpoint(self, val_loss, model, path):
  42. if self.verbose:
  43. print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
  44. model.save_checkpoint(path + '/' + 'checkpoint.ckpt')
  45. self.val_loss_min = val_loss
  46. class dotdict(dict):
  47. """dot.notation access to dictionary attributes"""
  48. __getattr__ = dict.get
  49. __setattr__ = dict.__setitem__
  50. __delattr__ = dict.__delitem__
  51. class StandardScaler():
  52. def __init__(self):
  53. self.mean = 0.
  54. self.std = 1.
  55. def fit(self, data):
  56. self.mean = mnp.mean(data, 0)
  57. self.std = mnp.std(data, 0)
  58. def transform(self, data):
  59. mean = Tensor(self.mean, mindspore.float32)
  60. std = Tensor(self.std, mindspore.float32)
  61. return (data - mean) / std
  62. def inverse_transform(self, data):
  63. mean = Tensor(self.mean, mindspore.float32)
  64. std = Tensor(self.std, mindspore.float32)
  65. if data.shape[-1] != mean.shape[-1]:
  66. mean = mean[-1:]
  67. std = std[-1:]
  68. return (data * std) + mean

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN