
deepspeed checkpoint-related functions (not yet tested)

tags/v1.0.0alpha
x54-729 3 years ago
commit 2d2bf421fd
1 changed file with 108 additions and 12 deletions
  1. fastNLP/core/drivers/torch_driver/deepspeed.py (+108, -12)

fastNLP/core/drivers/torch_driver/deepspeed.py

@@ -1,11 +1,16 @@
import os
from pathlib import Path

from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List
from typing import Union, Dict, List
from .torch_driver import TorchDriver
from .ddp import TorchDDPDriver
from .utils import _create_default_config, _DDPWrappingModel
from fastNLP.core.utils import nullcontext
from fastNLP.core.log import logger
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
from fastNLP.envs import (
    FASTNLP_DISTRIBUTED_CHECK,
    FASTNLP_CHECKPOINT_FILENAME
)
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED

if _NEED_IMPORT_TORCH:
@@ -79,6 +84,15 @@ class DeepSpeedDriver(TorchDDPDriver):
        self._has_setup = False  # this flag exists because the evaluator also runs setup, which is clearly neither needed nor appropriate there;
        self._has_ddpwrapped = False  # whether the passed-in model has already been wrapped by _DDPWrappingModel;
        self.strategy = strategy
        self.accumulation_steps = kwargs.get("accumulation_steps", 1)
        # read the batch_size so that the train_micro_batch_size_per_gpu parameter can be set
        train_dl = kwargs.get("train_dataloader", None)
        if train_dl is not None:
            self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size
        else:
            logger.warn("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu` "
                        "to 1 for deepspeed configuration.")
            self.train_micro_batch_size = 1

        self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
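The micro batch size captured here is what DeepSpeed uses as `train_micro_batch_size_per_gpu`; together with `gradient_accumulation_steps` and the world size it fixes the effective batch size per optimizer step. A minimal sketch of that relationship (the concrete numbers are illustrative only, not taken from the commit):

    # DeepSpeed's invariant:
    # train_batch_size == train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    train_micro_batch_size_per_gpu = 4   # batch size read from the train dataloader
    gradient_accumulation_steps = 8      # fastNLP's accumulation_steps kwarg
    world_size = 2                       # number of GPU processes
    train_batch_size = train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size  # 64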

@@ -93,8 +107,8 @@ class DeepSpeedDriver(TorchDDPDriver):
raise ValueError("Multi optimizers is not supported for DeepSpeedDriver right now.")
if self._has_setup:
return
self.setup_config()
self._has_setup = True
self.setup_config()
# 如果用户需要使用多机模式,那么一定进入到这里;
if self.is_pull_by_torch_run:
# dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用;
@@ -152,12 +166,14 @@ class DeepSpeedDriver(TorchDDPDriver):
        # set up deepspeed
        if not isinstance(self.model, deepspeed.DeepSpeedEngine):
            model = _DDPWrappingModel(self.model)
            model_parameters = filter(lambda p: p.requires_grad, model.parameters())
            self.model, ds_optimizer, _, _ = deepspeed.initialize(
                model=_DDPWrappingModel(self.model),
                model=model,
                optimizer=self.optimizers[0],
                config=self.config
                model_parameters=model_parameters,
                config=self.config,
            )
            # TODO is this really necessary?
            self._optimizers = [ds_optimizer]

        if self.config.get("activation_checkpointing"):
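For context, `deepspeed.initialize` returns a 4-tuple of (engine, optimizer, training dataloader, lr scheduler); the driver keeps only the engine and the wrapped optimizer. A small sketch of the call shape, where `wrapped_model`, `base_optimizer` and `ds_config` are placeholder names, not part of the commit:

    import deepspeed

    engine, ds_optimizer, _, _ = deepspeed.initialize(
        model=wrapped_model,          # an nn.Module to be wrapped into a DeepSpeedEngine
        optimizer=base_optimizer,     # the user's torch optimizer
        model_parameters=filter(lambda p: p.requires_grad, wrapped_model.parameters()),
        config=ds_config,             # a dict or a path to a DeepSpeed JSON config
    )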
@@ -174,6 +190,13 @@ class DeepSpeedDriver(TorchDDPDriver):

    def setup_config(self):

        self.config = self._ds_kwargs.get("config")
        if self.config is not None:
            # TODO decide which parameters should follow the config and which the trainer arguments
            logger.warn("Notice that you have defined a configuration for deepspeed and parameters like "
                        "`optimizers`, `strategy` and `fp16` may not take effect.")
            return

        if self.strategy == "deepspeed":
            self.config = _create_default_config(stage=2)
        elif self.strategy == "deepspeed_stage_1":
@@ -202,13 +225,11 @@ class DeepSpeedDriver(TorchDDPDriver):
        else:
            raise ValueError(f"Unknown deepspeed strategy {self.strategy}.")

        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
        # set to max int so that deepspeed's own printing does not interfere with fastNLP's output
        self.config.setdefault("steps_per_print", 2147483647)
        self.config["gradient_accumulation_steps"] = self.accumulation_steps
        self.config.setdefault("train_micro_batch_size_per_gpu", self.train_micro_batch_size)

        # TODO gradient clipping settings; the trainer is needed here
        # obtained from kwargs
        # precision settings
        # _format_precision_config
        if self.fp16:
            if "fp16" not in self.config:
                # FP16 is a DeepSpeed standalone AMP implementation
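For reference, with the default `strategy="deepspeed"` (ZeRO stage 2), `fp16=True` and no user-supplied config, the dictionary assembled here should end up looking roughly like the sketch below; the exact stage-2 defaults come from `_create_default_config` and may differ:

    {
        "zero_optimization": {"stage": 2},
        "gradient_accumulation_steps": 1,      # the trainer's accumulation_steps
        "train_micro_batch_size_per_gpu": 4,   # batch size taken from the train dataloader
        "steps_per_print": 2147483647,         # effectively silences DeepSpeed's periodic logging
        "fp16": {"enabled": True},             # filled in by the fp16 branch below
    }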
@@ -238,6 +259,81 @@ class DeepSpeedDriver(TorchDDPDriver):

    def unwrap_model(self):
        r"""
        :return: the original model, e.g. without the ``DataParallel`` wrapper
        :return: the original model;
        """
        return self.model.module.model

    def get_model_no_sync_context(self):
        r"""
        :return: a ``context`` manager used to disable synchronization between processes; for ``deepspeed`` an empty context is returned
        """
        # note that the model at this point is a "DistributedDataParallel" object;
        return nullcontext

    def save_model(self, filepath: Union[str, Path], only_state_dict: bool = False, **kwargs):
        """
        Save the current driver's model under ``filepath``.

        :param filepath: the folder to save to;
        :param only_state_dict: whether to save only the weights;
        :return:
        """
        # the deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
        if self.stage_3:
            logger.rank_zero_warning(
                "When saving the DeepSpeed Stage 3 checkpoint, "
                "each worker will save a shard of the checkpoint within a directory. "
                # TODO double-check this
                # "If a single file is required after training, "
                # "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
                # "deepspeed-zero-stage-3-single-file for instructions."
            )
        if not only_state_dict:
            logger.rank_zero_warning("Only saving state dict is not allowed for `DeepSpeedDriver`. We will save its "
                                     "checkpoint for you instead.")
        self.model.save_checkpoint(filepath, **kwargs)

    def load_model(self, filepath: Union[Path, str], only_state_dict: bool = False, **kwargs):
        """
        Load weights from ``filepath`` and assign them to the current driver's model.

        :param filepath: path of the weights or model to load
        :param only_state_dict: whether the saved content is only the weights.
        :param kwargs:
        :return:
        """
        if not only_state_dict:
            logger.warn("Only loading state dict is not allowed for `DeepSpeedDriver`. We will load its "
                        "checkpoint for you instead.")
        self.model.load_checkpoint(filepath, **kwargs)

    def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
        # the deepspeed engine requires save_checkpoint to be called on every rank, so the rank_zero_call decorator was removed
        # 1. save the sampler state
        sampler_state_dict = self.get_sampler_state_dict()

        # 2. save the model state;
        if not should_save_model:
            logger.rank_zero_warning("Saving checkpoint without model is not allowed for `DeepSpeedDriver`, "
                                     "so we will still save the model for you.")

        self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME),
                                   client_state=sampler_state_dict)

    def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
        # 1. load the model state;
        if not should_load_model:
            logger.rank_zero_warning("Loading checkpoint without model is not allowed for `DeepSpeedDriver`, "
                                     "so we will still load the model for you.")
        load_path, states = self.model.load_checkpoint(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
        if load_path is None:
            raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}")

        # 2. restore the sampler state
        states = self.load_sampler_state_dict(states)

        return states

    @property
    def stage_3(self) -> bool:
        return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3
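The new save_checkpoint/load_checkpoint pair relies on DeepSpeed's `client_state` mechanism: arbitrary extra state (here the sampler state) rides along with the model and optimizer shards, and comes back as the second element of `engine.load_checkpoint`. A rough sketch of that round trip, assuming an already-initialized `engine` (a DeepSpeedEngine) and some `sampler_state_dict`:

    # save: the client_state dict is stored next to the model/optimizer shards
    engine.save_checkpoint("outputs/ckpt", client_state={"sampler_states": sampler_state_dict})

    # load: DeepSpeed returns (load_path, client_state); load_path is None when loading failed
    load_path, client_state = engine.load_checkpoint("outputs/ckpt")
    if load_path is None:
        raise RuntimeError("Failed to load checkpoint from outputs/ckpt")
    sampler_state_dict = client_state["sampler_states"]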
