Browse Source

修改Collator的实现,优化了nested场景下的报错

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
0fd17152b8
6 changed files with 224 additions and 482 deletions
  1. +105
    -426
      fastNLP/core/collators/collator.py
  2. +32
    -45
      fastNLP/core/collators/packer_unpacker.py
  3. +4
    -1
      fastNLP/core/collators/padders/torch_padder.py
  4. +3
    -5
      fastNLP/core/log/logger.py
  5. +1
    -1
      fastNLP/core/log/print.py
  6. +79
    -4
      tests/core/collators/test_collator.py

+ 105
- 426
fastNLP/core/collators/collator.py View File

@@ -12,8 +12,8 @@ from .padders.get_padder import get_padder

import re

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
pack_batch_sequence
from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \
NestedMappingPackerUnpacker

sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None]
@@ -126,46 +126,36 @@ class Collator:
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
f"is `{self.batch_data_type}`.")
if self.batch_data_type == 's':
self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整
self.pack_batch_func = lambda x: x['_single']
self.packer_unpacker = SinglePackerUnpacker() # 不需要做任何调整
elif self.batch_data_type == 'l':
self.unpack_batch_func = unpack_batch_sequence
self.pack_batch_func = pack_batch_sequence
self.packer_unpacker = SequencePackerUnpacker()
elif self.batch_data_type == 'd':
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value}
self.unpack_batch_func = unpack_batch_nested_mapping
self.pack_batch_func = pack_batch_nested_mapping
self.packer_unpacker = NestedMappingPackerUnpacker()
else:
self.unpack_batch_func = unpack_batch_mapping
self.pack_batch_func = lambda x:x
self.packer_unpacker = MappingPackerUnpacker()

if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸
unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys()))
else:
unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。
# 将 batch 中各个 field 组成自己的 batch;同时忽略处于 ignore_fields 中的数据。
unpack_batch = self.packer_unpacker.unpack_batch(batch, self.ignore_fields, self.input_fields)

pad_batch = {}
if len(self.padders)==0: # 第一次运行,准备 padder
if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。
self.backend = _get_backend()

for key in unpack_batch.keys():
if key not in self.input_fields and key not in self.ignore_fields:
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}
elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto':
self.input_fields[key]['backend'] = self.backend

for field_name, setting in self.input_fields.items():
pad_fn = setting.get('pad_fn', None)
for field_name, batch_field in unpack_batch.items():
setting = self.input_fields.get(field_name, {'backend': self.backend, 'pad_val': 0 ,
'dtype': None, 'pad_fn': None})
pad_fn = setting['pad_fn']
if callable(pad_fn):
padder = pad_fn
else:
backend = self.backend if setting['backend'] == 'auto' else setting['backend']
batch_field = unpack_batch.get(field_name)
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
dtype=setting['dtype'], backend=backend,
field_name=field_name)
self.padders[field_name] = padder

if self.batch_data_type == 'l':
self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序

@@ -173,7 +163,7 @@ class Collator:
batch = unpack_batch.get(key)
pad_batch[key] = padder(batch)

return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型
return self.packer_unpacker.pack_batch(pad_batch) # 根据情况恢复成与输入一致的类型

def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto',
pad_fn:Callable=None) -> "Collator":
@@ -195,16 +185,17 @@ class Collator:
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
"""
self.padders.clear() # 重新生成
self._renew()

if self.batch_data_type is not None:
if self.batch_data_type == 's':
logger.debug("Set as single field mode.")
self.input_fields.clear()
elif self.batch_data_type == 'd':
if self.batch_data_type == 's':
logger.debug("Set as single field mode.")
self.input_fields.clear()
elif self.batch_data_type == 'd':
if isinstance(field_name, str):
assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
f"index, but other field is set as dict mode."
elif self.batch_data_type == 'l':
elif self.batch_data_type == 'l':
if isinstance(field_name, str):
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
f"field name is {field_name}."

@@ -215,8 +206,40 @@ class Collator:
else:
self.batch_data_type = 'd'

if field_name in self.ignore_fields:
logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
# 检测是否已经设置了,主要需要考虑它的父亲节点的情况
ignore_fields = [(field, field) if isinstance(field, tuple) else ((field,), field)
for field in self.ignore_fields]
input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field)
for field in self.input_fields.keys()]
if isinstance(field_name, tuple):
_field_name = field_name
else:
_field_name = (field_name,)
for field, o_field in ignore_fields:
d = _compare_tuple(field, _field_name)
if d is None:
continue
if d == 0:
logger.rank_zero_warning(f"Field:`{field_name}` has been set as ignored before. It will not be "
f"ignored afterwards.")
self.ignore_fields.remove(o_field)
if d > 0:
raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set "
f"as ignore field.")
if d < 0:
raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set "
f"as ignore field.")
for field, o_field in input_field_names:
d = _compare_tuple(field, _field_name)
if d is None:
continue
if d > 0:
raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set "
f"pad.")
if d < 0:
raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set "
f"pad.")

if backend is None:
backend = self.backend
else:
@@ -235,7 +258,7 @@ class Collator:
:return:
"""
assert backend in SUPPORTED_BACKENDS
self.padders.clear()
self._renew()
self.backend = backend

