diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 23743ecf..33b69d7e 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -404,6 +404,9 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): :param args: :return: """ + if not torch.cuda.is_available(): + return + if not isinstance(device, torch.device): raise TypeError(f"device must be `torch.device`, got `{type(device)}`")