|
|
@@ -1,81 +1,293 @@ |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
|
from fastNLP.core.collators import AutoCollator |
|
|
|
from fastNLP.core.collators.collator import _MultiCollator |
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR |
|
|
|
|
|
|
|
from fastNLP.core.collators.collator import Collator |
|
|
|
|
|
|
|
|
|
|
|
def _assert_equal(d1, d2):
    """Assert that ``d1`` and ``d2`` are equal, whatever container they are.

    The strategy is chosen by trying the cheapest comparison first:

    * objects whose type name contains ``'torch'`` are compared element-wise
      and reduced with ``.all().item()``;
    * everything else is first reduced with the builtin ``all``; when the
      ``==`` result is not iterable (e.g. a plain ``bool`` from comparing
      lists, strings, dicts or sets) the ``TypeError`` fallback asserts the
      comparison result directly;
    * when reducing raises ``ValueError`` (e.g. the rows of a multi-dimensional
      numpy comparison have an ambiguous truth value) the ``.all()`` method of
      the comparison result is used instead.

    :param d1: expected value.
    :param d2: actual value.
    :raises AssertionError: if the two values differ.
    """
    try:
        # Detect torch by type name so that torch does not need to be imported.
        if 'torch' in str(type(d1)):
            assert (d1 == d2).all().item()
        else:
            assert all(d1 == d2)
    except TypeError:
        assert d1 == d2
    except ValueError:
        assert (d1 == d2).all()
|
|
|
|
|
|
|
|
|
|
|
def findDictDiff(d1, d2, path=""):
    """Recursively check that every entry of ``d1`` is matched in ``d2``.

    Only the keys of ``d1`` are inspected; keys that exist solely in ``d2``
    are not reported. Nested dicts are walked recursively, all other values
    are compared with ``_assert_equal``.

    :param d1: expected (possibly nested) dict.
    :param d2: actual dict to check against ``d1``.
    :param path: key trail of the current nesting level, used in the error
        message only.
    :raises RuntimeError: if a key of ``d1`` is missing from ``d2``.
    :raises AssertionError: if a leaf value differs.
    """
    for key in d1:
        if key not in d2:
            prefix = "%s: " % path if path else ""
            raise RuntimeError("%s%s as key not in d2\n" % (prefix, key))

        expected = d1[key]
        if not isinstance(expected, dict):
            _assert_equal(expected, d2[key])
            continue

        # Extend the trail shown in error messages before descending.
        child_path = "%s -> %s" % (path, key) if path else key
        findDictDiff(expected, d2[key], child_path)
|
|
|
|
|
|
|
|
|
|
|
def findListDiff(d1, d2):
    """Recursively check that two (possibly nested) lists hold equal items.

    Both sequences must have the same length. Items of ``d1`` that are lists
    are compared recursively; every other item is compared with
    ``_assert_equal``.

    :param d1: expected list.
    :param d2: actual sequence to check against ``d1``.
    :raises AssertionError: if the lengths or any pair of items differ.
    """
    assert len(d1) == len(d2)
    for expected, actual in zip(d1, d2):
        if not isinstance(expected, list):
            _assert_equal(expected, actual)
            continue
        findListDiff(expected, actual)
