|
12345678910111213141516171819202122232425262728293031323334353637 |
- import mindspore
- import mindspore.nn as nn
- from mindspore import context
-
class Exp_Basic(nn.Cell):
    """Base class for experiments.

    On construction it stores ``args``, configures the MindSpore execution
    context via ``_acquire_device`` and builds the model via ``_build_model``.
    Subclasses must override ``_build_model`` and are expected to override
    the ``_get_data`` / ``vali`` / ``train`` / ``test`` hooks, which do
    nothing here.
    """

    def __init__(self, args):
        """Store ``args``, select the device, then build the model.

        Args:
            args: Configuration object. This class reads ``use_gpu``,
                ``gpu``, ``use_multi_gpu`` and ``devices`` from it.
        """
        super().__init__()
        self.args = args
        self.device = self._acquire_device()
        # MindSpore cells have no ``.to(device)``: placement is decided
        # globally by ``context.set_context`` in ``_acquire_device``, so the
        # model is used exactly as built.
        self.model = self._build_model()

    def _build_model(self):
        """Return the model for this experiment; must be overridden.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def _acquire_device(self):
        """Configure the global MindSpore context and report the target.

        Graph mode is selected in both branches. With ``args.use_gpu`` set,
        the target is ``"GPU"`` and ``device_id`` comes from ``args.gpu``
        (or ``args.devices`` when ``args.use_multi_gpu`` is set); otherwise
        the target is ``"CPU"``.

        Returns:
            str: The configured device target, ``"GPU"`` or ``"CPU"``.
        """
        if self.args.use_gpu:
            if self.args.use_multi_gpu:
                # NOTE(review): ``device_id`` expects a single int. The type
                # of ``args.devices`` is not visible here -- confirm that it
                # is one device index and not a list / comma-separated string.
                device_id = self.args.devices
            else:
                # ``device_id`` must be an int; a str is rejected.
                device_id = int(self.args.gpu)
            context.set_context(
                mode=context.GRAPH_MODE,
                device_target="GPU",
                device_id=device_id,
            )
            device = "GPU"
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
            device = "CPU"
            print('Use CPU')
        return device

    def _get_data(self):
        """Hook for subclasses; does nothing in the base class."""
        pass

    def vali(self):
        """Hook for subclasses; does nothing in the base class."""
        pass

    def train(self):
        """Hook for subclasses; does nothing in the base class."""
        pass

    def test(self):
        """Hook for subclasses; does nothing in the base class."""
        pass
|