From 2d2bf421fdc6e97fb14346e21511c3e91e1933d5 Mon Sep 17 00:00:00 2001
From: x54-729 <17307130121@fudan.edu.cn>
Date: Sat, 18 Jun 2022 22:28:57 +0800
Subject: [PATCH] deepspeed checkpoint related functions (not yet tested)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../core/drivers/torch_driver/deepspeed.py   | 120 ++++++++++++++++--
 1 file changed, 108 insertions(+), 12 deletions(-)

diff --git a/fastNLP/core/drivers/torch_driver/deepspeed.py b/fastNLP/core/drivers/torch_driver/deepspeed.py
index 298945ed..bb4df495 100644
--- a/fastNLP/core/drivers/torch_driver/deepspeed.py
+++ b/fastNLP/core/drivers/torch_driver/deepspeed.py
@@ -1,11 +1,16 @@
 import os
+from pathlib import Path
 
-from typing import Optional, Union, Callable, Dict, Tuple, Sequence, List
+from typing import Union, Dict, List
 from .torch_driver import TorchDriver
 from .ddp import TorchDDPDriver
 from .utils import _create_default_config, _DDPWrappingModel
+from fastNLP.core.utils import nullcontext
 from fastNLP.core.log import logger
-from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK
+from fastNLP.envs import (
+    FASTNLP_DISTRIBUTED_CHECK,
+    FASTNLP_CHECKPOINT_FILENAME
+)
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_DEEPSPEED
 
 if _NEED_IMPORT_TORCH:
@@ -79,6 +84,15 @@ class DeepSpeedDriver(TorchDDPDriver):
         self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的;
         self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹;
         self.strategy = strategy
+        self.accumulation_steps = kwargs.get("accumulation_steps", 1)
+        # 获取 batch_size 以设置 train_micro_batch_size_per_gpu 参数
+        train_dl = kwargs.get("train_dataloader", None)
+        if train_dl is not None:
+            self.train_micro_batch_size = self.get_dataloader_args(train_dl).batch_size
+        else:
+            logger.warn("No `train_dataloader` found, and we will set `train_micro_batch_size_per_gpu` "
+                        "to 1 for deepspeed configuration.")
+            self.train_micro_batch_size = 1
 
         self._ds_kwargs = kwargs.get("deepspeed_kwargs", {})
 
@@ -93,8 +107,8 @@ class DeepSpeedDriver(TorchDDPDriver):
             raise ValueError("Multi optimizers is not supported for DeepSpeedDriver right now.")
         if self._has_setup:
             return
-        self.setup_config()
         self._has_setup = True
+        self.setup_config()
         # 如果用户需要使用多机模式,那么一定进入到这里;
         if self.is_pull_by_torch_run:
             # dist.get_world_size() 只能在 dist.init_process_group 初始化之后进行调用;
@@ -152,12 +166,14 @@
 
         # 设置 deepspeed
         if not isinstance(self.model, deepspeed.DeepSpeedEngine):
+            model=_DDPWrappingModel(self.model)
+            model_parameters = filter(lambda p: p.requires_grad, model.parameters())
             self.model, ds_optimizer, _, _ = deepspeed.initialize(
-                model=_DDPWrappingModel(self.model),
+                model=model,
                 optimizer=self.optimizers[0],
-                config=self.config
+                model_parameters=model_parameters,
+                config=self.config,
             )
-            # TODO 是否有必要
             self._optimizers = [ds_optimizer]
 
         if self.config.get("activation_checkpointing"):
@@ -174,6 +190,13 @@ class DeepSpeedDriver(TorchDDPDriver):
 
     def setup_config(self):
 
+        self.config = self._ds_kwargs.get("config")
+        if self.config is not None:
+            # TODO 究竟哪些参数按照config,哪些按照trainer参数
+            logger.warn("Notice that you have defined a configuration for deepspeed and parameters like "
+                        "`optimizers`, `strategy` and `fp16` may not take effect.")
+            return
+
         if self.strategy == "deepspeed":
             self.config = _create_default_config(stage=2)
         elif self.strategy == "deepspeed_stage_1":
@@ -202,13 +225,11 @@ class DeepSpeedDriver(TorchDDPDriver):
         else:
             raise ValueError(f"Unknown deepspeed strategy {self.strategy}.")
 
-        self.config.setdefault("train_micro_batch_size_per_gpu", 1)
+        # 设置成 max_int 防止 deepspeed 的输出干扰 fastnlp 的输出
         self.config.setdefault("steps_per_print", 2147483647)
+        self.config["gradient_accumulation_steps"] = self.accumulation_steps
+        self.config.setdefault("train_micro_batch_size_per_gpu", self.train_micro_batch_size)
 
-        # TODO 梯度裁剪的设置,这里需要用到trainer
-        # 从kwargs 获取
-        # 精度设置
-        # _format_precision_config
         if self.fp16:
             if "fp16" not in self.config:
                 # FP16 is a DeepSpeed standalone AMP implementation
@@ -238,6 +259,81 @@ class DeepSpeedDriver(TorchDDPDriver):
 
     def unwrap_model(self):
         r"""
-        :return: 返回原本的模型,例如没有被 ``DataParallel`` 包裹;
+        :return: 返回原本的模型;
         """
         return self.model.module.model
+
+    def get_model_no_sync_context(self):
+        r"""
+        :return: 返回一个 ``context`` 上下文环境,用于关闭各个进程之间的同步;在 ``deepspeed`` 中,返回一个空的上下文
+        """
+        # 注意此时的 model 是 "DistributedDataParallel" 对象;
+        return nullcontext
+
+    def save_model(self, filepath: Union[str, Path], only_state_dict: bool = False, **kwargs):
+        """
+        保存当前 driver 的模型到 folder 下。
+
+        :param filepath: 保存到哪个文件夹;
+        :param only_state_dict: 是否只保存权重;
+        :return:
+        """
+        # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器
+        if self.stage_3:
+            logger.rank_zero_warning(
+                "When saving the DeepSpeed Stage 3 checkpoint, "
+                "each worker will save a shard of the checkpoint within a directory. "
+                # TODO check一下
+                # "If a single file is required after training, "
+                # "see https://pytorch-lightning.readthedocs.io/en/latest/advanced/advanced_gpu.html#"
+                # "deepspeed-zero-stage-3-single-file for instructions."
+            )
+        if not only_state_dict:
+            logger.rank_zero_warning("Only saving state dict is not allowed for `DeepSpeedDriver`. We will save its "
+                                     "checkpoint for you instead.")
+        self.model.save_checkpoint(filepath, **kwargs)
+
+    def load_model(self, filepath: Union[Path, str], only_state_dict: bool = False, **kwargs):
+        """
+        从 folder 中加载权重并赋值到当前 driver 的模型上。
+
+        :param filepath: 加载权重或模型的路径
+        :param only_state_dict: 保存的内容是否只是权重。
+        :param kwargs:
+        :return:
+        """
+        if not only_state_dict:
+            logger.warn("Only loading state dict is not allowed for `DeepSpeedDriver`. We will load its "
+                        "checkpoint for you instead.")
+        self.model.load_checkpoint(filepath, **kwargs)
+
+    def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+        # deepspeed engine 要求在每个 rank 都调用 save_checkpoint,故去掉了 rank_zero_call 装饰器
+        # 1. 保存 sampler 的状态
+        sampler_state_dict = self.get_sampler_state_dict()
+
+        # 2. 保存模型的状态;
+        if not should_save_model:
+            logger.rank_zero_warning("Saving checkpoint without model is not allowed for `DeepSpeedDriver`, "
+                                     "so we will still save the model for you.")
+
+        self.model.save_checkpoint(Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME),
+                                   client_state=sampler_state_dict)
+
+    def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+        # 1. 加载模型状态;
+        if not should_load_model:
+            logger.rank_zero_warning("Loading checkpoint without model is not allowed for `DeepSpeedDriver`, "
+                                     "so we will still load the model for you.")
+        load_path, states = self.model.load_checkpoint(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))
+        if load_path is None:
+            raise RuntimeError(f"Failed to load checkpoint from path: {str(folder)}")
+
+        # 2. 恢复 sampler 的状态
+        states = self.load_sampler_state_dict(states)
+
+        return states
+
+    @property
+    def stage_3(self) -> bool:
+        return self.config.get("zero_optimization") and self.config.get("zero_optimization").get("stage") == 3
\ No newline at end of file
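
Usage sketch (illustrative only, not part of the commit): `driver` and `train_dataloader` below are assumed to be an already set-up DeepSpeedDriver and its dataloader; inside fastNLP these methods are normally driven by the Trainer and its checkpoint callbacks rather than called by hand.

    from pathlib import Path

    ckpt_folder = Path("outputs/fastnlp_checkpoint")

    # deepspeed's engine.save_checkpoint()/load_checkpoint() must be reached on every rank,
    # which is why these methods are not wrapped with rank_zero_call in this patch
    driver.save_checkpoint(folder=ckpt_folder, states={}, dataloader=train_dataloader)

    # reloads the engine state from FASTNLP_CHECKPOINT_FILENAME under ckpt_folder and
    # returns the sampler states that were stored as deepspeed client_state
    states = driver.load_checkpoint(folder=ckpt_folder, dataloader=train_dataloader)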