def set_ignore(self, *field_names) -> "Collator":
@@ -249,400 +272,56 @@ class Collator:
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身
"""
for field_name in field_names:
if field_name in self.input_fields:
self.input_fields.pop(field_name)
logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。
self.ignore_fields.add(field_name)
self._renew()
input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field)
for field in self.input_fields.keys()]

# 需要考虑父节点之类的情况
for field in field_names:
if not isinstance(field, tuple):
_field = (field,)
else:
_field = field
for _field_name, o_field_name in input_field_names:
d = _compare_tuple(_field, _field_name)
if d is None:
continue
if d == 0:
self.input_fields.pop(o_field_name)
logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.")
if d < 0:
self.input_fields.pop(o_field_name)
logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.")
if d > 0:
raise KeyError(f"Cannot ignore {field} since its parent key {o_field_name} has been set as pad.")
self.ignore_fields.add(field)

return self

def _renew(self):
self.packer_unpacker = None
self.padders.clear()







#
# from abc import ABCMeta, abstractmethod
# from typing import Any, Dict, List, Callable, Union, Tuple
# from numbers import Number
# import warnings
#
# import numpy as np
#
# from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
#
# if _NEED_IMPORT_PADDLE:
# import paddle
#
# if _NEED_IMPORT_TORCH:
# import torch
#
#
# class ApplyResultException(Exception):
# def __init__(self, msg, index=None):
# super().__init__(msg)
# self.msg = msg
# self.index = index # 标示在哪个数据遭遇到问题了
#
#
# class SetInputOrTargetException(Exception):
# def __init__(self, msg, index=None, field_name=None):
# super().__init__(msg)
# self.msg = msg
# self.index = index # 标示在哪个数据遭遇到问题了
# self.field_name = field_name # 标示当前 field 的名称
#
#
# def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
# r"""
# 识别cell的类别与dimension的数量
#
# numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
# :param cell:
# :param dim:
# :return:
# """
# if isinstance(cell, (str, Number, np.bool_)):
# if hasattr(cell, 'dtype'):
# return cell.dtype.type, dim
# return type(cell), dim
#
# elif isinstance(cell, list):
# dim += 1
# res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
# types = set([i for i, j in res])
# dims = set([j for i, j in res])
# if len(types) > 1:
# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
# elif len(types) == 0:
# raise SetInputOrTargetException("Empty value encountered.")
# if len(dims) > 1:
# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
# return types.pop(), dims.pop()
#
# elif isinstance(cell, torch.Tensor):
# return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0
#
# elif isinstance(cell, paddle.Tensor):
# return cell.dtype, cell.dim() + dim
#
# elif isinstance(cell, np.ndarray):
# if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了
# return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等
# # 否则需要继续往下 iterate
# dim += 1
# res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
# types = set([i for i, j in res])
# dims = set([j for i, j in res])
# if len(types) > 1:
# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
# elif len(types) == 0:
# raise SetInputOrTargetException("Empty value encountered.")
# if len(dims) > 1:
# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
# return types.pop(), dims.pop()
#
# else: # 包含 tuple, set, dict 以及其它的类型
# raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
#
#
# def _get_ds_type_dim(ds: dict):
# # 获取数据集第一行的 field 内部函数的类型和维度
# field_dtype, field_dim = {}, {}
# for field_name, field_content in ds.items():
# type_0, dim_0 = _get_ele_type_and_dim(field_content)
# field_dtype[field_name], field_dim[field_name] = type_0, dim_0
# return field_dtype, field_dim
#
#
# class Collator(metaclass=ABCMeta):
# r"""
# 辅助DataLoader管理collate_fn的类
#
# """
#
# def __init__(self):
# super(Collator, self).__init__()
# self.collate_fn = []
#
# @abstractmethod
# def __call__(self, ins_lst: List) -> Any:
# raise NotImplementedError
#
# @abstractmethod
# def set_pad_val(self, *field_names: str, value=0):
# raise NotImplementedError
#
#
# class _MultiCollator:
# """
# 管理所有collator的容器,
# 遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。
# """
#
# def __init__(self, collate_fns: Union[Callable, List[Callable], None]):
#
# if collate_fns is None:
# collate_fns = []
#
# if isinstance(collate_fns, Callable):
# collate_fns = [collate_fns]
#
# self._collators: list = collate_fns
#
# def __call__(self, ins_lst) -> Dict:
# out, list_out = {}, []
# for idx, _collate_fn in enumerate(self._collators):
# res = _collate_fn(ins_lst)
# if isinstance(res, Dict):
# out.update(res)
# else:
# list_out.append(res)
# # else:
# # raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict")
# if len(out) > 0 and len(list_out) > 0:
# raise ValueError("the return of collate_fns is not the same, must be dict or list")
# if len(list_out) == 1:
# list_out = list_out[-1]
# # print(list_out)
# return out if len(out) > 0 else list_out
#
# def get_collators(self):
# return self._collators
#
# def add_collator(self, collator: Callable):
# self._collators.append(collator)
#
# def set_as_numpy(self, as_numpy: bool):
# """
# 存在AutoCollator时,as_numpy控制其返回值的类型
#
# :param as_numpy:
# :return:
# """
# for collator in self._collators:
# if isinstance(collator, AutoCollator):
# collator.set_as_numpy(as_numpy)
# return self
#
# def set_pad_val(self, *field_names, val=0):
# """
# 存在AutoCollator时,设置field_name的padding值
#
# :param field_names: 数据集的field名
# :param val: padding的值
# :return:
# """
# flag = True
# for collator in self._collators:
# if isinstance(collator, AutoCollator):
# collator.set_pad_val(*field_names, val=val)
# flag = False
# if flag:
# warnings.warn("AutoCollator is remove, set_padding is unavailable!!")
# return self
#
# def set_input(self, *field_names):
# """
# 设置AutoCollator需要的field_names,未被设置默认过滤掉
#
# :param field_names:
# :return:
# """
# flag = True
# for collator in self._collators:
# if isinstance(collator, AutoCollator):
# collator.set_input(*field_names)
# flag = False
# if flag:
# warnings.warn("AutoCollator is removed, set_input is unavailable!!")
# return self
#
#
# class AutoCollator(Collator):
#
# def __init__(self, as_numpy: bool):
# super(AutoCollator, self).__init__()
# self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0
# self.need_inputs = set() # 需要的 field name
# self.field_dtypes = None # 每列数据单元的 dtype 类型
# self.field_dims = None # 每列数据单元维度
# self.as_numpy = as_numpy
#
# def __call__(self, ins_lst: List[Dict]) -> dict:
# if len(self.need_inputs) == 0:
# raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
# # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的
#
# # 第一种情况,设置了 set_input 的值
# # 第二种情况, 根据数据的类型的判断是否 padding
# if self.field_dtypes is None and self.field_dims is None:
# field_dtypes, field_dims = {}, {}
# for key, value in ins_lst[0].items():
# if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
# field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
# self.field_dtypes = field_dtypes
# self.field_dims = field_dims
#
# pack_ins_lst, pad_ins_lst = {field_name: []
# for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
# # 将 list 列表内数据按列名打包
# for per_ins in ins_lst:
# for field_name, _field_content in per_ins.items():
# if field_name in self.need_inputs:
# pack_ins_lst[field_name].append(_field_content)
#
# pad_field_kv = {field_name: 0 for field_name in self.need_inputs}
# pad_field_kv.update(self.pad_field_value)
# self.pad_field_value = pad_field_kv
#
# if len(self.pad_field_value.keys()) > 0:
# # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略
# non_pad_field_names = []
# for k, v in self.pad_field_value.items():
# if v is None:
# non_pad_field_names.append(k)
#
# # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
# for field_name in non_pad_field_names:
# field_array = pack_ins_lst.pop(field_name)
# pad_ins_lst[field_name] = np.array(field_array)
#
# for field_name, field_array in pack_ins_lst.items():
# content = pad_content(field_array, field_name, self.field_dtypes[field_name],
# self.field_dims[field_name],
# self.pad_field_value[field_name],
# as_numpy=self.as_numpy)
# pad_ins_lst[field_name] = content
#
# # else:
# # # 取出每列的数据,根据类型判断是否能 pad
# # for field_name, field_array in pack_ins_lst.items():
# # pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name],
# # self.field_dims[field_name],
# # pad_val=0, as_numpy=self.as_numpy)
# # pad_ins_lst[field_name] = pad_field_array
#
# return pad_ins_lst
#
# def set_pad_val(self, *field_names, val=0):
# for field_name in field_names:
# self.pad_field_value[field_name] = val
#
# def set_as_numpy(self, as_numpy: bool):
# self.as_numpy = as_numpy
#
# def set_input(self, *field_names):
# for field_name in field_names:
# self.need_inputs.add(field_name)
#
#
# def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
#
# if field_type:
# # 不处理, 返回 np.array 类型
# if field_dim > 3:
# return np.array(content)
# # 元素类型为数值类型 np.int64, np.float64, int, float 等
# if isinstance(field_type, type) and \
# (issubclass(field_type, np.number) or issubclass(field_type, Number)):
# if field_dim == 0:
# array = np.array(content, dtype=field_type)
# elif field_dim == 1:
# max_len = max(map(len, content))
# array = np.full((len(content), max_len), pad_val, dtype=field_type)
# for i, content_i in enumerate(content):
# array[i, :len(content_i)] = content_i
# elif field_dim == 2:
# max_len = max(map(len, content))
# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
# content_i in content])
# array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type)
# for i, content_i in enumerate(content):
# for j, content_ii in enumerate(content_i):
# array[i, j, :len(content_ii)] = content_ii
# else:
# shape = np.shape(content)
# if len(shape) == 4: # 说明各 dimension 是相同的大小
# array = np.array(content, dtype=field_type)
# else:
# raise RuntimeError(
# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
# if as_numpy is False:
# array = torch.tensor(array)
# return array
# # 元素类型为数值类型 torch.float 等
# elif str(field_type).startswith('torch'):
# if field_dim == 0:
# tensor = torch.tensor(content).to(field_type)
# elif field_dim == 1:
# max_len = max(map(len, content))
# tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
# for i, content_i in enumerate(content):
# tensor[i, :len(content_i)] = content_i.clone().detach()
# elif field_dim == 2:
# max_len = max(map(len, content))
# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
# content_i in content])
# tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val,
# dtype=field_type)
# for i, content_i in enumerate(content):
# for j, content_ii in enumerate(content_i):
# tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
# else:
# shapes = set([np.shape(content_i) for content_i in content])
# if len(shapes) > 1:
# raise RuntimeError(
# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
# shape = shapes.pop()
# if len(shape) == 3:
# tensor = torch.full([len(content)] + list(shape), fill_value=pad_val,
# dtype=field_type)
# for i, content_i in enumerate(content):
# tensor[i] = content_i.clone().detach().to(field_type)
# else:
# raise RuntimeError(
# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
# return tensor
# # TODO 增加jittor/paddle?
# elif str(field_type).startswith('paddle'):
# if field_dim == 0:
# tensor = paddle.Tensor(content).to(field_type)
# elif field_dim == 1:
# max_len = max(map(len, content))
# tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type)
# for i, content_i in enumerate(content):
# tensor[i, :len(content_i)] = content_i.clone().detach()
# elif field_dim == 2:
# max_len = max(map(len, content))
# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
# content_i in content])
# tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val,
# dtype=field_type)
# for i, content_i in enumerate(content):
# for j, content_ii in enumerate(content_i):
# tensor[i, j, :len(content_ii)] = content_ii.clone().detach()
# else:
# shapes = set([np.shape(content_i) for content_i in content])
# if len(shapes) > 1:
# raise RuntimeError(
# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
# shape = shapes.pop()
# if len(shape) == 3:
# tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val,
# dtype=field_type)
# for i, content_i in enumerate(content):
# tensor[i] = content_i.clone().detach().to(field_type)
# else:
# raise RuntimeError(
# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
# return tensor
#
# else:
# return np.array(content) # 不进行任何操作
# else:
# return np.array(content)
def _compare_tuple(t1, t2):
"""
检测 t1 和 t2 的关系。
例如 (1, ) 和 (1, ) 关系为 0,表示两者完全没有差异
例如 (1, ) 和 (2, ) 关系为 None,表示完全不同
例如 (1, 2, 3) 和 (1, ) 关系为 2,表示前者比后者长 2 位
但 例如 (1, 2, 3) 和 (2, ) 关系为 None,因为它们从前往后的key 不一样
例如 (1, 2, 3) 和 (1, 3) 关系为 None,因为它们从前往后的key 不一样

例如 (1, ) 和 (1, 2, 3) 关系为 -2,表示后者比前者长 2 位
但 例如 (2, ) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样
例如 (1, 3) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样
:param t1:
:param t2:
:return: None 没有关系; 0 两者完全一样; >0 t1比t2长,<0 t2比t1长
"""
if t1 == t2:
return 0
for _t1, _t2 in zip(t1, t2): # 会按照最短的计算
if _t1 != _t2:
return None
return len(t1) - len(t2)

