From b3e0ebd7fc56b119ce4116c41fd7660071165940 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sat, 30 Apr 2022 17:10:03 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BC=98=E5=8C=96=E4=BA=86Collator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 3 +- .../callbacks/load_best_model_callback.py | 33 ++++----- fastNLP/core/collators/new_collator.py | 34 +++++---- fastNLP/core/collators/padders/utils.py | 2 + fastNLP/core/collators/utils.py | 51 ++++++++------ fastNLP/core/dataloaders/fdataloader.py | 7 -- .../core/dataloaders/torch_dataloader/fdl.py | 2 +- .../collators/padders/test_numpy_padder.py | 2 +- tests/core/collators/test_new_collator.py | 70 ++++++++++++++++++- tests/core/collators/test_utils.py | 16 ++--- 10 files changed, 148 insertions(+), 72 deletions(-) delete mode 100644 fastNLP/core/dataloaders/fdataloader.py diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 982df7da..7f0c290d 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -126,7 +126,8 @@ class Callback: :param trainer: `fastNLP.Trainer` :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 - :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 + :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值, + 其它时候为 None 。 """ pass diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 5addd2e2..32534d2a 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -94,20 +94,21 @@ class LoadBestModelCallback(HasMonitorCallback): else: self.buffer.seek(0) trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) - - self._delete_after_after(trainer) - - def _delete_after_after(self, trainer): - trainer.driver.barrier() if self.delete_after_after: - if self.real_save_folder: - logger.info(f"Deleting {self.real_save_folder}...") - shutil.rmtree(self.real_save_folder, ignore_errors=True) - try: - # 如果是 emtpy 的,就会被删除掉 - os.rmdir(self.save_folder) - except: - pass - elif hasattr(self, 'buffer'): - self.buffer.close() - del self.buffer \ No newline at end of file + trainer.driver.barrier() + self._delete_folder() + trainer.driver.barrier() + + def _delete_folder(self): + if self.real_save_folder: + logger.info(f"Deleting {self.real_save_folder}...") + shutil.rmtree(self.real_save_folder, ignore_errors=True) + try: + # 如果是 emtpy 的,就会被删除掉 + os.rmdir(self.save_folder) + logger.debug(f"Since {self.save_folder} is an empty folder, it has been removed.") + except: + pass + elif hasattr(self, 'buffer'): + self.buffer.close() + del self.buffer \ No newline at end of file diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py index 869a60a7..9123a293 100644 --- a/fastNLP/core/collators/new_collator.py +++ b/fastNLP/core/collators/new_collator.py @@ -6,7 +6,7 @@ from .padders.get_padder import get_padder import re from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ - pack_batch_sequence, NESTED_DICT_SEPARATOR + pack_batch_sequence sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] @@ -16,10 +16,11 @@ class Collator: def __init__(self, backend='torch'): """ 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 - 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 + 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 + 时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 - :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], - 若为 None ,则不进行 padding 。 + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 + 若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没用影响,不能 pad 的数据返回一定是 list 。 """ self.unpack_batch_func = None self.pack_batch_func = None @@ -54,22 +55,25 @@ class Collator: else: self.batch_data_type = 's' logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " - f"is {self.batch_data_type}") + f"is `{self.batch_data_type}`.") if self.batch_data_type == 's': - self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 - self.pack_batch_func = lambda x:x['_single'] + self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 + self.pack_batch_func = lambda x: x['_single'] elif self.batch_data_type == 'l': self.unpack_batch_func = unpack_batch_sequence self.pack_batch_func = pack_batch_sequence elif self.batch_data_type == 'd': - if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} + if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} self.unpack_batch_func = unpack_batch_nested_mapping self.pack_batch_func = pack_batch_nested_mapping else: self.unpack_batch_func = unpack_batch_mapping self.pack_batch_func = lambda x:x - unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 + if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 + unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) + else: + unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 pad_batch = {} if len(self.padders)==0: # 第一次运行,准备 padder @@ -96,13 +100,13 @@ class Collator: return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 - def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, + def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, pad_fn:Callable=None) -> "Collator": """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; + field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 @@ -126,11 +130,11 @@ class Collator: f"index, but other field is set as dict mode." elif self.batch_data_type == 'l': assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ - f"field name is {field_name}" + f"field name is {field_name}." if field_name == '_single': self.batch_data_type = 's' - elif sequence_idx_str.match(field_name): + elif isinstance(field_name, str) and sequence_idx_str.match(field_name): self.batch_data_type = 'l' else: self.batch_data_type = 'd' @@ -165,8 +169,8 @@ class Collator: collator.set_ignore('field1', 'field2') :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 + __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 :return: 返回 Collator 自身 """ for field_name in field_names: diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py index f6240219..d2d3a8e0 100644 --- a/fastNLP/core/collators/padders/utils.py +++ b/fastNLP/core/collators/padders/utils.py @@ -149,6 +149,7 @@ def is_number(dtype): if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ and not is_numpy_number_dtype(dtype): return True + return False except: return False @@ -161,6 +162,7 @@ if __name__ == '__main__': # print(type(b[0])) # print(b) # import torch + print(is_number(type('a'))) print(is_number_or_numpy_number(type(3))) # True print(is_number_or_numpy_number(type(3.1))) # True print(is_number_or_numpy_number(type('3'))) # False diff --git a/fastNLP/core/collators/utils.py b/fastNLP/core/collators/utils.py index 9a397c66..1a82aa23 100644 --- a/fastNLP/core/collators/utils.py +++ b/fastNLP/core/collators/utils.py @@ -2,54 +2,58 @@ from collections import defaultdict from functools import reduce from typing import Sequence, Mapping, Dict -NESTED_DICT_SEPARATOR = '@@' - -def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: +def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict: """ 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} :param batch: + :param ignore_fields: :return: """ dict_batch = defaultdict(list) for sample in batch: for key, value in sample.items(): + if key in ignore_fields: + continue dict_batch[key].append(value) return dict_batch -def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: +def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: """ 将 nested 的 dict 中的内容展开到一个 flat dict 中 :param batch: - :param _parent: 内部使用 + :param ignore_fields: 需要忽略的 field 。 + :param stop_deep_fields: 不需要继续往下衍射的 :return: """ dict_batch = defaultdict(list) - if _parent != '': - _parent += NESTED_DICT_SEPARATOR for sample in batch: for key, value in sample.items(): - if isinstance(value, Mapping): - _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) + if key in ignore_fields: + continue + if isinstance(value, Mapping) and key not in stop_deep_fields: + _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,)) for key, value in _dict_batch.items(): dict_batch[key].append(value) else: - dict_batch[_parent + key].append(value) + dict_batch[key].append(value) return dict_batch -def _unpack_batch_nested_mapping(value, _parent)->Dict: +def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict: _dict = {} - _parent += NESTED_DICT_SEPARATOR for k, v in value.items(): - if isinstance(v, Mapping): - __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) + _k = _parent + (k,) + if _k in ignore_fields: + continue + if isinstance(v, Mapping) and _k not in stop_deep_fields: + __dict = _unpack_batch_nested_mapping(v, ignore_fields, stop_deep_fields, _parent=_k) _dict.update(__dict) else: - _dict[_parent + k] = v + _dict[_k] = v return _dict @@ -63,10 +67,11 @@ def pack_batch_nested_mapping(batch:Mapping) -> Dict: dicts = [] for key, value in batch.items(): - keys = key.split(NESTED_DICT_SEPARATOR) - d = {keys[-1]: value} - for key in keys[:-1:][::-1]: - d = {key: d} + if not isinstance(key, tuple): + key = [key] + d = {key[-1]: value} + for k in key[:-1:][::-1]: + d = {k: d} dicts.append(d) return reduce(_merge_dict, dicts) @@ -85,17 +90,21 @@ def _merge_dict(a, b, path=None): return a -def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: +def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict: """ 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} :param batch: + :param ignore_fields: 需要忽略的field :return: """ dict_batch = defaultdict(list) for sample in batch: for i, content in enumerate(sample): - dict_batch[f'_{i}'].append(content) + field_name = f'_{i}' + if field_name in ignore_fields: + continue + dict_batch[field_name].append(content) return dict_batch diff --git a/fastNLP/core/dataloaders/fdataloader.py b/fastNLP/core/dataloaders/fdataloader.py deleted file mode 100644 index 742f3909..00000000 --- a/fastNLP/core/dataloaders/fdataloader.py +++ /dev/null @@ -1,7 +0,0 @@ -__all__ = [ - 'FDataLoader' -] - - -class FDataLoader: - pass diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index cf8e2c31..02721aaf 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -17,7 +17,7 @@ if _NEED_IMPORT_TORCH: from torch.utils.data import DataLoader, Sampler from torch.utils.data._utils.collate import default_collate else: - from ..fdataloader import FDataLoader as DataLoader + from fastNLP.core.utils.dummy_class import DummyClass as DataLoader class _FDataSet: diff --git a/tests/core/collators/padders/test_numpy_padder.py b/tests/core/collators/padders/test_numpy_padder.py index 42665857..6cc9d668 100644 --- a/tests/core/collators/padders/test_numpy_padder.py +++ b/tests/core/collators/padders/test_numpy_padder.py @@ -10,7 +10,7 @@ class TestNumpyNumberPadder: def test_run(self): padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) a = [1, 2, 3] - assert isinstance(a, np.ndarray) + assert isinstance(padder(a), np.ndarray) assert (padder(a) == np.array(a)).sum() == 3 diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py index 5fc82c91..7c27b3a9 100644 --- a/tests/core/collators/test_new_collator.py +++ b/tests/core/collators/test_new_collator.py @@ -158,7 +158,7 @@ class TestCollator: # 测试 ignore collator = Collator(backend='raw') - collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} findDictDiff(raw_pad_batch, collator(dict_batch)) @@ -171,7 +171,7 @@ class TestCollator: # 测试设置 pad 值 collator = Collator(backend='raw') collator.set_pad('nest_lst_int', pad_val=100) - collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} findDictDiff(raw_pad_batch, collator(dict_batch)) @@ -217,6 +217,72 @@ class TestCollator: collator.set_pad('_single') findListDiff(list_batch, collator(list_batch)) + def test_nest_ignore(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} + } + ] + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': {'int':[1, 1]}}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=None) + collator.set_ignore('str', 'int', 'lst_int') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [{'int':1}, {'int':1}]}} + pad_batch = collator(dict_batch) + findDictDiff(raw_pad_batch, pad_batch) + + collator = Collator(backend='raw') + collator.set_pad(('nested_dict', 'c'), pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int') + collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) + pad_batch = collator(dict_batch) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], + 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, + 'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], + 'c': [1, 1]}} + findDictDiff(raw_pad_batch, pad_batch) diff --git a/tests/core/collators/test_utils.py b/tests/core/collators/test_utils.py index d56dacc6..74c54a36 100644 --- a/tests/core/collators/test_utils.py +++ b/tests/core/collators/test_utils.py @@ -4,25 +4,25 @@ from fastNLP.core.collators.utils import * def test_unpack_batch_mapping(): batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] - assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} + assert unpack_batch_mapping(batch, {})=={'a': [[1, 2], [3]], 'b': [1, 2]} def test_unpack_batch_nested_mapping(): batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] - assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} + assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]} batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] - assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} + assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]} batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] - assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], - 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], + ('c','c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} def test_pack_batch_nested_mapping(): - batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], - 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + batch = {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], + ('c', 'c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} new_batch = pack_batch_nested_mapping(batch) assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} @@ -30,7 +30,7 @@ def test_pack_batch_nested_mapping(): def test_unpack_batch_sequence(): batch = [[1, 2, 3], [2, 4, 6]] - new_batch = unpack_batch_sequence(batch) + new_batch = unpack_batch_sequence(batch, {}) assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}