From 2366bc320bcc6d5bbd2b28e703572a3b8b71480d Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 11 Apr 2022 15:04:39 +0000 Subject: [PATCH] =?UTF-8?q?=E8=B7=9F=E8=BF=9B=E6=96=AD=E7=82=B9=E9=87=8D?= =?UTF-8?q?=E8=AE=AD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 59 ++++++++++++---- .../drivers/paddle_driver/paddle_driver.py | 55 +++++++-------- .../drivers/paddle_driver/single_device.py | 69 +++++++------------ 3 files changed, 94 insertions(+), 89 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 2a1d5228..86198959 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -10,7 +10,8 @@ from .utils import ( _MODE_PARAMETER, get_device_from_visible, reset_seed, - replace_sampler + replace_sampler, + replace_batch_sampler, ) from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -23,10 +24,12 @@ from fastNLP.core.utils import ( from fastNLP.core.samplers import ( RandomBatchSampler, ReproducibleSampler, - ReproducibleIterator, + ReproducibleBatchSampler, RandomSampler, - UnrepeatedDistributedSampler, + UnrepeatedSampler, + UnrepeatedSequentialSampler, re_instantiate_sampler, + conversion_between_reproducible_and_unrepeated_sampler, ) from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger @@ -261,7 +264,6 @@ class PaddleFleetDriver(PaddleDriver): 当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 根据 paddle 设置的环境变量来获得各种属性 """ - print("set_from_env") self.world_size = dist.get_world_size() self.global_rank = dist.get_rank() @@ -325,23 +327,50 @@ class PaddleFleetDriver(PaddleDriver): # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." - if isinstance(dist, ReproducibleIterator): - dist = re_instantiate_sampler(dist) + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleSampler 说明是在断点重训时 driver.load 函数调用; + # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; + if isinstance(dist, ReproducibleBatchSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleSampler): + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_sampler(dataloader, dist) + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; # trainer, evaluator - # 自己初始化了分布式,什么都不做 if dist is None: if reproducible: - raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " + raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " "control.") else: + if isinstance(dist, ReproducibleBatchSampler): + dist = re_instantiate_sampler(dist) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleSampler): + dist = re_instantiate_sampler(dist) + return replace_sampler(dataloader, dist) return dataloader # trainer elif dist == "dist": args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(args.sampler, ReproducibleIterator): + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + batch_sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, @@ -364,10 +393,14 @@ class PaddleFleetDriver(PaddleDriver): # evaluator elif dist == "unrepeatdist": args = self.get_dataloader_args(dataloader) - sampler = UnrepeatedDistributedSampler( - dataset=args.dataset, - shuffle=args.shuffle, - ) + if isinstance(args.sampler, ReproducibleSampler): + sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler) + elif not isinstance(args.sampler, UnrepeatedSampler): + sampler = UnrepeatedSequentialSampler( + dataset=args.dataset + ) + else: + sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 69f9ed44..95e6215e 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -14,7 +14,7 @@ from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler if _NEED_IMPORT_PADDLE: import paddle @@ -178,11 +178,13 @@ class PaddleDriver(Driver): :param kwargs: :return: """ + debug = kwargs.get("debug", False) model = self.unwrap_model() - if only_state_dict: states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} paddle.save(states, filepath) + if debug: + logger.debug("Save model state dict.") else: # paddle 在保存整个模型时需要传入额外参数 input_spec = kwargs.get("input_spec", None) @@ -196,6 +198,8 @@ class PaddleDriver(Driver): self.move_model_to_device(model, self.model_device) else: paddle.jit.save(model, filepath, input_spec) + if debug: + logger.debug("Save model.") def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" @@ -207,11 +211,16 @@ class PaddleDriver(Driver): :param kwargs: :return: """ + debug = kwargs.get("debug", False) model = self.unwrap_model() if only_state_dict: model.load_dict(paddle.load(filepath)) + if debug: + logger.debug("Load model state dict.") else: model.load_dict(paddle.jit.load(filepath).state_dict()) + if debug: + logger.debug("Load model.") @rank_zero_call def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): @@ -252,17 +261,7 @@ class PaddleDriver(Driver): # 2. 保存模型的状态; if should_save_model: - model = self.unwrap_model() - if only_state_dict: - model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} - paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) - logger.debug("Save model state dict") - else: - input_spec = kwargs.get("input_spec", None) - if input_spec is None: - raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") - paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) - logger.debug("Save model") + self.save_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True, **kwargs) # 3. 保存 optimizers 的状态; optimizers_state_dict = {} @@ -272,7 +271,7 @@ class PaddleDriver(Driver): optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; - logger.debug("Save optimizer state dict") + logger.debug("Save optimizer state dict.") states["optimizers_state_dict"] = optimizers_state_dict paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) @@ -289,30 +288,23 @@ class PaddleDriver(Driver): # 2. 加载模型状态; if should_load_model: - model = self.unwrap_model() - if only_state_dict: - res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) - model.load_dict(res) - logger.debug("Load model state dict.") - else: - model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) - logger.debug("Load model.") + self.load_model(folder.joinpath(FASTNLP_MODEL_FILENAME), only_state_dict, debug=True) # 3. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) - sampler = dataloader_args.sampler - if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): - # 说明这里需要使用 ReproduceSampler 来弄一下了 - if self.is_distributed(): - raise RuntimeError( - "It is not allowed to use single device checkpoint retraining before but ddp now.") + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif isinstance(dataloader_args.sampler, ReproducibleSampler): + sampler = dataloader_args.sampler + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") + else: sampler = ReproducibleBatchSampler( - batch_sampler=sampler, - batch_size=dataloader_args.batch_sampler.batch_size, + batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, + batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last ) sampler.load_state_dict(states['sampler_states']) - states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) # 4. 修改 trainer_state.batch_idx_in_epoch @@ -420,6 +412,7 @@ class PaddleDriver(Driver): res.dataset = dataloader.dataset if dataloader.batch_sampler is not None: + # 不过在 paddle 中,我们限定了 batch_sampler 不能为 None res.batch_sampler = dataloader.batch_sampler if hasattr(dataloader.batch_sampler, "batch_size"): res.batch_size = getattr(dataloader.batch_sampler, "batch_size") diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 83c3112a..dd5a340a 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -11,7 +11,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -137,55 +137,34 @@ class PaddleSingleDriver(PaddleDriver): """ return paddle_move_data_to_device(batch, "gpu:0") -<<<<<<< HEAD - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], -======= - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler], ->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 - reproducible: bool = False, sampler_or_batch_sampler=None): - # 暂时不支持IteratorDataset + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, + reproducible: bool = False): + + # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ - "FastNLP does not support `IteratorDataset` now." + "FastNLP does not support `IteratorDataset` now." + # 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; if isinstance(dist, ReproducibleBatchSampler): -<<<<<<< HEAD return replace_batch_sampler(dataloader, dist) - elif isinstance(dist, ReproducibleIterator): - return replace_sampler(dataloader, dist) - - if reproducible: - args = self.get_dataloader_args(dataloader) - if isinstance(args.sampler, ReproducibleIterator): - sampler = re_instantiate_sampler(args.sampler) - return replace_sampler(dataloader, sampler) - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - batch_sampler = re_instantiate_sampler(dataloader.batch_sampler) - return replace_batch_sampler(dataloader, batch_sampler) - else: - batch_sampler = ReproducibleBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last -======= - dataloader.batch_sampler = dist - return dataloader - if isinstance(dist, ReproducibleSampler): - dataloader.batch_sampler.sampler = dist - return dataloader + elif isinstance(dist, ReproducibleSampler): + return replace_sampler(dataloader, dist) + + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleSampler): - return dataloader - elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - return dataloader - else: - # TODO - batch_sampler = ReproducibleBatchSampler( - batch_sampler=dataloader.batch_sampler, - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.drop_last ->>>>>>> 388e426d78e8985a2f34dc83dfffe881274239a1 - ) - return replace_batch_sampler(dataloader, batch_sampler) + batch_sampler = ReproducibleBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader