From 8ca17fc9ed2cfd21a41180d91bd7350366ad52e8 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Mon, 11 Apr 2022 15:45:06 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=E6=96=AD=E7=82=B9?= =?UTF-8?q?=E9=87=8D=E6=96=B0sampler=E4=B8=AD=E7=9A=84=E9=83=A8=E5=88=86?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/torch_driver/ddp.py | 37 ++++++++++++++++--- .../drivers/torch_driver/single_device.py | 2 + .../core/drivers/torch_driver/torch_driver.py | 26 +++++++++++++ fastNLP/core/samplers/reproducible_sampler.py | 1 - 4 files changed, 59 insertions(+), 7 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 9b3325d8..3be40279 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import ( ForwardState, _MODE_PARAMETER, reset_seed, - replace_sampler + replace_sampler, + replace_batch_sampler ) from fastNLP.core.drivers.utils import distributed_open_proc from fastNLP.core.utils import auto_param_call, check_user_specific_params -from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler +from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler, ReproducibleBatchSampler from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED from fastNLP.core.log import logger from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object @@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver): # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]], - reproducible: bool = False, sampler_or_batch_sampler=None): + def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, + reproducible: bool = False): + if isinstance(dist, ReproducibleBatchSampler): + dist = re_instantiate_sampler(dist) + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, dist) + if isinstance(dist, ReproducibleIterator): # 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; dist = re_instantiate_sampler(dist) + dist.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) return replace_sampler(dataloader, dist) # trainer, evaluator @@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver): elif dist == "dist": args = self.get_dataloader_args(dataloader) # 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; - if isinstance(args.sampler, ReproducibleIterator): + if isinstance(args.batch_sampler, ReproducibleBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + batch_sampler.set_distributed( + num_replicas=self.world_size, + rank=self.global_rank, + pad=True + ) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleIterator): sampler = re_instantiate_sampler(args.sampler) sampler.set_distributed( num_replicas=self.world_size, @@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver): shuffle=args.shuffle, seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0)) ) - # todo 这个你写个todo吧,有两个角度;第一个是dataloader即使检测到sampler是我们reproducible,也不能直接set_distributeds; 第二个如果是单卡的,也需要替换sampler乃至切换sampler的状态,方式之前多卡,现在切换成单卡运行 sampler.set_distributed( num_replicas=self.world_size, rank=self.global_rank, @@ -487,7 +509,10 @@ class TorchDDPDriver(TorchDriver): # evaluator elif dist == "unrepeatdist": + # todo @yh,补充 unrepeatdist 相关内容; args = self.get_dataloader_args(dataloader) + + # todo 判断 batch_sampler; sampler = UnrepeatedDistributedSampler( dataset=args.dataset, shuffle=args.shuffle, diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 952712be..3375d557 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver): def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, reproducible: bool = False): if isinstance(dist, ReproducibleBatchSampler): + dist = re_instantiate_sampler(dist) return replace_batch_sampler(dataloader, dist) elif isinstance(dist, ReproducibleIterator): + dist = re_instantiate_sampler(dist) return replace_sampler(dataloader, dist) if reproducible: diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 96d11761..66c93d4d 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -246,8 +246,34 @@ class TorchDriver(Driver): logger.debug("Load model.") # 3. 恢复 sampler 的状态; + """ + 使用场景: + + 现在sampler/batch_sampler的替换情况: + 1. 单卡多卡; + 2. 是否断点重训; + + 3. 用户通过 dist 传入; + 4. 用户自己直接在外面替换dataloader的sampler或者 batchsampler; + + 应当确定的规则: + batchsampler 优先级高于 sampler; + + 单卡: + 不是断点重训: + 用户自己 + + + 用户不自己在外面直接替换 sampler 或者 batchsampler + 1. 单卡: + + """ dataloader_args = self.get_dataloader_args(dataloader) + # todo 先捋一下; + # batch_sampler = dataloader_args.batch_sampler + # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)): + sampler = dataloader_args.sampler if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): # 说明这里需要使用 ReproduceSampler 来弄一下了 diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 0a4ac7bf..6d2c8246 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler): return type(sampler)(**all_attributes) - class ReproducibleIterator: """ 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler