Browse Source

small

tags/v1.0.0alpha
x54-729 2 years ago
parent
commit
8cda30c426
1 changed file with 7 additions and 6 deletions
  1. +7
    -6
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 7
- 6
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -255,13 +255,14 @@ class TorchDriver(Driver):
logger.debug("Load model...")

# 3. 加载fp16的状态
if 'grad_scaler_state_dict' in states:
grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
if not isinstance(self.grad_scaler, DummyGradScaler):
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
self.auto_cast = torch.cuda.amp.autocast
if "grad_scaler_state_dict" in states:
grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
if isinstance(self.grad_scaler, DummyGradScaler):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
self.grad_scaler = _grad_scaler()
self.fp16 = True
logger.debug("Load grad_scaler state dict...")
self.grad_scaler.load_state_dict(grad_scaler_state_dict)
logger.debug("Load grad_scaler state dict...")
elif not isinstance(self.grad_scaler, DummyGradScaler):
logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
f"the training process may be unstable.")


Loading…
Cancel
Save