
Fix bug

tags/v0.4.10
yh committed 6 years ago
commit c077107555
2 changed files with 4 additions and 4 deletions
  1. fastNLP/core/trainer.py  +0 -1
  2. fastNLP/core/utils.py  +4 -3

fastNLP/core/trainer.py  +0 -1

@@ -444,7 +444,6 @@ class Trainer(object):
         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
 
-        # Whether the model was DataParallel from the start.
         self.model = _move_model_to_device(self.model, device=device)
 
         if isinstance(optimizer, torch.optim.Optimizer):
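The self.n_steps expression kept as context above is a ceiling division of the dataset size by the batch size, multiplied by the number of epochs. A quick standalone check of that arithmetic, with made-up values (train_len, batch_size and n_epochs are hypothetical, not taken from the repo):

train_len, batch_size, n_epochs = 1000, 32, 3

# int(... % ... != 0) adds one extra step for the final, partially filled batch.
n_steps = (train_len // batch_size + int(train_len % batch_size != 0)) * n_epochs
print(n_steps)  # 96: 31 full batches + 1 partial batch per epoch, times 3 epochs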


fastNLP/core/utils.py  +4 -3

@@ -193,13 +193,14 @@ def _move_model_to_device(model, device):
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
 
-    if not torch.cuda.is_available() and (device!='cpu' or (isinstance(device, torch.device) and device.type!='cpu')):
-        raise ValueError("There is no usable gpu. set `device` as `cpu`.")
-
     if device is None:
         if isinstance(model, torch.nn.DataParallel):
             model.cuda()
         return model
+    else:
+        if not torch.cuda.is_available() and (
+                device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
+            raise ValueError("There is no usable gpu. set `device` as `cpu`.")
 
     if isinstance(model, torch.nn.DataParallel):
         raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
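Why the guard moved: in the old placement the GPU check ran before the `device is None` branch, and since None != 'cpu' evaluates to True, a CPU-only machine raised "There is no usable gpu. set `device` as `cpu`." even when no device had been requested. A minimal sketch of the old condition in isolation (old_guard and cuda_available are hypothetical names standing in for the inline expression and torch.cuda.is_available()):

import torch

def old_guard(device, cuda_available):
    # Pre-fix condition, evaluated before `if device is None:` was reached.
    return not cuda_available and (
        device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu'))

# On a machine without CUDA (cuda_available=False):
print(old_guard(None, False))    # True  -> raised even though no device was requested (the bug)
print(old_guard('cpu', False))   # False -> allowed
print(old_guard('cuda', False))  # True  -> correctly rejected

After the change, the same expression only runs in the else: branch, i.e. once device is not None, so device=None no longer trips it on machines without a GPU.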

