@@ -49,6 +49,9 @@ class Driver(ABC): | |||||
不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 | ||||
数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; | ||||
否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; | ||||
注意当 dist 为 ReproducibleIterator, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; | |||||
当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; | |||||
:param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 | ||||
可以加载。 | 可以加载。 | ||||
:return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, | ||||
@@ -448,31 +448,26 @@ class TorchDDPDriver(TorchDriver): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist = re_instantiate_sampler(dist) | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
if isinstance(dist, ReproducibleIterator): | if isinstance(dist, ReproducibleIterator): | ||||
# 注意这里不需要调用 dist_sampler.set_distributed;因为如果用户使用的是 TorchDDPDriver,那么其在 Trainer 初始化的时候就已经调用了该函数; | |||||
dist = re_instantiate_sampler(dist) | |||||
dist.set_distributed( | |||||
num_replicas=self.world_size, | |||||
rank=self.global_rank, | |||||
pad=True | |||||
) | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||||
# trainer, evaluator | # trainer, evaluator | ||||
if dist is None: | if dist is None: | ||||
if reproducible: | if reproducible: | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our " | ||||
"control.") | "control.") | ||||
else: | else: | ||||
if isinstance(dist, ReproducibleBatchSampler): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_batch_sampler(dataloader, dist) | |||||
if isinstance(dist, ReproducibleIterator): | |||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | |||||
return dataloader | return dataloader | ||||
# trainer | # trainer | ||||
elif dist == "dist": | elif dist == "dist": | ||||
@@ -506,7 +501,6 @@ class TorchDDPDriver(TorchDriver): | |||||
pad=True | pad=True | ||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
# evaluator | # evaluator | ||||
elif dist == "unrepeatdist": | elif dist == "unrepeatdist": | ||||
# todo @yh,补充 unrepeatdist 相关内容; | # todo @yh,补充 unrepeatdist 相关内容; | ||||
@@ -132,25 +132,29 @@ class TorchSingleDriver(TorchDriver): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | ||||
reproducible: bool = False): | reproducible: bool = False): | ||||
# 如果 dist 为 ReproducibleBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
dist = re_instantiate_sampler(dist) | |||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleIterator): | elif isinstance(dist, ReproducibleIterator): | ||||
dist = re_instantiate_sampler(dist) | |||||
return replace_sampler(dataloader, dist) | return replace_sampler(dataloader, dist) | ||||
# 如果 dist 为 str 或者 None,说明是在 trainer 初始化时调用; | |||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.batch_sampler, ReproducibleBatchSampler): | |||||
batch_sampler = re_instantiate_sampler(args.batch_sampler) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
elif isinstance(args.sampler, ReproducibleIterator): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
if reproducible: | if reproducible: | ||||
args = self.get_dataloader_args(dataloader) | |||||
if isinstance(args.sampler, ReproducibleIterator): | |||||
sampler = re_instantiate_sampler(args.sampler) | |||||
return replace_sampler(dataloader, sampler) | |||||
else: | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler=args.batch_sampler, | |||||
batch_size=args.batch_size, | |||||
drop_last=args.drop_last | |||||
) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
return dataloader | return dataloader | ||||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -244,47 +244,21 @@ class TorchDriver(Driver): | |||||
logger.debug("Load model.") | logger.debug("Load model.") | ||||
# 3. 恢复 sampler 的状态; | # 3. 恢复 sampler 的状态; | ||||
""" | |||||
使用场景: | |||||
现在sampler/batch_sampler的替换情况: | |||||
1. 单卡多卡; | |||||
2. 是否断点重训; | |||||
3. 用户通过 dist 传入; | |||||
4. 用户自己直接在外面替换dataloader的sampler或者 batchsampler; | |||||
应当确定的规则: | |||||
batchsampler 优先级高于 sampler; | |||||
单卡: | |||||
不是断点重训: | |||||
用户自己 | |||||
用户不自己在外面直接替换 sampler 或者 batchsampler | |||||
1. 单卡: | |||||
""" | |||||
dataloader_args = self.get_dataloader_args(dataloader) | dataloader_args = self.get_dataloader_args(dataloader) | ||||
# todo 先捋一下; | |||||
# batch_sampler = dataloader_args.batch_sampler | |||||
# if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)): | |||||
sampler = dataloader_args.sampler | |||||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||||
# 说明这里需要使用 ReproduceSampler 来弄一下了 | |||||
if self.is_distributed(): | |||||
raise RuntimeError( | |||||
"It is not allowed to use single device checkpoint retraining before but ddp now.") | |||||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||||
sampler = dataloader_args.batch_sampler | |||||
elif isinstance(dataloader_args.sampler, ReproducibleIterator): | |||||
sampler = dataloader_args.sampler | |||||
elif self.is_distributed(): | |||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " | |||||
"`ReproducibleBatchSampler` or `ReproducibleIterator`.") | |||||
else: | |||||
sampler = ReproducibleBatchSampler( | sampler = ReproducibleBatchSampler( | ||||
batch_sampler=sampler, | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
) | ) | ||||
sampler.load_state_dict(states['sampler_states']) | sampler.load_state_dict(states['sampler_states']) | ||||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | ||||
# 4. 修改 trainer_state.batch_idx_in_epoch | # 4. 修改 trainer_state.batch_idx_in_epoch | ||||