From 0fd17152b8a88e137b09d2ab3069d70f76744639 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 5 May 2022 22:24:13 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9Collator=E7=9A=84=E5=AE=9E?= =?UTF-8?q?=E7=8E=B0=EF=BC=8C=E4=BC=98=E5=8C=96=E4=BA=86nested=E5=9C=BA?= =?UTF-8?q?=E6=99=AF=E4=B8=8B=E7=9A=84=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/collator.py | 531 ++++-------------- fastNLP/core/collators/packer_unpacker.py | 77 ++- .../core/collators/padders/torch_padder.py | 5 +- fastNLP/core/log/logger.py | 8 +- fastNLP/core/log/print.py | 2 +- tests/core/collators/test_collator.py | 83 ++- 6 files changed, 224 insertions(+), 482 deletions(-) diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 9ea08d95..70848fd0 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -12,8 +12,8 @@ from .padders.get_padder import get_padder import re -from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ - pack_batch_sequence +from .packer_unpacker import SequencePackerUnpacker, SinglePackerUnpacker, MappingPackerUnpacker, \ + NestedMappingPackerUnpacker sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', 'auto', None] @@ -126,46 +126,36 @@ class Collator: logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " f"is `{self.batch_data_type}`.") if self.batch_data_type == 's': - self.unpack_batch_func = lambda batch, ignore_fields: {'_single': batch} # 不需要做任何调整 - self.pack_batch_func = lambda x: x['_single'] + self.packer_unpacker = SinglePackerUnpacker() # 不需要做任何调整 elif self.batch_data_type == 'l': - self.unpack_batch_func = unpack_batch_sequence - self.pack_batch_func = pack_batch_sequence + self.packer_unpacker = SequencePackerUnpacker() elif self.batch_data_type == 'd': if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{('a', 'b'): value} - self.unpack_batch_func = unpack_batch_nested_mapping - self.pack_batch_func = pack_batch_nested_mapping + self.packer_unpacker = NestedMappingPackerUnpacker() else: - self.unpack_batch_func = unpack_batch_mapping - self.pack_batch_func = lambda x:x + self.packer_unpacker = MappingPackerUnpacker() - if self.unpack_batch_func is unpack_batch_nested_mapping: # 比较特殊,需要防止继续往下延伸 - unpack_batch: Dict = self.unpack_batch_func(batch, self.ignore_fields, set(self.input_fields.keys())) - else: - unpack_batch:Dict = self.unpack_batch_func(batch, self.ignore_fields) # 将各自 field 组成 batch 形式。 + # 将 batch 中各个 field 组成自己的 batch;同时忽略处于 ignore_fields 中的数据。 + unpack_batch = self.packer_unpacker.unpack_batch(batch, self.ignore_fields, self.input_fields) pad_batch = {} if len(self.padders)==0: # 第一次运行,准备 padder if self.backend == 'auto': # 如果 backend 为 auto ,则尝试通过调用栈等自动获取 backend 。 self.backend = _get_backend() - for key in unpack_batch.keys(): - if key not in self.input_fields and key not in self.ignore_fields: - self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} - elif key in self.input_fields and self.input_fields[key]['backend'] == 'auto': - self.input_fields[key]['backend'] = self.backend - - for field_name, setting in self.input_fields.items(): - pad_fn = setting.get('pad_fn', None) + for field_name, batch_field in unpack_batch.items(): + setting = self.input_fields.get(field_name, {'backend': self.backend, 'pad_val': 0 , + 'dtype': None, 'pad_fn': None}) + pad_fn = setting['pad_fn'] if callable(pad_fn): padder = pad_fn else: backend = self.backend if setting['backend'] == 'auto' else setting['backend'] - batch_field = unpack_batch.get(field_name) padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], dtype=setting['dtype'], backend=backend, field_name=field_name) self.padders[field_name] = padder + if self.batch_data_type == 'l': self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 @@ -173,7 +163,7 @@ class Collator: batch = unpack_batch.get(key) pad_batch[key] = padder(batch) - return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 + return self.packer_unpacker.pack_batch(pad_batch) # 根据情况恢复成与输入一致的类型 def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend='auto', pad_fn:Callable=None) -> "Collator": @@ -195,16 +185,17 @@ class Collator: 形式,输出将被直接作为结果输出。 :return: 返回 Collator 自身 """ - self.padders.clear() # 重新生成 + self._renew() - if self.batch_data_type is not None: - if self.batch_data_type == 's': - logger.debug("Set as single field mode.") - self.input_fields.clear() - elif self.batch_data_type == 'd': + if self.batch_data_type == 's': + logger.debug("Set as single field mode.") + self.input_fields.clear() + elif self.batch_data_type == 'd': + if isinstance(field_name, str): assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ f"index, but other field is set as dict mode." - elif self.batch_data_type == 'l': + elif self.batch_data_type == 'l': + if isinstance(field_name, str): assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ f"field name is {field_name}." @@ -215,8 +206,40 @@ class Collator: else: self.batch_data_type = 'd' - if field_name in self.ignore_fields: - logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") + # 检测是否已经设置了,主要需要考虑它的父亲节点的情况 + ignore_fields = [(field, field) if isinstance(field, tuple) else ((field,), field) + for field in self.ignore_fields] + input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) + for field in self.input_fields.keys()] + if isinstance(field_name, tuple): + _field_name = field_name + else: + _field_name = (field_name,) + for field, o_field in ignore_fields: + d = _compare_tuple(field, _field_name) + if d is None: + continue + if d == 0: + logger.rank_zero_warning(f"Field:`{field_name}` has been set as ignored before. It will not be " + f"ignored afterwards.") + self.ignore_fields.remove(o_field) + if d > 0: + raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set " + f"as ignore field.") + if d < 0: + raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set " + f"as ignore field.") + for field, o_field in input_field_names: + d = _compare_tuple(field, _field_name) + if d is None: + continue + if d > 0: + raise KeyError(f"Cannot set `{field_name}` as input, since its children `{o_field}` has been set " + f"pad.") + if d < 0: + raise KeyError(f"Cannot set `{field_name}` as input, since its parent `{o_field}` has been set " + f"pad.") + if backend is None: backend = self.backend else: @@ -235,7 +258,7 @@ class Collator: :return: """ assert backend in SUPPORTED_BACKENDS - self.padders.clear() + self._renew() self.backend = backend def set_ignore(self, *field_names) -> "Collator": @@ -249,400 +272,56 @@ class Collator: __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 :return: 返回 Collator 自身 """ - for field_name in field_names: - if field_name in self.input_fields: - self.input_fields.pop(field_name) - logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") - self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 - self.ignore_fields.add(field_name) + self._renew() + input_field_names = [(field, field) if isinstance(field, tuple) else ((field,), field) + for field in self.input_fields.keys()] + + # 需要考虑父节点之类的情况 + for field in field_names: + if not isinstance(field, tuple): + _field = (field,) + else: + _field = field + for _field_name, o_field_name in input_field_names: + d = _compare_tuple(_field, _field_name) + if d is None: + continue + if d == 0: + self.input_fields.pop(o_field_name) + logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.") + if d < 0: + self.input_fields.pop(o_field_name) + logger.rank_zero_warning(f"Field:{o_field_name} has been set as pad before. It will be ignored afterwards.") + if d > 0: + raise KeyError(f"Cannot ignore {field} since its parent key {o_field_name} has been set as pad.") + self.ignore_fields.add(field) return self + def _renew(self): + self.packer_unpacker = None + self.padders.clear() - - - - - -# -# from abc import ABCMeta, abstractmethod -# from typing import Any, Dict, List, Callable, Union, Tuple -# from numbers import Number -# import warnings -# -# import numpy as np -# -# from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH -# -# if _NEED_IMPORT_PADDLE: -# import paddle -# -# if _NEED_IMPORT_TORCH: -# import torch -# -# -# class ApplyResultException(Exception): -# def __init__(self, msg, index=None): -# super().__init__(msg) -# self.msg = msg -# self.index = index # 标示在哪个数据遭遇到问题了 -# -# -# class SetInputOrTargetException(Exception): -# def __init__(self, msg, index=None, field_name=None): -# super().__init__(msg) -# self.msg = msg -# self.index = index # 标示在哪个数据遭遇到问题了 -# self.field_name = field_name # 标示当前 field 的名称 -# -# -# def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]: -# r""" -# 识别cell的类别与dimension的数量 -# -# numpy scalar type:https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html -# :param cell: -# :param dim: -# :return: -# """ -# if isinstance(cell, (str, Number, np.bool_)): -# if hasattr(cell, 'dtype'): -# return cell.dtype.type, dim -# return type(cell), dim -# -# elif isinstance(cell, list): -# dim += 1 -# res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] -# types = set([i for i, j in res]) -# dims = set([j for i, j in res]) -# if len(types) > 1: -# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) -# elif len(types) == 0: -# raise SetInputOrTargetException("Empty value encountered.") -# if len(dims) > 1: -# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) -# return types.pop(), dims.pop() -# -# elif isinstance(cell, torch.Tensor): -# return cell.dtype, cell.dim() + dim # 如果是 torch.mean 的结果是0 -# -# elif isinstance(cell, paddle.Tensor): -# return cell.dtype, cell.dim() + dim -# -# elif isinstance(cell, np.ndarray): -# if cell.dtype != np.dtype('O'): # 如果不是 object 的话说明是 well-formatted 的了 -# return cell.dtype.type, cell.ndim + dim # dtype.type 返回的会是 np.int32, np.float 等 -# # 否则需要继续往下 iterate -# dim += 1 -# res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] -# types = set([i for i, j in res]) -# dims = set([j for i, j in res]) -# if len(types) > 1: -# raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) -# elif len(types) == 0: -# raise SetInputOrTargetException("Empty value encountered.") -# if len(dims) > 1: -# raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) -# return types.pop(), dims.pop() -# -# else: # 包含 tuple, set, dict 以及其它的类型 -# raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") -# -# -# def _get_ds_type_dim(ds: dict): -# # 获取数据集第一行的 field 内部函数的类型和维度 -# field_dtype, field_dim = {}, {} -# for field_name, field_content in ds.items(): -# type_0, dim_0 = _get_ele_type_and_dim(field_content) -# field_dtype[field_name], field_dim[field_name] = type_0, dim_0 -# return field_dtype, field_dim -# -# -# class Collator(metaclass=ABCMeta): -# r""" -# 辅助DataLoader管理collate_fn的类 -# -# """ -# -# def __init__(self): -# super(Collator, self).__init__() -# self.collate_fn = [] -# -# @abstractmethod -# def __call__(self, ins_lst: List) -> Any: -# raise NotImplementedError -# -# @abstractmethod -# def set_pad_val(self, *field_names: str, value=0): -# raise NotImplementedError -# -# -# class _MultiCollator: -# """ -# 管理所有collator的容器, -# 遵循覆盖原则,后加入的collate_fn会覆盖之前处理的数据。 -# """ -# -# def __init__(self, collate_fns: Union[Callable, List[Callable], None]): -# -# if collate_fns is None: -# collate_fns = [] -# -# if isinstance(collate_fns, Callable): -# collate_fns = [collate_fns] -# -# self._collators: list = collate_fns -# -# def __call__(self, ins_lst) -> Dict: -# out, list_out = {}, [] -# for idx, _collate_fn in enumerate(self._collators): -# res = _collate_fn(ins_lst) -# if isinstance(res, Dict): -# out.update(res) -# else: -# list_out.append(res) -# # else: -# # raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") -# if len(out) > 0 and len(list_out) > 0: -# raise ValueError("the return of collate_fns is not the same, must be dict or list") -# if len(list_out) == 1: -# list_out = list_out[-1] -# # print(list_out) -# return out if len(out) > 0 else list_out -# -# def get_collators(self): -# return self._collators -# -# def add_collator(self, collator: Callable): -# self._collators.append(collator) -# -# def set_as_numpy(self, as_numpy: bool): -# """ -# 存在AutoCollator时,as_numpy控制其返回值的类型 -# -# :param as_numpy: -# :return: -# """ -# for collator in self._collators: -# if isinstance(collator, AutoCollator): -# collator.set_as_numpy(as_numpy) -# return self -# -# def set_pad_val(self, *field_names, val=0): -# """ -# 存在AutoCollator时,设置field_name的padding值 -# -# :param field_names: 数据集的field名 -# :param val: padding的值 -# :return: -# """ -# flag = True -# for collator in self._collators: -# if isinstance(collator, AutoCollator): -# collator.set_pad_val(*field_names, val=val) -# flag = False -# if flag: -# warnings.warn("AutoCollator is remove, set_padding is unavailable!!") -# return self -# -# def set_input(self, *field_names): -# """ -# 设置AutoCollator需要的field_names,未被设置默认过滤掉 -# -# :param field_names: -# :return: -# """ -# flag = True -# for collator in self._collators: -# if isinstance(collator, AutoCollator): -# collator.set_input(*field_names) -# flag = False -# if flag: -# warnings.warn("AutoCollator is removed, set_input is unavailable!!") -# return self -# -# -# class AutoCollator(Collator): -# -# def __init__(self, as_numpy: bool): -# super(AutoCollator, self).__init__() -# self.pad_field_value = {} # field padding 自定义的 padding 值, 默认为0 -# self.need_inputs = set() # 需要的 field name -# self.field_dtypes = None # 每列数据单元的 dtype 类型 -# self.field_dims = None # 每列数据单元维度 -# self.as_numpy = as_numpy -# -# def __call__(self, ins_lst: List[Dict]) -> dict: -# if len(self.need_inputs) == 0: -# raise ValueError({"set_inputs is None, you should use set_inputs method first!!"}) -# # TODO 这里应该是先 check 有哪些需要 padding,然后check这些是否是可以pad的 -# -# # 第一种情况,设置了 set_input 的值 -# # 第二种情况, 根据数据的类型的判断是否 padding -# if self.field_dtypes is None and self.field_dims is None: -# field_dtypes, field_dims = {}, {} -# for key, value in ins_lst[0].items(): -# if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None: -# field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value) -# self.field_dtypes = field_dtypes -# self.field_dims = field_dims -# -# pack_ins_lst, pad_ins_lst = {field_name: [] -# for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} -# # 将 list 列表内数据按列名打包 -# for per_ins in ins_lst: -# for field_name, _field_content in per_ins.items(): -# if field_name in self.need_inputs: -# pack_ins_lst[field_name].append(_field_content) -# -# pad_field_kv = {field_name: 0 for field_name in self.need_inputs} -# pad_field_kv.update(self.pad_field_value) -# self.pad_field_value = pad_field_kv -# -# if len(self.pad_field_value.keys()) > 0: -# # 去掉不需要 pad 的列,如果 set_input 的列不存在则忽略 -# non_pad_field_names = [] -# for k, v in self.pad_field_value.items(): -# if v is None: -# non_pad_field_names.append(k) -# -# # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) -# for field_name in non_pad_field_names: -# field_array = pack_ins_lst.pop(field_name) -# pad_ins_lst[field_name] = np.array(field_array) -# -# for field_name, field_array in pack_ins_lst.items(): -# content = pad_content(field_array, field_name, self.field_dtypes[field_name], -# self.field_dims[field_name], -# self.pad_field_value[field_name], -# as_numpy=self.as_numpy) -# pad_ins_lst[field_name] = content -# -# # else: -# # # 取出每列的数据,根据类型判断是否能 pad -# # for field_name, field_array in pack_ins_lst.items(): -# # pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name], -# # self.field_dims[field_name], -# # pad_val=0, as_numpy=self.as_numpy) -# # pad_ins_lst[field_name] = pad_field_array -# -# return pad_ins_lst -# -# def set_pad_val(self, *field_names, val=0): -# for field_name in field_names: -# self.pad_field_value[field_name] = val -# -# def set_as_numpy(self, as_numpy: bool): -# self.as_numpy = as_numpy -# -# def set_input(self, *field_names): -# for field_name in field_names: -# self.need_inputs.add(field_name) -# -# -# def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): -# -# if field_type: -# # 不处理, 返回 np.array 类型 -# if field_dim > 3: -# return np.array(content) -# # 元素类型为数值类型 np.int64, np.float64, int, float 等 -# if isinstance(field_type, type) and \ -# (issubclass(field_type, np.number) or issubclass(field_type, Number)): -# if field_dim == 0: -# array = np.array(content, dtype=field_type) -# elif field_dim == 1: -# max_len = max(map(len, content)) -# array = np.full((len(content), max_len), pad_val, dtype=field_type) -# for i, content_i in enumerate(content): -# array[i, :len(content_i)] = content_i -# elif field_dim == 2: -# max_len = max(map(len, content)) -# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for -# content_i in content]) -# array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type) -# for i, content_i in enumerate(content): -# for j, content_ii in enumerate(content_i): -# array[i, j, :len(content_ii)] = content_ii -# else: -# shape = np.shape(content) -# if len(shape) == 4: # 说明各 dimension 是相同的大小 -# array = np.array(content, dtype=field_type) -# else: -# raise RuntimeError( -# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") -# if as_numpy is False: -# array = torch.tensor(array) -# return array -# # 元素类型为数值类型 torch.float 等 -# elif str(field_type).startswith('torch'): -# if field_dim == 0: -# tensor = torch.tensor(content).to(field_type) -# elif field_dim == 1: -# max_len = max(map(len, content)) -# tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type) -# for i, content_i in enumerate(content): -# tensor[i, :len(content_i)] = content_i.clone().detach() -# elif field_dim == 2: -# max_len = max(map(len, content)) -# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for -# content_i in content]) -# tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val, -# dtype=field_type) -# for i, content_i in enumerate(content): -# for j, content_ii in enumerate(content_i): -# tensor[i, j, :len(content_ii)] = content_ii.clone().detach() -# else: -# shapes = set([np.shape(content_i) for content_i in content]) -# if len(shapes) > 1: -# raise RuntimeError( -# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") -# shape = shapes.pop() -# if len(shape) == 3: -# tensor = torch.full([len(content)] + list(shape), fill_value=pad_val, -# dtype=field_type) -# for i, content_i in enumerate(content): -# tensor[i] = content_i.clone().detach().to(field_type) -# else: -# raise RuntimeError( -# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") -# return tensor -# # TODO 增加jittor/paddle? -# elif str(field_type).startswith('paddle'): -# if field_dim == 0: -# tensor = paddle.Tensor(content).to(field_type) -# elif field_dim == 1: -# max_len = max(map(len, content)) -# tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type) -# for i, content_i in enumerate(content): -# tensor[i, :len(content_i)] = content_i.clone().detach() -# elif field_dim == 2: -# max_len = max(map(len, content)) -# max_word_len = max([max([len(content_ii) for content_ii in content_i]) for -# content_i in content]) -# tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val, -# dtype=field_type) -# for i, content_i in enumerate(content): -# for j, content_ii in enumerate(content_i): -# tensor[i, j, :len(content_ii)] = content_ii.clone().detach() -# else: -# shapes = set([np.shape(content_i) for content_i in content]) -# if len(shapes) > 1: -# raise RuntimeError( -# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") -# shape = shapes.pop() -# if len(shape) == 3: -# tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val, -# dtype=field_type) -# for i, content_i in enumerate(content): -# tensor[i] = content_i.clone().detach().to(field_type) -# else: -# raise RuntimeError( -# f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") -# return tensor -# -# else: -# return np.array(content) # 不进行任何操作 -# else: -# return np.array(content) +def _compare_tuple(t1, t2): + """ + 检测 t1 和 t2 的关系。 + 例如 (1, ) 和 (1, ) 关系为 0,表示两者完全没有差异 + 例如 (1, ) 和 (2, ) 关系为 None,表示完全不同 + 例如 (1, 2, 3) 和 (1, ) 关系为 2,表示前者比后者长 2 位 + 但 例如 (1, 2, 3) 和 (2, ) 关系为 None,因为它们从前往后的key 不一样 + 例如 (1, 2, 3) 和 (1, 3) 关系为 None,因为它们从前往后的key 不一样 + + 例如 (1, ) 和 (1, 2, 3) 关系为 -2,表示后者比前者长 2 位 + 但 例如 (2, ) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样 + 例如 (1, 3) 和 (1, 2, 3) 关系为 None,因为它们从前往后的key 不一样 + :param t1: + :param t2: + :return: None 没有关系; 0 两者完全一样; >0 t1比t2长,<0 t2比t1长 + """ + if t1 == t2: + return 0 + for _t1, _t2 in zip(t1, t2): # 会按照最短的计算 + if _t1 != _t2: + return None + return len(t1) - len(t2) diff --git a/fastNLP/core/collators/packer_unpacker.py b/fastNLP/core/collators/packer_unpacker.py index f71b4113..033cfca5 100644 --- a/fastNLP/core/collators/packer_unpacker.py +++ b/fastNLP/core/collators/packer_unpacker.py @@ -3,7 +3,7 @@ from functools import reduce from typing import Sequence, Mapping, Dict -class MappingPackerUnPacker: +class MappingPackerUnpacker: @staticmethod def unpack_batch(batch:Sequence[Mapping], ignore_fields:set, input_fields:Dict)->Dict: """ @@ -53,8 +53,9 @@ class NestedMappingPackerUnpacker: @staticmethod def pack_batch(batch): + if len(batch) == 0: + return [] dicts = [] - for key, value in batch.items(): if not isinstance(key, tuple): key = [key] @@ -65,30 +66,38 @@ class NestedMappingPackerUnpacker: return reduce(_merge_dict, dicts) -class +class SequencePackerUnpacker: + @staticmethod + def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields)->Dict: + """ + 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} + :param batch: + :param ignore_fields: 需要忽略的field + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for i, content in enumerate(sample): + field_name = f'_{i}' + if field_name in ignore_fields: + continue + dict_batch[field_name].append(content) + return dict_batch -def unpack_batch_nested_mapping(batch:Sequence[Mapping], ignore_fields:set, stop_deep_fields:set)->Dict: - """ - 将 nested 的 dict 中的内容展开到一个 flat dict 中 + @staticmethod + def pack_batch(batch): + return list(batch.values()) - :param batch: - :param ignore_fields: 需要忽略的 field 。 - :param stop_deep_fields: 不需要继续往下衍射的 - :return: - """ - dict_batch = defaultdict(list) - for sample in batch: - for key, value in sample.items(): - if key in ignore_fields: - continue - if isinstance(value, Mapping) and key not in stop_deep_fields: - _dict_batch = _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent=(key,)) - for key, value in _dict_batch.items(): - dict_batch[key].append(value) - else: - dict_batch[key].append(value) - return dict_batch + +class SinglePackerUnpacker: + @staticmethod + def unpack_batch(batch:Sequence[Sequence], ignore_fields, input_fields): + return {'_single': batch} + + @staticmethod + def pack_batch(batch): + return batch['_single'] def _unpack_batch_nested_mapping(value, ignore_fields, stop_deep_fields, _parent)->Dict: @@ -136,25 +145,3 @@ def _merge_dict(a, b, path=None): else: a[key] = b[key] return a - - -def unpack_batch_sequence(batch:Sequence[Sequence], ignore_fields)->Dict: - """ - 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} - - :param batch: - :param ignore_fields: 需要忽略的field - :return: - """ - dict_batch = defaultdict(list) - for sample in batch: - for i, content in enumerate(sample): - field_name = f'_{i}' - if field_name in ignore_fields: - continue - dict_batch[field_name].append(content) - return dict_batch - - -def pack_batch_sequence(batch:Mapping)->Sequence: - return list(batch.values()) \ No newline at end of file diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index f1940380..d6d07dcd 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -112,16 +112,19 @@ class TorchTensorPadder(Padder): @staticmethod def pad(batch_field, pad_val, dtype): + device = None try: if not isinstance(batch_field[0], torch.Tensor): batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] + else: + device = batch_field[0].device except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] - tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) + tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype, device=device) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) tensor[slices] = field diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index bbc1e8e1..179755e2 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -134,11 +134,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): :return: """ if os.environ.get(FASTNLP_GLOBAL_RANK, '0') == '0': - if msg not in self._warning_msgs: - if self.isEnabledFor(WARNING): - # kwargs = self._add_rank_info(kwargs) - self._log(WARNING, msg, args, **kwargs) - self._warning_msgs.add(msg) + if self.isEnabledFor(WARNING): + # kwargs = self._add_rank_info(kwargs) + self._log(WARNING, msg, args, **kwargs) def warn(self, msg, *args, **kwargs): if self.isEnabledFor(WARNING): diff --git a/fastNLP/core/log/print.py b/fastNLP/core/log/print.py index 30797b89..1ebebcd3 100644 --- a/fastNLP/core/log/print.py +++ b/fastNLP/core/log/print.py @@ -21,5 +21,5 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): :param flush: 该参数无意义。 :return: """ - line = sep.join(args) + line = sep.join(map(str, args)) logger.info(line) \ No newline at end of file diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index 65101321..ae219793 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -5,6 +5,7 @@ import pytest from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR from fastNLP.core.collators.collator import Collator +from ...helpers.utils import Capturing def _assert_equal(d1, d2): @@ -42,7 +43,6 @@ def findListDiff(d1, d2): class TestCollator: - @pytest.mark.torch def test_run(self): dict_batch = [{ @@ -286,8 +286,83 @@ class TestCollator: 'c': [1, 1]}} findDictDiff(raw_pad_batch, pad_batch) + def test_raise(self, capsys): + from fastNLP.core.log import logger + logger.set_stdout('raw') + # 对于 nested 的情况 + collator = Collator(backend='numpy') + data = [[1, 2], [2, 3]] + collator.set_pad('_0') + collator.set_pad('_0') + print(collator(data)) + with Capturing() as out: + collator.set_ignore('_0') + assert '_0' in out[0] + + data = [{1: {2: 2, 3: 3}}] + collator = Collator() + collator.set_pad((1, 2)) + collator.set_pad((1, 3)) + with Capturing() as out: + collator.set_ignore(1) + assert '(1, 2)' in out[0] and '(1, 3)' in out[0] + assert len(collator(data))==0 + collator = Collator() + collator.set_ignore((1, 2)) + with pytest.raises(KeyError): + collator.set_pad(1) - - - + collator = Collator() + collator.set_ignore(1) + with pytest.raises(KeyError): + collator.set_pad((1, 2)) + + +@pytest.mark.torch +def test_torch_dl(): + from fastNLP import TorchDataLoader + from fastNLP import DataSet + import numpy as np + import torch + + ds = DataSet({ + 'x': [1, 2], 'y': [[1,2], [3]], 'z':[np.ones((1, 2)), np.ones((2, 3))], + 'i': [{'j': [1, 2]}, {'j': [3]}], 'j': ['a', 'b'] + }) + + dl = TorchDataLoader(ds, batch_size=2) + batch = next(iter(dl)) + assert 'x' in batch and 'y' in batch and 'z' in batch and 'i' in batch and 'j' in batch + assert isinstance(batch['z'], torch.Tensor) + assert isinstance(batch['j'], list) + assert isinstance(batch['i']['j'], torch.Tensor) + + dl.set_ignore('x') + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + + dl.set_pad('y', pad_val=None) + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + assert isinstance(batch['y'], list) + assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad + + dl.set_pad(('i', 'j'), pad_val=None) + batch = next(iter(dl)) + assert 'x' not in batch and 'y' in batch and 'z' in batch + assert isinstance(batch['y'], list) + assert len(batch['y'][0])!=len(batch['y'][1]) # 没有 pad + assert isinstance(batch['i']['j'], list) + assert len(batch['i']['j'][0])!=len(batch['i']['j'][1]) # 没有 pad + + with pytest.raises(KeyError): + dl.set_pad('i', pad_val=None) + + +def test_compare_tuple(): + from fastNLP.core.collators.collator import _compare_tuple + for t1, t2, t in zip([(1,), (1, 2, 3), (1,), (1, 2)], + [(1, 2, 3), (1,), (2,), (1, 3)], + [-2, 2, None, None]): + assert _compare_tuple(t1, t2) == t