
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 2 years ago
commit 36ed894a77
6 changed files with 439 additions and 160 deletions
  1. +0 -1     fastNLP/core/samplers/reproducible_batch_sampler.py
  2. +122 -1   tests/core/controllers/test_trainer_event_trigger.py
  3. +2 -2     tests/core/controllers/utils/test_utils.py
  4. +122 -152 tests/core/samplers/test_reproducible_batch_sampler.py
  5. +141 -0   tests/core/samplers/test_reproducible_batch_sampler_torch.py
  6. +52 -4    tests/helpers/datasets/normal_data.py

+0 -1    fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -8,7 +8,6 @@ import math
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os

import numpy as np



+122 -1  tests/core/controllers/test_trainer_event_trigger.py

@@ -70,7 +70,7 @@ def model_and_optimizers():
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger(
def test_trainer_event_trigger_1(
model_and_optimizers: TrainerParameters,
driver,
device,
@@ -104,5 +104,126 @@ def test_trainer_event_trigger(
assert member.value in output[0]


@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):

@Trainer.on(Events.on_after_trainer_initialized)
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Events.on_sanity_check_begin)
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Events.on_sanity_check_end)
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Events.on_train_begin)
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Events.on_train_end)
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Events.on_train_epoch_begin)
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
# trigger on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Events.on_train_epoch_end)
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Events.on_fetch_data_begin)
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Events.on_fetch_data_end)
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Events.on_train_batch_begin)
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Events.on_train_batch_end)
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Events.on_exception)
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Events.on_before_backward)
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Events.on_after_backward)
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Events.on_before_optimizers_step)
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Events.on_after_optimizers_step)
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Events.on_before_zero_grad)
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Events.on_after_zero_grad)
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Events.on_evaluate_begin)
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Events.on_evaluate_end)
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

for name, member in Events.__members__.items():
assert member.value in output[0]
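
For reference, the pattern exercised above: Trainer.on(event) registers a plain function as a callback that fires when the trainer reaches that event. A minimal sketch follows (the top-level import path is an assumption; the test module gets Trainer and Events from its own imports, which lie outside this hunk):

    # Sketch only; the import path below is assumed, not taken from this diff.
    from fastNLP import Trainer, Events

    @Trainer.on(Events.on_train_begin)
    def log_train_begin(trainer):
        # Invoked once when training starts; the Trainer instance is passed in.
        print("on_train_begin")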








+2 -2    tests/core/controllers/utils/test_utils.py

@@ -1,7 +1,7 @@
from functools import reduce

from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader # TODO: this class has been modified; remember to update the test as well;
from tests.helpers.datasets.normal_data import NormalIterator
from tests.helpers.datasets.normal_data import NormalSampler


class Test_WrapDataLoader:
@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
def test_normal_generator(self):
all_sanity_batches = [4, 20, 100]
for sanity_batches in all_sanity_batches:
data = NormalIterator(num_of_data=1000)
data = NormalSampler(num_of_data=1000)
wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
dataloader = iter(wrapper)
mark = 0


+122 -152  tests/core/samplers/test_reproducible_batch_sampler.py

@@ -1,161 +1,131 @@
from array import array

import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
from array import array

from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset

#
# class TestReproducibleBatchSampler:
# # TODO: split this test; only test one thing here
# def test_torch_dataloader_1(self):
# import torch
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# next(iter_dataloader)
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "ReproduceBatchSampler"}
#
# # 2. resume from the checkpoint and create a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # change the batch_size;
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
# # check whether the second epoch after resuming is a complete pass over the dataloader;
# # first finish the epoch in which training was resumed;
# begin_idx = 27
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# # start a new epoch;
# begin_idx = 0
# iter_dataloader = iter(dataloader)
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# def test_torch_dataloader_2(self):
# # test that the index list of a new epoch is regenerated rather than carried over from the previous one;
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # save all the data of one epoch to check whether the restored data is correct;
# all_supposed_data = []
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# all_supposed_data.extend(next(iter_dataloader).tolist())
#
# # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
# # 2. resume from the checkpoint and create a new dataloader;
# # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# # first consume the rest of this epoch;
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
# while True:
# try:
# all_supposed_data.extend(next(iter_dataloader).tolist())
# except StopIteration:
# break
# assert all_supposed_data == list(pre_index_list)
#
# # start new epochs again;
# for _ in range(3):
# iter_dataloader = iter(dataloader)
# res = []
# while True:
# try:
# res.append(next(iter_dataloader))
# except StopIteration:
# break
#
# def test_3(self):
# import torch
# from torch.utils.data import DataLoader
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# # enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
#
# for idx, data in enumerate(dataloader):
# if idx > 3:
# break
#
# iterator = iter(dataloader)
# for each in iterator:
# pass


class TestReproducibleBatchSampler:
def test_1(self):
sampler = NormalSampler(num_of_data=100) # whether this is a batch sampler or not does not matter here;

reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)

forward_steps = 3
iterator = iter(reproduce_batch_sampler)
i = 0
while i < forward_steps:
next(iterator)
i += 1

# save the state;
state = reproduce_batch_sampler.state_dict()

assert state == {"index_list": array("I", list(range(100))),
"num_consumed_samples": forward_steps * 4,
"sampler_type": "ReproduceBatchSampler"}

# recreate a batch sampler and then load the state;
sampler = NormalSampler(num_of_data=100) # whether this is a batch sampler or not does not matter here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 16)), list(range(16, 20)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

# change the batch_size;
sampler = NormalSampler(num_of_data=100) # whether this is a batch sampler or not does not matter here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 19)), list(range(19, 26)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

# check whether the second epoch after resuming is a complete pass over the dataloader;
# first finish the epoch in which training was resumed;
begin_idx = 26
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break

# start a new epoch;
begin_idx = 0
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break

def test_2(self):

# test that the index list of a new epoch is regenerated rather than carried over from the previous one;
before_batch_size = 7
sampler = NormalSampler(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)

# save all the data of one epoch to check whether the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader))

# 1. save the state
state = reproduce_batch_sampler.state_dict()

# 2. resume from the checkpoint and create a new dataloader;
# keep batch_size unchanged;
sampler = NormalSampler(num_of_data=100, shuffle=True)
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

# first consume the rest of this epoch;
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
all_supposed_data.extend(next(iter_dataloader))
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

# start new epochs again;
for _ in range(3):
iter_dataloader = iter(reproduce_batch_sampler)
res = []
while True:
try:
res.extend(next(iter_dataloader))
except StopIteration:
break
assert res != all_supposed_data
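
Taken together, the checkpoint-resume pattern these tests exercise looks like the following minimal sketch (using only the APIs shown above; the expected batch mirrors test_1, where three batches of four samples are consumed before saving):

    from fastNLP.core.samplers import ReproduceBatchSampler
    from tests.helpers.datasets.normal_data import NormalSampler

    sampler = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)
    it = iter(sampler)
    for _ in range(3):
        next(it)                      # consume 3 batches, i.e. samples 0..11

    state = sampler.state_dict()      # records index_list and num_consumed_samples

    # A freshly built sampler resumes right after the consumed samples once the state is loaded.
    resumed = ReproduceBatchSampler(NormalSampler(num_of_data=100), batch_size=4, drop_last=False)
    resumed.load_state_dict(state)
    assert next(iter(resumed)) == [12, 13, 14, 15]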


class DatasetWithVaryLength:


+141 -0  tests/core/samplers/test_reproducible_batch_sampler_torch.py

@@ -0,0 +1,141 @@
from array import array
import torch
from torch.utils.data import DataLoader

import pytest

from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
def test_torch_dataloader_1(self):
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "ReproduceBatchSampler"}

# 2. resume from the checkpoint and create a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# change the batch_size;
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

# check whether the second epoch after resuming is a complete pass over the dataloader;
# first finish the epoch in which training was resumed;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

# start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

def test_torch_dataloader_2(self):
# test that the index list of a new epoch is regenerated rather than carried over from the previous one;
from torch.utils.data import DataLoader
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
# enable shuffle to check whether the index list of the second epoch after resuming is regenerated;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

# save all the data of one epoch to check whether the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())

# 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()

# 2. resume from the checkpoint and create a new dataloader;
# keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

iter_dataloader = iter(dataloader)
# first consume the rest of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

# start new epochs again;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert res != all_supposed_data


+52 -4   tests/helpers/datasets/normal_data.py

@@ -1,13 +1,25 @@
import numpy as np
import random


class NormalIterator:
def __init__(self, num_of_data=1000):
class NormalSampler:
def __init__(self, num_of_data=1000, shuffle=False):
self._num_of_data = num_of_data
self._data = list(range(num_of_data))
if shuffle:
random.shuffle(self._data)
self.shuffle = shuffle
self._index = 0
self.need_reinitialize = False

def __iter__(self):
if self.need_reinitialize:
self._index = 0
if self.shuffle:
random.shuffle(self._data)
else:
self.need_reinitialize = True

return self

def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
raise StopIteration
_data = self._data[self._index]
self._index += 1
return self._data
return _data

def __len__(self):
return self._num_of_data


class NormalBatchSampler:
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size
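
For orientation, a small usage sketch of the two helpers above (assuming they are imported from tests/helpers/datasets/normal_data.py, as the sampler tests do):

    from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler

    sampler = NormalSampler(num_of_data=10)
    batch_sampler = NormalBatchSampler(sampler, batch_size=4, drop_last=False)

    # Yields [0, 1, 2, 3], [4, 5, 6, 7] and the trailing partial batch [8, 9],
    # since drop_last is False.
    for batch in batch_sampler:
        print(batch)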


class RandomDataset:
def __init__(self, num_data=10):
self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@ class RandomDataset:
return len(self.data)

def __getitem__(self, item):
return self.data[item]
return self.data[item]



