diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 2b56624a..ba1e7e08 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -1,81 +1,293 @@ + +import numpy as np import pytest -from fastNLP.core.collators import AutoCollator -from fastNLP.core.collators.collator import _MultiCollator -from fastNLP.core.dataset import DataSet +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) class TestCollator: - @pytest.mark.parametrize('as_numpy', [True, False]) - def test_auto_collator(self, as_numpy): - """ - 测试auto_collator的auto_pad功能 - - :param as_numpy: - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=as_numpy) - collator.set_input('x', 'y') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - results = [] - for bucket in bucket_data: - res = collator(bucket) - assert res['x'].shape == (40, 5) - assert res['y'].shape == (40,) - results.append(res) - - def test_auto_collator_v1(self): - """ - 测试auto_collator的set_pad_val和set_pad_val功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - collator.set_input('x') - collator.set_pad_val('x', val=-1) - collator.set_as_numpy(True) - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = collator(bucket) - print(res) - - def test_multicollator(self): - """ - 测试multicollator功能 - - :return: - """ - dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, - 'y': [0, 1, 1, 0] * 100}) - collator = AutoCollator(as_numpy=False) - multi_collator = _MultiCollator(collator) - multi_collator.set_as_numpy(as_numpy=True) - multi_collator.set_pad_val('x', val=-1) - multi_collator.set_input('x') - bucket_data = [] - data = [] - for i in range(len(dataset)): - data.append(dataset[i]) - if len(data) == 40: - bucket_data.append(data) - data = [] - for bucket in bucket_data: - res = multi_collator(bucket) - print(res) + @pytest.mark.torch + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + def test_nest_ignore(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} + } + ] + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': {'int':[1, 1]}}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=None) + collator.set_ignore('str', 'int', 'lst_int') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [{'int':1}, {'int':1}]}} + pad_batch = collator(dict_batch) + findDictDiff(raw_pad_batch, pad_batch) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int') + collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) + pad_batch = collator(dict_batch) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [1, 1]}} + findDictDiff(raw_pad_batch, pad_batch) + + + + + + diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py deleted file mode 100644 index ba1e7e08..00000000 --- a/tests/core/collators/test_new_collator.py +++ /dev/null @@ -1,293 +0,0 @@ - -import numpy as np -import pytest - -from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR - -from fastNLP.core.collators.collator import Collator - - -def _assert_equal(d1, d2): - try: - if 'torch' in str(type(d1)): - if 'float64' in str(d2.dtype): - print(d2.dtype) - assert (d1 == d2).all().item() - else: - assert all(d1 == d2) - except TypeError: - assert d1 == d2 - except ValueError: - assert (d1 == d2).all() - - -def findDictDiff(d1, d2, path=""): - for k in d1: - if k in d2: - if isinstance(d1[k], dict): - findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) - else: - _assert_equal(d1[k], d2[k]) - else: - raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) - - -def findListDiff(d1, d2): - assert len(d1)==len(d2) - for _d1, _d2 in zip(d1, d2): - if isinstance(_d1, list): - findListDiff(_d1, _d2) - else: - _assert_equal(_d1, _d2) - - -class TestCollator: - - @pytest.mark.torch - def test_run(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - collator = Collator(backend='raw') - assert raw_pad_batch == collator(dict_batch) - collator = Collator(backend='raw') - raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='numpy') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), - 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), - 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), - 'b': np.array([[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='numpy') - numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), - np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), - np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(numpy_pad_lst, collator(list_batch)) - - if _NEED_IMPORT_TORCH: - import torch - collator = Collator(backend='torch') - numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), - 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), - 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - 'float': torch.FloatTensor([1.1, 2.1]), - 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), - 'numpy': torch.FloatTensor([[1], [0]]), - 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), - 'b': torch.LongTensor( - [[1, 2], [1, 2]])}} - - findDictDiff(numpy_pad_batch, collator(dict_batch)) - collator = Collator(backend='torch') - torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), - torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), - torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), - torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(torch_pad_lst, collator(list_batch)) - - def test_pad(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'a': 1, 'b':[1, 2]} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'a': 2, 'b': [1, 2]} - } - ] - - raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} - - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 测试 set_pad - collator = Collator(backend='raw') - collator.set_pad('str', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - # 测试设置 pad 值 - collator = Collator(backend='raw') - collator.set_pad('nest_lst_int', pad_val=100) - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - # 设置 backend 和 type - collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], - 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - - # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - # [{'1'}, {'2'}]] - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_4', pad_val=None) - raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - collator = Collator(backend='raw') - collator.set_pad('_0', pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], - ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] - collator = Collator(backend='raw') - collator.set_ignore('_0', '_3', '_1') - collator.set_pad('_2', backend='numpy') - collator.set_pad('_4', backend='numpy', pad_val=100) - raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], - [{'1'}, {'2'}]] - findListDiff(raw_pad_lst, collator(list_batch)) - - # _single - collator = Collator() - collator.set_pad('_single') - findListDiff(list_batch, collator(list_batch)) - - def test_nest_ignore(self): - dict_batch = [{ - 'str': '1', - 'lst_str': ['1'], - 'int': 1, - 'lst_int': [1], - 'nest_lst_int': [[1]], - 'float': 1.1, - 'lst_float': [1.1], - 'bool': True, - 'numpy': np.ones(1), - 'dict': {'1': '1'}, - 'set': {'1'}, - 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} - }, - { - 'str': '2', - 'lst_str': ['2', '2'], - 'int': 2, - 'lst_int': [1, 2], - 'nest_lst_int': [[1], [1, 2]], - 'float': 2.1, - 'lst_float': [2.1], - 'bool': False, - 'numpy': np.zeros(1), - 'dict': {'1': '2'}, - 'set': {'2'}, - 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} - } - ] - # 测试 ignore - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': {'int':[1, 1]}}} - findDictDiff(raw_pad_batch, collator(dict_batch)) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=None) - collator.set_ignore('str', 'int', 'lst_int') - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [{'int':1}, {'int':1}]}} - pad_batch = collator(dict_batch) - findDictDiff(raw_pad_batch, pad_batch) - - collator = Collator(backend='raw') - collator.set_pad(('nested_dict', 'c'), pad_val=1) - with pytest.raises(BaseException): - collator(dict_batch) - - collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int') - collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) - pad_batch = collator(dict_batch) - raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], - 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, - 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], - 'c': [1, 1]}} - findDictDiff(raw_pad_batch, pad_batch) - - - - - -