
Add tests for BucketedBatchSampler

tags/v1.0.0alpha
yh_cc · 2 years ago · parent · commit 12d3e08568
6 changed files with 599 additions and 247 deletions

  1. fastNLP/core/drivers/torch_driver/single_device.py (+2, -2)
  2. fastNLP/core/samplers/__init__.py (+1, -1)
  3. fastNLP/core/samplers/reproducible_batch_sampler.py (+152, -80)
  4. fastNLP/core/samplers/reproducible_sampler.py (+5, -18)
  5. tests/core/samplers/test_reproducible_batch_sampler.py (+439, -0)
  6. tests/core/samplers/test_reproducible_sampler.py (+0, -146)

+2  -2   fastNLP/core/drivers/torch_driver/single_device.py

@@ -130,8 +130,8 @@ class TorchSingleDriver(TorchDriver):
else:
return self._test_step(batch)

def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator],
reproducible: bool = False, sampler_or_batch_sampler=None):
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None,
reproducible: bool = False):
if isinstance(dist, ReproducibleBatchSampler):
return replace_batch_sampler(dataloader, dist)
elif isinstance(dist, ReproducibleIterator):
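
For reference, this is the path the new tests exercise: an existing DataLoader is wrapped so that its position inside an epoch can be checkpointed and restored. A minimal sketch (the DataLoader below is illustrative and mirrors the tests added in this commit):

from torch.utils.data import DataLoader
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler

dataloader = DataLoader(dataset, batch_size=8)  # any map-style dataset
# wrap the existing batch_sampler so iteration state can be saved and restored
repro_sampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, repro_sampler)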


+1  -1   fastNLP/core/samplers/__init__.py

@@ -17,5 +17,5 @@ __all__ = [
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
from .reproducible_batch_sampler import ReproducibleBatchSampler
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
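
With this export in place, the new sampler can be imported straight from fastNLP.core.samplers. A minimal construction sketch (the dataset and the `lengths` list are illustrative; any List[int] with one entry per sample works, or a field name when the dataset is a fastNLP DataSet):

from fastNLP.core.samplers import BucketedBatchSampler

lengths = [len(sample) for sample in dataset]  # length of every sample
batch_sampler = BucketedBatchSampler(dataset, length=lengths, batch_size=32,
                                     num_batch_per_bucket=10, shuffle=True, drop_last=False)
for batch in batch_sampler:  # each batch is a list of integer indices with similar lengths
    ...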


+152  -80   fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -1,20 +1,48 @@
__all__ = [
'BucketedBatchSampler',
"ReproducibleBatchSampler"
]

import math
from array import array
from copy import deepcopy
from itertools import chain
from typing import Dict, Union, List
from itertools import chain

import numpy as np

from fastNLP.core.dataset import DataSet
from fastNLP.core.samplers import ReproducibleIterator
from fastNLP.core.log import logger
from abc import abstractmethod


class ReproducibleBatchIterator:
@abstractmethod
def set_distributed(self, num_replicas, rank, pad=True):
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.")

@abstractmethod
def __len__(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.")

@abstractmethod
def __iter__(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.")

@abstractmethod
def state_dict(self):
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.")

@abstractmethod
def load_state_dict(self, states):
raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.")

@abstractmethod
def set_epoch(self, epoch):
pass


class ReproducibleBatchSampler:
class ReproducibleBatchSampler(ReproducibleBatchIterator):
# the values of these two parameters should be obtained from the driver's get_dataloader_args function;
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs):
"""
@@ -94,7 +122,7 @@ class ReproducibleBatchSampler:
self.data_idx = states["data_idx"]
self.need_reinitialize = False

def set_distributed(self):
def set_distributed(self, num_replicas, rank, pad=True):
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.")

def set_epoch(self, epoch):
@@ -110,7 +138,7 @@ class ReproducibleBatchSampler:
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size


class BucketedBatchSampler(ReproducibleIterator):
class BucketedBatchSampler(ReproducibleBatchIterator):
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10,
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs):
"""
@@ -129,20 +157,20 @@ class BucketedBatchSampler(ReproducibleIterator):
:param kwargs: reserved for fastNLP internal use
"""
super().__init__()
if not isinstance(dataset, DataSet):
if isinstance(dataset, DataSet):
length = dataset.get_field(length)
if not isinstance(length[0], int):
length = list(map(len, length))
else:
assert isinstance(length, List) and len(length)==len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \
"the length parameter can only be List[int]"

if drop_last:
assert len(dataset)>=batch_size, "The number of samples must be at least batch_size when `drop_last=True`."
assert len(length) == len(dataset), "The length of `data` and `length` should be equal."

self.dataset = dataset
self.length = np.array(length, dtype=int)  # the length of each sample
self.sorted_indices = np.argsort(self.length)[::-1]  # indices sorted by length, longest first


self.batch_size = batch_size
self.num_batch_per_bucket = num_batch_per_bucket
@@ -161,6 +189,10 @@ class BucketedBatchSampler(ReproducibleIterator):
# whether we are in the middle of an iteration; while True, set_distributed() and load_state_dict() must not be called
self.during_iter = kwargs.get("during_iter", False)

# the following variables are used internally to restore state.
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size)
self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket)

def set_distributed(self, num_replicas, rank, pad=True):
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \
"during an unfinished iteration."
@@ -217,92 +249,123 @@ class BucketedBatchSampler(ReproducibleIterator):
if self.during_iter: # if during_iter is already True, the previous iteration has not finished; force a re-initialization
self.num_consumed_samples = 0
self.during_iter = True
indices = self.generate_indices()

if self.pad:
# add extra samples to make it evenly divisible
padding_size = self.total_size - len(indices)
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[:self.total_size]

assert len(indices) == self.total_size

# subsample
indices = indices[self.num_consumed_samples:]
indices = indices[self.rank:len(indices):self.num_replicas]
assert len(indices) == self.num_left_samples

# sort by length within this rank
sub_length = self.length[indices] # lengths of the samples on this rank
sorted_indices = np.argsort(sub_length)[::-1] # sorted by length, longest first
sorted_indices = deepcopy(self.sorted_indices).tolist() # sorted by length, longest first

if self.shuffle:
# the actual bucket size
bucket_size = min(len(sorted_indices), self.batch_size * self.num_batch_per_bucket)
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
num_buckets = (len(sorted_indices) + bucket_size - 1)//bucket_size
batches = []
batch_indices = []
for i in range(num_buckets):
bucket = sorted_indices[i*bucket_size:(i+1)*bucket_size]
rng.shuffle(bucket) # shuffle within the bucket
_indices = np.full(fill_value=self.batch_size, dtype=int,
shape=(len(bucket)//self.batch_size)).cumsum()
_batches = np.split(bucket, _indices)
batch_indices.extend(list(range(len(batches), len(batches)+len(_batches))))
batches.extend(_batches)
last_batches = []
if len(batches)>=1 and len(batches[-1])<self.batch_size:
last_batches = batches[-1].tolist()
batch_indices = batch_indices[:-1]
batches = batches[:-1]
if self.drop_last and len(last_batches)<self.batch_size:
last_batches = []
rng.shuffle(batch_indices) # shuffle the batches as well; this keeps the batch lengths on every card close to each other.
batches = np.array(batches)[batch_indices]
indices = list(chain(*batches)) + last_batches
if self.num_consumed_samples > 0: # first reproduce the previous ordering so the already-consumed samples can be removed
_batches = []
for _i in range(self.old_num_replicas):
_sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas]
__batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket,
seed=self.seed+self.epoch)
_batches.append(__batches)
batches = list(chain(*[_ for _ in zip(*_batches)]))
sorted_indices = list(chain(*batches))
sorted_indices = sorted_indices[self.num_consumed_samples:]
# sort again
sub_length = self.length[sorted_indices]
sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]] # sorted by length, longest first
# take the part belonging to this rank
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket,
seed=self.seed+self.epoch)
batches = list(map(list, batches))
else:
indices = sorted_indices
if len(indices)<self.batch_size and self.drop_last:
indices = []

for index in range(indices):
self.num_consumed_samples += self.num_replicas
yield index
sorted_indices = sorted_indices[self.num_consumed_samples:]
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas]
_num_batches = len(sorted_indices) // self.batch_size
if _num_batches == 0:
batches = [sorted_indices]
else:
batches = list(map(list, np.array_split(sorted_indices[:_num_batches*self.batch_size], _num_batches)))
if len(sorted_indices)%self.batch_size!=0:
batches.append(sorted_indices[_num_batches*self.batch_size:])

need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank:
if len(batches) > 0:
if len(batches[-1])<self.batch_size:
batches[-1].append(batches[-1][0]) # this keeps the length distribution of the bucket intact.
else:
batches.append([batches[-1][0]])
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank:
if len(batches):
batches[-1].pop(-1)
if len(batches[-1])==0:
batches.pop(-1)

assert len(list(chain(*batches))) == self.num_left_samples

if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size:
batches = batches[:-1]

for batch in batches:
self.num_consumed_samples += self.num_replicas * len(batch)
yield list(map(int, batch))
self.during_iter = False
self.num_consumed_samples = 0
self.old_batch_size = self.batch_size
self.old_num_batch_per_bucket = self.num_batch_per_bucket
self.old_num_replicas = self.num_replicas
if self.epoch < 0: # guard against the user never changing the epoch, which would make every epoch identical
self.epoch -= 1

def generate_indices(self) -> List[int]:
def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed):
"""
Generate a random sequence of indices, guaranteeing that the totals across all cards add up to the original data size.
Split the indices into buckets.

:return:
:param sorted_indices: List[int]
:param batch_size: int
:param num_batch_per_bucket: int
:param seed: int
:return: List[List[int]]
"""
if self.shuffle:
indices = list(range(len(self.dataset)))
seed = self.seed + self.epoch
rng = np.random.default_rng(abs(seed))
rng.shuffle(indices)
if self.epoch < 0: # guard against the user forgetting to call set_epoch; at least the index order then differs between epochs.
self.epoch -= 1
else:
indices = list(range(len(self.dataset)))
return indices
# the actual bucket size
bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket)
rng = np.random.default_rng(abs(seed))
num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size
batches = []
batch_indices = []
for i in range(num_buckets):
bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size]
rng.shuffle(bucket) # shuffle within the bucket
_num_batches = len(bucket) // batch_size
if _num_batches == 0:
_batches = [bucket]
else:
_batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches)
if len(bucket) % batch_size != 0:
_batches.append(bucket[_num_batches*batch_size:])
batch_indices.extend(list(range(len(batches), len(batches) + len(_batches))))
batches.extend(_batches)
last_batches = []
# The last batch never takes part in the shuffle: on some ranks it may contain fewer than batch_size samples,
# and such a batch has to stay at the end, so the last batch is excluded from shuffling on every rank.
if len(batches) >= 1:
last_batches = [list(batches[-1])]
batch_indices = list(batch_indices[:-1])
rng = np.random.default_rng(abs(seed)) # re-seed so that differing bucket lengths cannot affect the random state
rng.shuffle(batch_indices) # shuffle the batches as well; this keeps the batch lengths on every card close to each other.
batches = (np.array(batches)[batch_indices]).tolist()
if last_batches:
batches = batches + last_batches
return batches

def state_dict(self) -> Dict:
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
" consumed. ")
states = {
'seed': self.seed,
'epoch': self.epoch,
'num_consumed_samples': self.num_consumed_samples, # note: this counts the data consumed across all ranks;
'sampler_type': self.__class__.__name__,
'length': len(self.dataset),
'shuffle': self.shuffle
'shuffle': self.shuffle,
'batch_size': self.batch_size,
'num_batch_per_bucket': self.num_batch_per_bucket,
'num_replicas': self.num_replicas
}
return states

@@ -322,4 +385,13 @@ class BucketedBatchSampler(ReproducibleIterator):
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length: # if the last sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
self.shuffle = states["shuffle"]
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]
self.old_batch_size = states['batch_size']
self.old_num_batch_per_bucket = states['num_batch_per_bucket']
self.old_num_replicas = states['num_replicas']

def set_epoch(self, epoch):
self.epoch = epoch
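
Taken together, the save/restore cycle that the new tests exercise looks roughly like this (a sketch under the assumption that both samplers are built over the same dataset and the same length list as above):

sampler = BucketedBatchSampler(dataset, length=lengths, batch_size=7, num_batch_per_bucket=4)
sampler.set_epoch(0)
iterator = iter(sampler)
for _ in range(10):
    next(iterator)                 # consume a few batches
state = sampler.state_dict()       # may only be called once any previously restored state is fully consumed

resumed = BucketedBatchSampler(dataset, length=lengths, batch_size=7, num_batch_per_bucket=4)
resumed.load_state_dict(state)     # remembers the batch_size / num_batch_per_bucket / num_replicas used before saving
resumed.set_epoch(0)
for batch in resumed:              # yields only the indices that were not consumed before saving
    ...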

+5  -18   fastNLP/core/samplers/reproducible_sampler.py

@@ -2,14 +2,14 @@ from typing import Dict, List
import math
import numpy as np

from fastNLP.core.log import logger

__all__ = [
'ReproducibleIterator',
'RandomSampler',
're_instantiate_sampler'
]

from fastNLP.core.samplers import ReproducibleBatchSampler


def re_instantiate_sampler(sampler):
all_attributes = vars(sampler)
@@ -164,6 +164,9 @@ class RandomSampler(ReproducibleIterator):
self.num_consumed_samples = states['num_consumed_samples']
if self.num_consumed_samples>=length: # if the last sample had already been reached when saving, simply reset to 0
self.num_consumed_samples = 0
if self.shuffle != states['shuffle']:
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, "
f"we use shuffle={states['shuffle']}")
self.shuffle = states["shuffle"]

def set_epoch(self, epoch: int) -> None:
@@ -212,24 +215,8 @@ class RandomSampler(ReproducibleIterator):
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas))


# todo
# class SortedSampler(ReproducibleIterator):
# def __init__(self, dataset, key):
# pass
#
#
# class BucketedSampler(ReproducibleIterator):
# def __init__(self, dataset, key):
# pass

if __name__ == "__main__":

sampler = RandomSampler(1)

print(vars(sampler))

batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True)
print(vars(batch_sampler))
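
The shuffle-restoration rule added above now applies to RandomSampler as well: the flag recorded in the checkpoint takes precedence over the constructor argument. A sketch of the behaviour (assuming the constructor takes the dataset and a shuffle flag, as elsewhere in fastNLP):

sampler = RandomSampler(dataset, shuffle=True)
state = sampler.state_dict()

restored = RandomSampler(dataset, shuffle=False)
restored.load_state_dict(state)
assert restored.shuffle is True    # the checkpointed value wins; an info log reports the override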





+439  -0   tests/core/samplers/test_reproducible_batch_sampler.py

@@ -0,0 +1,439 @@
from array import array

import numpy as np
import pytest
from itertools import chain

from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


class TestReproducibleBatchSampler:
# TODO: split this test so that each test case checks only one thing
def test_torch_dataloader_1(self):
import torch
from torch.utils.data import DataLoader
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
"sampler_type": "ReproducibleBatchSampler"}

# 2. resume from the checkpoint and build a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# change batch_size;
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# check whether the second epoch after resuming iterates over a full dataloader;
# first finish the epoch in which training was resumed;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

# start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

def test_torch_dataloader_2(self):
# verify that the index list of a new epoch is regenerated rather than reused from the previous one;
from torch.utils.data import DataLoader
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

# record all data of one epoch to verify that the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
state = _get_re_batchsampler.state_dict()

# 2. resume from the checkpoint and build a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

# first exhaust the remaining data of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

# start new epochs;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.append(next(iter_dataloader))
except StopIteration:
break

def test_3(self):
import torch
from torch.utils.data import DataLoader
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
dataloader = DataLoader(dataset, batch_size=before_batch_size)

for idx, data in enumerate(dataloader):
if idx > 3:
break

iterator = iter(dataloader)
for each in iterator:
pass


class DatasetWithVaryLength:
def __init__(self, num_of_data=100):
self.data = np.arange(num_of_data)

def __getitem__(self, item):
return self.data[item]

def __len__(self):
return len(self.data)


class TestBucketedBatchSampler:
@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71])
def test_single_num_batch(self, shuffle, drop_last, num):
# having fewer samples than a full batch should not raise an error
for num in [2, 7, 14, 15, 70, 71]:
dataset = DatasetWithVaryLength(num_of_data=num)
before_batch_size = 7
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
num_batch_per_bucket=10, drop_last=drop_last,
shuffle=shuffle)
count = len(list(iter(re_batchsampler)))
if drop_last:
assert count==num//before_batch_size, num
else:
assert count==(num+before_batch_size-1)//before_batch_size, num

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
def test_single(self, shuffle, drop_last):

before_batch_size = 7
num_batch_per_bucket = 4 # so the length spread inside any batch should stay within 4 batches' worth of samples

dataset = DatasetWithVaryLength(num_of_data=1000)
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
shuffle=shuffle)
re_batchsampler.set_epoch(0)
forward_steps = 10
iterator = iter(re_batchsampler)
already_generate_indices = set()
for _ in range(forward_steps):
batch = next(iterator)
assert max(batch) - min(batch) <= before_batch_size * num_batch_per_bucket
already_generate_indices.update(batch)

# 1. save the state
state = re_batchsampler.state_dict()

# 2. resume from the checkpoint and continue training
re_batchsampler2 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size,
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
shuffle=shuffle)
re_batchsampler2.load_state_dict(state)
re_batchsampler2.set_epoch(0)
new_already_generate_indices = set()
mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-before_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i])
for batch in re_batchsampler2:
assert max(batch) - min(batch) <= max_diff
for b in batch:
assert b not in already_generate_indices
new_already_generate_indices.update(batch)
if drop_last is False:
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset)

# change batch_size;
after_batch_size = 3
re_batchsampler3 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last,
shuffle=shuffle)
re_batchsampler3.load_state_dict(state)
re_batchsampler3.set_epoch(0)
count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-after_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+after_batch_size * num_batch_per_bucket]-indices[i])

for batch in re_batchsampler3:
assert max(batch) - min(batch) <= max_diff
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

# save again; saving is not allowed while the previous epoch has not finished
after_batch_size = 5
with pytest.raises(RuntimeError):
state = re_batchsampler3.state_dict()

for batch in re_batchsampler3: # consume everything so that saving becomes possible
pass

already_generate_indices = set()
count = 0
for batch in re_batchsampler3: # start over
assert max(batch) - min(batch) <= max_diff
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)
count += 1
if count > 5:
break

state = re_batchsampler3.state_dict()
# drop_last is False here, so eventually every sample must be covered
re_batchsampler4 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size,
num_batch_per_bucket=num_batch_per_bucket, drop_last=False,
shuffle=shuffle)
re_batchsampler4.load_state_dict(state)
re_batchsampler4.set_epoch(0)

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_generate_indices)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices) - after_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i + after_batch_size * num_batch_per_bucket] - indices[i])

for batch in re_batchsampler4:
assert max(batch) - min(batch) <= max_diff
for b in batch:
assert b not in already_generate_indices
already_generate_indices.update(batch)

assert len(already_generate_indices) == len(dataset)

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
def test_multi(self, shuffle, drop_last, pad):
# def test_multi(self, shuffle=True, drop_last=False, pad=False):

# no shuffle
num_replica = 2
dataset = DatasetWithVaryLength(num_of_data=1000)
batch_size = 5
num_batch_per_bucket = 10
lengths = []
rank0_already_seen_indexes = None
max_diff = num_batch_per_bucket * batch_size * num_replica
for rank in range(num_replica):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
num_batch_per_bucket = num_batch_per_bucket,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
already_seen_indexes = set()
repeat_count = 0
for batch in sampler:
assert max_diff>=max(batch)-min(batch)
for b in batch:
repeat_count += int(b in already_seen_indexes)
if rank0_already_seen_indexes: # indices must not appear on both ranks
assert b not in rank0_already_seen_indexes
already_seen_indexes.update(batch)
if rank0_already_seen_indexes is None:
rank0_already_seen_indexes = already_seen_indexes
if pad: # at most one repeated sample is allowed
assert repeat_count<=1
else:
assert repeat_count==0

assert len(set(lengths))==1, lengths # every process has the same number of batches

# saving in the multi-process setting
already_seen_indexes = set()
for rank in range(num_replica):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
num_batch_per_bucket = num_batch_per_bucket,
shuffle = shuffle, drop_last=drop_last)
sampler.set_epoch(0)
sampler.set_distributed(num_replica, rank=rank, pad=pad)
lengths.append(len(sampler))
count = 0
for batch in sampler:
assert max_diff>=max(batch)-min(batch)
already_seen_indexes.update(batch)
if count>5:
break
count += 1
state = sampler.state_dict()

# switch to a single process
new_batch_size = 6
num_batch_per_bucket = 3
new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
num_batch_per_bucket=num_batch_per_bucket,
shuffle=shuffle, drop_last=drop_last)
new_sampler.load_state_dict(state)
repeat_count = 0
new_already_seen_indexes = set(list(already_seen_indexes))

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices)-new_batch_size * num_batch_per_bucket):
max_diff = max(max_diff, indices[i+new_batch_size * num_batch_per_bucket]-indices[i])

for batch in new_sampler:
assert max_diff>=max(batch)-min(batch)
for b in batch:
repeat_count += int(b in new_already_seen_indexes)
new_already_seen_indexes.update(batch)
if pad: # at most one repeated sample is allowed
assert repeat_count <= 1
else:
assert repeat_count == 0
if drop_last is False: # without dropping, the counts should be equal
assert len(new_already_seen_indexes)==len(dataset)

# test changing the number of cards.
num_replica = 3
new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
num_batch_per_bucket=num_batch_per_bucket,
shuffle=shuffle, drop_last=drop_last)
new_sampler.set_epoch(0)
new_sampler.load_state_dict(state)
new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
repeat_count = 0

mask = np.ones(len(dataset), dtype=bool)
mask[list(already_seen_indexes)] = 0
indices = np.arange(len(dataset))[mask]
max_diff = -1
for i in range(len(indices) - new_batch_size * num_batch_per_bucket*num_replica):
max_diff = max(max_diff, indices[i + new_batch_size * num_batch_per_bucket*num_replica] - indices[i])

for batch in new_sampler:
assert max_diff>=max(batch)-min(batch)
for b in batch:
repeat_count += int(b in already_seen_indexes)
if pad: # at most one repeated sample is allowed
assert repeat_count <= 1
else:
assert repeat_count == 0

@pytest.mark.parametrize('shuffle', [True, False])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('pad', [True, False])
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
@pytest.mark.parametrize('num_replica', [2, 3])
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica):
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2):
# TODO: the lengths on the two ranks should fall within the same bucket
dataset = DatasetWithVaryLength(num_of_data=num_samples)
batch_size = 6
if num_replica*batch_size > num_samples:
return
num_batch_per_bucket = 10
samplers = []
lengths = []
for i in range(num_replica):
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size,
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last)
sampler.set_distributed(num_replica, rank=i, pad=pad)
sampler.set_epoch(0)
samplers.append(sampler)
lengths.append(len(list(iter(sampler))))
assert len(set(lengths))==1
bucket_diff = batch_size * num_batch_per_bucket * num_replica

for bs in zip(*samplers):
diff = max(chain(*bs)) - min(chain(*bs))
assert diff <= bucket_diff
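
These tests can be run on their own with pytest, for example:

pytest tests/core/samplers/test_reproducible_batch_sampler.py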

+0  -146   tests/core/samplers/test_reproducible_sampler.py

@@ -7,7 +7,6 @@ from functools import partial
from array import array

from fastNLP.core.samplers.reproducible_sampler import RandomSampler
from fastNLP.core.samplers import ReproducibleBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset

@@ -362,148 +361,3 @@ class TestRandomSampler(unittest.TestCase):



class TestReproducibleBatchSampler:
def test_torch_dataloader_1(self):
import torch
from torch.utils.data import DataLoader
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
"sampler_type": "ReproducibleBatchSampler"}

# 2. resume from the checkpoint and build a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# change batch_size;
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# check whether the second epoch after resuming iterates over a full dataloader;
# first finish the epoch in which training was resumed;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

# start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

def test_torch_dataloader_2(self):
# verify that the index list of a new epoch is regenerated rather than reused from the previous one;
from torch.utils.data import DataLoader
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

# record all data of one epoch to verify that the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
state = _get_re_batchsampler.state_dict()

# 2. resume from the checkpoint and build a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

# first exhaust the remaining data of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

# start new epochs;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.append(next(iter_dataloader))
except StopIteration:
break

def test_3(self):
import torch
from torch.utils.data import DataLoader, RandomSampler, BatchSampler
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
dataloader = DataLoader(dataset, batch_size=before_batch_size)

for idx, data in enumerate(dataloader):
if idx > 3:
break

iterator = iter(dataloader)
for each in iterator:
pass