+ 32
- 45
fastNLP/core/collators/packer_unpacker.py View File

@@ -3,7 +3,7 @@ from functools import reduce
from typing import Sequence, Mapping, Dict


class MappingPackerUnPacker:
class MappingPackerUnpacker:
@staticmethod
def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict:
"""
@@ -53,8 +53,9 @@ class NestedMappingPackerUnpacker:

@staticmethod
def pack_batch(batch):
if len(batch) == 0:
return []
dicts = []

for key, value in batch.items():
if not isinstance(key, tuple):
key = [key]
@@ -65,30 +66,38 @@ class NestedMappingPackerUnpacker:
return reduce(_merge_dict, dicts)


class
class SequencePackerUnpacker:
@staticmethod
def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict:
"""
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]}

:param batch:
:param ignore_fields: 需要忽略的field
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for i, content in enumerate(sample):
field_name = f'_{i}'
if field_name in ignore_fields:
continue
dict_batch[field_name].append(content)
return dict_batch

def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict:
"""
将 nested 的 dict 中的内容展开到一个 flat dict 中
@staticmethod
def pack_batch(batch):
return list(batch.values())

:param batch:
:param ignore_fields: 需要忽略的 field 。
:param stop_deep_fields: 不需要继续往下衍射的
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for key, value in sample.items():
if key in ignore_fields:
continue
if isinstance(value, Mapping) and key not in stop_deep_fields:
_dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,))
for key, value in _dict_batch.items():
dict_batch[key].append(value)
else:
dict_batch[key].append(value)
return dict_batch

class SinglePackerUnpacker:
    """
    Packer/unpacker for batches that are handled as one single field.

    ``unpack_batch`` wraps the whole batch under the key ``'_single'`` and
    ``pack_batch`` returns the value stored under that key again.
    """

    @staticmethod
    def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields):
        """
        Wrap ``batch`` into a one-entry dict.

        :param batch: the batch; stored as-is, not copied.
        :param ignore_fields: accepted but not used here.
        :param input_fields: accepted but not used here.
        :return: ``{'_single': batch}``
        """
        return {'_single': batch}

    @staticmethod
    def pack_batch(batch):
        """
        Return the value stored under ``'_single'``.

        :param batch: a mapping that contains the key ``'_single'``.
        :return: ``batch['_single']``
        """
        return batch['_single']


def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict:
@@ -136,25 +145,3 @@ def _merge_dict(a, b, path=None):
else:
a[key] = b[key]
return a


def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict:
"""
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]}

