|
|
@@ -193,13 +193,14 @@ def _move_model_to_device(model, device): |
|
|
|
if isinstance(model, torch.nn.parallel.DistributedDataParallel): |
|
|
|
raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") |
|
|
|
|
|
|
|
if not torch.cuda.is_available() and (device!='cpu' or (isinstance(device, torch.device) and device.type!='cpu')): |
|
|
|
raise ValueError("There is no usable gpu. set `device` as `cpu`.") |
|
|
|
|
|
|
|
if device is None: |
|
|
|
if isinstance(model, torch.nn.DataParallel): |
|
|
|
model.cuda() |
|
|
|
return model |
|
|
|
else: |
|
|
|
if not torch.cuda.is_available() and ( |
|
|
|
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): |
|
|
|
raise ValueError("There is no usable gpu. set `device` as `cpu`.") |
|
|
|
|
|
|
|
if isinstance(model, torch.nn.DataParallel): |
|
|
|
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") |
|
|
|