diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index b6c282b4..253ae46d 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -444,7 +444,6 @@ class Trainer(object):
         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
 
-        # 是否一开始就是DataParallel的。
         self.model = _move_model_to_device(self.model, device=device)
 
         if isinstance(optimizer, torch.optim.Optimizer):
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index efb4faa7..cc9e8164 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -193,13 +193,14 @@ def _move_model_to_device(model, device):
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
 
-    if not torch.cuda.is_available() and (device!='cpu' or (isinstance(device, torch.device) and device.type!='cpu')):
-        raise ValueError("There is no usable gpu. set `device` as `cpu`.")
-
     if device is None:
         if isinstance(model, torch.nn.DataParallel):
             model.cuda()
         return model
+    else:
+        if not torch.cuda.is_available() and (
+                device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
+            raise ValueError("There is no usable gpu. set `device` as `cpu`.")
 
     if isinstance(model, torch.nn.DataParallel):
         raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
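
The patch above moves the CUDA-availability check into the `else` branch of `_move_model_to_device`, so `device=None` no longer trips the "There is no usable gpu" error on a CPU-only machine, while the check still applies when an explicit device is requested. A minimal sketch of the resulting behaviour, assuming a machine without usable GPUs (only names from the patched modules are used):

import torch
from fastNLP.core.utils import _move_model_to_device

model = torch.nn.Linear(4, 2)

# device=None skips the GPU check after this patch: a plain model (or an
# existing DataParallel wrapper) passes through without raising on CPU-only.
model = _move_model_to_device(model, device=None)

# Requesting an explicit non-CPU device still raises when no GPU is usable.
try:
    _move_model_to_device(model, device='cuda:0')
except ValueError as e:
    print(e)  # There is no usable gpu. set `device` as `cpu`.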