@@ -8,7 +8,6 @@ import math
 from copy import deepcopy
 from typing import Dict, Union, List
 from itertools import chain
-import os

 import numpy as np
@@ -70,7 +70,7 @@ def model_and_optimizers():
 @pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
 @pytest.mark.torch
 @magic_argv_env_context
-def test_trainer_event_trigger(
+def test_trainer_event_trigger_1(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
@@ -104,5 +104,126 @@ def test_trainer_event_trigger(
         assert member.value in output[0]
@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7]) | |||||
@pytest.mark.torch | |||||
@magic_argv_env_context | |||||
def test_trainer_event_trigger_2( | |||||
model_and_optimizers: TrainerParameters, | |||||
driver, | |||||
device, | |||||
n_epochs=2, | |||||
): | |||||
@Trainer.on(Events.on_after_trainer_initialized) | |||||
def on_after_trainer_initialized(trainer, driver): | |||||
print("on_after_trainer_initialized") | |||||
@Trainer.on(Events.on_sanity_check_begin) | |||||
def on_sanity_check_begin(trainer): | |||||
print("on_sanity_check_begin") | |||||
@Trainer.on(Events.on_sanity_check_end) | |||||
def on_sanity_check_end(trainer, sanity_check_res): | |||||
print("on_sanity_check_end") | |||||
@Trainer.on(Events.on_train_begin) | |||||
def on_train_begin(trainer): | |||||
print("on_train_begin") | |||||
@Trainer.on(Events.on_train_end) | |||||
def on_train_end(trainer): | |||||
print("on_train_end") | |||||
@Trainer.on(Events.on_train_epoch_begin) | |||||
def on_train_epoch_begin(trainer): | |||||
if trainer.cur_epoch_idx >= 1: | |||||
# 触发 on_exception; | |||||
raise Exception | |||||
print("on_train_epoch_begin") | |||||
@Trainer.on(Events.on_train_epoch_end) | |||||
def on_train_epoch_end(trainer): | |||||
print("on_train_epoch_end") | |||||
@Trainer.on(Events.on_fetch_data_begin) | |||||
def on_fetch_data_begin(trainer): | |||||
print("on_fetch_data_begin") | |||||
@Trainer.on(Events.on_fetch_data_end) | |||||
def on_fetch_data_end(trainer): | |||||
print("on_fetch_data_end") | |||||
@Trainer.on(Events.on_train_batch_begin) | |||||
def on_train_batch_begin(trainer, batch, indices=None): | |||||
print("on_train_batch_begin") | |||||
@Trainer.on(Events.on_train_batch_end) | |||||
def on_train_batch_end(trainer): | |||||
print("on_train_batch_end") | |||||
@Trainer.on(Events.on_exception) | |||||
def on_exception(trainer, exception): | |||||
print("on_exception") | |||||
@Trainer.on(Events.on_before_backward) | |||||
def on_before_backward(trainer, outputs): | |||||
print("on_before_backward") | |||||
@Trainer.on(Events.on_after_backward) | |||||
def on_after_backward(trainer): | |||||
print("on_after_backward") | |||||
@Trainer.on(Events.on_before_optimizers_step) | |||||
def on_before_optimizers_step(trainer, optimizers): | |||||
print("on_before_optimizers_step") | |||||
@Trainer.on(Events.on_after_optimizers_step) | |||||
def on_after_optimizers_step(trainer, optimizers): | |||||
print("on_after_optimizers_step") | |||||
@Trainer.on(Events.on_before_zero_grad) | |||||
def on_before_zero_grad(trainer, optimizers): | |||||
print("on_before_zero_grad") | |||||
@Trainer.on(Events.on_after_zero_grad) | |||||
def on_after_zero_grad(trainer, optimizers): | |||||
print("on_after_zero_grad") | |||||
@Trainer.on(Events.on_evaluate_begin) | |||||
def on_evaluate_begin(trainer): | |||||
print("on_evaluate_begin") | |||||
@Trainer.on(Events.on_evaluate_end) | |||||
def on_evaluate_end(trainer, results): | |||||
print("on_evaluate_end") | |||||
with pytest.raises(Exception): | |||||
with Capturing() as output: | |||||
trainer = Trainer( | |||||
model=model_and_optimizers.model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=model_and_optimizers.optimizers, | |||||
train_dataloader=model_and_optimizers.train_dataloader, | |||||
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, | |||||
input_mapping=model_and_optimizers.input_mapping, | |||||
output_mapping=model_and_optimizers.output_mapping, | |||||
metrics=model_and_optimizers.metrics, | |||||
n_epochs=n_epochs, | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
for name, member in Events.__members__.items(): | |||||
assert member.value in output[0] | |||||
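Note on the design of test_trainer_event_trigger_2: with n_epochs=2, epoch 0 completes normally, so every callback registered above prints its event name at least once; at the start of epoch 1, on_train_epoch_begin raises, which exercises the on_exception hook and aborts trainer.run(). pytest.raises absorbs the exception, and the final loop then checks that every Events member's value appears in the stdout captured by Capturing.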
@@ -1,7 +1,7 @@
 from functools import reduce

 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class has been modified; remember to update the test accordingly;
-from tests.helpers.datasets.normal_data import NormalIterator
+from tests.helpers.datasets.normal_data import NormalSampler


 class Test_WrapDataLoader:
@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
     def test_normal_generator(self):
         all_sanity_batches = [4, 20, 100]
         for sanity_batches in all_sanity_batches:
-            data = NormalIterator(num_of_data=1000)
+            data = NormalSampler(num_of_data=1000)
             wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
             dataloader = iter(wrapper)
             mark = 0
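Judging by the test above, _TruncatedDataLoader caps iteration over the wrapped loader at num_batches. A minimal sketch of that truncation idea (TruncatedLoaderSketch is a hypothetical stand-in, not the actual fastNLP implementation):

class TruncatedLoaderSketch:
    """Yield at most num_batches items from the wrapped dataloader per epoch."""

    def __init__(self, dataloader, num_batches):
        self.dataloader = dataloader
        self.num_batches = num_batches

    def __iter__(self):
        for i, batch in enumerate(self.dataloader):
            if i >= self.num_batches:
                break
            yield batch

    def __len__(self):
        # never longer than the underlying dataloader
        return min(self.num_batches, len(self.dataloader))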
@@ -1,161 +1,131 @@
-from array import array
 import numpy as np
 import pytest
 from itertools import chain
 from copy import deepcopy
+from array import array

+from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
 from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
-from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
-from tests.helpers.datasets.torch_data import TorchNormalDataset
-#
-# class TestReproducibleBatchSampler:
-#     # TODO split this test up; test only one thing here
-#     def test_torch_dataloader_1(self):
-#         import torch
-#         from torch.utils.data import DataLoader
-#         # no shuffle
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         forward_steps = 3
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             next(iter_dataloader)
-#
-#         # 1. save the state
-#         _get_re_batchsampler = dataloader.batch_sampler
-#         assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
-#         state = _get_re_batchsampler.state_dict()
-#         assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
-#                          "sampler_type": "ReproduceBatchSampler"}
-#
-#         # 2. resume from the checkpoint and re-create a dataloader;
-#         # keep batch_size unchanged;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         real_res = []
-#         supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
-#         forward_steps = 2
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             real_res.append(next(iter_dataloader))
-#
-#         for i in range(forward_steps):
-#             assert all(real_res[i] == supposed_res[i])
-#
-#         # change batch_size;
-#         after_batch_size = 3
-#         dataloader = DataLoader(dataset, batch_size=after_batch_size)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         real_res = []
-#         supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
-#         forward_steps = 2
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             real_res.append(next(iter_dataloader))
-#
-#         for i in range(forward_steps):
-#             assert all(real_res[i] == supposed_res[i])
-#
-#         # check that the second epoch after resuming is a complete dataloader;
-#         # first finish the epoch in which we resumed;
-#         begin_idx = 27
-#         while True:
-#             try:
-#                 data = next(iter_dataloader)
-#                 _batch_size = len(data)
-#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-#                 begin_idx += _batch_size
-#             except StopIteration:
-#                 break
-#
-#         # start a new epoch;
-#         begin_idx = 0
-#         iter_dataloader = iter(dataloader)
-#         while True:
-#             try:
-#                 data = next(iter_dataloader)
-#                 _batch_size = len(data)
-#                 assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
-#                 begin_idx += _batch_size
-#             except StopIteration:
-#                 break
-#
-#     def test_torch_dataloader_2(self):
-#         # test that the index list of a new epoch is regenerated instead of reusing the previous epoch's;
-#         from torch.utils.data import DataLoader
-#         # no shuffle
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         # enable shuffle to check that the index list of the second epoch after resuming is regenerated;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         # save all data of one epoch to check that restoration is correct;
-#         all_supposed_data = []
-#         forward_steps = 3
-#         iter_dataloader = iter(dataloader)
-#         for _ in range(forward_steps):
-#             all_supposed_data.extend(next(iter_dataloader).tolist())
-#
-#         # 1. save the state
-#         _get_re_batchsampler = dataloader.batch_sampler
-#         assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
-#         state = _get_re_batchsampler.state_dict()
-#
-#         # 2. resume from the checkpoint and re-create a dataloader;
-#         # keep batch_size unchanged;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
-#         re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
-#         re_batchsampler.load_state_dict(state)
-#         dataloader = replace_batch_sampler(dataloader, re_batchsampler)
-#
-#         # first consume the rest of this epoch;
-#         pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
-#         while True:
-#             try:
-#                 all_supposed_data.extend(next(iter_dataloader).tolist())
-#             except StopIteration:
-#                 break
-#         assert all_supposed_data == list(pre_index_list)
-#
-#         # start new epochs again;
-#         for _ in range(3):
-#             iter_dataloader = iter(dataloader)
-#             res = []
-#             while True:
-#                 try:
-#                     res.append(next(iter_dataloader))
-#                 except StopIteration:
-#                     break
-#
-#     def test_3(self):
-#         import torch
-#         from torch.utils.data import DataLoader
-#         before_batch_size = 7
-#         dataset = TorchNormalDataset(num_of_data=100)
-#         # enable shuffle to check that the index list of the second epoch after resuming is regenerated;
-#         dataloader = DataLoader(dataset, batch_size=before_batch_size)
-#
-#         for idx, data in enumerate(dataloader):
-#             if idx > 3:
-#                 break
-#
-#         iterator = iter(dataloader)
-#         for each in iterator:
-#             pass
+class TestReproducibleBatchSampler:
+    def test_1(self):
+        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler does not matter here;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
+
+        forward_steps = 3
+        iterator = iter(reproduce_batch_sampler)
+        i = 0
+        while i < forward_steps:
+            next(iterator)
+            i += 1
+
+        # save the state;
+        state = reproduce_batch_sampler.state_dict()
+        assert state == {"index_list": array("I", list(range(100))),
+                         "num_consumed_samples": forward_steps * 4,
+                         "sampler_type": "ReproduceBatchSampler"}
+
+        # re-create a batch sampler, then load the state;
+        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler does not matter here;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        real_res = []
+        supposed_res = (list(range(12, 16)), list(range(16, 20)))
+        forward_steps = 2
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert supposed_res[i] == real_res[i]
+
+        # change the batch size;
+        sampler = NormalSampler(num_of_data=100)  # whether this is a batch sampler does not matter here;
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        real_res = []
+        supposed_res = (list(range(12, 19)), list(range(19, 26)))
+        forward_steps = 2
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert supposed_res[i] == real_res[i]
+
+        # check that the second epoch after resuming is a complete pass;
+        # first finish the epoch in which we resumed;
+        begin_idx = 26
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert data == list(range(begin_idx, begin_idx + _batch_size))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+        # start a new epoch;
+        begin_idx = 0
+        iter_dataloader = iter(reproduce_batch_sampler)
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert data == list(range(begin_idx, begin_idx + _batch_size))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+    def test_2(self):
+        # test that the index list of a new epoch is regenerated instead of reusing the previous epoch's;
+        before_batch_size = 7
+        sampler = NormalSampler(num_of_data=100)
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
+
+        # save all data of one epoch to check that restoration is correct;
+        all_supposed_data = []
+        forward_steps = 3
+        iter_dataloader = iter(reproduce_batch_sampler)
+        for _ in range(forward_steps):
+            all_supposed_data.extend(next(iter_dataloader))
+
+        # 1. save the state
+        state = reproduce_batch_sampler.state_dict()
+
+        # 2. resume from the checkpoint and re-create the sampler; keep batch_size unchanged;
+        # shuffle is enabled here to check that the index list of the second epoch after resuming is regenerated;
+        sampler = NormalSampler(num_of_data=100, shuffle=True)
+        reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
+        reproduce_batch_sampler.load_state_dict(state)
+
+        # first consume the rest of this epoch;
+        pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
+        iter_dataloader = iter(reproduce_batch_sampler)
+        while True:
+            try:
+                all_supposed_data.extend(next(iter_dataloader))
+            except StopIteration:
+                break
+
+        assert all_supposed_data == list(pre_index_list)
+
+        # start new epochs; each one should regenerate its index list;
+        for _ in range(3):
+            iter_dataloader = iter(reproduce_batch_sampler)
+            res = []
+            while True:
+                try:
+                    res.extend(next(iter_dataloader))
+                except StopIteration:
+                    break
+            assert res != all_supposed_data


 class DatasetWithVaryLength:
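Distilled from test_1 above, a usage sketch of the save/resume flow; the expected values follow the same arithmetic as the test (one consumed batch of 4 means the restored sampler continues at index 4):

sampler = NormalSampler(num_of_data=100)
batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
it = iter(batch_sampler)
assert next(it) == [0, 1, 2, 3]     # consume one batch
state = batch_sampler.state_dict()  # num_consumed_samples is now 4

restored = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)
restored.load_state_dict(state)
assert next(iter(restored)) == [4, 5, 6, 7]  # resumes where the old sampler stopped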
@@ -0,0 +1,141 @@
+from array import array
+
+import torch
+from torch.utils.data import DataLoader
+import pytest
+
+from fastNLP.core.samplers import ReproduceBatchSampler
+from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
+from tests.helpers.datasets.torch_data import TorchNormalDataset
+
+
+@pytest.mark.torch
+class TestReproducibleBatchSamplerTorch:
+    def test_torch_dataloader_1(self):
+        # no shuffle
+        before_batch_size = 7
+        dataset = TorchNormalDataset(num_of_data=100)
+        dataloader = DataLoader(dataset, batch_size=before_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        forward_steps = 3
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            next(iter_dataloader)
+
+        # 1. save the state
+        _get_re_batchsampler = dataloader.batch_sampler
+        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
+        state = _get_re_batchsampler.state_dict()
+        assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps * before_batch_size,
+                         "sampler_type": "ReproduceBatchSampler"}
+
+        # 2. resume from the checkpoint and re-create a dataloader;
+        # keep batch_size unchanged;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        real_res = []
+        supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
+        forward_steps = 2
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert all(real_res[i] == supposed_res[i])
+
+        # change batch_size;
+        after_batch_size = 3
+        dataloader = DataLoader(dataset, batch_size=after_batch_size)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        real_res = []
+        supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
+        forward_steps = 2
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            real_res.append(next(iter_dataloader))
+
+        for i in range(forward_steps):
+            assert all(real_res[i] == supposed_res[i])
+
+        # check that the second epoch after resuming is a complete dataloader;
+        # first finish the epoch in which we resumed;
+        begin_idx = 27
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+        # start a new epoch;
+        begin_idx = 0
+        iter_dataloader = iter(dataloader)
+        while True:
+            try:
+                data = next(iter_dataloader)
+                _batch_size = len(data)
+                assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
+                begin_idx += _batch_size
+            except StopIteration:
+                break
+
+    def test_torch_dataloader_2(self):
+        # test that the index list of a new epoch is regenerated instead of reusing the previous epoch's;
+        before_batch_size = 7
+        dataset = TorchNormalDataset(num_of_data=100)
+        # enable shuffle to check that the index list of the second epoch after resuming is regenerated;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+
+        # save all data of one epoch to check that restoration is correct;
+        all_supposed_data = []
+        forward_steps = 3
+        iter_dataloader = iter(dataloader)
+        for _ in range(forward_steps):
+            all_supposed_data.extend(next(iter_dataloader).tolist())
+
+        # 1. save the state
+        _get_re_batchsampler = dataloader.batch_sampler
+        assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
+        state = _get_re_batchsampler.state_dict()
+
+        # 2. resume from the checkpoint and re-create a dataloader;
+        # keep batch_size unchanged;
+        dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
+        re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
+        re_batchsampler.load_state_dict(state)
+        dataloader = replace_batch_sampler(dataloader, re_batchsampler)
+        iter_dataloader = iter(dataloader)
+
+        # first consume the rest of this epoch;
+        pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
+        while True:
+            try:
+                all_supposed_data.extend(next(iter_dataloader).tolist())
+            except StopIteration:
+                break
+
+        assert all_supposed_data == list(pre_index_list)
+
+        # start new epochs; each one should regenerate its index list;
+        for _ in range(3):
+            iter_dataloader = iter(dataloader)
+            res = []
+            while True:
+                try:
+                    res.extend(next(iter_dataloader).tolist())
+                except StopIteration:
+                    break
+            assert res != all_supposed_data
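For reference, the expected values in test_torch_dataloader_1 follow directly from the saved state: three forward steps at batch_size 7 give num_consumed_samples = 21, so resuming with the same batch size continues with batches 21–27 and 28–34, while resuming with batch_size 3 continues with 21–23 and 24–26; after those two 3-sample batches, the rest of that epoch is consumed from index 27 onward, which is why begin_idx starts at 27.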
@@ -1,13 +1,25 @@
 import numpy as np
+import random


-class NormalIterator:
-    def __init__(self, num_of_data=1000):
+class NormalSampler:
+    def __init__(self, num_of_data=1000, shuffle=False):
         self._num_of_data = num_of_data
         self._data = list(range(num_of_data))
+        if shuffle:
+            random.shuffle(self._data)
+        self.shuffle = shuffle
         self._index = 0
+        self.need_reinitialize = False

     def __iter__(self):
+        if self.need_reinitialize:
+            self._index = 0
+            if self.shuffle:
+                random.shuffle(self._data)
+        else:
+            self.need_reinitialize = True
+
         return self

     def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
             raise StopIteration
         _data = self._data[self._index]
         self._index += 1
-        return self._data
+        return _data

     def __len__(self):
         return self._num_of_data
+
+
+class NormalBatchSampler:
+    def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
+        # Since collections.abc.Iterable does not check for `__getitem__`, which
+        # is one way for an object to be an iterable, we don't do an `isinstance`
+        # check here.
+        if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
+                batch_size <= 0:
+            raise ValueError("batch_size should be a positive integer value, "
+                             "but got batch_size={}".format(batch_size))
+        if not isinstance(drop_last, bool):
+            raise ValueError("drop_last should be a boolean value, but got "
+                             "drop_last={}".format(drop_last))
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+
+    def __iter__(self):
+        batch = []
+        for idx in self.sampler:
+            batch.append(idx)
+            if len(batch) == self.batch_size:
+                yield batch
+                batch = []
+        if len(batch) > 0 and not self.drop_last:
+            yield batch
+
+    def __len__(self) -> int:
+        if self.drop_last:
+            return len(self.sampler) // self.batch_size
+        else:
+            return (len(self.sampler) + self.batch_size - 1) // self.batch_size

 class RandomDataset:
     def __init__(self, num_data=10):
         self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@ class RandomDataset:
         return len(self.data)

     def __getitem__(self, item):
-        return self.data[item]
\ No newline at end of file
+        return self.data[item]
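A quick usage sketch of the two helpers added above; the expected values follow from shuffle=False, where NormalSampler yields 0 through num_of_data - 1 in order:

sampler = NormalSampler(num_of_data=10)
batch_sampler = NormalBatchSampler(sampler, batch_size=4, drop_last=False)
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert len(batch_sampler) == 3  # ceil(10 / 4); with drop_last=True it would be 2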