
Merge branch 'dev0.8.0' of gitee.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729, 2 years ago
commit 49eb1fcc6b
2 changed files with 2 additions and 2 deletions:
  1. fastNLP/core/drivers/torch_driver/torch_driver.py (+1, -1)
  2. fastNLP/core/drivers/torch_driver/utils.py (+1, -1)

fastNLP/core/drivers/torch_driver/torch_driver.py (+1, -1)

@@ -55,7 +55,7 @@ class TorchDriver(Driver):
         # Since the mixed-precision setup is the same for ddp and single_device, it can be abstracted here uniformly;
         self.fp16 = fp16
         if parse_version(torch.__version__) < parse_version('1.6'):
-            raise RuntimeError("Pytorch supports float16 after version 1.6, please upgrade your pytorch version.")
+            raise RuntimeError(f"Pytorch({torch.__version__}) need to be older than 1.6.")
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
         self.grad_scaler = _grad_scaler()
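
For context, the line being reworded gates fp16 training on the installed PyTorch version: torch.cuda.amp (autocast and GradScaler) first shipped in 1.6, so anything older cannot run the fp16 path at all. Below is a minimal, self-contained sketch of that check, assuming parse_version is packaging.version.parse and using a hypothetical helper name; fastNLP's actual import and error wording may differ.

    import torch
    from packaging.version import parse as parse_version

    def assert_amp_available() -> None:
        # torch.cuda.amp (autocast + GradScaler) first shipped in PyTorch 1.6,
        # so the fp16 path cannot work on anything older.
        if parse_version(torch.__version__) < parse_version("1.6"):
            raise RuntimeError(
                f"Pytorch({torch.__version__}) is too old for fp16 training; "
                "please upgrade to 1.6 or newer."
            )

    assert_amp_available()  # no-op on torch >= 1.6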



fastNLP/core/drivers/torch_driver/utils.py (+1, -1)

@@ -160,7 +160,7 @@ def _build_fp16_env(dummy=False):
         GradScaler = DummyGradScaler
     else:
         if not torch.cuda.is_available():
-            raise RuntimeError("No cuda")
+            raise RuntimeError("Pytorch is not installed in gpu version, please use device='cpu'.")
         if torch.cuda.get_device_capability(0)[0] < 7:
             logger.rank_zero_warning(
                 "NOTE: your device does NOT support faster training with fp16, "

