diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index 958cf5b4..21b3f059 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -8,7 +8,6 @@ import math
 from copy import deepcopy
 from typing import Dict, Union, List
 from itertools import chain
-import os
 
 import numpy as np
 
diff --git a/tests/core/controllers/utils/test_utils.py b/tests/core/controllers/utils/test_utils.py
index 0cf7a252..860d84d5 100644
--- a/tests/core/controllers/utils/test_utils.py
+++ b/tests/core/controllers/utils/test_utils.py
@@ -1,7 +1,7 @@
 from functools import reduce
 
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: 该类修改过,记得将 test 也修改;
-from tests.helpers.datasets.normal_data import NormalIterator
+from tests.helpers.datasets.normal_data import NormalSampler
 
 
 class Test_WrapDataLoader:
@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
     def test_normal_generator(self):
         all_sanity_batches = [4, 20, 100]
         for sanity_batches in all_sanity_batches:
-            data = NormalIterator(num_of_data=1000)
+            data = NormalSampler(num_of_data=1000)
             wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
             dataloader = iter(wrapper(dataloader=data))
             mark = 0
diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py
index cac595ba..c4dd8c50 100644
--- a/tests/core/samplers/test_reproducible_batch_sampler.py
+++ b/tests/core/samplers/test_reproducible_batch_sampler.py
@@ -1,161 +1,131 @@
-from array import array
-
 import numpy as np
 import pytest
 from itertools import chain
 from copy import deepcopy
+from array import array
+from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
 
 from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
-from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
-from tests.helpers.datasets.torch_data import TorchNormalDataset
-
-#
-# class TestReproducibleBatchSampler:
-#     # TODO 拆分测试,在这里只测试一个东西
-#     def test_torch_dataloader_1(self):
-#         import torch
-#         from torch.utils.data import DataLoader
-#         # no shuffle
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         forward_steps = 3
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             next(iter_dataloader)
-#
-#         # 1. 保存状态
-#         _get_re_batchsampler = dataloader.batch_sampler
-#         assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
-#         state = _get_re_batchsampler.state_dict()
-#         assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
-#                          "sampler_type": "ReproduceBatchSampler"}
-#
-#         # 2. 断点重训,重新生成一个 dataloader;
-#         # 不改变 batch_size;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         real_res = []
-#         supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
-#         forward_steps = 2
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             real_res.append(next(iter_dataloader))
-#
-#         for i in range(forward_steps):
-#             assert all(real_res[i] == supposed_res[i])
-#
-#         # 改变 batch_size;
-#         after_batch_size = 3
-#         dataloader = DataLoader(dataset, batch_size=after_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         real_res = []
-#         supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
-#         forward_steps = 2
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             real_res.append(next(iter_dataloader))
-#
-#         for i in range(forward_steps):
-#             assert all(real_res[i] == supposed_res[i])
-#
-#         # 断点重训的第二轮是否是一个完整的 dataloader;
-#         # 先把断点重训所在的那一个 epoch 跑完;
-#         begin_idx = 27
-#         while True:
-#             try:
-#                 data = next(iter_dataloader)
-#                 _batch_size = len(data)
-#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-#                 begin_idx += _batch_size
-#             except StopIteration:
-#                 break
-#
-#         # 开始新的一轮;
-#         begin_idx = 0
-#         iter_dataloader = iter(dataloader)
-#         while True:
-#             try:
-#                 data = next(iter_dataloader)
-#                 _batch_size = len(data)
-#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-#                 begin_idx += _batch_size
-#             except StopIteration:
-#                 break
-#
-#     def test_torch_dataloader_2(self):
-#         # 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的;
-#         from torch.utils.data import DataLoader
-#         # no shuffle
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         # 将一轮的所有数据保存下来,看是否恢复的是正确的;
-#         all_supposed_data = []
-#         forward_steps = 3
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             all_supposed_data.extend(next(iter_dataloader).tolist())
-#
-#         # 1. 保存状态
-#         _get_re_batchsampler = dataloader.batch_sampler
-#         assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
-#         state = _get_re_batchsampler.state_dict()
-#
-#         # 2. 断点重训,重新生成一个 dataloader;
-#         # 不改变 batch_size;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         # 先把这一轮的数据过完;
-#         pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
-#         while True:
-#             try:
-#                 all_supposed_data.extend(next(iter_dataloader).tolist())
-#             except StopIteration:
-#                 break
-#         assert all_supposed_data == list(pre_index_list)
-#
-#         # 重新开启新的一轮;
-#         for _ in range(3):
-#             iter_dataloader = iter(dataloader)
-#             res = []
-#             while True:
-#                 try:
-#                     res.append(next(iter_dataloader))
-#                 except StopIteration:
-#                     break
-#
-#     def test_3(self):
-#         import torch
-#         from torch.utils.data import DataLoader
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#
-#         for idx, data in enumerate(dataloader):
-#             if idx > 3:
-#                 break
-#
-#         iterator = iter(dataloader)
-#         for each in iterator:
-#             pass
+
+
+class TestReproducibleBatchSampler:
+    def test_1(self):
+        sampler = NormalSampler(num_of_data=100)  # it makes no difference here whether this is a batch sampler;
+
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
+
+        forward_steps = 3
+        iterator = iter(reproduce_batch_sampler)
+        i = 0
+        while i < forward_steps:
+            next(iterator)
+            i += 1
+
+        # save the state;
+        state = reproduce_batch_sampler.state_dict()
+
+        assert state == {"index_list": array("I", list(range(100))),
+                         "num_consumed_samples": forward_steps * 4,
+                         "sampler_type": "ReproduceBatchSampler"}
+
+        # create a new batch sampler, then load the saved state;
+        sampler = NormalSampler(num_of_data=100)  # it makes no difference here whether this is a batch sampler;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        real_res = []
+        supposed_res = (list(range(12, 16)), list(range(16, 20)))
+        forward_steps = 2
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert supposed_res[i] == real_res[i]
+
+        # change the batch size;
+        sampler = NormalSampler(num_of_data=100)  # it makes no difference here whether this is a batch sampler;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        real_res = []
+        supposed_res = (list(range(12, 19)), list(range(19, 26)))
+        forward_steps = 2
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert supposed_res[i] == real_res[i]
+
+        # check that the second epoch after resuming is a complete pass;
+        # first finish the epoch in which training was resumed;
+        begin_idx = 26
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert data == list(range(begin_idx, begin_idx + _batch_size))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+        # start a new epoch;
+        begin_idx = 0
+        iter_dataloader = iter(reproduce_batch_sampler)
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert data == list(range(begin_idx, begin_idx + _batch_size))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+    def test_2(self):
+
+        # check that a new epoch's index list is regenerated instead of being carried over from the previous epoch;
+        before_batch_size = 7
+        sampler = NormalSampler(num_of_data=100)
+        # shuffle is enabled on the reloaded sampler below, to check that the second epoch's index list is regenerated after resuming;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
+
+        # record all data of one epoch to check that the restored data is correct;
+        all_supposed_data = []
+        forward_steps = 3
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            all_supposed_data.extend(next(iter_dataloader))
+
+        # 1. save the state
+        state = reproduce_batch_sampler.state_dict()
+
+        # 2. resume from the checkpoint with a newly created sampler;
+        # keep batch_size unchanged;
+        sampler = NormalSampler(num_of_data=100, shuffle=True)
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        # first run through the rest of this epoch;
+        pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
+        iter_dataloader = iter(reproduce_batch_sampler)
+        while True:
+            try:
+                all_supposed_data.extend(next(iter_dataloader))
+            except StopIteration:
+                break
+        assert all_supposed_data == list(pre_index_list)
+
+        # start new epochs;
+        for _ in range(3):
+            iter_dataloader = iter(reproduce_batch_sampler)
+            res = []
+            while True:
+                try:
+                    res.extend(next(iter_dataloader))
+                except StopIteration:
+                    break
+            assert res != all_supposed_data
 
 
 class DatasetWithVaryLength:
diff --git a/tests/core/samplers/test_reproducible_batch_sampler_torch.py b/tests/core/samplers/test_reproducible_batch_sampler_torch.py
new file mode 100644
index 00000000..af180f56
--- /dev/null
+++ b/tests/core/samplers/test_reproducible_batch_sampler_torch.py
@@ -0,0 +1,141 @@
+from array import array
+
+import torch
+from torch.utils.data import DataLoader
+
+import pytest
+
+from fastNLP.core.samplers import ReproduceBatchSampler
+from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
+from tests.helpers.datasets.torch_data import TorchNormalDataset
+
+
+@pytest.mark.torch
+class TestReproducibleBatchSamplerTorch:
+    def test_torch_dataloader_1(self):
+        # no shuffle
+        before_batch_size = 7
+        dataset = TorchNormalDataset(num_of_data=100)
+        dataloader = DataLoader(dataset, batch_size=before_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        forward_steps = 3
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            next(iter_dataloader)
+
+        # 1. save the state
+        _get_re_batchsampler = dataloader.batch_sampler
+        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
+        state = _get_re_batchsampler.state_dict()
+        assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
+                         "sampler_type": "ReproduceBatchSampler"}
+
+        # 2. resume from the checkpoint with a newly created dataloader;
+        # keep batch_size unchanged;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        real_res = []
+        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
+        forward_steps = 2
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert all(real_res[i] == supposed_res[i])
+
+        # change the batch_size;
+        after_batch_size = 3
+        dataloader = DataLoader(dataset, batch_size=after_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        real_res = []
+        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
+        forward_steps = 2
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert all(real_res[i] == supposed_res[i])
+
+        # check that the second epoch after resuming is a complete pass;
+        # first finish the epoch in which training was resumed;
+        begin_idx = 27
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+        # start a new epoch;
+        begin_idx = 0
+        iter_dataloader = iter(dataloader)
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+    def test_torch_dataloader_2(self):
+        # check that a new epoch's index list is regenerated instead of being carried over from the previous epoch;
+        from torch.utils.data import DataLoader
+        before_batch_size = 7
+        dataset = TorchNormalDataset(num_of_data=100)
+        # enable shuffle to check that the second epoch's index list is regenerated after resuming;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        # record all data of one epoch to check that the restored data is correct;
+        all_supposed_data = []
+        forward_steps = 3
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            all_supposed_data.extend(next(iter_dataloader).tolist())
+
+        # 1. save the state
+        _get_re_batchsampler = dataloader.batch_sampler
+        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
+        state = _get_re_batchsampler.state_dict()
+
+        # 2. resume from the checkpoint with a newly created dataloader;
+        # keep batch_size unchanged;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        iter_dataloader = iter(dataloader)
+        # first run through the rest of this epoch;
+        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
+        while True:
+            try:
+                all_supposed_data.extend(next(iter_dataloader).tolist())
+            except StopIteration:
+                break
+        assert all_supposed_data == list(pre_index_list)
+
+        # start new epochs;
+        for _ in range(3):
+            iter_dataloader = iter(dataloader)
+            res = []
+            while True:
+                try:
+                    res.extend(next(iter_dataloader).tolist())
+                except StopIteration:
+                    break
+            assert res != all_supposed_data
+
diff --git a/tests/helpers/datasets/normal_data.py b/tests/helpers/datasets/normal_data.py
index 714ec676..b4e3ffca 100644
--- a/tests/helpers/datasets/normal_data.py
+++ b/tests/helpers/datasets/normal_data.py
@@ -1,13 +1,25 @@
 import numpy as np
+import random
 
 
-class NormalIterator:
-    def __init__(self, num_of_data=1000):
+class NormalSampler:
+    def __init__(self, num_of_data=1000, shuffle=False):
         self._num_of_data = num_of_data
         self._data = list(range(num_of_data))
+        if shuffle:
+            random.shuffle(self._data)
+        self.shuffle = shuffle
         self._index = 0
+        self.need_reinitialize = False
 
     def __iter__(self):
+        if self.need_reinitialize:
+            self._index = 0
+            if self.shuffle:
+                random.shuffle(self._data)
+        else:
+            self.need_reinitialize = True
+
         return self
 
     def __next__(self):
@@ -15,12 +27,45 @@
         if self._index >= self._num_of_data:
             raise StopIteration
         _data = self._data[self._index]
         self._index += 1
-        return self._data
+        return _data
 
     def __len__(self):
         return self._num_of_data
 
 
+class NormalBatchSampler:
+    def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
+        # Since collections.abc.Iterable does not check for `__getitem__`, which
+        # is one way for an object to be an iterable, we don't do an `isinstance`
+        # check here.
+        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
+                batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, "
+                             "but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got "
+                             "drop_last={}".format(drop_last))
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0 and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
+
+
 class RandomDataset:
     def __init__(self, num_data=10):
         self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@
     def __len__(self):
         return len(self.data)
 
     def __getitem__(self, item):
-        return self.data[item]
\ No newline at end of file
+        return self.data[item]
+
+
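
Reviewer note: below is a minimal sketch of the checkpoint/resume contract these tests exercise, condensed from test_1. It assumes only what is visible in this diff (ReproduceBatchSampler with state_dict/load_state_dict, and the NormalSampler test helper being importable from the repo root); it is illustrative and not part of the patch.

    from fastNLP.core.samplers import ReproduceBatchSampler
    from tests.helpers.datasets.normal_data import NormalSampler

    # Wrap a plain sampler so batching progress can be checkpointed.
    batch_sampler = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)

    it = iter(batch_sampler)
    for _ in range(3):
        next(it)                        # consume 3 batches: [0..3], [4..7], [8..11]
    state = batch_sampler.state_dict()  # {"index_list": ..., "num_consumed_samples": 12, "sampler_type": ...}

    # Simulate a restart: a fresh sampler loads the state and resumes exactly
    # where the old one stopped (test_1 asserts this also holds when batch_size changes).
    resumed = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)
    resumed.load_state_dict(state)
    assert next(iter(resumed)) == [12, 13, 14, 15]

The saved state is just the epoch's full index_list plus a num_consumed_samples offset, which is why resuming survives a batch-size change: the restored sampler simply re-batches the unconsumed tail of the same index list, and only the next epoch regenerates it.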