
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest · 2 years ago · commit 36149a57b0
10 changed files with 251 additions and 102 deletions
  1. fastNLP/core/drivers/paddle_driver/fleet.py (+2 -2)
  2. fastNLP/core/drivers/torch_driver/ddp.py (+32 -7)
  3. fastNLP/core/drivers/torch_driver/single_device.py (+2 -0)
  4. fastNLP/core/drivers/torch_driver/torch_driver.py (+26 -0)
  5. fastNLP/core/samplers/__init__.py (+10 -5)
  6. fastNLP/core/samplers/mix_sampler.py (+0 -1)
  7. fastNLP/core/samplers/reproducible_sampler.py (+0 -1)
  8. fastNLP/core/samplers/sampler.py (+1 -86)
  9. fastNLP/core/samplers/unrepeated_sampler.py (+114 -0)
  10. tests/core/samplers/test_unrepeated_sampler.py (+64 -0)

fastNLP/core/drivers/paddle_driver/fleet.py (+2 -2)

@@ -19,7 +19,7 @@ from fastNLP.core.utils import (
    paddle_move_data_to_device,
    is_in_paddle_dist,
)
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger

@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver):
            return dataloader
        # evaluator
        elif dist == "unrepeatdist":
-            sampler = UnrepeatedDistributedSampler(
+            sampler = UnrepeatedSampler(
                dataset=dataloader.dataset,
                shuffle=shuffle,
                seed=int(os.environ.get("FASTNLP_SEED", 0))


fastNLP/core/drivers/torch_driver/ddp.py (+32 -7)

@@ -23,11 +23,12 @@ from fastNLP.core.drivers.torch_driver.utils import (
    ForwardState,
    _MODE_PARAMETER,
    reset_seed,
-    replace_sampler
+    replace_sampler,
+    replace_batch_sampler
)
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
@@ -445,11 +446,25 @@ class TorchDDPDriver(TorchDriver):
        # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST})
        return self._test_step(batch)

-    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator]],
-                                  reproducible: bool = False, sampler_or_batch_sampler=None):
+    def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]=None,
+                                  reproducible: bool = False):
+        if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
+            return replace_batch_sampler(dataloader, dist)
+
        if isinstance(dist, ReproducibleIterator):
            # Note: there is no need to call dist_sampler.set_distributed here, because if the user is using TorchDDPDriver, that function has already been called when the Trainer was initialized;
            dist = re_instantiate_sampler(dist)
+            dist.set_distributed(
+                num_replicas=self.world_size,
+                rank=self.global_rank,
+                pad=True
+            )
            return replace_sampler(dataloader, dist)

        # trainer, evaluator
@@ -463,7 +478,15 @@ class TorchDDPDriver(TorchDriver):
        elif dist == "dist":
            args = self.get_dataloader_args(dataloader)
            # If the user's trainer.use_dist_sampler is True, whether or not this is a resume from a checkpoint does not affect the behavior here;
-            if isinstance(args.sampler, ReproducibleIterator):
+            if isinstance(args.batch_sampler, ReproducibleBatchSampler):
+                batch_sampler = re_instantiate_sampler(args.batch_sampler)
+                batch_sampler.set_distributed(
+                    num_replicas=self.world_size,
+                    rank=self.global_rank,
+                    pad=True
+                )
+                return replace_batch_sampler(dataloader, batch_sampler)
+            elif isinstance(args.sampler, ReproducibleIterator):
                sampler = re_instantiate_sampler(args.sampler)
                sampler.set_distributed(
                    num_replicas=self.world_size,
@@ -477,7 +500,6 @@ class TorchDDPDriver(TorchDriver):
                    shuffle=args.shuffle,
                    seed=int(os.environ.get(FASTNLP_GLOBAL_SEED, 0))
                )
-                # todo write this TODO up properly; there are two angles: first, even if the dataloader detects that its sampler is one of our reproducible samplers, it cannot call set_distributed on it directly; second, the single-card case also needs the sampler replaced (with its state carried over), to cover a run that used to be multi-card and now runs single-card
                sampler.set_distributed(
                    num_replicas=self.world_size,
                    rank=self.global_rank,
@@ -487,8 +509,11 @@ class TorchDDPDriver(TorchDriver):

        # evaluator
        elif dist == "unrepeatdist":
-            # todo @yh, fill in the unrepeatdist-related content;
            args = self.get_dataloader_args(dataloader)
-            sampler = UnrepeatedDistributedSampler(
+
+            # todo check the batch_sampler;
+            sampler = UnrepeatedSampler(
                dataset=args.dataset,
                shuffle=args.shuffle,
            )
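Taken together, set_dist_repro_dataloader now dispatches on the type of `dist`, with the batch-sampler case always checked before the plain-sampler case. A minimal standalone sketch of that priority rule — the two classes below are toy stand-ins, not the fastNLP API:

from typing import Optional, Union

class ReproducibleBatchSampler:  # toy stand-in for fastNLP's class
    pass

class ReproducibleIterator:  # toy stand-in for fastNLP's class
    pass

def dispatch(dist: Optional[Union[str, ReproducibleIterator, ReproducibleBatchSampler]]) -> str:
    # Batch samplers win: they control batching *and* ordering, so replacing
    # one subsumes replacing a plain sampler.
    if isinstance(dist, ReproducibleBatchSampler):
        return "re-instantiate, set_distributed(pad=True), replace_batch_sampler"
    if isinstance(dist, ReproducibleIterator):
        return "re-instantiate, set_distributed(pad=True), replace_sampler"
    if dist == "dist":          # trainer: repeatable, padded split across ranks
        return "wrap the dataloader's (batch_)sampler for distributed training"
    if dist == "unrepeatdist":  # evaluator: each sample lands on exactly one rank
        return "swap in an UnrepeatedSampler"
    return "leave the dataloader unchanged"

print(dispatch(ReproducibleBatchSampler()))
print(dispatch("unrepeatdist"))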


fastNLP/core/drivers/torch_driver/single_device.py (+2 -0)

@@ -133,8 +133,10 @@ class TorchSingleDriver(TorchDriver):
    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
                                  reproducible: bool = False):
        if isinstance(dist, ReproducibleBatchSampler):
+            dist = re_instantiate_sampler(dist)
            return replace_batch_sampler(dataloader, dist)
        elif isinstance(dist, ReproducibleIterator):
+            dist = re_instantiate_sampler(dist)
            return replace_sampler(dataloader, dist)

        if reproducible:


fastNLP/core/drivers/torch_driver/torch_driver.py (+26 -0)

@@ -244,8 +244,34 @@ class TorchDriver(Driver):
        logger.debug("Load model.")

        # 3. restore the sampler state;
+        """
+        Usage scenarios:
+        Situations in which the sampler/batch_sampler currently gets replaced:
+            1. single-card vs. multi-card;
+            2. whether this is a resume from a checkpoint;
+            3. the user passes one in via `dist`;
+            4. the user directly replaces the dataloader's sampler or batch_sampler themselves;
+        Rules that should be settled:
+            the batch_sampler takes priority over the sampler;
+            single card:
+                not resuming from a checkpoint:
+                    the user themselves
+                    the user does not directly replace the sampler or batch_sampler outside
+            1. single card:
+        """
        dataloader_args = self.get_dataloader_args(dataloader)

+        # todo think this through first;
+        # batch_sampler = dataloader_args.batch_sampler
+        # if not (hasattr(batch_sampler, 'load_state_dict') and callable(batch_sampler.load_state_dict)):
+
        sampler = dataloader_args.sampler
        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
            # this means a reproducible sampler needs to be swapped in here
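The duck-typed check above is the crux of the restore step: a sampler qualifies for state restoration iff it exposes a callable load_state_dict. A tiny self-contained illustration of the same check — the class name is illustrative, not fastNLP's:

class StatefulSampler:
    """Toy sampler that can restore its position after a checkpoint."""
    def load_state_dict(self, state: dict) -> None:
        self.num_consumed_samples = state.get("num_consumed_samples", 0)

def can_restore(sampler) -> bool:
    # mirrors the check in TorchDriver: the attribute exists and is callable
    return hasattr(sampler, "load_state_dict") and callable(sampler.load_state_dict)

assert can_restore(StatefulSampler())
assert not can_restore(object())  # plain torch samplers fail the check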


fastNLP/core/samplers/__init__.py (+10 -5)

@@ -3,19 +3,24 @@ __all__ = [
    'SortedSampler',
    'ConstTokenNumSampler',
    'ConstantTokenNumSampler',
-    'UnrepeatedDistributedSampler',
    'MixSampler',
-    'InnerSampler',
    'DopedSampler',
    'MixSequentialSampler',
    'PollingSampler',

    'ReproducibleIterator',
    'RandomSampler',
-    're_instantiate_sampler'
+    're_instantiate_sampler',
+
+    'UnrepeatedSampler',
+    "UnrepeatedSortedSampler"
]

-from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
-from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
+from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
+from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler
+from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler


fastNLP/core/samplers/mix_sampler.py (+0 -1)

@@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict

__all__ = [
    'MixSampler',
-    'InnerSampler',
    'DopedSampler',
    'MixSequentialSampler',
    'PollingSampler'


fastNLP/core/samplers/reproducible_sampler.py (+0 -1)

@@ -16,7 +16,6 @@ def re_instantiate_sampler(sampler):
    return type(sampler)(**all_attributes)

-
class ReproducibleIterator:
    """
    Note: the `__init__` method of every class that inherits from `ReproducibleIterator` must accept `**kwargs`, so that the sampler can be re-instantiated when resuming training from a checkpoint
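The `**kwargs` requirement exists because re_instantiate_sampler (shown above) rebuilds a sampler as type(sampler)(**all_attributes). A sketch of why extra keys must be swallowed — GoodSampler is a hypothetical example, not a fastNLP class:

class GoodSampler:
    def __init__(self, dataset, shuffle=False, **kwargs):
        # **kwargs swallows saved state attributes that are not __init__
        # parameters, e.g. a consumption counter stored before a checkpoint
        self.dataset = dataset
        self.shuffle = shuffle
        self.num_consumed_samples = kwargs.get("num_consumed_samples", 0)

saved_attributes = {"dataset": list(range(10)), "shuffle": True, "num_consumed_samples": 4}
clone = GoodSampler(**saved_attributes)  # works only because of **kwargs
assert clone.num_consumed_samples == 4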


fastNLP/core/samplers/sampler.py (+1 -86)

@@ -7,7 +7,6 @@ __all__ = [
    "SortedSampler",
    'ConstTokenNumSampler',
    "ConstantTokenNumSampler",
-    "UnrepeatedDistributedSampler",
]

from itertools import chain
@@ -18,7 +17,7 @@ import numpy as np
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
-    from torch.utils.data import SequentialSampler, Sampler, RandomSampler
+    from torch.utils.data import Sampler
else:
    from fastNLP.core.utils.dummy_class import DummyClass as Sampler
@@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
        if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
            bucket_data[bucket_id].append(idx)
    return bucket_data
-
-
-class UnrepeatedDistributedSampler:
-    def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
-        """
-        For multi-card evaluation, no sample may be drawn more than once across cards.
-
-        :param dataset:
-        :param shuffle:
-        :param seed:
-        """
-        self.dataset = dataset
-        self.shuffle = shuffle
-        self.seed = seed
-
-        # multi-card related parameters
-        self.num_replicas = 1
-        self.rank = 0
-        self.epoch = -1
-
-    def __len__(self):
-        """
-        Returns how many indices one full pass of the sampler yields. In the multi-card case, only the current rank is considered;
-        :return:
-        """
-        num_common = len(self.dataset)//self.num_replicas
-        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
-        return self.num_samples
-
-    def __iter__(self):
-        r"""
-        The current num_consumed_samples approach runs into trouble when iterators are used alternately;
-        Example:
-            >>> sampler = RandomSampler()
-            >>> iter1 = iter(sampler)
-            >>> iter2 = iter(sampler)
-            >>> next(iter1)
-            >>> next(iter2)  # at this point num_consumed_samples changes unexpectedly
-        """
-
-        indices = self.generate_indices()
-
-        # subsample
-        indices = indices[self.rank:len(indices):self.num_replicas]
-        assert len(indices) == len(self)
-
-        for index in indices:
-            yield index
-
-    def generate_indices(self) -> List[int]:
-        """
-        Generate the index sequence
-
-        :return:
-        """
-        if self.shuffle:
-            indices = list(range(len(self.dataset)))
-            seed = self.seed + self.epoch
-            rng = np.random.default_rng(abs(seed))
-            rng.shuffle(indices)
-            if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least the index order then differs from epoch to epoch.
-                self.epoch -= 1
-        else:
-            indices = list(range(len(self.dataset)))
-        return indices
-
-    def set_epoch(self, epoch: int) -> None:
-        self.epoch = epoch
-
-    def set_distributed(self, num_replicas, rank):
-        """
-        This method is essentially the part of initialization left unfinished in the ddp case; it should be called immediately after the sampler itself is constructed;
-
-        :param num_replicas:
-        :param rank:
-        :return:
-        """
-        assert num_replicas > 0 and isinstance(num_replicas, int)
-        assert isinstance(rank, int) and 0 <= rank < num_replicas
-        # note: when this is called, all state should still be the default state at the very start of an epoch;
-        self.num_replicas = num_replicas
-        self.rank = rank
-
-        return self
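The __iter__ docstring above hints at why this class avoids num_consumed_samples bookkeeping altogether. A toy reproduction of the pitfall it describes — a simplified stand-in, not the fastNLP sampler:

class CountingSampler:
    """Tracks consumption on self, which breaks with concurrent iterators."""
    def __init__(self, n):
        self.n = n
        self.num_consumed_samples = 0

    def __iter__(self):
        for i in range(self.num_consumed_samples, self.n):
            self.num_consumed_samples += 1
            yield i

s = CountingSampler(4)
it1, it2 = iter(s), iter(s)
print(next(it1))  # 0
print(next(it2))  # 1, not 0 — the second iterator sees the shared counter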

fastNLP/core/samplers/unrepeated_sampler.py (+114 -0)

@@ -0,0 +1,114 @@
__all__ = [
    'UnrepeatedSortedSampler',
    'UnrepeatedSampler'
]

from typing import List, Union
from fastNLP.core.dataset import DataSet

import numpy as np


class UnrepeatedSampler:
    def __init__(self, dataset, shuffle: bool = False, seed: int = 0, **kwargs):
        """
        For multi-card evaluation, no sample may be drawn more than once across cards.

        :param dataset:
        :param shuffle:
        :param seed:
        """
        self.dataset = dataset
        self.shuffle = shuffle
        self.seed = seed

        # multi-card related parameters
        self.num_replicas = kwargs.get('num_replicas', 1)
        self.rank = kwargs.get('rank', 0)
        self.epoch = kwargs.get('epoch', -1)

    def __len__(self):
        """
        Returns how many indices one full pass of the sampler yields. In the multi-card case, only the current rank is considered;
        :return:
        """
        num_common = len(self.dataset)//self.num_replicas
        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
        return self.num_samples

    def __iter__(self):
        indices = self.generate_indices()

        # subsample
        indices = indices[self.rank:len(indices):self.num_replicas]
        assert len(indices) == len(self)

        for index in indices:
            yield index

    def generate_indices(self) -> List[int]:
        """
        Generate the index sequence

        :return:
        """
        if self.shuffle:
            indices = list(range(len(self.dataset)))
            seed = self.seed + self.epoch
            rng = np.random.default_rng(abs(seed))
            rng.shuffle(indices)
            if self.epoch < 0:  # guard against the user forgetting to call set_epoch; at least the index order then differs from epoch to epoch.
                self.epoch -= 1
        else:
            indices = list(range(len(self.dataset)))
        return indices

    def set_epoch(self, epoch: int) -> None:
        self.epoch = epoch

    def set_distributed(self, num_replicas, rank):
        """
        This method is essentially the part of initialization left unfinished in the ddp case; it should be called immediately after the sampler itself is constructed;

        :param num_replicas:
        :param rank:
        :return:
        """
        assert num_replicas > 0 and isinstance(num_replicas, int)
        assert isinstance(rank, int) and 0 <= rank < num_replicas
        # note: when this is called, all state should still be the default state at the very start of an epoch;
        self.num_replicas = num_replicas
        self.rank = rank

        return self


class UnrepeatedSortedSampler(UnrepeatedSampler):
    def __init__(self, dataset, length: Union[str, List], seed: int = 0):
        """
        Iterates over the data in dataset from longest to shortest according to length, while guaranteeing that no sample is repeated across cards. This sampler may produce a slightly different number of batches on each machine.

        :param dataset: a data container implementing __len__.
        :param length: if a List, it must have the same length as dataset and gives the length of each element; a str is only supported when dataset is a fastNLP DataSet, in which case the str is taken as a field name of the dataset — if the field's elements are ints, they are treated as the sample lengths.
        :param seed: random seed to use
        """
        super().__init__(dataset=dataset, shuffle=False, seed=seed)
        if isinstance(dataset, DataSet):
            length = dataset.get_field(length)
            if not isinstance(length[0], int):
                length = list(map(len, length))
        else:
            assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
                                                "the length parameter can only be List[int]"

        assert len(length) == len(dataset), "The length of `data` and `length` should be equal."

        self.length = np.array(length, dtype=int)  # length of each sample
        self.sorted_indices = np.argsort(self.length)[::-1].tolist()  # indices sorted from longest to shortest

    def generate_indices(self) -> List[int]:
        return self.sorted_indices
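A quick usage sketch of the two new samplers, assuming the exports from __init__.py above are in place. The arithmetic mirrors __len__: with 100 samples on 3 replicas, 100 // 3 == 33 samples are common to every rank, rank 0 picks up the single leftover, and the rank-strided slice indices[rank::num_replicas] keeps the shards disjoint:

from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler

dataset = list(range(100))

shards = []
for rank in range(3):
    sampler = UnrepeatedSampler(dataset, shuffle=True).set_distributed(num_replicas=3, rank=rank)
    shards.append(list(sampler))

assert [len(s) for s in shards] == [34, 33, 33]  # 34 = 33 common + 1 leftover
assert set().union(*shards) == set(range(100))   # disjoint shards, full coverage

# sorted variant: the lengths drive the order, longest first
lengths = [i % 7 + 1 for i in dataset]
sampler = UnrepeatedSortedSampler(dataset, length=lengths)
assert lengths[next(iter(sampler))] == 7  # a longest sample comes first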

tests/core/samplers/test_unrepeated_sampler.py (+64 -0)

@@ -0,0 +1,64 @@
from itertools import chain

import pytest

from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler


class DatasetWithVaryLength:
    def __init__(self, num_of_data=100):
        self.data = list(range(num_of_data))

    def __getitem__(self, item):
        return self.data[item]

    def __len__(self):
        return len(self.data)


class TestUnrepeatedSampler:
    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSampler(data, shuffle)
        indexes = set(sampler)
        assert indexes == set(range(num_of_data))

    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replica, num_of_data, shuffle):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)

        indexes = set(chain(*samplers))
        assert indexes == set(range(num_of_data))


class TestUnrepeatedSortedSampler:
    @pytest.mark.parametrize('shuffle', [True, False])
    def test_single(self, shuffle):
        num_of_data = 100
        data = DatasetWithVaryLength(num_of_data)
        sampler = UnrepeatedSortedSampler(data, length=data.data)
        indexes = list(sampler)
        assert indexes == list(range(num_of_data - 1, -1, -1))

    @pytest.mark.parametrize('num_replica', [2, 3])
    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
    @pytest.mark.parametrize('shuffle', [False, True])
    def test_multi(self, num_replica, num_of_data, shuffle):
        data = DatasetWithVaryLength(num_of_data=num_of_data)
        samplers = []
        for i in range(num_replica):
            sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
            sampler.set_distributed(num_replica, rank=i)
            samplers.append(sampler)

        indexes = set(chain(*samplers))
        assert indexes == set(range(num_of_data))
