Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
28aabb5a2f
1 changed files with 7 additions and 4 deletions
  1. +7
    -4
      fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

+ 7
- 4
fastNLP/core/drivers/torch_driver/initialize_torch_driver.py View File

@@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
if isinstance(device, str):
device = torch.device(device)
elif isinstance(device, int):
if device < 0 and device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
if device >= _could_use_device_num:
if device < 0:
if device != -1:
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.")
device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)]
elif device >= _could_use_device_num:
raise ValueError("The gpu device that parameter `device` specifies is not existed.")
device = torch.device(f"cuda:{device}")
else:
device = torch.device(f"cuda:{device}")
elif isinstance(device, Sequence):
device = list(set(device))
for each in device:


Loading…
Cancel
Save