:param batch:
:param ignore_fields: 需要忽略的field
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for i, content in enumerate(sample):
field_name = f'_{i}'
if field_name in ignore_fields:
continue
dict_batch[field_name].append(content)
return dict_batch


def pack_batch_sequence(batch:Mapping)->Sequence:
return list(batch.values())

+ 4
- 1
fastNLP/core/collators/padders/torch_padder.py View File

@@ -112,16 +112,19 @@ class TorchTensorPadder(Padder):

@staticmethod
def pad(batch_field, pad_val, dtype):
device = None
try:
if not isinstance(batch_field[0], torch.Tensor):
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field]
else:
device = batch_field[0].device
except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype, device=device)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
tensor[slices] = field


+ 3
- 5
fastNLP/core/log/logger.py View File

@@ -134,11 +134,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
:return:
"""
if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0':
if msg not in self._warning_msgs:
if self.isEnabledFor(WARNING):
# kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)
if self.isEnabledFor(WARNING):
# kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)

def warn(self, msg, *args, **kwargs):
if self.isEnabledFor(WARNING):


+ 1
- 1
fastNLP/core/log/print.py View File

@@ -21,5 +21,5 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
:param flush: 该参数无意义。
:return:
"""
line = sep.join(args)
line = sep.join(map(str, args))
logger.info(line)

