@@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import (
     ForwardState,
     _MODE_PARAMETER,
     reset_seed,
-    replace_sampler
+    replace_sampler,
+    replace_batch_sampler
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler, ReproducibleBatchSampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
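Note: the new imports pair `ReproducibleBatchSampler` with a `replace_batch_sampler` helper. A `batch_sampler` cannot be swapped on a live DataLoader, so a helper of this shape has to rebuild the loader; the sketch below is a hedged illustration, not fastNLP's actual implementation in `fastNLP.core.drivers.torch_driver.utils`.

```python
# Minimal sketch of a replace-batch-sampler helper; the real fastNLP version
# carries over more DataLoader attributes than shown here.
from torch.utils.data import DataLoader

def replace_batch_sampler_sketch(dataloader: DataLoader, new_batch_sampler) -> DataLoader:
    # batch_sampler is mutually exclusive with batch_size/shuffle/sampler/drop_last,
    # so a fresh DataLoader is built instead of mutating the existing one.
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
    )
```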

@@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver):
         # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)

-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
-                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
+                                  reproducible: bool = False):
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+
         if isinstance(dist, ReproducibleIterator):
             # Note that dist_sampler.set_distributed does not need to be called here; when the user is using TorchDDPDriver, it has already been called during Trainer initialization;
             dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
             return replace_sampler(dataloader, dist)

         # trainer, evaluator
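Note: both branches above follow the same pattern: re-instantiate the passed-in sampler, then shard it across ranks with `pad=True`. A hedged illustration of what `set_distributed(num_replicas, rank, pad=True)` is expected to achieve for a reproducible sampler (names and logic are illustrative, not fastNLP's implementation):

```python
# Shard indices across ranks; pad so every rank yields the same number of
# samples, which keeps DDP ranks in lockstep.
def shard_indices(indices: list, num_replicas: int, rank: int, pad: bool = True) -> list:
    if pad and len(indices) % num_replicas != 0:
        # repeat leading indices so the length divides evenly across ranks
        padding = num_replicas - len(indices) % num_replicas
        indices = indices + indices[:padding]
    return indices[rank::num_replicas]

# e.g. 10 samples on 4 ranks: every rank gets 3 indices, two of them padded
assert len(shard_indices(list(range(10)), num_replicas=4, rank=3)) == 3
```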

@@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver):
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # If the user's trainer.use_dist_sampler is True, whether or not training is being resumed from a checkpoint does not affect the behavior here;
-            if isinstance(args.sampler, ReproducibleIterator):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleIterator):
                 sampler = re_instantiate_sampler(args.sampler)
                 sampler.set_distributed(
                     num_replicas=self.world_size,

@@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver):
                     shuffle=args.shuffle,
                     seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                 )
-                # todo Turn this into a proper todo; there are two angles: first, even if the dataloader detects that its sampler is one of our reproducible samplers, it must not call set_distributed directly; second, the single-GPU case also needs the sampler replaced and its state switched over, to cover a run that previously used multiple GPUs and now runs on a single GPU
                 sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
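Note on `seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))` above: every rank must draw the same shuffled order before sharding, otherwise the per-rank shards would overlap. A small self-contained illustration of that invariant (rank and world size are hypothetical values):

```python
import os
import random

seed = int(os.environ.get("FASTNLP_GLOBAL_SEED", 0))  # identical on every rank
order = list(range(8))
random.Random(seed).shuffle(order)  # same permutation on every rank
rank, world_size = 1, 2             # hypothetical values for illustration
local = order[rank::world_size]     # disjoint shard per rank, no overlap
```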

@@ -487,7 +509,10 @@ class TorchDDPDriver(TorchDriver):

         # evaluator
         elif dist == "unrepeatdist":
+            # todo @yh: fill in the unrepeatdist-related logic;
             args = self.get_dataloader_args(dataloader)
+
+            # todo handle the batch_sampler case;
             sampler = UnrepeatedDistributedSampler(
                 dataset=args.dataset,
                 shuffle=args.shuffle,
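Note: "unrepeatdist" is the evaluator path. Unlike the padded training shards above, an unrepeated sampler must cover every sample exactly once, so ranks may receive unequal counts. A sketch of one such sharding rule (illustrative only, not fastNLP's actual UnrepeatedDistributedSampler):

```python
# Contiguous, non-overlapping, unpadded split: no sample is evaluated twice.
def unrepeated_shard(n_samples: int, num_replicas: int, rank: int) -> range:
    per_rank = n_samples // num_replicas
    remainder = n_samples % num_replicas
    start = rank * per_rank + min(rank, remainder)
    end = start + per_rank + (1 if rank < remainder else 0)
    return range(start, end)

# 10 samples on 4 ranks -> sizes 3, 3, 2, 2; union covers all, no duplicates
assert [len(unrepeated_shard(10, 4, r)) for r in range(4)] == [3, 3, 2, 2]
```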