Browse Source

优化了Collator

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
b3e0ebd7fc
10 changed files with 148 additions and 72 deletions
  1. +2
    -1
      fastNLP/core/callbacks/callback.py
  2. +17
    -16
      fastNLP/core/callbacks/load_best_model_callback.py
  3. +19
    -15
      fastNLP/core/collators/new_collator.py
  4. +2
    -0
      fastNLP/core/collators/padders/utils.py
  5. +30
    -21
      fastNLP/core/collators/utils.py
  6. +0
    -7
      fastNLP/core/dataloaders/fdataloader.py
  7. +1
    -1
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  8. +1
    -1
      tests/core/collators/padders/test_numpy_padder.py
  9. +68
    -2
      tests/core/collators/test_new_collator.py
  10. +8
    -8
      tests/core/collators/test_utils.py

+ 2
- 1
fastNLP/core/callbacks/callback.py View File

@@ -126,7 +126,8 @@ class Callback:

:param trainer: `fastNLP.Trainer`
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值,
其它时候为 None 。
"""
pass



+ 17
- 16
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -94,20 +94,21 @@ class LoadBestModelCallback(HasMonitorCallback):
else:
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)

self._delete_after_after(trainer)

def _delete_after_after(self, trainer):
trainer.driver.barrier()
if self.delete_after_after:
if self.real_save_folder:
logger.info(f"Deleting {self.real_save_folder}...")
shutil.rmtree(self.real_save_folder, ignore_errors=True)
try:
# 如果是 empty 的,就会被删除掉
os.rmdir(self.save_folder)
except:
pass
elif hasattr(self, 'buffer'):
self.buffer.close()
del self.buffer
trainer.driver.barrier()
self._delete_folder()
trainer.driver.barrier()

def _delete_folder(self):
if self.real_save_folder:
logger.info(f"Deleting {self.real_save_folder}...")
shutil.rmtree(self.real_save_folder, ignore_errors=True)
try:
# 如果是 empty 的,就会被删除掉
os.rmdir(self.save_folder)
logger.debug(f"Since {self.save_folder} is an empty folder, it has been removed.")
except:
pass
elif hasattr(self, 'buffer'):
self.buffer.close()
del self.buffer

+ 19
- 15
fastNLP/core/collators/new_collator.py View File

@@ -6,7 +6,7 @@ from .padders.get_padder import get_padder
import re

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
pack_batch_sequence, NESTED_DICT_SEPARATOR
pack_batch_sequence

sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]
@@ -16,10 +16,11 @@ class Collator:
def __init__(self, backend='torch'):
"""
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。

:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]
若为 None ,则不进行 padding 。
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]
若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没有影响,不能 pad 的数据返回一定是 list 。
"""
self.unpack_batch_func = None
self.pack_batch_func = None
@@ -54,22 +55,25 @@ class Collator:
else:
self.batch_data_type = 's'
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
f"is {self.batch_data_type}")
f"is `{self.batch_data_type}`.")
if self.batch_data_type == 's':
self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整
self.pack_batch_func = lambda x:x['_single']
self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整
self.pack_batch_func = lambda x: x['_single']
elif self.batch_data_type == 'l':
self.unpack_batch_func = unpack_batch_sequence
self.pack_batch_func = pack_batch_sequence
elif self.batch_data_type == 'd':
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value}
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value}
self.unpack_batch_func = unpack_batch_nested_mapping
self.pack_batch_func = pack_batch_nested_mapping
else:
self.unpack_batch_func = unpack_batch_mapping
self.pack_batch_func = lambda x:x

unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。
if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸
unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys()))
else:
unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。

pad_batch = {}
if len(self.padders)==0: # 第一次运行,准备 padder
@@ -96,13 +100,13 @@ class Collator:

return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型

def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None,
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "Collator":
"""
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b;
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b');
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
@@ -126,11 +130,11 @@ class Collator:
f"index, but other field is set as dict mode."
elif self.batch_data_type == 'l':
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
f"field name is {field_name}"
f"field name is {field_name}."

if field_name == '_single':
self.batch_data_type = 's'
elif sequence_idx_str.match(field_name):
elif isinstance(field_name, str) and sequence_idx_str.match(field_name):
self.batch_data_type = 'l'
else:
self.batch_data_type = 'd'
@@ -165,8 +169,8 @@ class Collator:
collator.set_ignore('field1', 'field2')

:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b;
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身
"""
for field_name in field_names:


+ 2
- 0
fastNLP/core/collators/padders/utils.py View File

@@ -149,6 +149,7 @@ def is_number(dtype):
if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \
and not is_numpy_number_dtype(dtype):
return True
return False
except:
return False

@@ -161,6 +162,7 @@ if __name__ == '__main__':
# print(type(b[0]))
# print(b)
# import torch
print(is_number(type('a')))
print(is_number_or_numpy_number(type(3))) # True
print(is_number_or_numpy_number(type(3.1))) # True
print(is_number_or_numpy_number(type('3'))) # False


+ 30
- 21
fastNLP/core/collators/utils.py View File

@@ -2,54 +2,58 @@ from collections import defaultdict
from functools import reduce
from typing import Sequence, Mapping, Dict

NESTED_DICT_SEPARATOR = '@@'


def unpack_batch_mapping(batch:Sequence[Mapping])->Dict:
def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict:
"""
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}

