diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 3630b593..2a04e62f 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -255,13 +255,14 @@ class TorchDriver(Driver):
         logger.debug("Load model...")
 
         # 3. Load the fp16 state
-        if 'grad_scaler_state_dict' in states:
-            grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
-            if not isinstance(self.grad_scaler, DummyGradScaler):
-                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
-                self.auto_cast = torch.cuda.amp.autocast
+        if "grad_scaler_state_dict" in states:
+            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
+            if isinstance(self.grad_scaler, DummyGradScaler):
+                self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
+                self.grad_scaler = _grad_scaler()
                 self.fp16 = True
-            logger.debug("Load grad_scaler state dict...")
+            self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+            logger.debug("Load grad_scaler state dict...")
         elif not isinstance(self.grad_scaler, DummyGradScaler):
             logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
                            f"the training process may be unstable.")
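
For context on why the reordering matters: torch.cuda.amp.GradScaler exposes
state_dict()/load_state_dict(), and the saved loss scale only survives a resume
if it is loaded into a real scaler rather than the dummy placeholder used for
fp16=False runs. The sketch below is not fastNLP code; it assumes a
CUDA-enabled PyTorch build and illustrates the round trip the patched branch
now performs even when the driver started out with a dummy scaler:

    import torch

    # Save side: the real scaler's state (loss scale, growth counters)
    # is what ends up under "grad_scaler_state_dict" in the checkpoint.
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0 ** 14)
    states = {"grad_scaler_state_dict": scaler.state_dict()}

    # Load side: resuming with fp16 enabled again; a fresh real scaler
    # (not a dummy) must be constructed before the saved state is loaded.
    new_scaler = torch.cuda.amp.GradScaler()
    new_scaler.load_state_dict(states.pop("grad_scaler_state_dict"))
    assert new_scaler.get_scale() == 2.0 ** 14  # scale survived the resume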