diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 74c812a2..d324af72 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -409,7 +409,7 @@ def _move_model_to_device(model, device): if device is None: if isinstance(model, torch.nn.DataParallel): - model.cuda() + model.cuda(model.device_ids[0]) return model else: if not torch.cuda.is_available() and ((isinstance(device, str) and device!='cpu') or