:param batch:
:param ignore_fields:
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for key, value in sample.items():
if key in ignore_fields:
continue
dict_batch[key].append(value)
return dict_batch


def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict:
def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict:
"""
将 nested 的 dict 中的内容展开到一个 flat dict 中

:param batch:
:param _parent: 内部使用
:param ignore_fields: 需要忽略的 field 。
:param stop_deep_fields: 不需要继续往下延伸的
:return:
"""
dict_batch = defaultdict(list)
if _parent != '':
_parent += NESTED_DICT_SEPARATOR
for sample in batch:
for key, value in sample.items():
if isinstance(value, Mapping):
_dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key)
if key in ignore_fields:
continue
if isinstance(value, Mapping) and key not in stop_deep_fields:
_dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,))
for key, value in _dict_batch.items():
dict_batch[key].append(value)
else:
dict_batch[_parent + key].append(value)
dict_batch[key].append(value)
return dict_batch


def _unpack_batch_nested_mapping(value, _parent)->Dict:
def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict:
_dict = {}
_parent += NESTED_DICT_SEPARATOR
for k, v in value.items():
if isinstance(v, Mapping):
__dict = _unpack_batch_nested_mapping(v, _parent=_parent + k)
_k = _parent + (k,)
if _k in ignore_fields:
continue
if isinstance(v, Mapping) and _k not in stop_deep_fields:
__dict = _unpack_batch_nested_mapping(v, ignore_fields, stop_deep_fields, _parent=_k)
_dict.update(__dict)
else:
_dict[_parent + k] = v
_dict[_k] = v
return _dict


@@ -63,10 +67,11 @@ def pack_batch_nested_mapping(batch:Mapping) -> Dict:
dicts = []

for key, value in batch.items():
keys = key.split(NESTED_DICT_SEPARATOR)
d = {keys[-1]: value}
for key in keys[:-1:][::-1]:
d = {key: d}
if not isinstance(key, tuple):
key = [key]
d = {key[-1]: value}
for k in key[:-1:][::-1]:
d = {k: d}
dicts.append(d)
return reduce(_merge_dict, dicts)

@@ -85,17 +90,21 @@ def _merge_dict(a, b, path=None):
return a


def unpack_batch_sequence(batch:Sequence[Sequence])->Dict:
def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict:
"""
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]}

:param batch:
:param ignore_fields: 需要忽略的field
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for i, content in enumerate(sample):
dict_batch[f'_{i}'].append(content)
field_name = f'_{i}'
if field_name in ignore_fields:
continue
dict_batch[field_name].append(content)
return dict_batch




+ 0
- 7
fastNLP/core/dataloaders/fdataloader.py View File

@@ -1,7 +0,0 @@
__all__ = [
'FDataLoader'
]


class FDataLoader:
pass

+ 1
- 1
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -17,7 +17,7 @@ if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader, Sampler
from torch.utils.data._utils.collate import default_collate
else:
from ..fdataloader import FDataLoader as DataLoader
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader


class _FDataSet:


+ 1
- 1
tests/core/collators/padders/test_numpy_padder.py View File

@@ -10,7 +10,7 @@ class TestNumpyNumberPadder:
def test_run(self):
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
assert isinstance(a, np.ndarray)
assert isinstance(padder(a), np.ndarray)
assert (padder(a) == np.array(a)).sum() == 3




+ 68
- 2
tests/core/collators/test_new_collator.py View File

@@ -158,7 +158,7 @@ class TestCollator:

# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

@@ -171,7 +171,7 @@ class TestCollator:
# 测试设置 pad 值
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))
@@ -217,6 +217,72 @@ class TestCollator:
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))

def test_nest_ignore(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}}
}
]
# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int'))
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': {'int':[1, 1]}}}
findDictDiff(raw_pad_batch, collator(dict_batch))

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=None)
collator.set_ignore('str', 'int', 'lst_int')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [{'int':1}, {'int':1}]}}
pad_batch = collator(dict_batch)
findDictDiff(raw_pad_batch, pad_batch)

collator = Collator(backend='raw')
collator.set_pad(('nested_dict', 'c'), pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int')
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x])
pad_batch = collator(dict_batch)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']},
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]],
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)





+ 8
- 8
tests/core/collators/test_utils.py View File

@@ -4,25 +4,25 @@ from fastNLP.core.collators.utils import *

def test_unpack_batch_mapping():
batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]
assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]}
assert unpack_batch_mapping(batch, {})=={'a': [[1, 2], [3]], 'b': [1, 2]}


def test_unpack_batch_nested_mapping():
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]}
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]}
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}},
{'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2],
('c','c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]}


def test_pack_batch_nested_mapping():
batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}
batch = {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2],
('c', 'c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]}
new_batch = pack_batch_nested_mapping(batch)
assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2],
'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}}
@@ -30,7 +30,7 @@ def test_pack_batch_nested_mapping():

def test_unpack_batch_sequence():
batch = [[1, 2, 3], [2, 4, 6]]
new_batch = unpack_batch_sequence(batch)
new_batch = unpack_batch_sequence(batch, {})
assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}




Loading…
Cancel
Save