diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 034292eb..952712be 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -130,8 +130,8 @@ class TorchSingleDriver(TorchDriver): else: return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], - reproducible: bool = False, sampler_or_batch_sampler=None): + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, + reproducible: bool = False): if isinstance(dist, ReproducibleBatchSampler): return replace_batch_sampler(dataloader, dist) elif isinstance(dist, ReproducibleIterator): diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index e5721ebc..68928b66 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -17,5 +17,5 @@ __all__ = [ from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler -from .reproducible_batch_sampler import ReproducibleBatchSampler +from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 3476ba71..3e39aca5 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -1,20 +1,48 @@ +__all__ = [ + 'BucketedBatchSampler', + "ReproducibleBatchSampler" +] + import math from array import array from copy import deepcopy -from itertools import chain from typing import Dict, Union, List +from itertools import chain import numpy as np from fastNLP.core.dataset import DataSet -from fastNLP.core.samplers import ReproducibleIterator +from fastNLP.core.log import logger +from abc import abstractmethod + + +class ReproducibleBatchIterator: + @abstractmethod + def set_distributed(self, num_replicas, rank, pad=True): + raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") + + @abstractmethod + def __len__(self): + raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.") + @abstractmethod + def __iter__(self): + raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.") + @abstractmethod + def state_dict(self): + raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") + @abstractmethod + def load_state_dict(self, states): + raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.") + @abstractmethod + def set_epoch(self, epoch): + pass -class ReproducibleBatchSampler: +class ReproducibleBatchSampler(ReproducibleBatchIterator): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ @@ -94,7 +122,7 @@ class ReproducibleBatchSampler: self.data_idx = states["data_idx"] self.need_reinitialize = False - def set_distributed(self): + def set_distributed(self, num_replicas, rank, pad=True): raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") def set_epoch(self, epoch): @@ -110,7 +138,7 @@ class ReproducibleBatchSampler: (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size -class BucketedBatchSampler(ReproducibleIterator): +class BucketedBatchSampler(ReproducibleBatchIterator): def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): """ @@ -129,20 +157,20 @@ class BucketedBatchSampler(ReproducibleIterator): :param kwargs: fastNLP 保留使用 """ super().__init__() - if not isinstance(dataset, DataSet): + if isinstance(dataset, DataSet): length = dataset.get_field(length) if not isinstance(length[0], int): length = list(map(len, length)) else: - assert isinstance(length, List) and len(length)==len(dataset), "When the dataset is not fastNLP.DataSet, " \ - "the length parameter can only be List[int]" - assert len(length) == len(dataset), "The length of `data` and `length` should be equal." + assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ + "the length parameter can only be List[int]" - if drop_last: - assert len(dataset)>=batch_size, "The number of samplers must be larger than batch_size when `drop_last=True`." + assert len(length) == len(dataset), "The length of `data` and `length` should be equal." self.dataset = dataset self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 + self.batch_size = batch_size self.num_batch_per_bucket = num_batch_per_bucket @@ -161,6 +189,10 @@ class BucketedBatchSampler(ReproducibleIterator): # 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() self.during_iter = kwargs.get("during_iter", False) + # 以下变量为内部使用恢复状态的变量。 + self.old_batch_size = kwargs.get('old_batch_size', self.batch_size) + self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket) + def set_distributed(self, num_replicas, rank, pad=True): assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ "during an unfinished iteration." @@ -217,92 +249,123 @@ class BucketedBatchSampler(ReproducibleIterator): if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 self.num_consumed_samples = 0 self.during_iter = True - indices = self.generate_indices() - - if self.pad: - # add extra samples to make it evenly divisible - padding_size = self.total_size - len(indices) - if padding_size <= len(indices): - indices += indices[:padding_size] - else: - indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] - else: - # remove tail of data to make it evenly divisible. - indices = indices[:self.total_size] - assert len(indices) == self.total_size - - # subsample - indices = indices[self.num_consumed_samples:] - indices = indices[self.rank:len(indices):self.num_replicas] - assert len(indices) == self.num_left_samples - - # 根据内部的长度进行排序 - sub_length = self.length[indices] # 取出这个 rank 中的长度 - sorted_indices = np.argsort(sub_length)[::-1] # 按长度从高到低排序的 + sorted_indices = deepcopy(self.sorted_indices).tolist() # 按长度从高到低排序的 if self.shuffle: - # 实际的 bucket 大小 - bucket_size = min(len(sorted_indices), self.batch_size * self.num_batch_per_bucket) - seed = self.seed + self.epoch - rng = np.random.default_rng(abs(seed)) - num_buckets = (len(sorted_indices) + bucket_size - 1)//bucket_size - batches = [] - batch_indices = [] - for i in range(num_buckets): - bucket = sorted_indices[i*bucket_size:(i+1)*bucket_size] - rng.shuffle(bucket) # bucket 内部 shuffle 一下 - _indices = np.full(fill_value=self.batch_size, dtype=int, - shape=(len(bucket)//self.batch_size)).cumsum() - _batches = np.split(bucket, _indices) - batch_indices.extend(list(range(len(batches), len(batches)+len(_batches)))) - batches.extend(_batches) - last_batches = [] - if len(batches)>=1 and len(batches[-1]) 0: # 需要先按照原来的排序,删掉多余的 + _batches = [] + for _i in range(self.old_num_replicas): + _sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas] + __batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket, + seed=self.seed+self.epoch) + _batches.append(__batches) + batches = list(chain(*[_ for _ in zip(*_batches)])) + sorted_indices = list(chain(*batches)) + sorted_indices = sorted_indices[self.num_consumed_samples:] + # 再进行排序 + sub_length = self.length[sorted_indices] + sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]] # 按长度从高到低排序的 + # 取出这个 rank , + sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas] + batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket, + seed=self.seed+self.epoch) + batches = list(map(list, batches)) else: - indices = sorted_indices - if len(indices) 0: + if len(batches[-1])self.rank: + if len(batches): + batches[-1].pop(-1) + if len(batches[-1])==0: + batches.pop(-1) + + assert len(list(chain(*batches))) == self.num_left_samples + + if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: + batches = batches[:-1] + + for batch in batches: + self.num_consumed_samples += self.num_replicas * len(batch) + yield list(map(int, batch)) self.during_iter = False self.num_consumed_samples = 0 + self.old_batch_size = self.batch_size + self.old_num_batch_per_bucket = self.num_batch_per_bucket + self.old_num_replicas = self.num_replicas + if self.epoch < 0: # 防止用户没有修改epoch,导致每个epoch都一样了 + self.epoch -= 1 - def generate_indices(self) -> List[int]: + def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed): """ - 生成随机序列,用于保证在所有卡的总和加起来是原来的数据量。 + 将 indices 分桶 - :return: + :param sorted_indices: List[int] + :param batch_size: int + :param num_batch_per_bucket: int + :param seed: int + :return: List[List[int]] """ - if self.shuffle: - indices = list(range(len(self.dataset))) - seed = self.seed + self.epoch - rng = np.random.default_rng(abs(seed)) - rng.shuffle(indices) - if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 - self.epoch -= 1 - else: - indices = list(range(len(self.dataset))) - return indices + # 实际的 bucket 大小 + bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket) + rng = np.random.default_rng(abs(seed)) + num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size + batches = [] + batch_indices = [] + for i in range(num_buckets): + bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size] + rng.shuffle(bucket) # bucket 内部 shuffle 一下 + _num_batches = len(bucket) // batch_size + if _num_batches == 0: + _batches = [bucket] + else: + _batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches) + if len(bucket) % batch_size != 0: + _batches.append(bucket[_num_batches*batch_size:]) + batch_indices.extend(list(range(len(batches), len(batches) + len(_batches)))) + batches.extend(_batches) + last_batches = [] + # 最后一个batch 统一不参与shuffle,因为有的rank最后一个 batch 可能不足一个batch_size (不足的时候 + # 一定要放在末尾,所以就干脆所有的rank都不对最后一个batch进行shuffle)。 + if len(batches) >= 1: + last_batches = [list(batches[-1])] + batch_indices = list(batch_indices[:-1]) + rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 + rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 + batches = (np.array(batches)[batch_indices]).tolist() + if last_batches: + batches = batches + last_batches + return batches def state_dict(self) -> Dict: + if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: + raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" + " consumed. ") states = { 'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; 'sampler_type': self.__class__.__name__, 'length': len(self.dataset), - 'shuffle': self.shuffle + 'shuffle': self.shuffle, + 'batch_size': self.batch_size, + 'num_batch_per_bucket': self.num_batch_per_bucket, + 'num_replicas': self.num_replicas } return states @@ -322,4 +385,13 @@ class BucketedBatchSampler(ReproducibleIterator): self.num_consumed_samples = states['num_consumed_samples'] if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 self.num_consumed_samples = 0 - self.shuffle = states["shuffle"] \ No newline at end of file + if self.shuffle != states['shuffle']: + logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " + f"we use shuffle={states['shuffle']}") + self.shuffle = states["shuffle"] + self.old_batch_size = states['batch_size'] + self.old_num_batch_per_bucket = states['num_batch_per_bucket'] + self.old_num_replicas = states['num_replicas'] + + def set_epoch(self, epoch): + self.epoch = epoch \ No newline at end of file diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 0ae011b2..0a4ac7bf 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -2,14 +2,14 @@ from typing import Dict, List import math import numpy as np +from fastNLP.core.log import logger + __all__ = [ 'ReproducibleIterator', 'RandomSampler', 're_instantiate_sampler' ] -from fastNLP.core.samplers import ReproducibleBatchSampler - def re_instantiate_sampler(sampler): all_attributes = vars(sampler) @@ -164,6 +164,9 @@ class RandomSampler(ReproducibleIterator): self.num_consumed_samples = states['num_consumed_samples'] if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 self.num_consumed_samples = 0 + if self.shuffle != states['shuffle']: + logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " + f"we use shuffle={states['shuffle']}") self.shuffle = states["shuffle"] def set_epoch(self, epoch: int) -> None: @@ -212,24 +215,8 @@ class RandomSampler(ReproducibleIterator): self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) -# todo -# class SortedSampler(ReproducibleIterator): -# def __init__(self, dataset, key): -# pass -# -# -# class BucketedSampler(ReproducibleIterator): -# def __init__(self, dataset, key): -# pass - -if __name__ == "__main__": - - sampler = RandomSampler(1) - print(vars(sampler)) - batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True) - print(vars(batch_sampler)) diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index e69de29b..42b86dcd 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -0,0 +1,439 @@ +from array import array + +import numpy as np +import pytest +from itertools import chain + +from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler +from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler +from tests.helpers.datasets.torch_data import TorchNormalDataset + + +class TestReproducibleBatchSampler: + # TODO 拆分测试,在这里只测试一个东西 + def test_torch_dataloader_1(self): + import torch + from torch.utils.data import DataLoader + # no shuffle + before_batch_size = 7 + dataset = TorchNormalDataset(num_of_data=100) + dataloader = DataLoader(dataset, batch_size=before_batch_size) + re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + forward_steps = 3 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + next(iter_dataloader) + + # 1. 保存状态 + _get_re_batchsampler = dataloader.batch_sampler + assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + state = _get_re_batchsampler.state_dict() + assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, + "sampler_type": "ReproducibleBatchSampler"} + + # 2. 断点重训,重新生成一个 dataloader; + # 不改变 batch_size; + dataloader = DataLoader(dataset, batch_size=before_batch_size) + re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + real_res = [] + supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) + forward_steps = 2 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert all(real_res[i] == supposed_res[i]) + + # 改变 batch_size; + after_batch_size = 3 + dataloader = DataLoader(dataset, batch_size=after_batch_size) + re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + real_res = [] + supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) + forward_steps = 2 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + real_res.append(next(iter_dataloader)) + + for i in range(forward_steps): + assert all(real_res[i] == supposed_res[i]) + + # 断点重训的第二轮是否是一个完整的 dataloader; + # 先把断点重训所在的那一个 epoch 跑完; + begin_idx = 27 + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) + begin_idx += _batch_size + except StopIteration: + break + + # 开始新的一轮; + begin_idx = 0 + iter_dataloader = iter(dataloader) + while True: + try: + data = next(iter_dataloader) + _batch_size = len(data) + assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) + begin_idx += _batch_size + except StopIteration: + break + + def test_torch_dataloader_2(self): + # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; + from torch.utils.data import DataLoader + # no shuffle + before_batch_size = 7 + dataset = TorchNormalDataset(num_of_data=100) + # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; + dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) + re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + # 将一轮的所有数据保存下来,看是否恢复的是正确的; + all_supposed_data = [] + forward_steps = 3 + iter_dataloader = iter(dataloader) + for _ in range(forward_steps): + all_supposed_data.extend(next(iter_dataloader).tolist()) + + # 1. 保存状态 + _get_re_batchsampler = dataloader.batch_sampler + assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + state = _get_re_batchsampler.state_dict() + + # 2. 断点重训,重新生成一个 dataloader; + # 不改变 batch_size; + dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) + re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler.load_state_dict(state) + dataloader = replace_batch_sampler(dataloader, re_batchsampler) + + # 先把这一轮的数据过完; + pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] + while True: + try: + all_supposed_data.extend(next(iter_dataloader).tolist()) + except StopIteration: + break + assert all_supposed_data == list(pre_index_list) + + # 重新开启新的一轮; + for _ in range(3): + iter_dataloader = iter(dataloader) + res = [] + while True: + try: + res.append(next(iter_dataloader)) + except StopIteration: + break + + def test_3(self): + import torch + from torch.utils.data import DataLoader + before_batch_size = 7 + dataset = TorchNormalDataset(num_of_data=100) + # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; + dataloader = DataLoader(dataset, batch_size=before_batch_size) + + for idx, data in enumerate(dataloader): + if idx > 3: + break + + iterator = iter(dataloader) + for each in iterator: + pass + + +class DatasetWithVaryLength: + def __init__(self, num_of_data=100): + self.data = np.arange(num_of_data) + + def __getitem__(self, item): + return self.data[item] + + def __len__(self): + return len(self.data) + + +class TestBucketedBatchSampler: + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) + def test_single_num_batch(self, shuffle, drop_last, num): + # 数量不够不报错 + for num in [2, 7, 14, 15, 70, 71]: + dataset = DatasetWithVaryLength(num_of_data=num) + before_batch_size = 7 + re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + num_batch_per_bucket=10, drop_last=drop_last, + shuffle=shuffle) + count = len(list(iter(re_batchsampler))) + if drop_last: + assert count==num//before_batch_size, num + else: + assert count==(num+before_batch_size-1)//before_batch_size, num + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + def test_single(self, shuffle, drop_last): + + before_batch_size = 7 + num_batch_per_bucket = 4 # 那么任意 batch 内的长度差值不应该超过4 + + dataset = DatasetWithVaryLength(num_of_data=1000) + re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, + shuffle=shuffle) + re_batchsampler.set_epoch(0) + forward_steps = 10 + iterator = iter(re_batchsampler) + already_generate_indices = set() + for _ in range(forward_steps): + batch = next(iterator) + assert max(batch) - min(batch) <= before_batch_size * num_batch_per_bucket + already_generate_indices.update(batch) + + # 1. 保存状态 + state = re_batchsampler.state_dict() + + # 2. 断点重训,继续训练 + re_batchsampler2 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, + num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, + shuffle=shuffle) + re_batchsampler2.load_state_dict(state) + re_batchsampler2.set_epoch(0) + new_already_generate_indices = set() + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices)-before_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) + for batch in re_batchsampler2: + assert max(batch) - min(batch) <= max_diff + for b in batch: + assert b not in already_generate_indices + new_already_generate_indices.update(batch) + if drop_last is False: + assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) + + # 改变 batch_size; + after_batch_size = 3 + re_batchsampler3 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, + shuffle=shuffle) + re_batchsampler3.load_state_dict(state) + re_batchsampler3.set_epoch(0) + count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices)-after_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i+after_batch_size * num_batch_per_bucket]-indices[i]) + + for batch in re_batchsampler3: + assert max(batch) - min(batch) <= max_diff + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + # 再 save ,不允许再上个epoch没结束继续sample + after_batch_size = 5 + with pytest.raises(RuntimeError): + state = re_batchsampler3.state_dict() + + for batch in re_batchsampler3: # consume all, 这样才能save + pass + + already_generate_indices = set() + count = 0 + for batch in re_batchsampler3: # 重新开始 + assert max(batch) - min(batch) <= max_diff + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + count += 1 + if count > 5: + break + + state = re_batchsampler3.state_dict() + # 这里的 drop_last 为 False,需要最终是所有 sample + re_batchsampler4 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, + num_batch_per_bucket=num_batch_per_bucket, drop_last=False, + shuffle=shuffle) + re_batchsampler4.load_state_dict(state) + re_batchsampler4.set_epoch(0) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_generate_indices)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices) - after_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i + after_batch_size * num_batch_per_bucket] - indices[i]) + + for batch in re_batchsampler4: + assert max(batch) - min(batch) <= max_diff + for b in batch: + assert b not in already_generate_indices + already_generate_indices.update(batch) + + assert len(already_generate_indices) == len(dataset) + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + def test_multi(self, shuffle, drop_last, pad): + # def test_multi(self, shuffle=True, drop_last=False, pad=False): + + # no shuffle + num_replica = 2 + dataset = DatasetWithVaryLength(num_of_data=1000) + batch_size = 5 + num_batch_per_bucket = 10 + lengths = [] + rank0_already_seen_indexes = None + max_diff = num_batch_per_bucket * batch_size * num_replica + for rank in range(num_replica): + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + num_batch_per_bucket = num_batch_per_bucket, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + already_seen_indexes = set() + repeat_count = 0 + for batch in sampler: + assert max_diff>=max(batch)-min(batch) + for b in batch: + repeat_count += int(b in already_seen_indexes) + if rank0_already_seen_indexes: # 不能交叉出现 + assert b not in rank0_already_seen_indexes + already_seen_indexes.update(batch) + if rank0_already_seen_indexes is None: + rank0_already_seen_indexes = already_seen_indexes + if pad: # 应该允许重复一次 + assert repeat_count<=1 + else: + assert repeat_count==0 + + assert len(set(lengths))==1, lengths # 每个进程的batch数量一致 + + # 多进程的保存 + already_seen_indexes = set() + for rank in range(num_replica): + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size, + num_batch_per_bucket = num_batch_per_bucket, + shuffle = shuffle, drop_last=drop_last) + sampler.set_epoch(0) + sampler.set_distributed(num_replica, rank=rank, pad=pad) + lengths.append(len(sampler)) + count = 0 + for batch in sampler: + assert max_diff>=max(batch)-min(batch) + already_seen_indexes.update(batch) + if count>5: + break + count += 1 + state = sampler.state_dict() + + # 切换成单机 + new_batch_size = 6 + num_batch_per_bucket = 3 + new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + num_batch_per_bucket=num_batch_per_bucket, + shuffle=shuffle, drop_last=drop_last) + new_sampler.load_state_dict(state) + repeat_count = 0 + new_already_seen_indexes = set(list(already_seen_indexes)) + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices)-new_batch_size * num_batch_per_bucket): + max_diff = max(max_diff, indices[i+new_batch_size * num_batch_per_bucket]-indices[i]) + + for batch in new_sampler: + assert max_diff>=max(batch)-min(batch) + for b in batch: + repeat_count += int(b in new_already_seen_indexes) + new_already_seen_indexes.update(batch) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + if drop_last is False: # 如果没有drop应该相等 + assert len(new_already_seen_indexes)==len(dataset) + + # 测试替换卡的数量。 + num_replica = 3 + new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size, + num_batch_per_bucket=num_batch_per_bucket, + shuffle=shuffle, drop_last=drop_last) + new_sampler.set_epoch(0) + new_sampler.load_state_dict(state) + new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad) + repeat_count = 0 + + mask = np.ones(len(dataset), dtype=bool) + mask[list(already_seen_indexes)] = 0 + indices = np.arange(len(dataset))[mask] + max_diff = -1 + for i in range(len(indices) - new_batch_size * num_batch_per_bucket*num_replica): + max_diff = max(max_diff, indices[i + new_batch_size * num_batch_per_bucket*num_replica] - indices[i]) + + for batch in new_sampler: + assert max_diff>=max(batch)-min(batch) + for b in batch: + repeat_count += int(b in already_seen_indexes) + if pad: # 应该允许重复一次 + assert repeat_count <= 1 + else: + assert repeat_count == 0 + + @pytest.mark.parametrize('shuffle', [True, False]) + @pytest.mark.parametrize('drop_last', [True, False]) + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) + @pytest.mark.parametrize('num_replica', [2, 3]) + def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): + # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): + # TODO 两个 rank 上的长度是要在同一个bucket的 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + batch_size = 6 + if num_replica*batch_size > num_samples: + return + num_batch_per_bucket = 10 + samplers = [] + lengths = [] + for i in range(num_replica): + sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, + num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) + sampler.set_distributed(num_replica, rank=i, pad=pad) + sampler.set_epoch(0) + samplers.append(sampler) + lengths.append(len(list(iter(sampler)))) + assert len(set(lengths))==1 + bucket_diff = batch_size * num_batch_per_bucket * num_replica + + for bs in zip(*samplers): + diff = max(chain(*bs)) - min(chain(*bs)) + assert diff <= bucket_diff diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py index 29e07a09..0a3697d3 100644 --- a/tests/core/samplers/test_reproducible_sampler.py +++ b/tests/core/samplers/test_reproducible_sampler.py @@ -7,7 +7,6 @@ from functools import partial from array import array from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.samplers import ReproducibleBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -362,148 +361,3 @@ class TestRandomSampler(unittest.TestCase): -class TestReproducibleBatchSampler: - def test_torch_dataloader_1(self): - import torch - from torch.utils.data import DataLoader - # no shuffle - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - forward_steps = 3 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - next(iter_dataloader) - - # 1. 保存状态 - _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) - state = _get_re_batchsampler.state_dict() - assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, - "sampler_type": "ReproducibleBatchSampler"} - - # 2. 断点重训,重新生成一个 dataloader; - # 不改变 batch_size; - dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - real_res = [] - supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) - forward_steps = 2 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - real_res.append(next(iter_dataloader)) - - for i in range(forward_steps): - assert all(real_res[i] == supposed_res[i]) - - # 改变 batch_size; - after_batch_size = 3 - dataloader = DataLoader(dataset, batch_size=after_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - real_res = [] - supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) - forward_steps = 2 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - real_res.append(next(iter_dataloader)) - - for i in range(forward_steps): - assert all(real_res[i] == supposed_res[i]) - - # 断点重训的第二轮是否是一个完整的 dataloader; - # 先把断点重训所在的那一个 epoch 跑完; - begin_idx = 27 - while True: - try: - data = next(iter_dataloader) - _batch_size = len(data) - assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) - begin_idx += _batch_size - except StopIteration: - break - - # 开始新的一轮; - begin_idx = 0 - iter_dataloader = iter(dataloader) - while True: - try: - data = next(iter_dataloader) - _batch_size = len(data) - assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) - begin_idx += _batch_size - except StopIteration: - break - - def test_torch_dataloader_2(self): - # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; - from torch.utils.data import DataLoader - # no shuffle - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; - dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - # 将一轮的所有数据保存下来,看是否恢复的是正确的; - all_supposed_data = [] - forward_steps = 3 - iter_dataloader = iter(dataloader) - for _ in range(forward_steps): - all_supposed_data.extend(next(iter_dataloader).tolist()) - - # 1. 保存状态 - _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) - state = _get_re_batchsampler.state_dict() - - # 2. 断点重训,重新生成一个 dataloader; - # 不改变 batch_size; - dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) - re_batchsampler.load_state_dict(state) - dataloader = replace_batch_sampler(dataloader, re_batchsampler) - - # 先把这一轮的数据过完; - pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] - while True: - try: - all_supposed_data.extend(next(iter_dataloader).tolist()) - except StopIteration: - break - assert all_supposed_data == list(pre_index_list) - - # 重新开启新的一轮; - for _ in range(3): - iter_dataloader = iter(dataloader) - res = [] - while True: - try: - res.append(next(iter_dataloader)) - except StopIteration: - break - - def test_3(self): - import torch - from torch.utils.data import DataLoader, RandomSampler, BatchSampler - before_batch_size = 7 - dataset = TorchNormalDataset(num_of_data=100) - # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; - dataloader = DataLoader(dataset, batch_size=before_batch_size) - - for idx, data in enumerate(dataloader): - if idx > 3: - break - - iterator = iter(dataloader) - for each in iterator: - pass