|
|
@@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic |
|
|
|
if isinstance(device, str): |
|
|
|
device = torch.device(device) |
|
|
|
elif isinstance(device, int): |
|
|
|
if device < 0 and device != -1: |
|
|
|
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") |
|
|
|
if device >= _could_use_device_num: |
|
|
|
if device < 0: |
|
|
|
if device != -1: |
|
|
|
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") |
|
|
|
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] |
|
|
|
elif device >= _could_use_device_num: |
|
|
|
raise ValueError("The gpu device that parameter `device` specifies is not existed.") |
|
|
|
device = torch.device(f"cuda:{device}") |
|
|
|
else: |
|
|
|
device = torch.device(f"cuda:{device}") |
|
|
|
elif isinstance(device, Sequence): |
|
|
|
device = list(set(device)) |
|
|
|
for each in device: |
|
|
|