|
|
@@ -255,13 +255,14 @@ class TorchDriver(Driver):
         logger.debug("Load model...")
 
         # 3. Load the fp16 state
-        if 'grad_scaler_state_dict' in states:
-            grad_scaler_state_dict = states.pop('grad_scaler_state_dict')
-            if not isinstance(self.grad_scaler, DummyGradScaler):
-                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
-                self.auto_cast = torch.cuda.amp.autocast
+        if "grad_scaler_state_dict" in states:
+            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
+            if isinstance(self.grad_scaler, DummyGradScaler):
+                self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
+                self.grad_scaler = _grad_scaler()
+                self.fp16 = True
+            self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+            logger.debug("Load grad_scaler state dict...")
         elif not isinstance(self.grad_scaler, DummyGradScaler):
             logger.warning(f"Checkpoint {folder} was not trained with fp16=True, while training resumes with fp16=True; "
                            f"the training process may be unstable.")
|
|
|