From 0506fc2fcb8e3e31da13ec271f19e85913505c5f Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 5 Jul 2022 23:51:44 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=20TorchFSDPDriver?= =?UTF-8?q?=EF=BC=9B=E4=BF=AE=E6=94=B9=E4=BA=86=20ddp=20=E4=B8=AD=E7=9A=84?= =?UTF-8?q?=E9=83=A8=E5=88=86=E7=BB=86=E8=8A=82=EF=BC=9B=E5=88=A0=E9=99=A4?= =?UTF-8?q?=E4=BA=86=20topksaveer=20=E7=9A=84=20rank=5Fzero=5Fonly=20?= =?UTF-8?q?=E4=BF=AE=E9=A5=B0=E5=99=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/topk_saver.py | 4 +- fastNLP/core/drivers/choose_driver.py | 2 +- fastNLP/core/drivers/torch_driver/ddp.py | 18 +- .../torch_driver/initialize_torch_driver.py | 13 +- .../core/drivers/torch_driver/torch_driver.py | 2 +- .../core/drivers/torch_driver/torch_fsdp.py | 341 ++++++++++++++++ fastNLP/envs/imports.py | 1 + .../test_checkpoint_callback_torch.py | 2 +- .../test_trainer_w_evaluator_torch.py | 14 +- tests/core/drivers/torch_driver/test_fsdp.py | 379 ++++++++++++++++++ 10 files changed, 756 insertions(+), 20 deletions(-) create mode 100644 fastNLP/core/drivers/torch_driver/torch_fsdp.py create mode 100644 tests/core/drivers/torch_driver/test_fsdp.py diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 1ac23b77..b7bfbf19 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -51,7 +51,6 @@ class Saver: self.timestamp_path = self.folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) - @rank_zero_call def save(self, trainer, folder_name): """ 执行保存的函数,将数据保存在:: @@ -66,6 +65,7 @@ class Saver: """ folder = self.timestamp_path.joinpath(folder_name) folder.mkdir(parents=True, exist_ok=True) + save_fn = getattr(trainer, self.save_fn_name) save_fn( folder=folder, @@ -217,7 +217,7 @@ class TopkSaver(ResultsMonitor, Saver): self.topk_queue = TopkQueue(topk) self.save_evaluate_results = save_evaluate_results - @rank_zero_call + # 注意这里我们为了支持 torch_fsdp 去除了 ''@rank_zero_call''; def save_topk(self, trainer, results: Dict) -> Optional[str]: """ 根据 ``results`` 是否满足 topk 的相关设定决定是否保存,如果发生了保存,将返回保存的文件夹。如果返回为 ``None`` ,则说明此次没有满足 diff --git a/fastNLP/core/drivers/choose_driver.py b/fastNLP/core/drivers/choose_driver.py index 0f173b1c..7618a17e 100644 --- a/fastNLP/core/drivers/choose_driver.py +++ b/fastNLP/core/drivers/choose_driver.py @@ -30,7 +30,7 @@ def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, else: raise ValueError(f"Cannot choose driver automatically based on model, please set `driver` specifically.") - if driver in {"torch", "fairscale", "deepspeed"}: + if driver in {"torch", "fairscale", "deepspeed", "torch_fsdp"}: from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver return initialize_torch_driver(driver, device, model, **kwargs) elif driver in {"jittor"}: diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 28670071..26978ae2 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -309,9 +309,9 @@ class TorchDDPDriver(TorchDriver): self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) self.global_rank = 0 - self._ddp_kwargs = self._torch_kwargs.get("ddp_kwargs", {}) - check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) - if len(self.model._buffers) != 0 and self._ddp_kwargs.get("broadcast_buffers", None) is None: + self._fsdp_kwargs = self._torch_kwargs.get("ddp_kwargs", {}) + check_user_specific_params(self._fsdp_kwargs, DistributedDataParallel.__init__, DistributedDataParallel.__name__) + if len(self.model._buffers) != 0 and self._fsdp_kwargs.get("broadcast_buffers", None) is None: logger.info("Notice your model has buffers and you are using `TorchDDPDriver`, but you do not set " "'broadcast_buffers' in your trainer. Cause in most situations, this parameter can be set" " to 'False' to avoid redundant data communication between different processes.") @@ -381,8 +381,6 @@ class TorchDDPDriver(TorchDriver): self.global_rank = dist.get_rank() if not self.outside_ddp: - torch.cuda.set_device(self.model_device) - self.model.to(self.model_device) self.configure_ddp() self.barrier() @@ -400,11 +398,13 @@ class TorchDDPDriver(TorchDriver): self._pids = self.tensor_to_numeric(self._pids) def configure_ddp(self): + torch.cuda.set_device(self.model_device) + self.model.to(self.model_device) if not isinstance(self.model, DistributedDataParallel): self.model = DistributedDataParallel( # 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; _DDPWrappingModel(self.model), device_ids=[self.model_device.index], - **self._ddp_kwargs + **self._fsdp_kwargs ) self._has_ddpwrapped = True @@ -505,6 +505,12 @@ class TorchDDPDriver(TorchDriver): raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") return fn, None elif fn in {"train_step", "evaluate_step"}: + + logger.warning("\n\nfucking hei\n\n") + print(model) + print("\n\n") + print(type(model)) + print("\n\n") return model, model.forward else: raise RuntimeError(f"There is no `{fn}` method in your model.") diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index f242b813..5f5af2ad 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -9,6 +9,7 @@ from .single_device import TorchSingleDriver from .ddp import TorchDDPDriver from .fairscale import FairScaleDriver from .deepspeed import DeepSpeedDriver +from .torch_fsdp import TorchFSDPDriver from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_BACKEND_LAUNCH from pkg_resources import parse_version @@ -45,7 +46,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), is_pull_by_torch_run=True, **kwargs) - if driver not in {"torch", "fairscale", "deepspeed"}: + if driver not in {"torch", "fairscale", "deepspeed", "torch_fsdp"}: raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'fairscale'].") _could_use_device_num = torch.cuda.device_count() @@ -95,4 +96,12 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi logger.warning_once("Notice you are using `deepspeed`, but the `device` is only one gpu.") return DeepSpeedDriver(model, [device], **kwargs) else: - return DeepSpeedDriver(model, device, **kwargs) \ No newline at end of file + return DeepSpeedDriver(model, device, **kwargs) + elif driver == "torch_fsdp": + if not isinstance(device, List): + if device.type == 'cpu': + raise ValueError("You are using `torch_fsdp` driver, but your chosen `device` is 'cpu'.") + logger.warning_once("Notice you are using `torch_fsdp`, but the `device` is only one gpu.") + return TorchFSDPDriver(model, [device], **kwargs) + else: + return TorchFSDPDriver(model, device, **kwargs) \ No newline at end of file diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index a748aa32..4c71f155 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -27,7 +27,7 @@ from .utils import optimizer_state_to_device from fastNLP.core.drivers.driver import Driver from fastNLP.core.drivers.torch_driver.utils import _build_fp16_env, DummyGradScaler from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device -from fastNLP.envs import rank_zero_call +from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py new file mode 100644 index 00000000..e6011603 --- /dev/null +++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py @@ -0,0 +1,341 @@ + + + +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12 + +if _TORCH_GREATER_EQUAL_1_12: + from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType + +import os +import torch +import torch.distributed as dist +from torch.nn.parallel import DistributedDataParallel +from typing import Optional, Union, List, Dict, Mapping +from pathlib import Path + +from .ddp import TorchDDPDriver +from fastNLP.core.drivers.torch_driver.utils import ( + _DDPWrappingModel, +) + +from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, \ + FASTNLP_GLOBAL_RANK, rank_zero_call +from fastNLP.core.drivers.torch_driver.utils import DummyGradScaler +from fastNLP.core.log import logger +from fastNLP.core.utils import check_user_specific_params +from .utils import optimizer_state_to_device + + +""" +参考文档: +1. https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/ +2. https://pytorch.org/docs/stable/fsdp.html?highlight=fsdp +3. https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html +4. https://engineering.fb.com/2021/07/15/open-source/fsdp/ +""" + +class TorchFSDPDriver(TorchDDPDriver): + r""" + 实现对于 pytorch 自己实现的 fully sharded data parallel;请阅读该文档了解更多: + https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict; + + ..note:: + + ``TorchFSDPDriver`` 大部分行为与 ``TorchDDPDriver`` 相同,如果您不了解 ``TorchDDPDriver``, + 您可以先阅读 :class:`~fastNLP.core.drivers.TorchDDPDriver`; + + ..warning:: + + ``TorchFSDPDriver`` 现在还不支持断点重训功能,但是支持保存模型和加载模型; + + """ + + def __init__( + self, + model, + parallel_device: Optional[Union[List["torch.device"], "torch.device"]], + is_pull_by_torch_run: bool = False, + fp16: bool = False, + torch_kwargs: Dict = None, + **kwargs + ): + + # 在加入很多东西后,需要注意这里调用 super 函数的位置; + super(TorchDDPDriver, self).__init__(model, fp16=fp16, torch_kwargs=torch_kwargs, **kwargs) + + if isinstance(model, torch.nn.DataParallel): + raise ValueError(f"Parameter `model` can not be `DataParallel` in `TorchDDPDriver`, it should be " + f"`torch.nn.Module` or `torch.nn.parallel.DistributedDataParallel` type.") + + # 如果用户自己在外面初始化 DDP,那么其一定是通过 python -m torch.distributed.launch 拉起的; + self.is_pull_by_torch_run = is_pull_by_torch_run + self.parallel_device = parallel_device + if not is_pull_by_torch_run and parallel_device is None: + raise ValueError( + "Parameter `parallel_device` can not be None when using `TorchDDPDriver`. This error is caused " + "when your value of parameter `device` is `None` in your `Trainer` instance.") + + # 注意我们在 initialize_torch_driver 中的逻辑就是如果是 is_pull_by_torch_run,那么我们就直接把 parallel_device 置为当前进程的gpu; + if is_pull_by_torch_run: + self.model_device = parallel_device + else: + # 我们的 model_device 一定是 torch.device,而不是一个 list; + self.model_device = parallel_device[self.local_rank] + + # 如果用户自己在外面初始化了 FSDP; + self.outside_ddp = False + if dist.is_initialized() and FASTNLP_DISTRIBUTED_CHECK not in os.environ and \ + "fastnlp_torch_launch_not_ddp" not in os.environ: + # 如果用户自己在外面初始化了 DDP,那么我们要求用户传入的模型一定是已经由 DistributedDataParallel 包裹后的模型; + if not isinstance(model, FullyShardedDataParallel): + raise RuntimeError( + "It is not allowed to input a normal model instead of `FullyShardedDataParallel` when" + "you initialize the ddp process out of our control.") + if isinstance(model, DistributedDataParallel): + logger.warning("You are using `TorchFSDPDriver`, but you have initialized your model as " + "`DistributedDataParallel`, which will make the `FullyShardedDataParallel` not work " + "as expected. You could just delete `DistributedDataParallel` wrap operation.") + + self.outside_ddp = True + # 用户只有将模型上传到对应机器上后才能用 DistributedDataParallel 包裹,因此如果用户在外面初始化了 DDP,那么在 TorchDDPDriver 中 + # 我们就直接将 model_device 置为 None; + self.model_device = None + + # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; + self._data_device = kwargs.get("data_device", None) + if isinstance(self._data_device, int): + if self._data_device < 0: + raise ValueError("Parameter `data_device` can not be smaller than 0.") + _could_use_device_num = torch.cuda.device_count() + if self._data_device >= _could_use_device_num: + raise ValueError("The gpu device that parameter `device` specifies is not existed.") + self._data_device = torch.device(f"cuda:{self._data_device}") + elif isinstance(self._data_device, str): + self._data_device = torch.device(self._data_device) + elif self._data_device is not None and not isinstance(self._data_device, torch.device): + raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") + + self._master_port = None + # world_size 表示的就是全局的显卡的数量; + self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) + self.global_rank = 0 + + self._fsdp_kwargs = self._torch_kwargs.get("fsdp_kwargs", {}) + self._save_on_rank0 = self._fsdp_kwargs.get("save_on_rank0", False) + if "save_on_rank0" in self._fsdp_kwargs: + self._fsdp_kwargs.pop("save_on_rank0") + self._load_on_rank0 = self._fsdp_kwargs.get("load_on_rank0", False) + if "load_on_rank0" in self._fsdp_kwargs: + self._fsdp_kwargs.pop("load_on_rank0") + + if self._save_on_rank0 != self._load_on_rank0: + logger.warning(f"Notice the behavior between ``save`` and ``load`` is not matched, you choose " + f"{'save on rank0' if self._save_on_rank0 else 'save on each rank'}, but " + f"{'load on rank0' if self._save_on_rank0 else 'load on each rank'}!") + + check_user_specific_params(self._fsdp_kwargs, FullyShardedDataParallel.__init__, FullyShardedDataParallel.__name__) + if "cpu_offload" in self._fsdp_kwargs and kwargs["accumulation_steps"] != 1: + logger.warning("It is not supported ``accumulation_steps`` when using ``cpu_offload`` in " + "``FullyShardedDataParallel``.") + + self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") + assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." + if self.output_from_new_proc not in {"all", "ignore", "only_error"}: + os.makedirs(name=self.output_from_new_proc, exist_ok=True) + self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) + + self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; + self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; + + def configure_ddp(self): + torch.cuda.set_device(self.model_device) + if not isinstance(self.model, FullyShardedDataParallel): + self.model = FullyShardedDataParallel( + # 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; + _DDPWrappingModel(self.model), device_id=self.model_device.index, + **self._fsdp_kwargs + ) + + # 必须先使用 FullyShardedDataParallel 包裹模型后再使用 optimizer 包裹模型的参数,因此这里需要将 optimizer 重新初始化一遍; + for i in range(len(self.optimizers)): + self.optimizers[i] = type(self.optimizers[i])(self.model.parameters(), **self.optimizers[i].defaults) + + self._has_ddpwrapped = True + + def unwrap_model(self): + """ + 注意该函数因为需要在特定的时候进行调用,例如 ddp 在 get_model_call_fn 的时候,因此不能够删除; + 如果您使用该函数来获取原模型的结构信息,是可以的; + 但是如果您想要通过该函数来获取原模型实际的参数,是不可以的,因为在 FullyShardedDataParallel 中模型被切分成了多个部分,而对于每个 gpu 上 + 的模型只是整体模型的一部分。 + """ + _module = self.model.module.module + if isinstance(_module, _DDPWrappingModel): + return _module.model + else: + return _module + + def save_model(self, filepath: Union[str, Path], only_state_dict: bool = True, **kwargs): + filepath = Path(filepath) + prefix = filepath.parent + filename = filepath.name + _filename = filename.split('.') + filename, suffix = _filename[0], '.'.join(_filename[1:]) + if only_state_dict: + if self._save_on_rank0: + full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT, full_state_dict_config): + state_dict = self.model.state_dict() + rank_zero_call(torch.save)(state_dict, filepath) + else: + # 添加 'rank0/1' 字段来区分全部聚集到 rank0 保存的方式; + _filename = filename.split('_') + filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1] + filepath = prefix.joinpath(filename + "." + suffix) + with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + state_dict = self.model.state_dict() + torch.save(state_dict, filepath) + else: + raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.") + + def load_model(self, filepath: Union[Path, str], only_state_dict: bool = True, **kwargs): + if only_state_dict is False: + raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.") + filepath = Path(filepath) + prefix = filepath.parent + filename = filepath.name + _filename = filename.split('.') + filename, suffix = _filename[0], '.'.join(_filename[1:]) + + if not self._load_on_rank0: + _filename = filename.split('_') + filename = _filename[0] + f"_rank{int(os.environ.get(FASTNLP_GLOBAL_RANK, 0))}_" + _filename[1] + filepath = prefix.joinpath(filename + "." + suffix) + states = torch.load(filepath) + else: + states = torch.load(filepath, map_location="cpu") + + if isinstance(states, dict) and only_state_dict is False: + logger.rank_zero_warning(f"It seems like that {filepath} only contains state, you may need to use " + f"`only_state_dict=True`") + elif not isinstance(states, dict) and only_state_dict is True: + logger.rank_zero_warning(f"It seems like that {filepath} is not state, you may need to use " + f"`only_state_dict=False`") + if not isinstance(states, Mapping): + states = states.state_dict() + + if self._load_on_rank0: + with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.FULL_STATE_DICT): + self.model.load_state_dict(states) + else: + with FullyShardedDataParallel.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + self.model.load_state_dict(states) + + def save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + raise RuntimeError("``TorchFSDPDriver`` does not support ``save_checkpoint`` function for now, there is some " + "technical issues that needs to solve. You can implement your own breakpoint retraining " + "by rewriting this function. The important thing is how to save and load the optimizers' state dict, " + "you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.") + + def load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + raise RuntimeError("``TorchFSDPDriver`` does not support ``load_checkpoint`` function for now, there is some " + "technical issues that needs to solve. You can implement your own breakpoint retraining " + "by rewriting this function. The important thing is how to save and load the optimizers' state dict, " + "you can see ``https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.full_optim_state_dict``.") + + # todo 这些加了 __ 的函数是目前还不支持; + # 这是因为 1.12 的 pytorch fsdp 的关于如何保存和加载 optimizer state dict 的接口有点过于反人类,无法在 fastNLP 的框架中进行调和 + # 使用; + def __get_optimizer_state(self): + optimizers_state_dict = {} + for i in range(len(self.optimizers)): + # 注意这里其余 rank 拿到的是一个空字典,因此在真正保存的时候需要保证只有 rank0 在工作; + optimizer_state = FullyShardedDataParallel.full_optim_state_dict(self.model, self.optimizers[i]) + if self._save_on_rank0: + with FullyShardedDataParallel.summon_full_params(self.model): + if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: + unwrapped_model = self.model.module.module + optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict( + optimizer_state, OptimStateKeyType.PARAM_ID, unwrapped_model) + if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: + optimizer_state["state"] = optimizer_state_to_device(optimizer_state["state"], torch.device("cpu")) + optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; + return optimizers_state_dict + + # 这里单独拿出来是因为对于 fsdp 来说,每一个进程都需要运行此函数,因此不能包裹 rank_zero_call; + def __save_checkpoint(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): + if not only_state_dict: + raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.") + + # 1. sampler 的状态; + num_consumed_batches = states.pop('num_consumed_batches') + states['sampler_states'] = self.get_sampler_state(dataloader, num_consumed_batches) + + # 2. 保存模型的状态; + if should_save_model: + if not os.path.exists(folder): + os.mkdir(folder) + model_path = folder.joinpath(FASTNLP_MODEL_FILENAME) + self.save_model(model_path, only_state_dict=True) + + # 3. 保存 optimizers 的状态; + states["optimizers_state_dict"] = self.get_optimizer_state() + logger.debug("Save optimizer state dict.") + + # 4. 保存fp16的状态 + if not isinstance(self.grad_scaler, DummyGradScaler): + grad_scaler_state_dict = self.grad_scaler.state_dict() + states['grad_scaler_state_dict'] = grad_scaler_state_dict + + # 确保只有 rank0 才会执行实际的保存操作; + rank_zero_call(torch.save)(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) + + def __load_optimizer_state(self, states): + assert len(states) == len(self.optimizers), f"The number of optimizers is:{len(self.optimizers)}, while in " \ + f"checkpoint it is:{len(states)}" + + with FullyShardedDataParallel.summon_full_params(self.model): + unwrapped_model = self.model.module.module + + for i in range(len(self.optimizers)): + optimizer_state = states[f'optimizer{i}'] + if self._load_on_rank0: + optimizer_state = FullyShardedDataParallel.rekey_optim_state_dict(optimizer_state, OptimStateKeyType.PARAM_NAME, unwrapped_model) + optimizer_state = FullyShardedDataParallel.shard_full_optim_state_dict(optimizer_state, unwrapped_model) + optimizer: torch.optim.Optimizer = type(self.optimizers[i])(unwrapped_model.parameters(), **self.optimizers[i].defaults) + optimizer.load_state_dict(optimizer_state) + self.optimizers[i] = optimizer + + logger.debug("Load optimizer state dict.") + + def __load_checkpoint(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + if not only_state_dict: + raise RuntimeError("When using `TorchFSDPDriver`, only `only_state_dict=True` is allowed.") + + states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) + + # 1. 加载 optimizers 的状态; + optimizers_state_dict = states.pop("optimizers_state_dict") + self.load_optimizer_state(optimizers_state_dict) + + # 2. 加载模型状态; + if should_load_model: + self.load_model(filepath=folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict=only_state_dict) + + # 3. 加载 fp16 的状态 + if "grad_scaler_state_dict" in states: + grad_scaler_state_dict = states.pop("grad_scaler_state_dict") + if not isinstance(self.grad_scaler, DummyGradScaler): + self.grad_scaler.load_state_dict(grad_scaler_state_dict) + logger.debug("Load grad_scaler state dict...") + elif not isinstance(self.grad_scaler, DummyGradScaler): + logger.rank_zero_warning(f"Checkpoint {folder} is not trained with fp16=True, while resume to a fp16=True training, " + f"the training process may be unstable.") + + # 4. 恢复 sampler 的状态; + sampler_states = states.pop('sampler_states') + states_ret = self.load_sampler_state(dataloader, sampler_states) + states.update(states_ret) + + return states + diff --git a/fastNLP/envs/imports.py b/fastNLP/envs/imports.py index 08afc6a5..2a8e5317 100644 --- a/fastNLP/envs/imports.py +++ b/fastNLP/envs/imports.py @@ -26,3 +26,4 @@ _NEED_IMPORT_DEEPSPEED = _module_available("deepspeed") and 'torch' in need_impo _NEED_IMPORT_ONEFLOW = _module_available("oneflow") and 'oneflow' in need_import _TORCH_GREATER_EQUAL_1_8 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.8.0") +_TORCH_GREATER_EQUAL_1_12 = _NEED_IMPORT_TORCH and _compare_version("torch", operator.ge, "1.12.0") \ No newline at end of file diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index d227a162..00b73b51 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -75,7 +75,7 @@ def model_and_optimizers(request): @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) @magic_argv_env_context(timeout=100) def test_model_checkpoint_callback_1( model_and_optimizers: TrainerParameters, diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 78eff36c..4d31e5f8 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -103,8 +103,8 @@ def model_and_optimizers(request): # 测试一下普通的情况; @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), - ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), + ("torch", [4, 5])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) @pytest.mark.parametrize("evaluate_every", [-3, -1, 2]) @magic_argv_env_context def test_trainer_torch_with_evaluator( @@ -139,7 +139,7 @@ def test_trainer_torch_with_evaluator( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1)]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", [4, 5]), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) @pytest.mark.parametrize("fp16", [True, False]) @pytest.mark.parametrize("accumulation_steps", [1, 3]) @magic_argv_env_context @@ -250,7 +250,7 @@ def test_trainer_on( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 0)]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", 'cpu'), ("torch", 4)]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_specific_params_1( model_and_optimizers: TrainerParameters, @@ -291,7 +291,7 @@ def test_trainer_specific_params_1( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) @magic_argv_env_context def test_trainer_specific_params_2( model_and_optimizers: TrainerParameters, @@ -331,7 +331,7 @@ def test_trainer_specific_params_2( assert trainer.driver.wo_auto_param_call is True assert trainer.driver.output_from_new_proc == "all" - _ddp_kwargs = trainer.driver._ddp_kwargs + _ddp_kwargs = trainer.driver._fsdp_kwargs assert _ddp_kwargs.get("broadcast_buffers") is True assert _ddp_kwargs.get("find_unused_parameters") is True @@ -340,7 +340,7 @@ def test_trainer_specific_params_2( @pytest.mark.torch -@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1) +@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", [0, 1]),("torch", 1) @pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)]) @magic_argv_env_context def test_trainer_w_evaluator_overfit_torch( diff --git a/tests/core/drivers/torch_driver/test_fsdp.py b/tests/core/drivers/torch_driver/test_fsdp.py new file mode 100644 index 00000000..9ba890ca --- /dev/null +++ b/tests/core/drivers/torch_driver/test_fsdp.py @@ -0,0 +1,379 @@ +import os +from dataclasses import dataclass +from typing import Any +from pathlib import Path +import re + +import pytest +from fastNLP.core.controllers.trainer import Trainer +from torchmetrics import Accuracy +from fastNLP.core.callbacks import CheckpointCallback +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset +from tests.helpers.callbacks.helper_callbacks import RecordLossCallback +from tests.helpers.utils import magic_argv_env_context +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.envs import FASTNLP_LAUNCH_TIME, rank_zero_rm +if _NEED_IMPORT_TORCH: + import torch.distributed as dist + from torch.optim import SGD + from torch.utils.data import DataLoader + + +@dataclass +class ArgMaxDatasetConfig: + num_labels: int = 10 + feature_dimension: int = 10 + data_num: int = 50 + seed: int = 0 + + batch_size: int = 2 + shuffle: bool = True + + +@dataclass +class TrainerParameters: + model: Any = None + optimizers: Any = None + train_dataloader: Any = None + evaluate_dataloaders: Any = None + input_mapping: Any = None + output_mapping: Any = None + metrics: Any = None + + +@pytest.fixture(scope="module", params=[0], autouse=True) +def model_and_optimizers(request): + trainer_params = TrainerParameters() + + trainer_params.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) + dataset = TorchArgMaxDataset( + feature_dimension=ArgMaxDatasetConfig.feature_dimension, + data_num=ArgMaxDatasetConfig.data_num, + seed=ArgMaxDatasetConfig.seed + ) + _dataloader = DataLoader( + dataset=dataset, + batch_size=ArgMaxDatasetConfig.batch_size, + shuffle=True + ) + trainer_params.train_dataloader = _dataloader + trainer_params.evaluate_dataloaders = _dataloader + trainer_params.metrics = {"acc": Accuracy()} + + return trainer_params + +@pytest.mark.torch +@magic_argv_env_context +def test_trainer_torch_without_evaluator( + model_and_optimizers: TrainerParameters, + n_epochs=3, +): + callbacks = [RecordLossCallback(loss_threshold=0.5)] + trainer = Trainer( + model=model_and_optimizers.model, + driver="torch_fsdp", + device=[4, 5], + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=3, + callbacks=callbacks, + output_from_new_proc="all" + + ) + + trainer.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch_fsdp", [4, 5])]) +@magic_argv_env_context(timeout=100) +def test_model_checkpoint_callback_1( + model_and_optimizers: TrainerParameters, + driver, + device +): + for version in [0]: + # 需要在每一个循环开始重新初始化 model,是因为 fsdp 会将当前卡上的 model 删除,从而导致这个引用实际上引用到的是一个空模型; + model_and_optimizers.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + try: + path = Path.cwd().joinpath(f"test_model_checkpoint") + path.mkdir(exist_ok=True, parents=True) + + if version == 0: + callbacks = [ + CheckpointCallback(folder=path, every_n_epochs=1, every_n_batches=123, last=False, on_exceptions=None, topk=0, + monitor=None, only_state_dict=True, save_object='model') + ] + elif version == 1: + callbacks = [ + CheckpointCallback(folder=path, every_n_epochs=3, every_n_batches=None, last=True, on_exceptions=None, topk=2, + monitor="acc", only_state_dict=True, save_object='model') + ] + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=10, + callbacks=callbacks, + output_from_new_proc="all", + # torch_kwargs={"fsdp_kwargs": {'save_on_rank0': True}} + ) + + trainer.run() + print("Finish train") + all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + # 检查生成保存模型文件的数量是不是正确的; + if version == 0: + + if not isinstance(device, list): + assert "model-epoch_10" in all_saved_model_paths + assert "model-epoch_4-batch_123" in all_saved_model_paths + + epoch_save_path = all_saved_model_paths["model-epoch_10"] + step_save_path = all_saved_model_paths["model-epoch_4-batch_123"] + + assert len(all_saved_model_paths) == 12 + # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; + else: + assert "model-epoch_6" in all_saved_model_paths + assert "model-epoch_9-batch_123" in all_saved_model_paths + + epoch_save_path = all_saved_model_paths["model-epoch_6"] + step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] + + assert len(all_saved_model_paths) == 11 + all_state_dicts = [epoch_save_path]#, step_save_path] + + elif version == 1: + + pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") + + if not isinstance(device, list): + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths + aLL_topk_folders = [] + for each_folder_name in all_saved_model_paths: + each_folder_name = pattern.findall(each_folder_name) + if len(each_folder_name) != 0: + aLL_topk_folders.append(each_folder_name[0]) + assert len(aLL_topk_folders) == 2 + + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] + topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] + + assert len(all_saved_model_paths) == 6 + # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; + else: + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths + + aLL_topk_folders = [] + for each_folder_name in all_saved_model_paths: + each_folder_name = pattern.findall(each_folder_name) + if len(each_folder_name) != 0: + aLL_topk_folders.append(each_folder_name[0]) + assert len(aLL_topk_folders) == 2 + + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] + topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] + + assert len(all_saved_model_paths) == 6 + + all_state_dicts = [epoch_save_path, last_save_path, topk_save_path] + + for folder in all_state_dicts: + model_and_optimizers.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=20, + output_from_new_proc="all", + + ) + trainer.load_model(folder, only_state_dict=True) + + trainer.run() + trainer.driver.barrier() + finally: + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() + + + + + +@pytest.mark.skip("现在 fsdp 还不支持断点重训;") +@pytest.mark.torch +@pytest.mark.parametrize("driver,device", [("torch_fsdp", [6, 7])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1) +@magic_argv_env_context(timeout=100) +def test_trainer_checkpoint_callback_1( + model_and_optimizers: TrainerParameters, + driver, + device +): + for version in [0, 1]: + model_and_optimizers.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + try: + path = Path.cwd().joinpath(f"test_model_checkpoint") + path.mkdir(exist_ok=True, parents=True) + + if version == 0: + callbacks = [ + CheckpointCallback(folder=path, every_n_epochs=7, every_n_batches=123, last=False, on_exceptions=None, topk=0, + monitor=None, only_state_dict=True, save_object='trainer') + ] + elif version == 1: + callbacks = [ + CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=True, on_exceptions=None, + topk=2, monitor="acc", only_state_dict=True, save_object='trainer') + ] + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=10, + callbacks=callbacks, + output_from_new_proc="all" + ) + + trainer.run() + + all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + # 检查生成保存模型文件的数量是不是正确的; + if version == 0: + + if not isinstance(device, list): + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_4-batch_123" in all_saved_model_paths + + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"] + + assert len(all_saved_model_paths) == 3 + # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; + else: + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_9-batch_123" in all_saved_model_paths + + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"] + + assert len(all_saved_model_paths) == 2 + all_state_dicts = [epoch_save_path, step_save_path] + + elif version == 1: + + pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") + + # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} + if not isinstance(device, list): + assert "trainer-last" in all_saved_model_paths + aLL_topk_folders = [] + for each_folder_name in all_saved_model_paths: + each_folder_name = pattern.findall(each_folder_name) + if len(each_folder_name) != 0: + aLL_topk_folders.append(each_folder_name[0]) + assert len(aLL_topk_folders) == 2 + + last_save_path = all_saved_model_paths["trainer-last"] + topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] + + assert len(all_saved_model_paths) == 3 + # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; + else: + assert "trainer-last" in all_saved_model_paths + + aLL_topk_folders = [] + for each_folder_name in all_saved_model_paths: + each_folder_name = pattern.findall(each_folder_name) + if len(each_folder_name) != 0: + aLL_topk_folders.append(each_folder_name[0]) + assert len(aLL_topk_folders) == 2 + + last_save_path = all_saved_model_paths["trainer-last"] + topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] + + assert len(all_saved_model_paths) == 3 + + all_state_dicts = [last_save_path, topk_save_path] + + for folder in all_state_dicts: + model_and_optimizers.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + + trainer = Trainer( + model=model_and_optimizers.model, + driver=driver, + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + + n_epochs=13, + output_from_new_proc="all" + ) + trainer.load_checkpoint(folder, only_state_dict=True) + + trainer.run() + trainer.driver.barrier() + + finally: + rank_zero_rm(path) + + if dist.is_initialized(): + dist.destroy_process_group() \ No newline at end of file