@@ -126,7 +126,8 @@ class Callback: | |||
:param trainer: `fastNLP.Trainer` | |||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 | |||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据。仅在 DataLoader 支持得到当前 batch index 的时候有值, | |||
其它时候为 None 。 | |||
""" | |||
pass | |||
@@ -94,20 +94,21 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
else: | |||
self.buffer.seek(0) | |||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
self._delete_after_after(trainer) | |||
def _delete_after_after(self, trainer): | |||
trainer.driver.barrier() | |||
if self.delete_after_after: | |||
if self.real_save_folder: | |||
logger.info(f"Deleting {self.real_save_folder}...") | |||
shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||
try: | |||
# 如果是 empty 的,就会被删除掉 | |||
os.rmdir(self.save_folder) | |||
except: | |||
pass | |||
elif hasattr(self, 'buffer'): | |||
self.buffer.close() | |||
del self.buffer | |||
trainer.driver.barrier() | |||
self._delete_folder() | |||
trainer.driver.barrier() | |||
def _delete_folder(self): | |||
if self.real_save_folder: | |||
logger.info(f"Deleting {self.real_save_folder}...") | |||
shutil.rmtree(self.real_save_folder, ignore_errors=True) | |||
try: | |||
# 如果是 empty 的,就会被删除掉 | |||
os.rmdir(self.save_folder) | |||
logger.debug(f"Since {self.save_folder} is an empty folder, it has been removed.") | |||
except: | |||
pass | |||
elif hasattr(self, 'buffer'): | |||
self.buffer.close() | |||
del self.buffer |
@@ -6,7 +6,7 @@ from .padders.get_padder import get_padder | |||
import re | |||
from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ | |||
pack_batch_sequence, NESTED_DICT_SEPARATOR | |||
pack_batch_sequence | |||
sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | |||
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||
@@ -16,10 +16,11 @@ class Collator: | |||
def __init__(self, backend='torch'): | |||
""" | |||
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 | |||
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。Collator 在第一次进行 pad 的 | |||
时候自动根据设置以及数据情况,为每个 field 获取一个 padder ,在之后的每次调用中,都将使用对应的 Padder 给对应的 field 。 | |||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||
若为 None ,则不进行 padding 。 | |||
:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None]。 | |||
若为 None ,则不进行 padding 。该参数对本身就不能进行 pad 的数据没有影响,不能 pad 的数据返回一定是 list 。 | |||
""" | |||
self.unpack_batch_func = None | |||
self.pack_batch_func = None | |||
@@ -54,22 +55,25 @@ class Collator: | |||
else: | |||
self.batch_data_type = 's' | |||
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " | |||
f"is {self.batch_data_type}") | |||
f"is `{self.batch_data_type}`.") | |||
if self.batch_data_type == 's': | |||
self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 | |||
self.pack_batch_func = lambda x:x['_single'] | |||
self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 | |||
self.pack_batch_func = lambda x: x['_single'] | |||
elif self.batch_data_type == 'l': | |||
self.unpack_batch_func = unpack_batch_sequence | |||
self.pack_batch_func = pack_batch_sequence | |||
elif self.batch_data_type == 'd': | |||
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} | |||
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} | |||
self.unpack_batch_func = unpack_batch_nested_mapping | |||
self.pack_batch_func = pack_batch_nested_mapping | |||
else: | |||
self.unpack_batch_func = unpack_batch_mapping | |||
self.pack_batch_func = lambda x:x | |||
unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 | |||
if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 | |||
unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) | |||
else: | |||
unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 | |||
pad_batch = {} | |||
if len(self.padders)==0: # 第一次运行,准备 padder | |||
@@ -96,13 +100,13 @@ class Collator: | |||
return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | |||
def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||
pad_fn:Callable=None) -> "Collator": | |||
""" | |||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
@@ -126,11 +130,11 @@ class Collator: | |||
f"index, but other field is set as dict mode." | |||
elif self.batch_data_type == 'l': | |||
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ | |||
f"field name is {field_name}" | |||
f"field name is {field_name}." | |||
if field_name == '_single': | |||
self.batch_data_type = 's' | |||
elif sequence_idx_str.match(field_name): | |||
elif isinstance(field_name, str) and sequence_idx_str.match(field_name): | |||
self.batch_data_type = 'l' | |||
else: | |||
self.batch_data_type = 'd' | |||
@@ -165,8 +169,8 @@ class Collator: | |||
collator.set_ignore('field1', 'field2') | |||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 | |||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
:return: 返回 Collator 自身 | |||
""" | |||
for field_name in field_names: | |||
@@ -149,6 +149,7 @@ def is_number(dtype): | |||
if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ | |||
and not is_numpy_number_dtype(dtype): | |||
return True | |||
return False | |||
except: | |||
return False | |||
@@ -161,6 +162,7 @@ if __name__ == '__main__': | |||
# print(type(b[0])) | |||
# print(b) | |||
# import torch | |||
print(is_number(type('a'))) | |||
print(is_number_or_numpy_number(type(3))) # True | |||
print(is_number_or_numpy_number(type(3.1))) # True | |||
print(is_number_or_numpy_number(type('3'))) # False | |||
@@ -2,54 +2,58 @@ from collections import defaultdict | |||
from functools import reduce | |||
from typing import Sequence, Mapping, Dict | |||
NESTED_DICT_SEPARATOR = '@@' | |||
def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: | |||
def unpack_batch_mapping(batch:Sequence[Mapping], ignore_fields:set)->Dict: | |||
""" | |||
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | |||
:param batch: | |||
:param ignore_fields: | |||
:return: | |||
""" | |||
dict_batch = defaultdict(list) | |||
for sample in batch: | |||
for key, value in sample.items(): | |||
if key in ignore_fields: | |||
continue | |||
dict_batch[key].append(value) | |||
return dict_batch | |||
def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: | |||
def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: | |||
""" | |||
将 nested 的 dict 中的内容展开到一个 flat dict 中 | |||
:param batch: | |||
:param _parent: 内部使用 | |||
:param ignore_fields: 需要忽略的 field 。 | |||
:param stop_deep_fields: 不需要继续往下延伸的 field 。 | |||
:return: | |||
""" | |||
dict_batch = defaultdict(list) | |||
if _parent != '': | |||
_parent += NESTED_DICT_SEPARATOR | |||
for sample in batch: | |||
for key, value in sample.items(): | |||
if isinstance(value, Mapping): | |||
_dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) | |||
if key in ignore_fields: | |||
continue | |||
if isinstance(value, Mapping) and key not in stop_deep_fields: | |||
_dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,)) | |||
for key, value in _dict_batch.items(): | |||
dict_batch[key].append(value) | |||
else: | |||
dict_batch[_parent + key].append(value) | |||
dict_batch[key].append(value) | |||
return dict_batch | |||
def _unpack_batch_nested_mapping(value, _parent)->Dict: | |||
def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict: | |||
_dict = {} | |||
_parent += NESTED_DICT_SEPARATOR | |||
for k, v in value.items(): | |||
if isinstance(v, Mapping): | |||
__dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) | |||
_k = _parent + (k,) | |||
if _k in ignore_fields: | |||
continue | |||
if isinstance(v, Mapping) and _k not in stop_deep_fields: | |||
__dict = _unpack_batch_nested_mapping(v, ignore_fields, stop_deep_fields, _parent=_k) | |||
_dict.update(__dict) | |||
else: | |||
_dict[_parent + k] = v | |||
_dict[_k] = v | |||
return _dict | |||
@@ -63,10 +67,11 @@ def pack_batch_nested_mapping(batch:Mapping) -> Dict: | |||
dicts = [] | |||
for key, value in batch.items(): | |||
keys = key.split(NESTED_DICT_SEPARATOR) | |||
d = {keys[-1]: value} | |||
for key in keys[:-1:][::-1]: | |||
d = {key: d} | |||
if not isinstance(key, tuple): | |||
key = [key] | |||
d = {key[-1]: value} | |||
for k in key[:-1:][::-1]: | |||
d = {k: d} | |||
dicts.append(d) | |||
return reduce(_merge_dict, dicts) | |||
@@ -85,17 +90,21 @@ def _merge_dict(a, b, path=None): | |||
return a | |||
def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: | |||
def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict: | |||
""" | |||
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
:param batch: | |||
:param ignore_fields: 需要忽略的field | |||
:return: | |||
""" | |||
dict_batch = defaultdict(list) | |||
for sample in batch: | |||
for i, content in enumerate(sample): | |||
dict_batch[f'_{i}'].append(content) | |||
field_name = f'_{i}' | |||
if field_name in ignore_fields: | |||
continue | |||
dict_batch[field_name].append(content) | |||
return dict_batch | |||
@@ -1,7 +0,0 @@ | |||
__all__ = [ | |||
'FDataLoader' | |||
] | |||
class FDataLoader: | |||
pass |
@@ -17,7 +17,7 @@ if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
from torch.utils.data._utils.collate import default_collate | |||
else: | |||
from ..fdataloader import FDataLoader as DataLoader | |||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
class _FDataSet: | |||
@@ -10,7 +10,7 @@ class TestNumpyNumberPadder: | |||
def test_run(self): | |||
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
a = [1, 2, 3] | |||
assert isinstance(a, np.ndarray) | |||
assert isinstance(padder(a), np.ndarray) | |||
assert (padder(a) == np.array(a)).sum() == 3 | |||
@@ -158,7 +158,7 @@ class TestCollator: | |||
# 测试 ignore | |||
collator = Collator(backend='raw') | |||
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'a')) | |||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
@@ -171,7 +171,7 @@ class TestCollator: | |||
# 测试设置 pad 值 | |||
collator = Collator(backend='raw') | |||
collator.set_pad('nest_lst_int', pad_val=100) | |||
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict','a')) | |||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
@@ -217,6 +217,72 @@ class TestCollator: | |||
collator.set_pad('_single') | |||
findListDiff(list_batch, collator(list_batch)) | |||
def test_nest_ignore(self): | |||
dict_batch = [{ | |||
'str': '1', | |||
'lst_str': ['1'], | |||
'int': 1, | |||
'lst_int': [1], | |||
'nest_lst_int': [[1]], | |||
'float': 1.1, | |||
'lst_float': [1.1], | |||
'bool': True, | |||
'numpy': np.ones(1), | |||
'dict': {'1': '1'}, | |||
'set': {'1'}, | |||
'nested_dict': {'int': 1, 'lst_int':[1, 2], 'c': {'int': 1}} | |||
}, | |||
{ | |||
'str': '2', | |||
'lst_str': ['2', '2'], | |||
'int': 2, | |||
'lst_int': [1, 2], | |||
'nest_lst_int': [[1], [1, 2]], | |||
'float': 2.1, | |||
'lst_float': [2.1], | |||
'bool': False, | |||
'numpy': np.zeros(1), | |||
'dict': {'1': '2'}, | |||
'set': {'2'}, | |||
'nested_dict': {'int': 1, 'lst_int': [1, 2], 'c': {'int': 1}} | |||
} | |||
] | |||
# 测试 ignore | |||
collator = Collator(backend='raw') | |||
collator.set_ignore('str', 'int', 'lst_int', ('nested_dict', 'int')) | |||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
'c': {'int':[1, 1]}}} | |||
findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
collator = Collator(backend='raw') | |||
collator.set_pad(('nested_dict', 'c'), pad_val=None) | |||
collator.set_ignore('str', 'int', 'lst_int') | |||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
'c': [{'int':1}, {'int':1}]}} | |||
pad_batch = collator(dict_batch) | |||
findDictDiff(raw_pad_batch, pad_batch) | |||
collator = Collator(backend='raw') | |||
collator.set_pad(('nested_dict', 'c'), pad_val=1) | |||
with pytest.raises(BaseException): | |||
collator(dict_batch) | |||
collator = Collator(backend='raw') | |||
collator.set_ignore('str', 'int', 'lst_int') | |||
collator.set_pad(('nested_dict', 'c'), pad_fn=lambda x: [d['int'] for d in x]) | |||
pad_batch = collator(dict_batch) | |||
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], | |||
'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, | |||
'set': [{'1'}, {'2'}], 'nested_dict': {'lst_int': [[1, 2], [1, 2]], | |||
'c': [1, 1]}} | |||
findDictDiff(raw_pad_batch, pad_batch) | |||
@@ -4,25 +4,25 @@ from fastNLP.core.collators.utils import * | |||
def test_unpack_batch_mapping(): | |||
batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] | |||
assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
assert unpack_batch_mapping(batch, {})=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
def test_unpack_batch_nested_mapping(): | |||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] | |||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} | |||
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c','c'): [1, 2]} | |||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] | |||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} | |||
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2]} | |||
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, | |||
{'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] | |||
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
assert unpack_batch_nested_mapping(batch, {}, {}) == {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
('c','c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
def test_pack_batch_nested_mapping(): | |||
batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
batch = {'a': [[1, 2], [3]], 'b': [1, 2], ('c', 'c', 'c'): [1, 2], | |||
('c', 'c', 'd'):[[1, 1], [2, 2]], ('c', 'd'): [[1], [2, 2]]} | |||
new_batch = pack_batch_nested_mapping(batch) | |||
assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], | |||
'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} | |||
@@ -30,7 +30,7 @@ def test_pack_batch_nested_mapping(): | |||
def test_unpack_batch_sequence(): | |||
batch = [[1, 2, 3], [2, 4, 6]] | |||
new_batch = unpack_batch_sequence(batch) | |||
new_batch = unpack_batch_sequence(batch, {}) | |||
assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} | |||