diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d7092d48..27f266fa 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -223,8 +223,8 @@ def _move_model_to_device(model, device): model.cuda() return model else: - if not torch.cuda.is_available() and ( - device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): + if not torch.cuda.is_available() and ((isinstance(device, str) and device!='cpu') or + (isinstance(device, torch.device) and device.type != 'cpu')): raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") if isinstance(model, torch.nn.DataParallel):