diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py
index 710bf6be..d9d66970 100644
--- a/fastNLP/core/drivers/driver.py
+++ b/fastNLP/core/drivers/driver.py
@@ -49,6 +49,9 @@ class Driver(ABC):
             duplicated across different gpus; when it is 'unrepeatdist', the dataloader should guarantee that the data iterated on all gpus, merged together,
             is exactly the original data, while different gpus may yield different numbers of batches. When the `use_dist_sampler` argument in the trainer's kwargs is True, this value is "dist";
             otherwise it is None; when the `use_dist_sampler` argument in the evaluator's kwargs is True, this value is "unrepeatdist", otherwise None;
+            Note that when dist is a ReproducibleIterator or ReproducibleBatchSampler, this function is being called by driver.load while restoring from a checkpoint;
+            when dist is a str or None, this function is being called by the trainer during initialization;
+
         :param reproducible: if False, nothing extra needs to be taken care of; if True, the returned dataloader must be able to save its
             current iteration state so that it can be loaded back later.
         :return: a new dataloader object whose sampler has been replaced (note that a new dataloader object must be returned here); in addition,
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 7fe0bcee..9e5e16fd 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -448,31 +448,26 @@ class TorchDDPDriver(TorchDriver):
 
     def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
                                   reproducible: bool = False):
+        # If dist is a ReproducibleBatchSampler or ReproducibleIterator, this function is being called by driver.load while restoring from a checkpoint;
+        # note that there is no need to call dist_sampler.set_distributed here, because if the user is using TorchDDPDriver, that call has already been made during Trainer initialization;
         if isinstance(dist, ReproducibleBatchSampler):
-            dist = re_instantiate_sampler(dist)
-            dist.set_distributed(
-                num_replicas=self.world_size,
-                rank=self.global_rank,
-                pad=True
-            )
             return replace_batch_sampler(dataloader, dist)
-
         if isinstance(dist, ReproducibleIterator):
-            # note that there is no need to call dist_sampler.set_distributed here, because if the user is using TorchDDPDriver, that call has already been made during Trainer initialization;
-            dist = re_instantiate_sampler(dist)
-            dist.set_distributed(
-                num_replicas=self.world_size,
-                rank=self.global_rank,
-                pad=True
-            )
             return replace_sampler(dataloader, dist)
 
+        # If dist is a str or None, this function is being called during trainer initialization;
         # trainer, evaluator
         if dist is None:
             if reproducible:
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
                                    "control.")
             else:
+                if isinstance(dist, ReproducibleBatchSampler):
+                    dist = re_instantiate_sampler(dist)
+                    return replace_batch_sampler(dataloader, dist)
+                if isinstance(dist, ReproducibleIterator):
+                    dist = re_instantiate_sampler(dist)
+                    return replace_sampler(dataloader, dist)
                 return dataloader
         # trainer
         elif dist == "dist":
@@ -506,7 +501,6 @@ class TorchDDPDriver(TorchDriver):
                 pad=True
             )
             return replace_sampler(dataloader, sampler)
-
         # evaluator
         elif dist == "unrepeatdist":
             # todo @yh, fill in the unrepeatdist-related logic;
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index 3375d557..14a135ee 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -132,25 +132,29 @@ class TorchSingleDriver(TorchDriver):
 
     def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
                                   reproducible: bool = False):
+
+        # If dist is a ReproducibleBatchSampler or ReproducibleIterator, this function is being called by driver.load while restoring from a checkpoint;
         if isinstance(dist, ReproducibleBatchSampler):
-            dist = re_instantiate_sampler(dist)
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleIterator):
-            dist = re_instantiate_sampler(dist)
             return replace_sampler(dataloader, dist)
 
+        # If dist is a str or None, this function is being called during trainer initialization;
+        args = self.get_dataloader_args(dataloader)
+        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+            batch_sampler = re_instantiate_sampler(args.batch_sampler)
+            return replace_batch_sampler(dataloader, batch_sampler)
+        elif isinstance(args.sampler, ReproducibleIterator):
+            sampler = re_instantiate_sampler(args.sampler)
+            return replace_sampler(dataloader, sampler)
+
         if reproducible:
-            args = self.get_dataloader_args(dataloader)
-            if isinstance(args.sampler, ReproducibleIterator):
-                sampler = re_instantiate_sampler(args.sampler)
-                return replace_sampler(dataloader, sampler)
-            else:
-                batch_sampler = ReproducibleBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproducibleBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py
index 369e4432..ce1bff14 100644
--- a/fastNLP/core/drivers/torch_driver/torch_driver.py
+++ b/fastNLP/core/drivers/torch_driver/torch_driver.py
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import ReproducibleBatchSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator
 
 
 class TorchDriver(Driver):
@@ -244,47 +244,21 @@ class TorchDriver(Driver):
         logger.debug("Load model.")
 
         # 3. restore the sampler's state;
-        """
-        Usage scenarios:
-
-        Current sampler/batch_sampler replacement cases:
-            1. single gpu vs. multiple gpus;
-            2. whether we are resuming from a checkpoint;
-
-            3. the user passes one in via dist;
-            4. the user directly replaces the dataloader's sampler or batch_sampler themselves;
-
-        Rules to be settled:
-        batch_sampler takes precedence over sampler;
-
-        single gpu:
-            not resuming from a checkpoint:
-                the user themselves
-
-
-        the user does not replace the sampler or batch_sampler themselves
-        1. single gpu:
-
-        """
         dataloader_args = self.get_dataloader_args(dataloader)
-
-        # todo sort this out first;
-        # batch_sampler = dataloader_args.batch_sampler
-        # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
-
-        sampler = dataloader_args.sampler
-        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
-            # this means a ReproduceSampler has to be constructed here
-            if self.is_distributed():
-                raise RuntimeError(
-                    "It is not allowed to use single device checkpoint retraining before but ddp now.")
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif isinstance(dataloader_args.sampler, ReproducibleIterator):
+            sampler = dataloader_args.sampler
+        elif self.is_distributed():
+            raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our "
+                               "`ReproducibleBatchSampler` or `ReproducibleIterator`.")
+        else:
             sampler = ReproducibleBatchSampler(
-                batch_sampler=sampler,
+                batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                 batch_size=dataloader_args.batch_size,
                 drop_last=dataloader_args.drop_last
             )
         sampler.load_state_dict(states['sampler_states'])
-
         states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
 
         # 4. update trainer_state.batch_idx_in_epoch
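
Condensed, the control flow that the single-device driver converges on after this patch can be read as the sketch below. It is only an illustration, not part of the diff; ReproducibleBatchSampler, ReproducibleIterator, replace_batch_sampler, replace_sampler, re_instantiate_sampler and get_dataloader_args are the fastNLP names already used in single_device.py above.

    def set_dist_repro_dataloader(self, dataloader, dist=None, reproducible=False):
        # 1. Checkpoint-restore path: driver.load passes in a fully configured sampler
        #    object, so it is used as-is (no re-instantiation; for DDP, set_distributed
        #    was already called during Trainer initialization).
        if isinstance(dist, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, dist)
        if isinstance(dist, ReproducibleIterator):
            return replace_sampler(dataloader, dist)

        # 2. Trainer-init path (dist is a str or None): if the dataloader already carries
        #    a reproducible sampler, re-instantiate it so the new dataloader does not
        #    share iteration state with the old one.
        args = self.get_dataloader_args(dataloader)
        if isinstance(args.batch_sampler, ReproducibleBatchSampler):
            return replace_batch_sampler(dataloader, re_instantiate_sampler(args.batch_sampler))
        if isinstance(args.sampler, ReproducibleIterator):
            return replace_sampler(dataloader, re_instantiate_sampler(args.sampler))

        # 3. Otherwise wrap the existing batch_sampler only when reproducibility was
        #    explicitly requested; leave the dataloader untouched if not.
        if reproducible:
            batch_sampler = ReproducibleBatchSampler(
                batch_sampler=args.batch_sampler,
                batch_size=args.batch_size,
                drop_last=args.drop_last,
            )
            return replace_batch_sampler(dataloader, batch_sampler)
        return dataloader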