From c077107555ffa9dee9be26138f47f223e99f2b76 Mon Sep 17 00:00:00 2001
From: yh
Date: Mon, 29 Apr 2019 14:43:44 +0800
Subject: [PATCH] Fix bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py | 1 -
 fastNLP/core/utils.py   | 7 ++++---
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index b6c282b4..253ae46d 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -444,7 +444,6 @@ class Trainer(object):
         self.n_steps = (len(self.train_data) // self.batch_size + int(
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
 
-        # 是否一开始就是DataParallel的。
         self.model = _move_model_to_device(self.model, device=device)
 
         if isinstance(optimizer, torch.optim.Optimizer):
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index efb4faa7..cc9e8164 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -193,13 +193,14 @@ def _move_model_to_device(model, device):
     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
         raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
 
-    if not torch.cuda.is_available() and (device!='cpu' or (isinstance(device, torch.device) and device.type!='cpu')):
-        raise ValueError("There is no usable gpu. set `device` as `cpu`.")
-
     if device is None:
         if isinstance(model, torch.nn.DataParallel):
             model.cuda()
             return model
+    else:
+        if not torch.cuda.is_available() and (
+                device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
+            raise ValueError("There is no usable gpu. set `device` as `cpu`.")
 
     if isinstance(model, torch.nn.DataParallel):
         raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
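
For reference, the sketch below restates the patched control flow of _move_model_to_device as standalone Python, assuming only PyTorch: the CUDA-availability check now sits in the else branch, so calling the function with device=None (e.g. for a model already wrapped in torch.nn.DataParallel) no longer raises on a CPU-only machine. The name _move_model_to_device_sketch and the trailing usage line are illustrative only, and the rest of the real function's device handling is omitted.

    # Minimal sketch of the patched control flow (not the full fastNLP function).
    import torch


    def _move_model_to_device_sketch(model, device):
        """Mirror the order of checks in _move_model_to_device after this patch."""
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")

        if device is None:
            # With the fix, device=None no longer trips the CUDA-availability check:
            # an existing DataParallel model is simply moved to GPU.
            if isinstance(model, torch.nn.DataParallel):
                model.cuda()
                return model
        else:
            # The CUDA check now runs only when an explicit (non-None) device is requested.
            if not torch.cuda.is_available() and (
                    device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
                raise ValueError("There is no usable gpu. set `device` as `cpu`.")

        if isinstance(model, torch.nn.DataParallel):
            raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")

        # The real function continues with further device handling (omitted here).
        return model


    # Hypothetical usage: on a CPU-only machine this call no longer raises.
    model = _move_model_to_device_sketch(torch.nn.Linear(4, 2), device=None)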