@@ -200,13 +200,13 @@ def _move_model_to_device(model, device):
     else:
         if not torch.cuda.is_available() and (
                 device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
-            raise ValueError("There is no usable gpu. set `device` as `cpu`.")
+            raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
 
     if isinstance(model, torch.nn.DataParallel):
         raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
 
     if isinstance(device, int):
-        assert device>-1, "device can only be positive integer"
+        assert device>-1, "device can only be non-negative integer"
         assert torch.cuda.device_count()>device, "Only has {} gpus, cannot use device {}.".format(torch.cuda.device_count(),
                                                                                                    device)
         device = torch.device('cuda:{}'.format(device))
@@ -227,7 +227,7 @@ def _move_model_to_device(model, device):
         assert list(types)[0] == int, "Only int supported for multiple devices."
         assert len(set(device))==len(device), "Duplicated device id found in device."
         for d in device:
-            assert d>-1, "Only positive device id allowed."
+            assert d>-1, "Only non-negative device id allowed."
         if len(device)>1:
             output_device = device[0]
             model = nn.DataParallel(model, device_ids=device, output_device=output_device)
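
For reference, a minimal usage sketch of the `device` values this validation accepts. Assumptions not taken from the diff: the helper is importable as fastNLP.core.utils._move_model_to_device, it returns the moved model, and GPU ids 0 and 1 are placeholders.

import torch
import torch.nn as nn
from fastNLP.core.utils import _move_model_to_device  # assumed import path

model = nn.Linear(4, 2)

if torch.cuda.device_count() > 1:
    # a list of unique, non-negative ids wraps the model in nn.DataParallel
    model = _move_model_to_device(model, device=[0, 1])
elif torch.cuda.is_available():
    # a single non-negative int becomes torch.device('cuda:<id>')
    model = _move_model_to_device(model, device=0)
else:
    # with no usable gpu, `device` must be 'cpu' or None (per the updated message)
    model = _move_model_to_device(model, device='cpu')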