
Set both the auto_cast and fp16 attributes when loading an fp16 checkpoint

tags/v1.0.0alpha
x54-729 committed 2 years ago
commit 5d1ac72ec9
1 changed file with 2 additions and 0 deletions:
  fastNLP/core/drivers/torch_driver/torch_driver.py (+2, -0)

fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -259,6 +259,8 @@ class TorchDriver(Driver):
             grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
             if not isinstance(self.grad_scaler, DummyGradScaler):
                 self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                self.auto_cast = torch.cuda.amp.autocast
+                self.fp16 = True
                 logger.debug("Load grad_scaler state dict...")
         elif not isinstance(self.grad_scaler, DummyGradScaler):
             logger.warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
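A minimal sketch of why the two added lines matter. The class and method names below (TinyDriver, load, forward) are illustrative stand-ins, not fastNLP's actual code: the idea is that a driver keeps an auto_cast context-manager factory and an fp16 flag alongside its GradScaler, and the forward pass runs under self.auto_cast(). Restoring only the scaler state would leave both attributes at their non-fp16 defaults, so a resumed run would scale gradients without ever autocasting the forward pass.

import contextlib
import torch

class TinyDriver:
    def __init__(self):
        # Defaults before any checkpoint is loaded: mixed precision off.
        self.fp16 = False
        self.auto_cast = contextlib.nullcontext  # no-op stand-in for autocast

    def load(self, states):
        # Mirrors the commit: GradScaler state in the checkpoint implies the
        # original run used fp16, so re-enable autocast alongside the scaler.
        if 'grad_scaler_state_dict' in states:
            self.auto_cast = torch.cuda.amp.autocast
            self.fp16 = True

    def forward(self, model, batch):
        # After load(), forward passes run under autocast again, keeping the
        # restored GradScaler and the forward computation consistent.
        with self.auto_cast():
            return model(batch)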

