diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 35b20b72..9cd1ac01 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - if device < 0 and device != -1: - raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - if device >= _could_use_device_num: + if device < 0: + if device != -1: + raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") + device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] + elif device >= _could_use_device_num: raise ValueError("The gpu device that parameter `device` specifies is not existed.") - device = torch.device(f"cuda:{device}") + else: + device = torch.device(f"cuda:{device}") elif isinstance(device, Sequence): device = list(set(device)) for each in device: