You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

exp_basic.py 875 B

123456789101112131415161718192021222324252627282930313233343536
  1. import os
  2. import torch
  3. import numpy as np
  4. class Exp_Basic(object):
  5. def __init__(self, args):
  6. self.args = args
  7. self.device = self._acquire_device()
  8. self.model = self._build_model().to(self.device)
  9. def _build_model(self):
  10. raise NotImplementedError
  11. return None
  12. def _acquire_device(self):
  13. if self.args.use_gpu:
  14. os.environ["CUDA_VISIBLE_DEVICES"] = str(self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
  15. device = torch.device('cuda:{}'.format(self.args.gpu))
  16. print('Use GPU: cuda:{}'.format(self.args.gpu))
  17. else:
  18. device = torch.device('cpu')
  19. print('Use CPU')
  20. return device
  21. def _get_data(self):
  22. pass
  23. def vali(self):
  24. pass
  25. def train(self):
  26. pass
  27. def test(self):
  28. pass

基于MindSpore的多模态股票价格预测系统研究 Informer, LSTM, RNN