|
@@ -4,6 +4,7 @@ from typing import Union, Optional, Dict, Any
 from pathlib import Path
 from functools import partial
 from dataclasses import dataclass
+from jittor import grad
 
 import numpy as np
 
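Reviewer note: the only change in this hunk is the new `jittor.grad` import. For orientation (not part of the patch), `jittor.grad(loss, targets)` functionally computes the gradient of a scalar loss with respect to one or more Vars; a minimal sketch, assuming a stock Jittor install:

    # Minimal jittor.grad sketch; values and names are illustrative only.
    import jittor as jt
    from jittor import grad

    x = jt.array([1.0, 2.0, 3.0])
    y = (x * x).sum()   # scalar loss: sum of squares
    dx = grad(y, x)     # dy/dx = 2 * x
    print(dx.numpy())   # -> [2. 4. 6.]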
|
|
|
|
|
|
|
@@ -317,17 +318,14 @@ class PaddleDriver(Driver):
         logger.debug("Load model...")
 
         # 3. Load the fp16 state;
-        if "grad_scaler_state_dict" in states:
-            grad_scaler_state_dict = states.pop("grad_scaler_state_dict")
-            if isinstance(self.grad_scaler, DummyGradScaler):
-                self.auto_cast, _grad_scaler = _build_fp16_env(dummy=False)
-                self.grad_scaler = _grad_scaler()
-                self.fp16 = True
-            self.grad_scaler.load_state_dict(grad_scaler_state_dict)
-            logger.debug("Load grad_scaler state dict...")
-        elif not isinstance(self.grad_scaler, DummyGradScaler):
-            logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
-                        f"the training process may be unstable.")
+        grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
+        if self.fp16:
+            if grad_scaler_state_dict:
+                self.grad_scaler.load_state_dict(grad_scaler_state_dict)
+                logger.debug("Load grad_scaler state dict...")
+            else:
+                logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, "
+                            f"the training process may be unstable.")
 
         # 4. Restore the sampler state;
         dataloader_args = self.get_dataloader_args(dataloader)
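Reviewer note on the behavioral change: the removed branch could switch a run into mixed precision by itself (rebuilding the fp16 env and setting `self.fp16 = True` whenever the checkpoint carried `grad_scaler_state_dict`), so a run started with fp16=False could silently resume as fp16=True. After this hunk, the `self.fp16` flag fixed at driver construction is authoritative: the scaler state is restored only when fp16 is already enabled, and a missing state merely triggers a warning. A minimal standalone sketch of the new control flow, where `restore_fp16_state` and the plain `logging` logger are hypothetical stand-ins rather than the driver's real API:

    # Hypothetical sketch of the new restore logic; not the actual driver code.
    import logging

    logger = logging.getLogger("driver")

    def restore_fp16_state(states: dict, fp16: bool, grad_scaler) -> None:
        # Pop with a default so the key never leaks into later state handling,
        # whether or not the checkpoint contains fp16 state.
        grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None)
        if fp16:
            if grad_scaler_state_dict:
                grad_scaler.load_state_dict(grad_scaler_state_dict)
                logger.debug("Load grad_scaler state dict...")
            else:
                logger.warning("Checkpoint has no fp16 state; resuming with "
                               "fp16=True may make training unstable.")

One consequence of the pop-with-default: a checkpoint that does carry scaler state but is resumed with fp16=False now discards that state silently, which appears to be the intended contract after this change.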
|
|