@@ -130,8 +130,8 @@ class TorchSingleDriver(TorchDriver): | |||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator], | |||||
reproducible: bool = False, sampler_or_batch_sampler=None): | |||||
def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, | |||||
reproducible: bool = False): | |||||
if isinstance(dist, ReproducibleBatchSampler): | if isinstance(dist, ReproducibleBatchSampler): | ||||
return replace_batch_sampler(dataloader, dist) | return replace_batch_sampler(dataloader, dist) | ||||
elif isinstance(dist, ReproducibleIterator): | elif isinstance(dist, ReproducibleIterator): | ||||
@@ -17,5 +17,5 @@ __all__ = [ | |||||
from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler | ||||
from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler | ||||
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler | ||||
from .reproducible_batch_sampler import ReproducibleBatchSampler | |||||
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler | |||||
@@ -1,20 +1,48 @@ | |||||
__all__ = [ | |||||
'BucketedBatchSampler', | |||||
"ReproducibleBatchSampler" | |||||
] | |||||
import math | import math | ||||
from array import array | from array import array | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from itertools import chain | |||||
from typing import Dict, Union, List | from typing import Dict, Union, List | ||||
from itertools import chain | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.samplers import ReproducibleIterator | |||||
from fastNLP.core.log import logger | |||||
from abc import abstractmethod | |||||
class ReproducibleBatchIterator: | |||||
@abstractmethod | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") | |||||
@abstractmethod | |||||
def __len__(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `__len__` method.") | |||||
@abstractmethod | |||||
def __iter__(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `__iter__` method.") | |||||
@abstractmethod | |||||
def state_dict(self): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `state_dict` method.") | |||||
@abstractmethod | |||||
def load_state_dict(self, states): | |||||
raise NotImplementedError("Each specific batch_sampler should implement its own `load_state_dict` method.") | |||||
@abstractmethod | |||||
def set_epoch(self, epoch): | |||||
pass | |||||
class ReproducibleBatchSampler: | |||||
class ReproducibleBatchSampler(ReproducibleBatchIterator): | |||||
# 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; | ||||
def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): | ||||
""" | """ | ||||
@@ -94,7 +122,7 @@ class ReproducibleBatchSampler: | |||||
self.data_idx = states["data_idx"] | self.data_idx = states["data_idx"] | ||||
self.need_reinitialize = False | self.need_reinitialize = False | ||||
def set_distributed(self): | |||||
def set_distributed(self, num_replicas, rank, pad=True): | |||||
raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | raise RuntimeError(f"ReproduceBatchSampler does not support to change to distributed training.") | ||||
def set_epoch(self, epoch): | def set_epoch(self, epoch): | ||||
@@ -110,7 +138,7 @@ class ReproducibleBatchSampler: | |||||
(len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size | ||||
class BucketedBatchSampler(ReproducibleIterator): | |||||
class BucketedBatchSampler(ReproducibleBatchIterator): | |||||
def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, | ||||
shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): | ||||
""" | """ | ||||
@@ -129,20 +157,20 @@ class BucketedBatchSampler(ReproducibleIterator): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__() | super().__init__() | ||||
if not isinstance(dataset, DataSet): | |||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | length = dataset.get_field(length) | ||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
assert isinstance(length, List) and len(length)==len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||||
assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ | |||||
"the length parameter can only be List[int]" | |||||
if drop_last: | |||||
assert len(dataset)>=batch_size, "The number of samplers must be larger than batch_size when `drop_last=True`." | |||||
assert len(length) == len(dataset), "The length of `data` and `length` should be equal." | |||||
self.dataset = dataset | self.dataset = dataset | ||||
self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 | ||||
self.sorted_indices = np.argsort(self.length)[::-1] # 按长度从高到低排序的 | |||||
self.batch_size = batch_size | self.batch_size = batch_size | ||||
self.num_batch_per_bucket = num_batch_per_bucket | self.num_batch_per_bucket = num_batch_per_bucket | ||||
@@ -161,6 +189,10 @@ class BucketedBatchSampler(ReproducibleIterator): | |||||
# 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() | # 是否处于iteration之间,为True不允许调用 set_distributed()和load_state_dict() | ||||
self.during_iter = kwargs.get("during_iter", False) | self.during_iter = kwargs.get("during_iter", False) | ||||
# 以下变量为内部使用恢复状态的变量。 | |||||
self.old_batch_size = kwargs.get('old_batch_size', self.batch_size) | |||||
self.old_num_batch_per_bucket = kwargs.get('old_num_batch_per_bucket', self.num_batch_per_bucket) | |||||
def set_distributed(self, num_replicas, rank, pad=True): | def set_distributed(self, num_replicas, rank, pad=True): | ||||
assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | assert self.during_iter is False, "Cannot set the sampler to be distributed when it is " \ | ||||
"during an unfinished iteration." | "during an unfinished iteration." | ||||
@@ -217,92 +249,123 @@ class BucketedBatchSampler(ReproducibleIterator): | |||||
if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self.during_iter = True | self.during_iter = True | ||||
indices = self.generate_indices() | |||||
if self.pad: | |||||
# add extra samples to make it evenly divisible | |||||
padding_size = self.total_size - len(indices) | |||||
if padding_size <= len(indices): | |||||
indices += indices[:padding_size] | |||||
else: | |||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] | |||||
else: | |||||
# remove tail of data to make it evenly divisible. | |||||
indices = indices[:self.total_size] | |||||
assert len(indices) == self.total_size | |||||
# subsample | |||||
indices = indices[self.num_consumed_samples:] | |||||
indices = indices[self.rank:len(indices):self.num_replicas] | |||||
assert len(indices) == self.num_left_samples | |||||
# 根据内部的长度进行排序 | |||||
sub_length = self.length[indices] # 取出这个 rank 中的长度 | |||||
sorted_indices = np.argsort(sub_length)[::-1] # 按长度从高到低排序的 | |||||
sorted_indices = deepcopy(self.sorted_indices).tolist() # 按长度从高到低排序的 | |||||
if self.shuffle: | if self.shuffle: | ||||
# 实际的 bucket 大小 | |||||
bucket_size = min(len(sorted_indices), self.batch_size * self.num_batch_per_bucket) | |||||
seed = self.seed + self.epoch | |||||
rng = np.random.default_rng(abs(seed)) | |||||
num_buckets = (len(sorted_indices) + bucket_size - 1)//bucket_size | |||||
batches = [] | |||||
batch_indices = [] | |||||
for i in range(num_buckets): | |||||
bucket = sorted_indices[i*bucket_size:(i+1)*bucket_size] | |||||
rng.shuffle(bucket) # bucket 内部 shuffle 一下 | |||||
_indices = np.full(fill_value=self.batch_size, dtype=int, | |||||
shape=(len(bucket)//self.batch_size)).cumsum() | |||||
_batches = np.split(bucket, _indices) | |||||
batch_indices.extend(list(range(len(batches), len(batches)+len(_batches)))) | |||||
batches.extend(_batches) | |||||
last_batches = [] | |||||
if len(batches)>=1 and len(batches[-1])<self.batch_size: | |||||
last_batches = batches[-1].tolist() | |||||
batch_indices = batch_indices[:-1] | |||||
batches = batches[:-1] | |||||
if self.drop_last and len(last_batches)<self.batch_size: | |||||
last_batches = [] | |||||
rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 | |||||
batches = np.array(batches)[batch_indices] | |||||
indices = list(chain(*batches)) + last_batches | |||||
if self.num_consumed_samples > 0: # 需要先按照原来的排序,删掉多余的 | |||||
_batches = [] | |||||
for _i in range(self.old_num_replicas): | |||||
_sorted_indices = sorted_indices[_i:len(sorted_indices):self.old_num_replicas] | |||||
__batches = self.bucketerize(_sorted_indices, self.old_batch_size, self.old_num_batch_per_bucket, | |||||
seed=self.seed+self.epoch) | |||||
_batches.append(__batches) | |||||
batches = list(chain(*[_ for _ in zip(*_batches)])) | |||||
sorted_indices = list(chain(*batches)) | |||||
sorted_indices = sorted_indices[self.num_consumed_samples:] | |||||
# 再进行排序 | |||||
sub_length = self.length[sorted_indices] | |||||
sorted_indices = np.array(sorted_indices)[np.argsort(sub_length)[::-1]] # 按长度从高到低排序的 | |||||
# 取出这个 rank , | |||||
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas] | |||||
batches = self.bucketerize(sorted_indices, self.batch_size, self.num_batch_per_bucket, | |||||
seed=self.seed+self.epoch) | |||||
batches = list(map(list, batches)) | |||||
else: | else: | ||||
indices = sorted_indices | |||||
if len(indices)<self.batch_size and self.drop_last: | |||||
indices = [] | |||||
for index in range(indices): | |||||
self.num_consumed_samples += self.num_replicas | |||||
yield index | |||||
sorted_indices = sorted_indices[self.num_consumed_samples:] | |||||
sorted_indices = sorted_indices[self.rank:len(sorted_indices):self.num_replicas] | |||||
_num_batches = len(sorted_indices) // self.batch_size | |||||
if _num_batches == 0: | |||||
batches = [sorted_indices] | |||||
else: | |||||
batches = list(map(list, np.array_split(sorted_indices[:_num_batches*self.batch_size], _num_batches))) | |||||
if len(sorted_indices)%self.batch_size!=0: | |||||
batches.append(sorted_indices[_num_batches*self.batch_size:]) | |||||
need_pad_num = (len(self.dataset)-self.num_consumed_samples) % self.num_replicas | |||||
if self.pad and need_pad_num !=0 and need_pad_num<=self.rank: | |||||
if len(batches) > 0: | |||||
if len(batches[-1])<self.batch_size: | |||||
batches[-1].append(batches[-1][0]) # 这里可以保证这个bucket的长度没被破坏。 | |||||
else: | |||||
batches.append([batches[-1][0]]) | |||||
elif self.pad is False and need_pad_num !=0 and need_pad_num>self.rank: | |||||
if len(batches): | |||||
batches[-1].pop(-1) | |||||
if len(batches[-1])==0: | |||||
batches.pop(-1) | |||||
assert len(list(chain(*batches))) == self.num_left_samples | |||||
if self.drop_last and len(batches) >= 1 and len(batches[-1]) < self.batch_size: | |||||
batches = batches[:-1] | |||||
for batch in batches: | |||||
self.num_consumed_samples += self.num_replicas * len(batch) | |||||
yield list(map(int, batch)) | |||||
self.during_iter = False | self.during_iter = False | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self.old_batch_size = self.batch_size | |||||
self.old_num_batch_per_bucket = self.num_batch_per_bucket | |||||
self.old_num_replicas = self.num_replicas | |||||
if self.epoch < 0: # 防止用户没有修改epoch,导致每个epoch都一样了 | |||||
self.epoch -= 1 | |||||
def generate_indices(self) -> List[int]: | |||||
def bucketerize(self, sorted_indices, batch_size, num_batch_per_bucket, seed): | |||||
""" | """ | ||||
生成随机序列,用于保证在所有卡的总和加起来是原来的数据量。 | |||||
将 indices 分桶 | |||||
:return: | |||||
:param sorted_indices: List[int] | |||||
:param batch_size: int | |||||
:param num_batch_per_bucket: int | |||||
:param seed: int | |||||
:return: List[List[int]] | |||||
""" | """ | ||||
if self.shuffle: | |||||
indices = list(range(len(self.dataset))) | |||||
seed = self.seed + self.epoch | |||||
rng = np.random.default_rng(abs(seed)) | |||||
rng.shuffle(indices) | |||||
if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 | |||||
self.epoch -= 1 | |||||
else: | |||||
indices = list(range(len(self.dataset))) | |||||
return indices | |||||
# 实际的 bucket 大小 | |||||
bucket_size = min(len(sorted_indices), batch_size * num_batch_per_bucket) | |||||
rng = np.random.default_rng(abs(seed)) | |||||
num_buckets = (len(sorted_indices) + bucket_size - 1) // bucket_size | |||||
batches = [] | |||||
batch_indices = [] | |||||
for i in range(num_buckets): | |||||
bucket = sorted_indices[i * bucket_size:(i + 1) * bucket_size] | |||||
rng.shuffle(bucket) # bucket 内部 shuffle 一下 | |||||
_num_batches = len(bucket) // batch_size | |||||
if _num_batches == 0: | |||||
_batches = [bucket] | |||||
else: | |||||
_batches = np.array_split(bucket[:_num_batches*batch_size], _num_batches) | |||||
if len(bucket) % batch_size != 0: | |||||
_batches.append(bucket[_num_batches*batch_size:]) | |||||
batch_indices.extend(list(range(len(batches), len(batches) + len(_batches)))) | |||||
batches.extend(_batches) | |||||
last_batches = [] | |||||
# 最后一个batch 统一不参与shuffle,因为有的rank最后一个 batch 可能不足一个batch_size (不足的时候 | |||||
# 一定要放在末尾,所以就干脆所有的rank都不对最后一个batch进行shuffle)。 | |||||
if len(batches) >= 1: | |||||
last_batches = [list(batches[-1])] | |||||
batch_indices = list(batch_indices[:-1]) | |||||
rng = np.random.default_rng(abs(seed)) # 这里防止由于bucket长度不同,对随机数状态有影响 | |||||
rng.shuffle(batch_indices) # 不同的 batch 也 shuffle ,当前这种可以保证每张卡上每个 batch 长度都接近的。 | |||||
batches = (np.array(batches)[batch_indices]).tolist() | |||||
if last_batches: | |||||
batches = batches + last_batches | |||||
return batches | |||||
def state_dict(self) -> Dict: | def state_dict(self) -> Dict: | ||||
if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket: | |||||
raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been" | |||||
" consumed. ") | |||||
states = { | states = { | ||||
'seed': self.seed, | 'seed': self.seed, | ||||
'epoch': self.epoch, | 'epoch': self.epoch, | ||||
'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; | ||||
'sampler_type': self.__class__.__name__, | 'sampler_type': self.__class__.__name__, | ||||
'length': len(self.dataset), | 'length': len(self.dataset), | ||||
'shuffle': self.shuffle | |||||
'shuffle': self.shuffle, | |||||
'batch_size': self.batch_size, | |||||
'num_batch_per_bucket': self.num_batch_per_bucket, | |||||
'num_replicas': self.num_replicas | |||||
} | } | ||||
return states | return states | ||||
@@ -322,4 +385,13 @@ class BucketedBatchSampler(ReproducibleIterator): | |||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
self.shuffle = states["shuffle"] | |||||
if self.shuffle != states['shuffle']: | |||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||||
f"we use shuffle={states['shuffle']}") | |||||
self.shuffle = states["shuffle"] | |||||
self.old_batch_size = states['batch_size'] | |||||
self.old_num_batch_per_bucket = states['num_batch_per_bucket'] | |||||
self.old_num_replicas = states['num_replicas'] | |||||
def set_epoch(self, epoch): | |||||
self.epoch = epoch |
@@ -2,14 +2,14 @@ from typing import Dict, List | |||||
import math | import math | ||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.log import logger | |||||
__all__ = [ | __all__ = [ | ||||
'ReproducibleIterator', | 'ReproducibleIterator', | ||||
'RandomSampler', | 'RandomSampler', | ||||
're_instantiate_sampler' | 're_instantiate_sampler' | ||||
] | ] | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
def re_instantiate_sampler(sampler): | def re_instantiate_sampler(sampler): | ||||
all_attributes = vars(sampler) | all_attributes = vars(sampler) | ||||
@@ -164,6 +164,9 @@ class RandomSampler(ReproducibleIterator): | |||||
self.num_consumed_samples = states['num_consumed_samples'] | self.num_consumed_samples = states['num_consumed_samples'] | ||||
if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | if self.num_consumed_samples>=length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 | ||||
self.num_consumed_samples = 0 | self.num_consumed_samples = 0 | ||||
if self.shuffle != states['shuffle']: | |||||
logger.info(f"The shuffle from the checkpoint is {states['shuffle']}, while set as {self.shuffle}, " | |||||
f"we use shuffle={states['shuffle']}") | |||||
self.shuffle = states["shuffle"] | self.shuffle = states["shuffle"] | ||||
def set_epoch(self, epoch: int) -> None: | def set_epoch(self, epoch: int) -> None: | ||||
@@ -212,24 +215,8 @@ class RandomSampler(ReproducibleIterator): | |||||
self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) | ||||
# todo | |||||
# class SortedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
# | |||||
# | |||||
# class BucketedSampler(ReproducibleIterator): | |||||
# def __init__(self, dataset, key): | |||||
# pass | |||||
if __name__ == "__main__": | |||||
sampler = RandomSampler(1) | |||||
print(vars(sampler)) | |||||
batch_sampler = ReproducibleBatchSampler(list(range(3)), 1, True) | |||||
print(vars(batch_sampler)) | |||||
@@ -0,0 +1,439 @@ | |||||
from array import array | |||||
import numpy as np | |||||
import pytest | |||||
from itertools import chain | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||||
class TestReproducibleBatchSampler: | |||||
# TODO 拆分测试,在这里只测试一个东西 | |||||
def test_torch_dataloader_1(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
next(iter_dataloader) | |||||
# 1. 保存状态 | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, | |||||
"sampler_type": "ReproducibleBatchSampler"} | |||||
# 2. 断点重训,重新生成一个 dataloader; | |||||
# 不改变 batch_size; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# 改变 batch_size; | |||||
after_batch_size = 3 | |||||
dataloader = DataLoader(dataset, batch_size=after_batch_size) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
real_res = [] | |||||
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27)))) | |||||
forward_steps = 2 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
real_res.append(next(iter_dataloader)) | |||||
for i in range(forward_steps): | |||||
assert all(real_res[i] == supposed_res[i]) | |||||
# 断点重训的第二轮是否是一个完整的 dataloader; | |||||
# 先把断点重训所在的那一个 epoch 跑完; | |||||
begin_idx = 27 | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
# 开始新的一轮; | |||||
begin_idx = 0 | |||||
iter_dataloader = iter(dataloader) | |||||
while True: | |||||
try: | |||||
data = next(iter_dataloader) | |||||
_batch_size = len(data) | |||||
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size)))) | |||||
begin_idx += _batch_size | |||||
except StopIteration: | |||||
break | |||||
def test_torch_dataloader_2(self): | |||||
# 测试新的一轮的 index list 是重新生成的,而不是沿用上一轮的; | |||||
from torch.utils.data import DataLoader | |||||
# no shuffle | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# 将一轮的所有数据保存下来,看是否恢复的是正确的; | |||||
all_supposed_data = [] | |||||
forward_steps = 3 | |||||
iter_dataloader = iter(dataloader) | |||||
for _ in range(forward_steps): | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
# 1. 保存状态 | |||||
_get_re_batchsampler = dataloader.batch_sampler | |||||
assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) | |||||
state = _get_re_batchsampler.state_dict() | |||||
# 2. 断点重训,重新生成一个 dataloader; | |||||
# 不改变 batch_size; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) | |||||
re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) | |||||
re_batchsampler.load_state_dict(state) | |||||
dataloader = replace_batch_sampler(dataloader, re_batchsampler) | |||||
# 先把这一轮的数据过完; | |||||
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"] | |||||
while True: | |||||
try: | |||||
all_supposed_data.extend(next(iter_dataloader).tolist()) | |||||
except StopIteration: | |||||
break | |||||
assert all_supposed_data == list(pre_index_list) | |||||
# 重新开启新的一轮; | |||||
for _ in range(3): | |||||
iter_dataloader = iter(dataloader) | |||||
res = [] | |||||
while True: | |||||
try: | |||||
res.append(next(iter_dataloader)) | |||||
except StopIteration: | |||||
break | |||||
def test_3(self): | |||||
import torch | |||||
from torch.utils.data import DataLoader | |||||
before_batch_size = 7 | |||||
dataset = TorchNormalDataset(num_of_data=100) | |||||
# 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; | |||||
dataloader = DataLoader(dataset, batch_size=before_batch_size) | |||||
for idx, data in enumerate(dataloader): | |||||
if idx > 3: | |||||
break | |||||
iterator = iter(dataloader) | |||||
for each in iterator: | |||||
pass | |||||
class DatasetWithVaryLength: | |||||
def __init__(self, num_of_data=100): | |||||
self.data = np.arange(num_of_data) | |||||
def __getitem__(self, item): | |||||
return self.data[item] | |||||
def __len__(self): | |||||
return len(self.data) | |||||
class TestBucketedBatchSampler: | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('num', [2, 7, 14, 15, 70, 71]) | |||||
def test_single_num_batch(self, shuffle, drop_last, num): | |||||
# 数量不够不报错 | |||||
for num in [2, 7, 14, 15, 70, 71]: | |||||
dataset = DatasetWithVaryLength(num_of_data=num) | |||||
before_batch_size = 7 | |||||
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=10, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
count = len(list(iter(re_batchsampler))) | |||||
if drop_last: | |||||
assert count==num//before_batch_size, num | |||||
else: | |||||
assert count==(num+before_batch_size-1)//before_batch_size, num | |||||
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
def test_single(self, shuffle, drop_last): | |||||
before_batch_size = 7 | |||||
num_batch_per_bucket = 4 # 那么任意 batch 内的长度差值不应该超过4 | |||||
dataset = DatasetWithVaryLength(num_of_data=1000) | |||||
re_batchsampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler.set_epoch(0) | |||||
forward_steps = 10 | |||||
iterator = iter(re_batchsampler) | |||||
already_generate_indices = set() | |||||
for _ in range(forward_steps): | |||||
batch = next(iterator) | |||||
assert max(batch) - min(batch) <= before_batch_size * num_batch_per_bucket | |||||
already_generate_indices.update(batch) | |||||
# 1. 保存状态 | |||||
state = re_batchsampler.state_dict() | |||||
# 2. 断点重训,继续训练 | |||||
re_batchsampler2 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=before_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler2.load_state_dict(state) | |||||
re_batchsampler2.set_epoch(0) | |||||
new_already_generate_indices = set() | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-before_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+before_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in re_batchsampler2: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
new_already_generate_indices.update(batch) | |||||
if drop_last is False: | |||||
assert len(new_already_generate_indices.union(already_generate_indices))==len(dataset) | |||||
# 改变 batch_size; | |||||
after_batch_size = 3 | |||||
re_batchsampler3 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=drop_last, | |||||
shuffle=shuffle) | |||||
re_batchsampler3.load_state_dict(state) | |||||
re_batchsampler3.set_epoch(0) | |||||
count = 0 | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices)-after_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i+after_batch_size * num_batch_per_bucket]-indices[i]) | |||||
for batch in re_batchsampler3: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
# 再 save ,不允许再上个epoch没结束继续sample | |||||
after_batch_size = 5 | |||||
with pytest.raises(RuntimeError): | |||||
state = re_batchsampler3.state_dict() | |||||
for batch in re_batchsampler3: # consume all, 这样才能save | |||||
pass | |||||
already_generate_indices = set() | |||||
count = 0 | |||||
for batch in re_batchsampler3: # 重新开始 | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
count += 1 | |||||
if count > 5: | |||||
break | |||||
state = re_batchsampler3.state_dict() | |||||
# 这里的 drop_last 为 False,需要最终是所有 sample | |||||
re_batchsampler4 = BucketedBatchSampler(dataset, length=dataset.data, batch_size=after_batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, drop_last=False, | |||||
shuffle=shuffle) | |||||
re_batchsampler4.load_state_dict(state) | |||||
re_batchsampler4.set_epoch(0) | |||||
mask = np.ones(len(dataset), dtype=bool) | |||||
mask[list(already_generate_indices)] = 0 | |||||
indices = np.arange(len(dataset))[mask] | |||||
max_diff = -1 | |||||
for i in range(len(indices) - after_batch_size * num_batch_per_bucket): | |||||
max_diff = max(max_diff, indices[i + after_batch_size * num_batch_per_bucket] - indices[i]) | |||||
for batch in re_batchsampler4: | |||||
assert max(batch) - min(batch) <= max_diff | |||||
for b in batch: | |||||
assert b not in already_generate_indices | |||||
already_generate_indices.update(batch) | |||||
assert len(already_generate_indices) == len(dataset) | |||||
    @pytest.mark.parametrize('shuffle', [True, False])
    @pytest.mark.parametrize('drop_last', [True, False])
    @pytest.mark.parametrize('pad', [True, False])
    def test_multi(self, shuffle, drop_last, pad):
        """Distributed behaviour of BucketedBatchSampler, in three phases.

        1. With ``num_replica=2`` ranks, every rank yields the same number of
           batches, ranks never hand out overlapping indices, and ``pad=True``
           may repeat at most one index per rank.
        2. A ``state_dict`` saved mid-epoch while running distributed can be
           restored into a single-process sampler with a different batch size
           and bucket size; the resumed epoch must only produce not-yet-seen
           indices (and, without ``drop_last``, cover the whole dataset).
        3. The same saved state can be restored under a different world size
           (3 replicas) with the same no-unexpected-repeat guarantee.
        """
        # def test_multi(self, shuffle=True, drop_last=False, pad=False):
        # no shuffle
        num_replica = 2
        dataset = DatasetWithVaryLength(num_of_data=1000)
        batch_size = 5
        num_batch_per_bucket = 10
        lengths = []
        rank0_already_seen_indexes = None
        # upper bound on (max - min) within one batch: one bucket's span
        # across all replicas
        max_diff = num_batch_per_bucket * batch_size * num_replica
        for rank in range(num_replica):
            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
                                           num_batch_per_bucket = num_batch_per_bucket,
                                           shuffle = shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            already_seen_indexes = set()
            repeat_count = 0
            for batch in sampler:
                assert max_diff>=max(batch)-min(batch)
                for b in batch:
                    repeat_count += int(b in already_seen_indexes)
                    if rank0_already_seen_indexes:  # ranks must not hand out overlapping indices
                        assert b not in rank0_already_seen_indexes
                already_seen_indexes.update(batch)
            if rank0_already_seen_indexes is None:
                rank0_already_seen_indexes = already_seen_indexes
            if pad:  # padding may repeat at most one sample
                assert repeat_count<=1
            else:
                assert repeat_count==0
        assert len(set(lengths))==1, lengths  # every rank must yield the same number of batches
        # saving a checkpoint while running distributed
        already_seen_indexes = set()
        for rank in range(num_replica):
            sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size = batch_size,
                                           num_batch_per_bucket = num_batch_per_bucket,
                                           shuffle = shuffle, drop_last=drop_last)
            sampler.set_epoch(0)
            sampler.set_distributed(num_replica, rank=rank, pad=pad)
            lengths.append(len(sampler))
            count = 0
            for batch in sampler:
                assert max_diff>=max(batch)-min(batch)
                already_seen_indexes.update(batch)
                if count>5:
                    break
                count += 1
            # state from the last rank is reused for the resume phases below
            state = sampler.state_dict()
        # switch back to a single process, with new batch/bucket sizes
        new_batch_size = 6
        num_batch_per_bucket = 3
        new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                           num_batch_per_bucket=num_batch_per_bucket,
                                           shuffle=shuffle, drop_last=drop_last)
        new_sampler.load_state_dict(state)
        repeat_count = 0
        new_already_seen_indexes = set(list(already_seen_indexes))
        # recompute the spread bound over the remaining (unseen) indices only
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices)-new_batch_size * num_batch_per_bucket):
            max_diff = max(max_diff, indices[i+new_batch_size * num_batch_per_bucket]-indices[i])
        for batch in new_sampler:
            assert max_diff>=max(batch)-min(batch)
            for b in batch:
                repeat_count += int(b in new_already_seen_indexes)
            new_already_seen_indexes.update(batch)
        if pad:  # padding may repeat at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
        if drop_last is False:  # without drop_last the resumed epoch must cover the whole dataset
            assert len(new_already_seen_indexes)==len(dataset)
        # resume with a different number of replicas (simulates changing GPU count)
        num_replica = 3
        new_sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=new_batch_size,
                                           num_batch_per_bucket=num_batch_per_bucket,
                                           shuffle=shuffle, drop_last=drop_last)
        new_sampler.set_epoch(0)
        new_sampler.load_state_dict(state)
        new_sampler.set_distributed(num_replicas=num_replica, rank=1, pad=pad)
        repeat_count = 0
        mask = np.ones(len(dataset), dtype=bool)
        mask[list(already_seen_indexes)] = 0
        indices = np.arange(len(dataset))[mask]
        max_diff = -1
        for i in range(len(indices) - new_batch_size * num_batch_per_bucket*num_replica):
            max_diff = max(max_diff, indices[i + new_batch_size * num_batch_per_bucket*num_replica] - indices[i])
        for batch in new_sampler:
            assert max_diff>=max(batch)-min(batch)
            for b in batch:
                repeat_count += int(b in already_seen_indexes)
        if pad:  # padding may repeat at most one sample
            assert repeat_count <= 1
        else:
            assert repeat_count == 0