+ 79
- 4
tests/core/collators/test_collator.py View File

@@ -5,6 +5,7 @@ import pytest
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.collator import Collator
from ...helpers.utils import Capturing


def _assert_equal(d1, d2):
@@ -42,7 +43,6 @@ def findListDiff(d1, d2):


class TestCollator:

@pytest.mark.torch
def test_run(self):
dict_batch = [{
@@ -286,8 +286,83 @@ class TestCollator:
'c': [1, 1]}}
findDictDiff(raw_pad_batch, pad_batch)

def test_raise(self, capsys):
    """
    Check the warnings and errors produced when ``set_pad`` and
    ``set_ignore`` are called on the same, parent or child field.
    """
    from fastNLP.core.log import logger
    logger.set_stdout('raw')
    # for the nested case
    collator = Collator(backend='numpy')
    data = [[1, 2], [2, 3]]
    collator.set_pad('_0')
    collator.set_pad('_0')
    print(collator(data))
    with Capturing() as out:
        collator.set_ignore('_0')
    # ``out`` is read after leaving the ``with`` block.
    assert '_0' in out[0]

    data = [{1: {2: 2, 3: 3}}]
    collator = Collator()
    collator.set_pad((1, 2))
    collator.set_pad((1, 3))
    with Capturing() as out:
        collator.set_ignore(1)
    # Ignoring the parent key reports both children that had been set as pad.
    assert '(1, 2)' in out[0] and '(1, 3)' in out[0]
    assert len(collator(data))==0

    # A child is ignored -> setting pad on its parent must fail.
    collator = Collator()
    collator.set_ignore((1, 2))
    with pytest.raises(KeyError):
        collator.set_pad(1)

    # A parent is ignored -> setting pad on its child must fail.
    collator = Collator()
    collator.set_ignore(1)
    with pytest.raises(KeyError):
        collator.set_pad((1, 2))


