From 9678c559c99b68d572c228bad821031b5389bf31 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Sun, 10 Apr 2022 14:59:45 +0000 Subject: [PATCH] =?UTF-8?q?=E8=B7=9F=E8=BF=9B=E6=96=AD=E7=82=B9=E9=87=8D?= =?UTF-8?q?=E8=AE=AD=E7=9A=84=E8=AE=BE=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 50 ++-- .../drivers/paddle_driver/paddle_driver.py | 231 ++++++++++++++---- .../drivers/paddle_driver/single_device.py | 35 ++- fastNLP/core/drivers/paddle_driver/utils.py | 69 ++++-- 4 files changed, 272 insertions(+), 113 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 0fd74795..3635ae14 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -10,6 +10,7 @@ from .utils import ( _MODE_PARAMETER, get_device_from_visible, reset_seed, + replace_sampler ) from fastNLP.envs.imports import _NEED_IMPORT_PADDLE @@ -19,8 +20,13 @@ from fastNLP.core.utils import ( paddle_move_data_to_device, is_in_paddle_dist, ) -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler -from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES +from fastNLP.core.samplers import ( + ReproducibleIterator, + RandomSampler, + UnrepeatedDistributedSampler, + re_instantiate_sampler, +) +from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -314,23 +320,15 @@ class PaddleFleetDriver(PaddleDriver): def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], reproducible: bool = False, sampler_or_batch_sampler=None): - # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." if isinstance(dist, ReproducibleIterator): - dataloader.batch_sampler.sampler = dist - return dataloader - - # paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断 - # 但是其子类 DistributedBatchSampler 却有 shuffle 成员 - # 因此用 type() 进行严格的判断 - if type(dataloader.batch_sampler) == BatchSampler: - shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) - else: - shuffle = dataloader.batch_sampler.shuffle + dist = re_instantiate_sampler(dist) + return replace_sampler(dataloader, dist) # trainer, evaluator + # 自己初始化了分布式,什么都不做 if dist is None: if reproducible: raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " @@ -339,40 +337,40 @@ class PaddleFleetDriver(PaddleDriver): return dataloader # trainer elif dist == "dist": + args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): - dataloader.batch_sampler.sampler.set_distributed( + if isinstance(args.sampler, ReproducibleIterator): + sampler = re_instantiate_sampler(args.sampler) + sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, pad=True ) - return dataloader + return replace_sampler(dataloader, sampler) else: sampler = RandomSampler( - dataset=dataloader.dataset, - shuffle=shuffle, - seed=int(os.environ.get("FASTNLP_SEED", 0)) + dataset=args.dataset, + shuffle=args.shuffle, + seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) ) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, pad=True ) - dataloader.batch_sampler.sampler = sampler - return dataloader + return replace_sampler(dataloader, sampler) # evaluator elif dist == "unrepeatdist": + args = self.get_dataloader_args(dataloader) sampler = UnrepeatedDistributedSampler( - dataset=dataloader.dataset, - shuffle=shuffle, - seed=int(os.environ.get("FASTNLP_SEED", 0)) + dataset=args.dataset, + shuffle=args.shuffle, ) sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank ) - dataloader.batch_sampler.sampler = sampler - return dataloader + return replace_sampler(dataloader, sampler) else: raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 84ce6ec2..69f9ed44 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -1,21 +1,31 @@ import os import random -from typing import Union, Optional, Callable, Dict +from typing import Union, Optional, Dict +from pathlib import Path from functools import partial +from dataclasses import dataclass import numpy as np -from .utils import _build_fp16_env +from .utils import _build_fp16_env, optimizer_state_to_device from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.core.drivers.driver import Driver from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device from fastNLP.envs import rank_zero_call -from fastNLP.envs import FASTNLP_SEED_WORKERS +from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger +from fastNLP.core.samplers import ReproducibleBatchSampler if _NEED_IMPORT_PADDLE: import paddle - from paddle.io import DataLoader, IterableDataset + from paddle.io import ( + DataLoader, + IterableDataset, + Dataset, + Sampler, + BatchSampler, + RandomSampler, + ) from paddle.optimizer import Optimizer _reduces = { @@ -69,6 +79,8 @@ class PaddleDriver(Driver): # TODO 我们先禁止 dataloader 的 dataset 是 IterableDataset 种类; if isinstance(dataloader.dataset, IterableDataset): raise TypeError("`IterableDataset` is not allowed.") + if dataloader.batch_sampler is None and dataloader.batch_size is None: + raise ValueError(f"At least one of `{dataloader_name}`'s `batch_sampler` and `batch_size` should be set.") else: if not isinstance(dataloader, Dict): raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") @@ -79,6 +91,9 @@ class PaddleDriver(Driver): f"type, not {type(each_dataloader)}.") if isinstance(each_dataloader.dataset, IterableDataset): raise TypeError("`IterableDataset` is not allowed.") + if dataloader.batch_sampler is None and dataloader.batch_size is None: + raise ValueError(f"For each dataloader of parameter `{dataloader_name}`, at least one of " + f"`batch_sampler` and `batch_size` should be set.") @staticmethod def _check_optimizer_legality(optimizers): @@ -153,45 +168,53 @@ class PaddleDriver(Driver): getattr(self.model, mode)() @rank_zero_call - def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): + def save_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; 如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; :param filepath: 保存文件的文件位置(需要包括文件名); - :param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; - :param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); + :param only_state_dict: 是否只保存模型的 `state_dict`; + :param kwargs: + :return: """ - if model_save_fn is not None: - model_save_fn(filepath) + model = self.unwrap_model() + + if only_state_dict: + states = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} + paddle.save(states, filepath) else: - model = self.unwrap_model() - if only_state_dict: - paddle.save(model.state_dict(), filepath) + # paddle 在保存整个模型时需要传入额外参数 + input_spec = kwargs.get("input_spec", None) + if input_spec is None: + raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") + if self.model_device is not None: + if not self.is_distributed(): + self.move_model_to_device(model, "cpu") + paddle.jit.save(model, filepath, input_spec) + if not self.is_distributed(): + self.move_model_to_device(model, self.model_device) else: - input_spec = kwargs.get("input_spec", None) - if input_spec is None: - raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") paddle.jit.save(model, filepath, input_spec) - @staticmethod - @rank_zero_call - def load_model(filepath: str, load_dict: bool = True): + def load_model(self, filepath: str, only_state_dict: bool = True, **kwargs): r""" 加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; :param filepath: 需要被加载的对象的文件位置(需要包括文件名); :param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, 即保存了整个模型时,这个参数必须也为False - :return: 返回加载指定文件后的结果; + :param kwargs: + :return: """ - if load_dict: - return paddle.load(filepath) + model = self.unwrap_model() + if only_state_dict: + model.load_dict(paddle.load(filepath)) else: - return paddle.jit.load(filepath) + model.load_dict(paddle.jit.load(filepath).state_dict()) @rank_zero_call - def save(self, folder, states: Dict): + def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): r""" 断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; 需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver @@ -203,48 +226,110 @@ class PaddleDriver(Driver): :param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 传入的值保持一致。 + :param dataloader: 正在使用的 dataloader,需要保存里面的状态使得之后可以从当前迭代的位置恢复。 + :param only_state_dict: 是否只保存模型的参数,当 should_save_model 为 False ,该参数无效。 + :param should_save_model: 是否应该保存模型,如果为False,Driver 将不负责 model 的保存。 + :return: """ - # 1. 保存模型的状态; - model = self.unwrap_model() - model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} - # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; - states["model_state_dict"] = model_state_dict + # 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 + # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; + + # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; + # paddle 的 DataLoader 在初始化之后 batch_sampler 可能为 None,也可能为用户设置的 batch_sampler + dataloader_args = self.get_dataloader_args(dataloader) + if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + sampler = dataloader_args.batch_sampler + elif dataloader_args.sampler: + sampler = dataloader_args.sampler + else: + raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") + + if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): + states['sampler_states'] = sampler.state_dict() + else: + raise RuntimeError( + 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') - # 2. 保存 optimizers 的状态; + # 2. 保存模型的状态; + if should_save_model: + model = self.unwrap_model() + if only_state_dict: + model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} + paddle.save(model_state_dict, folder.joinpath(FASTNLP_MODEL_FILENAME)) + logger.debug("Save model state dict") + else: + input_spec = kwargs.get("input_spec", None) + if input_spec is None: + raise ValueError("To save the whole Paddle Layer, parameter `input_spec` is needed.") + paddle.jit.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME), input_spec) + logger.debug("Save model") + + # 3. 保存 optimizers 的状态; optimizers_state_dict = {} for i in range(len(self.optimizers)): optimizer: Optimizer = self.optimizers[i] optimizer_state = optimizer.state_dict() - optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} + optimizer_state["state"] = optimizer_state_to_device(optimizer_state, "cpu") optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; - states["optimizers_state_dict"] = optimizers_state_dict - - paddle.save(states, folder) - def load(self, filepath) -> Dict: - r""" - 断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; - driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 - 因此 save 函数和 load 函数的接受和返回值应该是对应的; - - 该函数需要在所有 rank 上执行。 + logger.debug("Save optimizer state dict") + states["optimizers_state_dict"] = optimizers_state_dict + paddle.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) - :param filepath: 保存断点重训的状态的文件名; - :return: 需要返回 save 函数输入的 states 内容; - """ - states = paddle.load(filepath) + def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: + + states = paddle.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) # 1. 加载 optimizers 的状态; optimizers_state_dict = states["optimizers_state_dict"] for i in range(len(self.optimizers)): - optimizer: paddle.optimizer.Optimizer = self.optimizers[i] + optimizer: Optimizer = self.optimizers[i] optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"]) + logger.debug("Load optimizer state dict.") # 2. 加载模型状态; - model = self.unwrap_model() - model.load_dict(states["model_state_dict"]) + if should_load_model: + model = self.unwrap_model() + if only_state_dict: + res = paddle.load(folder.joinpath(FASTNLP_MODEL_FILENAME)) + model.load_dict(res) + logger.debug("Load model state dict.") + else: + model.load_dict(paddle.jit.load(folder.joinpath(FASTNLP_MODEL_FILENAME)).state_dict()) + logger.debug("Load model.") + + # 3. 恢复 sampler 的状态; + dataloader_args = self.get_dataloader_args(dataloader) + sampler = dataloader_args.sampler + if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): + # 说明这里需要使用 ReproduceSampler 来弄一下了 + if self.is_distributed(): + raise RuntimeError( + "It is not allowed to use single device checkpoint retraining before but ddp now.") + sampler = ReproducibleBatchSampler( + batch_sampler=sampler, + batch_size=dataloader_args.batch_sampler.batch_size, + drop_last=dataloader_args.drop_last + ) + sampler.load_state_dict(states['sampler_states']) + + states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) + + # 4. 修改 trainer_state.batch_idx_in_epoch + # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; + if not isinstance(sampler, ReproducibleBatchSampler): + if dataloader_args.drop_last: + batch_idx_in_epoch = len( + sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size + else: + batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ + (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size + # sampler 是 batch_sampler; + else: + batch_idx_in_epoch = sampler.batch_idx_in_epoch + + states["batch_idx_in_epoch"] = batch_idx_in_epoch - self.barrier() return states def get_evaluate_context(self): @@ -313,3 +398,53 @@ class PaddleDriver(Driver): """ if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): dataloader.batch_sampler.set_epoch(cur_epoch_idx) + + @staticmethod + def get_dataloader_args(dataloader: "DataLoader"): + """ + 获取 dataloader 的 shuffle 和 drop_last 属性; + """ + + @dataclass + class Res: + dataset: Optional[Dataset] = None + batch_sampler: Optional[BatchSampler] = None + sampler: Optional[Sampler] = None + batch_size: Optional[int] = None + shuffle: Optional[bool] = None + drop_last: Optional[bool] = None + + res = Res() + + # paddle 的 DataLoader 一定会有 dataset 属性; + res.dataset = dataloader.dataset + + if dataloader.batch_sampler is not None: + res.batch_sampler = dataloader.batch_sampler + if hasattr(dataloader.batch_sampler, "batch_size"): + res.batch_size = getattr(dataloader.batch_sampler, "batch_size") + # 用户使用的是自己的 batch_sampler 并且其没有 "batch_size" 属性; + else: + dataloader_iter = iter(dataloader) + pre_sample = next(dataloader_iter) + res.batch_size = pre_sample.shape[0] + + if hasattr(dataloader.batch_sampler, "sampler"): + res.sampler = dataloader.batch_sampler.sampler + if hasattr(dataloader.batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): + res.shuffle = True + else: + res.shuffle = False + else: + res.sampler = None + res.shuffle = False + + if hasattr(dataloader.batch_sampler, "drop_last"): + res.drop_last = getattr(dataloader.batch_sampler, "drop_last") + # 用户使用的是自己的 batch_sampler 并且其没有 "drop_last" 属性; + else: + res.drop_last = False + + return res diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 97f14bb6..75d80478 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -2,6 +2,7 @@ import os from typing import Optional, Dict, Union from .paddle_driver import PaddleDriver +from .utils import replace_batch_sampler, replace_sampler from fastNLP.envs.imports import _NEED_IMPORT_PADDLE from fastNLP.envs.env import USER_CUDA_VISIBLE_DEVICES from fastNLP.core.utils import ( @@ -10,7 +11,7 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator, re_instantiate_sampler from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -93,11 +94,8 @@ class PaddleSingleDriver(PaddleDriver): self._test_signature_fn = model.forward def setup(self): - user_visible_devices = os.environ[USER_CUDA_VISIBLE_DEVICES] device_id = get_paddle_device_id(self.model_device) - if user_visible_devices is not None and user_visible_devices != "": - # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES - device_id = user_visible_devices.split(",")[device_id] + device_id = os.environ[USER_CUDA_VISIBLE_DEVICES].split(",")[device_id] os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) paddle.device.set_device("gpu:0") self.model.to("gpu:0") @@ -145,26 +143,25 @@ class PaddleSingleDriver(PaddleDriver): assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." if isinstance(dist, ReproducibleBatchSampler): - dataloader.batch_sampler = dist - return dataloader - if isinstance(dist, ReproducibleIterator): - dataloader.batch_sampler.sampler = dist - return dataloader + return replace_batch_sampler(dataloader, dist) + elif isinstance(dist, ReproducibleIterator): + return replace_sampler(dataloader, dist) if reproducible: - if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): - return dataloader + args = self.get_dataloader_args(dataloader) + if isinstance(args.sampler, ReproducibleIterator): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): - return dataloader + batch_sampler = re_instantiate_sampler(dataloader.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) else: - # TODO batch_sampler = ReproducibleBatchSampler( - batch_sampler=dataloader.batch_sampler, - batch_size=dataloader.batch_sampler.batch_size, - drop_last=dataloader.drop_last + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last ) - dataloader.batch_sampler = batch_sampler - return dataloader + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index ebe0f6c5..a8121879 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -9,7 +9,7 @@ from enum import IntEnum from typing import Dict, Optional, Union from fastNLP.envs.imports import _NEED_IMPORT_PADDLE -from fastNLP.core.utils import get_paddle_device_id, auto_param_call +from fastNLP.core.utils import get_paddle_device_id, auto_param_call, paddle_to from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES from fastNLP.core.log import logger @@ -272,11 +272,9 @@ def get_device_from_visible(device: Union[str, int]): else: # 利用 USER_CUDA_VISIBLDE_DEVICES 获取用户期望的设备 user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES) - if user_visible_devices is not None and user_visible_devices != "": - # 不为空,说明用户设置了 CUDA_VISIBLDE_DEVICES - idx = user_visible_devices.split(",")[idx] - else: - idx = str(idx) + if user_visible_devices is None: + raise RuntimeError("This situation cannot happen, please report a bug to us.") + idx = user_visible_devices.split(",")[idx] cuda_visible_devices_list = cuda_visible_devices.split(',') assert idx in cuda_visible_devices_list, "Can't find "\ @@ -285,31 +283,44 @@ def get_device_from_visible(device: Union[str, int]): res = cuda_visible_devices_list.index(idx) return res -def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): - # 拿到实例属性; +def replace_batch_sampler(dataloader: "DataLoader", batch_sampler: "BatchSampler"): + """ + 利用 `batch_sampler` 重新构建一个 DataLoader,起到替换 `batch_sampler` 又不影响原 `dataloader` 的作用。 + 考虑了用户自己定制了 DataLoader 的情形。 + """ + # 拿到非下划线开头的实例属性; instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')} - # 拿到 dataloader '__init__' 函数的默认函数签名; + # 拿到 dataloader '__init__' 函数的默认函数签名;可以获取参数名和参数的默认值以及类型 init_params = dict(inspect.signature(dataloader.__init__).parameters) # 这里为什么要单独弄的原因在于,用户在定制自己的 dataloader 的同时可能为了方便只设定一些参数,而后面直接使用 **kwargs 的方式,这时如果 # 其在初始化自己的 dataloader 实例的时候加入了一些其它的新的参数(首先这一步是必要的,因为我们只能通过这样加 sampler;另一方面,用户 # 可能确实通过 **kwargs 加入了一些新的参数),如果假设用户是这样使用的: "super().__init__(**kwargs)",那么我们就只能去 DataLoader - # 中寻找; + # 中寻找;VAR_KEYWORD 代表 **kwargs has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) if has_variadic_kwargs: init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) del init_params["self"] # 因为我们刚才可能用 DataLoader 的默认参数将用户定制的 dataloader 的参数覆盖掉了,因此需要重新弄一遍; + # 将同时在实例名和参数名中出现且不是默认值的参数收集起来 non_default_params = {name for name, p in init_params.items() if name in instance_attrs and p.default != instance_attrs[name]} # add `dataset` as it might have been replaced with `*args` non_default_params.add("dataset") + # 收集不是默认值的参数和它的值 reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} - reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) - + # persistent_workers 在类中的对应成员带有下划线,因此添加进来 + reconstruct_args.update({ + "batch_sampler": batch_sampler, "shuffle": False, "drop_last": False, "batch_size": 1, + "persistent_workers": dataloader._persistent_workers, + }) + + # POSITIONAL_OR_KEYWORD 代表一般的参数 + # 收集初始化函数中出现的、一般形式的、不带默认值且不在 reconstruct_args 中的参数 + # 也即它们没有在初始化函数和实例成员中同时出现 required_args = { p.name for p in init_params.values() @@ -323,12 +334,9 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): required_args = sorted(required_args) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " + f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " f"The missing attributes are {required_args}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " - "manually add the `DistributedBatchSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." ) # 这种错误针对的是传入的 dataloader 不是直接的 DataLoader,而是定制了 DataLoader,但是 __init__ 中没有 **kwargs; @@ -340,12 +348,33 @@ def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): missing_kwargs = sorted(missing_kwargs) dataloader_self_name = dataloader.__class__.__name__ raise Exception( - f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " + f"Trying to inject `BatchSampler` into the `{dataloader_self_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " f"The missing arguments are {missing_kwargs}. " - f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " - "manually add the `DistributedBatchSampler` as: " - f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." ) return type(dataloader)(**reconstruct_args) + +def replace_sampler(dataloader, new_sampler): + """ + 使用 `new_sampler` 重新构建一个 BatchSampler,并替换到 `dataloader` 中 + """ + new_batch_sampler = BatchSampler( + dataset=dataloader.batch_sampler.dataset, + sampler=new_sampler, + shuffle=isinstance(dataloader.batch_sampler.sampler, paddle.io.RandomSampler), + batch_size=dataloader.batch_sampler.batch_size, + drop_last=dataloader.batch_sampler.drop_last + ) + return replace_batch_sampler(dataloader, new_batch_sampler) + +def optimizer_state_to_device(state, device): + new_state = {} + for name, param in state.items(): + if isinstance(param, dict): + new_state[name] = optimizer_state_to_device(param, device) + elif isinstance(param, paddle.Tensor): + new_state[name] = paddle_to(param, device).clone() + else: + new_state[name] = param + return new_state