|
@@ -4,7 +4,6 @@ from typing import Union, Optional, Dict, Any |
|
|
from pathlib import Path |
|
|
from pathlib import Path |
|
|
from functools import partial |
|
|
from functools import partial |
|
|
from dataclasses import dataclass |
|
|
from dataclasses import dataclass |
|
|
from jittor import grad |
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import numpy as np |
|
|
|
|
|
|
|
@@ -84,11 +83,6 @@ class PaddleDriver(Driver): |
|
|
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) |
|
|
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) |
|
|
|
|
|
|
|
|
def zero_grad(self): |
|
|
def zero_grad(self): |
|
|
r""" |
|
|
|
|
|
实现深度学习中的梯度的置零操作,应当直接通过优化器 ``optimizers`` 来将梯度置零; |
|
|
|
|
|
注意梯度累积不需要在这里实现,:class:`~fastNLP.core.Trainer` 已经在内部实现了梯度累积; |
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
|
|
|
for optimizer in self.optimizers: |
|
|
for optimizer in self.optimizers: |
|
|
optimizer.clear_grad() |
|
|
optimizer.clear_grad() |
|
|
|
|
|
|
|
@@ -194,7 +188,7 @@ class PaddleDriver(Driver): |
|
|
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") |
|
|
raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") |
|
|
paddle.jit.save(model, filepath, input_spec) |
|
|
paddle.jit.save(model, filepath, input_spec) |
|
|
|
|
|
|
|
|
def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): |
|
|
|
|
|
|
|
|
def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): |
|
|
model = self.unwrap_model() |
|
|
model = self.unwrap_model() |
|
|
if isinstance(filepath, Path): |
|
|
if isinstance(filepath, Path): |
|
|
filepath = str(filepath) |
|
|
filepath = str(filepath) |
|
@@ -274,21 +268,10 @@ class PaddleDriver(Driver): |
|
|
# 2. 保存模型的状态; |
|
|
# 2. 保存模型的状态; |
|
|
if should_save_model: |
|
|
if should_save_model: |
|
|
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) |
|
|
self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, **kwargs) |
|
|
if only_state_dict: |
|
|
|
|
|
logger.debug("Save model state dict.") |
|
|
|
|
|
else: |
|
|
|
|
|
logger.debug("Save model.") |
|
|
|
|
|
|
|
|
|
|
|
# 3. 保存 optimizers 的状态; |
|
|
# 3. 保存 optimizers 的状态; |
|
|
optimizers_state_dict = {} |
|
|
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
|
|
optimizer_state = optimizer.state_dict() |
|
|
|
|
|
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") |
|
|
|
|
|
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
states["optimizers_state_dict"] = self.get_optimizer_state() |
|
|
logger.debug("Save optimizer state dict.") |
|
|
logger.debug("Save optimizer state dict.") |
|
|
states["optimizers_state_dict"] = optimizers_state_dict |
|
|
|
|
|
|
|
|
|
|
|
# 4.保存fp16的状态 |
|
|
# 4.保存fp16的状态 |
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
@@ -297,34 +280,45 @@ class PaddleDriver(Driver): |
|
|
|
|
|
|
|
|
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) |
|
|
paddle.save(states, str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) |
|
|
|
|
|
|
|
|
|
|
|
def get_optimizer_state(self): |
|
|
|
|
|
optimizers_state_dict = {} |
|
|
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
|
|
optimizer_state = optimizer.state_dict() |
|
|
|
|
|
optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") |
|
|
|
|
|
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; |
|
|
|
|
|
|
|
|
|
|
|
return optimizers_state_dict |
|
|
|
|
|
|
|
|
|
|
|
def load_optimizer_state(self, states): |
|
|
|
|
|
assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ |
|
|
|
|
|
f"checkpoint it is:{len(states)}" |
|
|
|
|
|
for i in range(len(self.optimizers)): |
|
|
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
|
|
optimizer.set_state_dict(states[f"optimizer{i}"]) |
|
|
|
|
|
logger.debug("Load optimizer state dict.") |
|
|
|
|
|
|
|
|
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: |
|
|
def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: |
|
|
|
|
|
|
|
|
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) |
|
|
states = paddle.load(str(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))) |
|
|
|
|
|
|
|
|
# 1. 加载 optimizers 的状态; |
|
|
# 1. 加载 optimizers 的状态; |
|
|
optimizers_state_dict = states.pop("optimizers_state_dict") |
|
|
optimizers_state_dict = states.pop("optimizers_state_dict") |
|
|
for i in range(len(self.optimizers)): |
|
|
|
|
|
optimizer: Optimizer = self.optimizers[i] |
|
|
|
|
|
optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) |
|
|
|
|
|
logger.debug("Load optimizer state dict.") |
|
|
|
|
|
|
|
|
self.load_optimizer_state(optimizers_state_dict) |
|
|
|
|
|
|
|
|
# 2. 加载模型状态; |
|
|
# 2. 加载模型状态; |
|
|
if should_load_model: |
|
|
if should_load_model: |
|
|
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) |
|
|
self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict) |
|
|
if only_state_dict: |
|
|
|
|
|
logger.debug("Load model state dict...") |
|
|
|
|
|
else: |
|
|
|
|
|
logger.debug("Load model...") |
|
|
|
|
|
|
|
|
|
|
|
# 3. 加载fp16的状态; |
|
|
# 3. 加载fp16的状态; |
|
|
grad_scaler_state_dict = states.pop("grad_scaler_state_dict", None) |
|
|
|
|
|
if self.fp16: |
|
|
|
|
|
if grad_scaler_state_dict: |
|
|
|
|
|
|
|
|
if "grad_scaler_state_dict" in states: |
|
|
|
|
|
grad_scaler_state_dict = states.pop("grad_scaler_state_dict") |
|
|
|
|
|
if not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
self.grad_scaler.load_state_dict(grad_scaler_state_dict) |
|
|
self.grad_scaler.load_state_dict(grad_scaler_state_dict) |
|
|
logger.debug("Load grad_scaler state dict...") |
|
|
logger.debug("Load grad_scaler state dict...") |
|
|
else: |
|
|
|
|
|
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " |
|
|
|
|
|
f"the training process may be unstable.") |
|
|
|
|
|
|
|
|
elif not isinstance(self.grad_scaler, DummyGradScaler): |
|
|
|
|
|
logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " |
|
|
|
|
|
f"the training process may be unstable.") |
|
|
|
|
|
|
|
|
# 4. 恢复 sampler 的状态; |
|
|
# 4. 恢复 sampler 的状态; |
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
|
dataloader_args = self.get_dataloader_args(dataloader) |
|
@@ -344,7 +338,7 @@ class PaddleDriver(Driver): |
|
|
batch_size=dataloader_args.batch_size, |
|
|
batch_size=dataloader_args.batch_size, |
|
|
drop_last=dataloader_args.drop_last |
|
|
drop_last=dataloader_args.drop_last |
|
|
) |
|
|
) |
|
|
sampler.load_state_dict(states["sampler_states"]) |
|
|
|
|
|
|
|
|
sampler.load_state_dict(states.pop("sampler_states")) |
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) |
|
|
|
|
|
|
|
|
# 5. 修改 trainer_state.batch_idx_in_epoch |
|
|
# 5. 修改 trainer_state.batch_idx_in_epoch |
|
|