You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

tools.py 2.8 kB

2 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. import numpy as np
  2. import torch
  3. 1
  4. def adjust_learning_rate(optimizer, epoch, args):
  5. # lr = args.learning_rate * (0.2 ** (epoch // 2))
  6. if args.lradj=='type1':
  7. lr_adjust = {epoch: args.learning_rate * (0.5 ** ((epoch-1) // 1))}
  8. elif args.lradj=='type2':
  9. lr_adjust = {
  10. 2: 5e-5, 4: 1e-5, 6: 5e-6, 8: 1e-6,
  11. 10: 5e-7, 15: 1e-7, 20: 5e-8
  12. }
  13. if epoch in lr_adjust.keys():
  14. lr = lr_adjust[epoch]
  15. for param_group in optimizer.param_groups:
  16. param_group['lr'] = lr
  17. print('Updating learning rate to {}'.format(lr))
  18. class EarlyStopping:
  19. def __init__(self, patience=7, verbose=False, delta=0):
  20. self.patience = patience
  21. self.verbose = verbose
  22. self.counter = 0
  23. self.best_score = None
  24. self.early_stop = False
  25. self.val_loss_min = np.Inf
  26. self.delta = delta
  27. def __call__(self, val_loss, model, path):
  28. score = -val_loss
  29. if self.best_score is None:
  30. self.best_score = score
  31. self.save_checkpoint(val_loss, model, path)
  32. elif score < self.best_score + self.delta:
  33. self.counter += 1
  34. print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
  35. if self.counter >= self.patience:
  36. self.early_stop = True
  37. else:
  38. self.best_score = score
  39. self.save_checkpoint(val_loss, model, path)
  40. self.counter = 0
  41. def save_checkpoint(self, val_loss, model, path):
  42. if self.verbose:
  43. print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
  44. torch.save(model.state_dict(), path+'/'+'checkpoint.pth')
  45. self.val_loss_min = val_loss
  46. class dotdict(dict):
  47. """dot.notation access to dictionary attributes"""
  48. __getattr__ = dict.get
  49. __setattr__ = dict.__setitem__
  50. __delattr__ = dict.__delitem__
  51. class StandardScaler():
  52. def __init__(self):
  53. self.mean = 0.
  54. self.std = 1.
  55. def fit(self, data):
  56. self.mean = data.mean(0)
  57. self.std = data.std(0)
  58. def transform(self, data):
  59. mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
  60. std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
  61. return (data - mean) / std
  62. def inverse_transform(self, data):
  63. mean = torch.from_numpy(self.mean).type_as(data).to(data.device) if torch.is_tensor(data) else self.mean
  64. std = torch.from_numpy(self.std).type_as(data).to(data.device) if torch.is_tensor(data) else self.std
  65. if data.shape[-1] != mean.shape[-1]:
  66. mean = mean[-1:]
  67. std = std[-1:]
  68. return (data * std) + mean

基于MindSpore的多模态股票价格预测系统研究 Informer,LSTM,RNN