You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

exp_basic.py 1.1 kB

12345678910111213141516171819202122232425262728293031323334353637
  1. import mindspore
  2. import mindspore.nn as nn
  3. from mindspore import context
  4. 1
  5. class Exp_Basic(nn.Cell):
  6. def __init__(self, args):
  7. super(Exp_Basic, self).__init__()
  8. self.args = args
  9. self.device = self._acquire_device()
  10. self.model = self._build_model().to(self.device)
  11. def _build_model(self):
  12. raise NotImplementedError
  13. return None
  14. def _acquire_device(self):
  15. if self.args.use_gpu:
  16. context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=str(self.args.gpu) if not self.args.use_multi_gpu else self.args.devices)
  17. device = mindspore.cuda_device('cuda:{}'.format(self.args.gpu))
  18. print('Use GPU: cuda:{}'.format(self.args.gpu))
  19. else:
  20. context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
  21. device = mindspore.cpu_device()
  22. print('Use CPU')
  23. return device
  24. def _get_data(self):
  25. pass
  26. def vali(self):
  27. pass
  28. def train(self):
  29. pass
  30. def test(self):
  31. pass

基于MindSpore的多模态股票价格预测系统研究 Informer, LSTM, RNN