@pytest.mark.torch
def test_torch_dl():
    """
    Run ``TorchDataLoader`` over a small ``DataSet`` and check the batch
    contents after successive ``set_ignore`` / ``set_pad`` calls.
    """
    from fastNLP import TorchDataLoader
    from fastNLP import DataSet
    import numpy as np
    import torch

    ds = DataSet({
        'x': [1, 2], 'y': [[1,2], [3]], 'z':[np.ones((1, 2)), np.ones((2, 3))],
        'i': [{'j': [1, 2]}, {'j': [3]}], 'j': ['a', 'b']
    })

    dl = TorchDataLoader(ds, batch_size=2)
    batch = next(iter(dl))
    assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch
    assert isinstance(batch['z'], torch.Tensor)
    assert isinstance(batch['j'], list)
    assert isinstance(batch['i']['j'], torch.Tensor)

    dl.set_ignore('x')
    batch = next(iter(dl))
    assert 'x' not in batch and 'y' in batch and 'z' in batch

    dl.set_pad('y', pad_val=None)
    batch = next(iter(dl))
    assert 'x' not in batch and 'y' in batch and 'z' in batch
    assert isinstance(batch['y'], list)
    assert len(batch['y'][0])!=len(batch['y'][1])  # not padded

    dl.set_pad(('i', 'j'), pad_val=None)
    batch = next(iter(dl))
    assert 'x' not in batch and 'y' in batch and 'z' in batch
    assert isinstance(batch['y'], list)
    assert len(batch['y'][0])!=len(batch['y'][1])  # not padded
    assert isinstance(batch['i']['j'], list)
    assert len(batch['i']['j'][0])!=len(batch['i']['j'][1])  # not padded

    # ('i', 'j') is already set as pad, so setting its parent 'i' must fail.
    with pytest.raises(KeyError):
        dl.set_pad('i', pad_val=None)


def test_compare_tuple():
    """
    Check ``_compare_tuple`` for prefix, extension, unrelated and identical
    key paths.
    """
    from fastNLP.core.collators.collator import _compare_tuple
    # (t1, t2, expected relation)
    cases = [
        ((1,), (1, 2, 3), -2),      # t2 extends t1 by two keys
        ((1, 2, 3), (1,), 2),       # t1 extends t2 by two keys
        ((1,), (2,), None),         # different first key
        ((1, 2), (1, 3), None),     # diverge at the second key
        ((1,), (1,), 0),            # identical paths
    ]
    for t1, t2, expected in cases:
        assert _compare_tuple(t1, t2) == expected

Loading…
Cancel
Save