@pytest.mark.parametrize('shuffle', [True, False]) | |||||
@pytest.mark.parametrize('drop_last', [True, False]) | |||||
@pytest.mark.parametrize('pad', [True, False]) | |||||
@pytest.mark.parametrize('num_samples', [13, 100, 623, 1000]) | |||||
@pytest.mark.parametrize('num_replica', [2, 3]) | |||||
def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): | |||||
# def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): | |||||
# TODO 两个 rank 上的长度是要在同一个bucket的 | |||||
dataset = DatasetWithVaryLength(num_of_data=num_samples) | |||||
batch_size = 6 | |||||
if num_replica*batch_size > num_samples: | |||||
return | |||||
num_batch_per_bucket = 10 | |||||
samplers = [] | |||||
lengths = [] | |||||
for i in range(num_replica): | |||||
sampler = BucketedBatchSampler(dataset, length=dataset.data, batch_size=batch_size, | |||||
num_batch_per_bucket=num_batch_per_bucket, shuffle=shuffle, drop_last=drop_last) | |||||
sampler.set_distributed(num_replica, rank=i, pad=pad) | |||||
sampler.set_epoch(0) | |||||
samplers.append(sampler) | |||||
lengths.append(len(list(iter(sampler)))) | |||||
assert len(set(lengths))==1 | |||||
bucket_diff = batch_size * num_batch_per_bucket * num_replica | |||||
for bs in zip(*samplers): | |||||
diff = max(chain(*bs)) - min(chain(*bs)) | |||||
assert diff <= bucket_diff |
@@ -7,7 +7,6 @@ from functools import partial | |||||
from array import array | from array import array | ||||
from fastNLP.core.samplers.reproducible_sampler import RandomSampler | from fastNLP.core.samplers.reproducible_sampler import RandomSampler | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||||
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
@@ -362,148 +361,3 @@ class TestRandomSampler(unittest.TestCase): | |||||
class TestReproducibleBatchSampler: | |||||
    def test_torch_dataloader_1(self):
        """Checkpoint/resume of ReproducibleBatchSampler on an unshuffled DataLoader.

        Saves the sampler state after a few forward steps, then verifies that a
        rebuilt dataloader resumes from the exact next sample — both with the
        same batch size and with a different one — and that the epoch after
        the resumed one iterates the full dataset from index 0 again.
        """
        import torch
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            next(iter_dataloader)
        # 1. save the sampler state after `forward_steps` consumed batches
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()
        # unshuffled: index_list is 0..99 and data_idx counts consumed samples
        assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size,
                         "sampler_type": "ReproducibleBatchSampler"}
        # 2. resume from the checkpoint with a freshly built dataloader;
        # first without changing batch_size
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        real_res = []
        # 3 * 7 = 21 samples were consumed, so the resume starts at index 21
        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])
        # now resume with a changed batch_size
        after_batch_size = 3
        dataloader = DataLoader(dataset, batch_size=after_batch_size)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        real_res = []
        # still resumes at index 21, just in batches of 3
        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
        forward_steps = 2
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            real_res.append(next(iter_dataloader))
        for i in range(forward_steps):
            assert all(real_res[i] == supposed_res[i])
        # check that the epoch following the resumed one is a complete epoch;
        # first finish off the epoch the checkpoint was taken in
        begin_idx = 27
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break
        # start a brand-new epoch: it must again cover indices 0..99 in order
        begin_idx = 0
        iter_dataloader = iter(dataloader)
        while True:
            try:
                data = next(iter_dataloader)
                _batch_size = len(data)
                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
                begin_idx += _batch_size
            except StopIteration:
                break
    def test_torch_dataloader_2(self):
        """Check that a new epoch regenerates the index list instead of reusing
        the previous epoch's one (shuffle is enabled to make this observable).
        """
        from torch.utils.data import DataLoader
        # no shuffle
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle so a regenerated index list for the post-resume epoch
        # is distinguishable from a reused one
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        # record everything the epoch yields, to compare against the restored order
        all_supposed_data = []
        forward_steps = 3
        iter_dataloader = iter(dataloader)
        for _ in range(forward_steps):
            all_supposed_data.extend(next(iter_dataloader).tolist())
        # 1. save the sampler state
        _get_re_batchsampler = dataloader.batch_sampler
        assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler)
        state = _get_re_batchsampler.state_dict()
        # 2. resume from the checkpoint with a freshly built dataloader;
        # batch_size unchanged
        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
        re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
        re_batchsampler.load_state_dict(state)
        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
        # first finish the current epoch
        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
        # NOTE(review): `iter_dataloader` still belongs to the ORIGINAL dataloader
        # built above, so this drains the original epoch; since the new sampler
        # loaded that epoch's state, its index_list should match — confirm intended.
        while True:
            try:
                all_supposed_data.extend(next(iter_dataloader).tolist())
            except StopIteration:
                break
        assert all_supposed_data == list(pre_index_list)
        # start fresh epochs; each iter() should yield a complete pass without error
        for _ in range(3):
            iter_dataloader = iter(dataloader)
            res = []
            while True:
                try:
                    res.append(next(iter_dataloader))
                except StopIteration:
                    break
    def test_3(self):
        """Smoke test: a plain torch DataLoader can be partially consumed and
        then re-iterated from scratch without error (no data assertions)."""
        import torch
        from torch.utils.data import DataLoader, RandomSampler, BatchSampler
        before_batch_size = 7
        dataset = TorchNormalDataset(num_of_data=100)
        # NOTE(review): the original comment here claimed shuffle is enabled,
        # but this DataLoader is built without shuffle — comment was stale.
        dataloader = DataLoader(dataset, batch_size=before_batch_size)
        # consume only the first few batches, then abandon the iterator
        for idx, data in enumerate(dataloader):
            if idx > 3:
                break
        # a fresh iterator must run to completion regardless of the early break
        iterator = iter(dataloader)
        for each in iterator:
            pass