|
|
|
|
|
|
|
|
|
|
|
class TestCollator:
    """Tests for ``AutoCollator``, ``_MultiCollator`` and ``Collator``.

    The ``_make_*`` / ``_common_*`` static helpers build the fixtures shared
    by several tests; each call returns fresh objects.
    """

    @staticmethod
    def _make_dataset():
        """Build the 400-sample ``DataSet`` used by the auto-collator tests."""
        return DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100,
                        'y': [0, 1, 1, 0] * 100})

    @staticmethod
    def _make_buckets(dataset, bucket_size=40):
        """Split ``dataset`` into consecutive buckets of ``bucket_size`` samples.

        Only full buckets are returned; a trailing partial bucket is dropped.
        """
        # Index explicitly: only ``len`` and integer indexing are relied upon.
        samples = [dataset[i] for i in range(len(dataset))]
        usable = len(samples) - len(samples) % bucket_size
        return [samples[start:start + bucket_size]
                for start in range(0, usable, bucket_size)]

    @staticmethod
    def _make_dict_batch(nested_dicts=None):
        """Build a two-sample dict-style batch covering many field types.

        :param nested_dicts: optional pair with the ``'nested_dict'`` value of
            the first and second sample; defaults to the ``'a'``/``'b'``
            variant.
        """
        if nested_dicts is None:
            nested_dicts = ({'a': 1, 'b': [1, 2]}, {'a': 2, 'b': [1, 2]})
        first_nested, second_nested = nested_dicts
        return [
            {
                'str': '1',
                'lst_str': ['1'],
                'int': 1,
                'lst_int': [1],
                'nest_lst_int': [[1]],
                'float': 1.1,
                'lst_float': [1.1],
                'bool': True,
                'numpy': np.ones(1),
                'dict': {'1': '1'},
                'set': {'1'},
                'nested_dict': first_nested,
            },
            {
                'str': '2',
                'lst_str': ['2', '2'],
                'int': 2,
                'lst_int': [1, 2],
                'nest_lst_int': [[1], [1, 2]],
                'float': 2.1,
                'lst_float': [2.1],
                'bool': False,
                'numpy': np.zeros(1),
                'dict': {'1': '2'},
                'set': {'2'},
                'nested_dict': second_nested,
            },
        ]

    @staticmethod
    def _make_list_batch():
        """Build a two-sample list-style batch covering many field types."""
        return [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
                ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

    @staticmethod
    def _common_raw_expected():
        """Return the raw-backend expectation shared by the ignore/pad tests.

        It holds the fields of the dict batch that none of those tests ignore
        or override by default; callers add ``'nested_dict'`` and any override.
        """
        return {
            'lst_str': [['1'], ['2', '2']],
            'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
            'float': [1.1, 2.1],
            'lst_float': [[1.1], [2.1]],
            'bool': [True, False],
            'numpy': [np.array([1.]), np.array([0.])],
            'dict': {'1': ['1', '2']},
            'set': [{'1'}, {'2'}],
        }

    @pytest.mark.parametrize('as_numpy', [True, False])
    def test_auto_collator(self, as_numpy):
        """Test the automatic padding of ``AutoCollator``.

        :param as_numpy: forwarded to ``AutoCollator(as_numpy=...)``.
        :return:
        """
        collator = AutoCollator(as_numpy=as_numpy)
        collator.set_input('x', 'y')
        for bucket in self._make_buckets(self._make_dataset()):
            res = collator(bucket)
            # Every bucket contains a sample whose 'x' has 5 items.
            assert res['x'].shape == (40, 5)
            assert res['y'].shape == (40,)

    def test_auto_collator_v1(self):
        """Exercise ``set_pad_val`` and ``set_as_numpy`` of ``AutoCollator``.

        Smoke test only: the collated batches are printed, not asserted on.

        :return:
        """
        collator = AutoCollator(as_numpy=False)
        collator.set_input('x')
        collator.set_pad_val('x', val=-1)
        collator.set_as_numpy(True)
        for bucket in self._make_buckets(self._make_dataset()):
            res = collator(bucket)
            print(res)

    def test_multicollator(self):
        """Exercise ``_MultiCollator`` wrapping an ``AutoCollator``.

        Smoke test only: the collated batches are printed, not asserted on.

        :return:
        """
        collator = AutoCollator(as_numpy=False)
        multi_collator = _MultiCollator(collator)
        multi_collator.set_as_numpy(as_numpy=True)
        multi_collator.set_pad_val('x', val=-1)
        multi_collator.set_input('x')
        for bucket in self._make_buckets(self._make_dataset()):
            res = multi_collator(bucket)
            print(res)

    @pytest.mark.torch
    def test_run(self):
        """Collate dict- and list-style batches with several backends.

        The raw and numpy backends are always checked; the torch backend is
        checked only when ``_NEED_IMPORT_TORCH`` is truthy.
        """
        dict_batch = self._make_dict_batch()
        list_batch = self._make_list_batch()

        # --- raw backend ---
        raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2],
                         'lst_int': [[1, 0], [1, 2]],
                         'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
                         'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
                         'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
                         'set': [{'1'}, {'2'}],
                         'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
        collator = Collator(backend='raw')
        assert raw_pad_batch == collator(dict_batch)

        collator = Collator(backend='raw')
        raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]],
                       [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        # --- numpy backend ---
        collator = Collator(backend='numpy')
        numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]),
                           'lst_int': np.array([[1, 0], [1, 2]]),
                           'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                           'float': np.array([1.1, 2.1]),
                           'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]),
                           'numpy': np.array([[1], [0]]),
                           'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                           'nested_dict': {'a': np.array([1, 2]),
                                           'b': np.array([[1, 2], [1, 2]])}}
        findDictDiff(numpy_pad_batch, collator(dict_batch))

        collator = Collator(backend='numpy')
        numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
                         np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                         np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
                         np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
                         [{'1'}, {'2'}]]
        findListDiff(numpy_pad_lst, collator(list_batch))

        # --- torch backend (only when torch is available) ---
        if _NEED_IMPORT_TORCH:
            import torch

            collator = Collator(backend='torch')
            torch_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']],
                               'int': torch.LongTensor([1, 2]),
                               'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
                               'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                               'float': torch.FloatTensor([1.1, 2.1]),
                               'lst_float': torch.FloatTensor([[1.1], [2.1]]),
                               'bool': torch.BoolTensor([True, False]),
                               'numpy': torch.FloatTensor([[1], [0]]),
                               'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                               'nested_dict': {'a': torch.LongTensor([1, 2]),
                                               'b': torch.LongTensor([[1, 2], [1, 2]])}}
            findDictDiff(torch_pad_batch, collator(dict_batch))

            collator = Collator(backend='torch')
            torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]),
                             torch.LongTensor([[1, 0], [2, 2]]),
                             torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                             torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]),
                             torch.BoolTensor([True, False]),
                             torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
                             [{'1'}, {'2'}]]
            findListDiff(torch_pad_lst, collator(list_batch))

    def test_pad(self):
        """Test ``set_ignore`` and ``set_pad`` on dict- and list-style batches."""
        dict_batch = self._make_dict_batch()

        # Ignored fields (including the nested ('nested_dict', 'a')) are not
        # expected in the output.
        collator = Collator(backend='raw')
        collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
        raw_pad_batch = self._common_raw_expected()
        raw_pad_batch['nested_dict'] = {'b': [[1, 2], [1, 2]]}
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # Setting a pad value on the 'str' field is expected to fail.
        collator = Collator(backend='raw')
        collator.set_pad('str', pad_val=1)
        with pytest.raises(BaseException):
            collator(dict_batch)

        # A custom pad value replaces the default 0 padding.
        collator = Collator(backend='raw')
        collator.set_pad('nest_lst_int', pad_val=100)
        collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
        raw_pad_batch['nest_lst_int'] = [[[1, 100], [100, 100]], [[1, 100], [1, 2]]]
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # Per-field backend and dtype, set on the same collator as above.
        collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
        raw_pad_batch['float'] = np.array([1, 2])
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # List-style batches address their fields as '_<index>'.
        list_batch = self._make_list_batch()
        collator = Collator(backend='raw')
        collator.set_ignore('_0', '_3', '_1')
        # With pad_val=None field '_4' is expected back unpadded.
        collator.set_pad('_4', pad_val=None)
        raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        collator = Collator(backend='raw')
        collator.set_pad('_0', pad_val=1)
        # NOTE(review): '_0' is a list-batch field name but the dict batch is
        # collated here; presumably ``list_batch`` was intended - confirm.
        with pytest.raises(BaseException):
            collator(dict_batch)

        # Fresh batch, as in the original test, before mixing backends per field.
        list_batch = self._make_list_batch()
        collator = Collator(backend='raw')
        collator.set_ignore('_0', '_3', '_1')
        collator.set_pad('_2', backend='numpy')
        collator.set_pad('_4', backend='numpy', pad_val=100)
        raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        # '_single': the collated result is expected to equal the input batch.
        collator = Collator()
        collator.set_pad('_single')
        findListDiff(list_batch, collator(list_batch))

    def test_nest_ignore(self):
        """Test ``set_ignore`` / ``set_pad`` on fields nested inside a dict field."""
        dict_batch = self._make_dict_batch(nested_dicts=(
            {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}},
            {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}},
        ))

        # Ignoring ('nested_dict', 'int') is expected to keep the deeper
        # ('nested_dict', 'c', 'int') entry.
        collator = Collator(backend='raw')
        collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
        raw_pad_batch = self._common_raw_expected()
        raw_pad_batch['nested_dict'] = {'lst_int': [[1, 2], [1, 2]],
                                        'c': {'int':[1, 1]}}
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # With pad_val=None the nested dict 'c' is expected back as a plain
        # list of the per-sample dicts.
        collator = Collator(backend='raw')
        collator.set_pad(('nested_dict', 'c'), pad_val=None)
        collator.set_ignore('str', 'int', 'lst_int')
        raw_pad_batch = self._common_raw_expected()
        raw_pad_batch['nested_dict'] = {'lst_int': [[1, 2], [1, 2]],
                                        'c': [{'int':1}, {'int':1}]}
        pad_batch = collator(dict_batch)
        findDictDiff(raw_pad_batch, pad_batch)

        # A concrete pad value on the dict-valued field is expected to fail.
        collator = Collator(backend='raw')
        collator.set_pad(('nested_dict', 'c'), pad_val=1)
        with pytest.raises(BaseException):
            collator(dict_batch)

        # A custom pad_fn receives the per-sample values of the field.
        collator = Collator(backend='raw')
        collator.set_ignore('str', 'int', 'lst_int')
        collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
        pad_batch = collator(dict_batch)
        raw_pad_batch = self._common_raw_expected()
        raw_pad_batch['nested_dict'] = {'lst_int': [[1, 2], [1, 2]],
                                        'c': [1, 1]}
        findDictDiff(raw_pad_batch, pad_batch)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|