From 7bd0ce0b5ea7f9ba76c93642c5639ea3c19b82f0 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sun, 10 Apr 2022 23:15:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20device=20?= =?UTF-8?q?=E4=B8=BA=20-1=20=E6=97=B6=20=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../drivers/torch_driver/initialize_torch_driver.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 35b20b72..9cd1ac01 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - if device < 0 and device != -1: - raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - if device >= _could_use_device_num: + if device < 0: + if device != -1: + raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") + device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] + elif device >= _could_use_device_num: raise ValueError("The gpu device that parameter `device` specifies is not existed.") - device = torch.device(f"cuda:{device}") + else: + device = torch.device(f"cuda:{device}") elif isinstance(device, Sequence): device = list(set(device)) for each in device: