
Modified part of the sampler logic for checkpoint-resume training

tags/v1.0.0alpha
YWMditto committed 2 years ago · commit 8ca17fc9ed
4 changed files with 59 additions and 7 deletions
  1. +31  -6  fastNLP/core/drivers/torch_driver/ddp.py
  2. +2   -0  fastNLP/core/drivers/torch_driver/single_device.py
  3. +26  -0  fastNLP/core/drivers/torch_driver/torch_driver.py
  4. +0   -1  fastNLP/core/samplers/reproducible_sampler.py

+31 -6  fastNLP/core/drivers/torch_driver/ddp.py

@@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import (
     ForwardState,
     _MODE_PARAMETER,
     reset_seed,
-    replace_sampler
+    replace_sampler,
+    replace_batch_sampler
 )
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler, ReproducibleBatchSampler
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
 from fastNLP.core.log import logger
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
@@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver):
         # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
         return self._test_step(batch)
 
-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
-                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
+                                  reproducible: bool = False):
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+
         if isinstance(dist, ReproducibleIterator):
             # Note: there is no need to call dist_sampler.set_distributed here, because when the user uses TorchDDPDriver it has already been called during Trainer initialization;
+            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
             return replace_sampler(dataloader, dist)
 
         # trainer, evaluator
@@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver):
         elif dist == "dist":
             args = self.get_dataloader_args(dataloader)
             # If the user's trainer.use_dist_sampler is True, whether or not they are resuming from a checkpoint does not affect the behavior here;
-            if isinstance(args.sampler, ReproducibleIterator):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleIterator):
                 sampler = re_instantiate_sampler(args.sampler)
                 sampler.set_distributed(
                     num_replicas=self.world_size,
@@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver):
                     shuffle=args.shuffle,
                     seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                 )
-                # todo Turn this into a proper todo; there are two angles: first, even if the dataloader detects that its sampler is one of our reproducible ones, it must not call set_distributed directly; second, in the single-card case the sampler also needs to be replaced and its state carried over, e.g. a previously multi-card run now resumed on a single card
                 sampler.set_distributed(
                     num_replicas=self.world_size,
                     rank=self.global_rank,
@@ -487,7 +509,10 @@ class TorchDDPDriver(TorchDriver):
 
         # evaluator
         elif dist == "unrepeatdist":
+            # todo @yh: fill in the unrepeatdist-related logic;
             args = self.get_dataloader_args(dataloader)
+
+            # todo handle the batch_sampler case;
             sampler = UnrepeatedDistributedSampler(
                 dataset=args.dataset,
                 shuffle=args.shuffle,
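
Both new ReproducibleBatchSampler branches above rely on replace_batch_sampler to swap a dataloader's batch sampler while keeping its remaining settings intact. As a rough sketch only (the real fastNLP helper may differ), such a function could rebuild the DataLoader like this:

# Hedged sketch, not the actual fastNLP implementation: carry over the
# DataLoader settings that stay valid once batching is delegated to the
# new batch_sampler (batch_size/shuffle/sampler/drop_last must be omitted,
# since they are mutually exclusive with batch_sampler).
from torch.utils.data import DataLoader

def replace_batch_sampler(dataloader: DataLoader, new_batch_sampler) -> DataLoader:
    return DataLoader(
        dataset=dataloader.dataset,
        batch_sampler=new_batch_sampler,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
        timeout=dataloader.timeout,
        worker_init_fn=dataloader.worker_init_fn,
    )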


+2 -0  fastNLP/core/drivers/torch_driver/single_device.py

@@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver):
     def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
                                   reproducible: bool = False):
         if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
             return replace_batch_sampler(dataloader, dist)
         elif isinstance(dist, ReproducibleIterator):
+            dist = re_instantiate_sampler(dist)
             return replace_sampler(dataloader, dist)
 
         if reproducible:
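
The re_instantiate_sampler calls added above give each call site a fresh copy instead of mutating the sampler the user passed in. For the plain-sampler path, a replace_sampler counterpart might look like the following sketch (again only an approximation, assuming the real helper preserves the batching arguments):

# Hedged sketch of a replace_sampler counterpart; the real fastNLP helper
# may handle more corner cases such as custom DataLoader subclasses.
from torch.utils.data import DataLoader

def replace_sampler(dataloader: DataLoader, new_sampler) -> DataLoader:
    return DataLoader(
        dataset=dataloader.dataset,
        batch_size=dataloader.batch_size,
        sampler=new_sampler,
        drop_last=dataloader.drop_last,
        num_workers=dataloader.num_workers,
        collate_fn=dataloader.collate_fn,
        pin_memory=dataloader.pin_memory,
    )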


+26 -0  fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -246,8 +246,34 @@ class TorchDriver(Driver):
         logger.debug("Load model.")
 
         # 3. Restore the sampler's state;
+        """
+        Usage scenarios:
+        The cases in which the sampler/batch_sampler currently gets replaced:
+            1. single-card vs. multi-card;
+            2. whether we are resuming from a checkpoint;
+            3. the user passes one in via `dist`;
+            4. the user directly replaces the dataloader's sampler or batch_sampler themselves;
+        Rules that should be settled on:
+            the batch_sampler takes precedence over the sampler;
+            single-card:
+                not resuming from a checkpoint:
+                    the user themselves
+                    the user does not replace the sampler or batch_sampler directly outside
+            1. single-card:
+        """
         dataloader_args = self.get_dataloader_args(dataloader)
+
+        # todo think this through first;
+        # batch_sampler = dataloader_args.batch_sampler
+        # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
 
         sampler = dataloader_args.sampler
         if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
             # this means we need to fall back to a ReproduceSampler here
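
The hasattr/callable test above is duck typing: any sampler exposing a state_dict/load_state_dict pair can have its position restored on resume, and anything else must be swapped for a reproducible sampler. A toy sampler satisfying that contract (hypothetical, not part of fastNLP):

from torch.utils.data import Sampler

class StatefulSequentialSampler(Sampler):
    """Toy sampler that remembers how many indices it has yielded, so a
    resumed run can skip the samples consumed before the checkpoint."""

    def __init__(self, data_source, **kwargs):
        self.data_source = data_source
        self.num_consumed = kwargs.get("num_consumed", 0)

    def __iter__(self):
        for idx in range(self.num_consumed, len(self.data_source)):
            self.num_consumed = idx + 1
            yield idx

    def __len__(self):
        return len(self.data_source)

    def state_dict(self):
        return {"num_consumed": self.num_consumed}

    def load_state_dict(self, state):
        self.num_consumed = state["num_consumed"]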


+0 -1  fastNLP/core/samplers/reproducible_sampler.py

@@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler):
     return type(sampler)(**all_attributes)
 
 
-
 class ReproducibleIterator:
     """
     Note that the `__init__` method of every class inheriting from `ReproducibleIterator` must accept `**kwargs`, so that we can re-instantiate the sampler when resuming from a checkpoint
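
The **kwargs requirement exists because re_instantiate_sampler feeds the sampler's recorded attributes back into __init__ (see type(sampler)(**all_attributes) above), and attributes that are not real constructor parameters, such as internal counters, must be absorbed silently. A hypothetical subclass showing the pattern:

# Hypothetical subclass, only to illustrate the **kwargs convention; it is
# not part of fastNLP.
class ToyReproducibleSampler(ReproducibleIterator):
    def __init__(self, dataset, shuffle=True, **kwargs):
        self.dataset = dataset
        self.shuffle = shuffle
        # internal state like this counter also ends up in the attribute
        # dict passed back to __init__ on re-instantiation, where any
        # non-parameter keys are harmlessly swallowed by **kwargs
        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

# a fresh copy with the same constructor state, as the drivers above obtain
fresh = re_instantiate_sampler(ToyReproducibleSampler(dataset=list(range(10))))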

