diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py new file mode 100644 index 00000000..869a60a7 --- /dev/null +++ b/fastNLP/core/collators/new_collator.py @@ -0,0 +1,181 @@ +from typing import List, Union, Dict, Callable, Sequence, Mapping + +from fastNLP.core.log import logger +from .padders.get_padder import get_padder + +import re + +from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ + pack_batch_sequence, NESTED_DICT_SEPARATOR + +sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 +SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] + + +class Collator: + def __init__(self, backend='torch'): + """ + 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 + 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 + + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], + 若为 None ,则不进行 padding 。 + """ + self.unpack_batch_func = None + self.pack_batch_func = None + self.ignore_fields = set() + self.padders = {} + self.input_fields = {} + self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 + self.set_backend(backend) + + def __call__(self, batch)->Union[List, Dict]: + """ + batch可能存在三种可能性 + List[Dict], List[List], List[Sample] + + 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 + 第二步:使用每个 field 各自的 padder 进行 pad 。 + 第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 + + 第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 + list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample + 的类别。 + 第一次调用会根据当前 field 决定对应的 Padder 。 + + """ + if self.unpack_batch_func is None: + # 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 + if self.batch_data_type is None: + if isinstance(batch[0], Mapping): + self.batch_data_type = 'd' + elif isinstance(batch[0], Sequence): # 这里存在误判的风险 + self.batch_data_type = 'l' + else: + self.batch_data_type = 's' + logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " + f"is {self.batch_data_type}") + if self.batch_data_type == 's': + self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 + self.pack_batch_func = lambda x:x['_single'] + elif self.batch_data_type == 'l': + self.unpack_batch_func = unpack_batch_sequence + self.pack_batch_func = pack_batch_sequence + elif self.batch_data_type == 'd': + if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} + self.unpack_batch_func = unpack_batch_nested_mapping + self.pack_batch_func = pack_batch_nested_mapping + else: + self.unpack_batch_func = unpack_batch_mapping + self.pack_batch_func = lambda x:x + + unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 + + pad_batch = {} + if len(self.padders)==0: # 第一次运行,准备 padder + for key in unpack_batch.keys(): + if key not in self.input_fields and key not in self.ignore_fields: + self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + + for field_name, setting in self.input_fields.items(): + pad_fn = setting.get('pad_fn', None) + if callable(pad_fn): + padder = pad_fn + else: + batch_field = unpack_batch.get(field_name) + padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], + dtype=setting['dtype'], backend=setting['backend'], + field_name=field_name) + self.padders[field_name] = padder + if self.batch_data_type == 'l': + self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 + + for key, padder in self.padders.items(): + batch = unpack_batch.get(key) + pad_batch[key] = padder(batch) + + return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 + + def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> "Collator": + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 + """ + self.padders.clear() # 重新生成 + + if self.batch_data_type is not None: + if self.batch_data_type == 's': + logger.debug("Set as single field mode.") + self.input_fields.clear() + elif self.batch_data_type == 'd': + assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ + f"index, but other field is set as dict mode." + elif self.batch_data_type == 'l': + assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ + f"field name is {field_name}" + + if field_name == '_single': + self.batch_data_type = 's' + elif sequence_idx_str.match(field_name): + self.batch_data_type = 'l' + else: + self.batch_data_type = 'd' + + if field_name in self.ignore_fields: + logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") + if backend is None: + backend = self.backend + else: + assert backend in SUPPORTED_BACKENDS + + self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} + + return self + + def set_backend(self, backend:str): + """ + 设置可以 pad 的 field 默认 pad 为什么类型的 tensor + + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], + 若为 None ,则不进行 padding 。 + :return: + """ + assert backend in SUPPORTED_BACKENDS + self.padders.clear() + self.backend = backend + + def set_ignore(self, *field_names) -> "Collator": + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 + """ + for field_name in field_names: + if field_name in self.input_fields: + self.input_fields.pop(field_name) + logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") + self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 + self.ignore_fields.add(field_name) + + return self + + diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/core/collators/padders/exceptions.py b/fastNLP/core/collators/padders/exceptions.py new file mode 100644 index 00000000..8b08683d --- /dev/null +++ b/fastNLP/core/collators/padders/exceptions.py @@ -0,0 +1,44 @@ +__all__ = [ + 'InconsistencyError', + 'EleDtypeUnsupportedError', + 'EleDtypeDtypeConversionError', + 'DtypeUnsupportedError', + "DtypeError" +] + + +class InconsistencyError(BaseException): + """ + 当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。 + + """ + def __init__(self, msg, *args): + super(InconsistencyError, self).__init__(msg, *args) + + +class DtypeError(BaseException): + def __init__(self, msg, *args): + super(DtypeError, self).__init__(msg, *args) + self.msg = msg + + +class EleDtypeUnsupportedError(DtypeError): + """ + 当 batch 中的 element 的类别本身无法 pad 的时候报错。 + 例如要求 str 类型的数据进行 padding 。 + + """ + + +class EleDtypeDtypeConversionError(DtypeError): + """ + 当 batch 中的 element 的类别无法转换为 dtype 类型时报错。 + + """ + + +class DtypeUnsupportedError(DtypeError): + """ + 当当前 backend 不支持这种类型的 dtype 时报错。 + + """ \ No newline at end of file diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py new file mode 100644 index 00000000..051a0ffc --- /dev/null +++ b/fastNLP/core/collators/padders/get_padder.py @@ -0,0 +1,193 @@ + +from typing import Dict + + + +from typing import Sequence, Any, Union, Dict +from abc import ABC + +from fastNLP.core.log import logger + + +from .padder import Padder, NullPadder +from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder +from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder +from .raw_padder import RawNumberPadder, RawSequencePadder +from .exceptions import * + + +def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: + """ + 根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 + + :param batch_field: 将某 field 的内容组合成一个 batch 传入。 + :param pad_val: + :param backend: + :param dtype: + :param field_name: 方便报错的。 + :return: + """ + logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) + if pad_val is None: + logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") + return NullPadder() + if backend is None: + logger.debug(f"The backend for field:{field_name} is None, not padding this field.") + return NullPadder() + + # 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。 + must_pad = False + if pad_val != 0 or dtype is not None: + must_pad = True + + catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。 + + # 根据 catalog 来判定当前是否可以进行 pad 。 + # 首先检查是否所有的 key 是一样长的,表明深度是一致的 + depths = set(map(len, catalog.keys())) + num_depth = len(depths) + if num_depth != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + # 再检查所有的元素 shape 是否一致? + shape_lens = set([len(v[0]) for v in catalog.values()]) + num_shape = len(shape_lens) + if num_shape != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + # 再检查所有的元素 type 是否一致 + ele_dtypes = set([v[1] for v in catalog.values()]) + num_eletypes = len(ele_dtypes) + if num_eletypes != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + depth = depths.pop() + shape_len = shape_lens.pop() + ele_dtype = ele_dtypes.pop() + + # 需要由 padder 自己决定是否能够 pad 。 + try: + if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] + if backend == 'raw': + return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'numpy': + return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 + if backend == 'raw': + return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'numpy': + return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if depth == 1 and shape_len != 0: + if backend == 'numpy': + return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if shape_len != 0 and depth>1: + msg = "Does not support pad tensor under nested list. If you need this, please report." + if must_pad: + raise RuntimeError(msg) + logger.debug(msg) + return NullPadder() + + except DtypeError as e: + msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ + "information please set logger's level to DEBUG." + if must_pad: + raise type(e)(msg=msg) + logger.debug(msg) + return NullPadder() + + except BaseException as e: + raise e + + return NullPadder() + + +class HasShapeDtype(ABC): + """ + 检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。 + + """ + + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is HasShapeDtype: + if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'): + return True + return False + return NotImplemented + + +def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: + """ + 获取对象的中 element 的基本信息,用于判断是否可以 padding。 + + :param content: + :param tuple parent: + :param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。 + 例如: [1, 2, 3] -> {(0,): ((), ), (1,): ((), ), (2,): ((), )} + 例如: [1, [2, 3], 4] -> {(0,): ((), ), (1, 0): ((), ), (1, 1): ((), ), (2,): ((), )} + 例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), ), (0, 1): ((), ), (1, 0): ((), ), (2, 0): ((), ), (2, 1): ((), )} + 例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)] + -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)} + + :return: + """ + if catalog is None: + catalog = {} + + if parent is None: + parent = () + + if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray + shape = content.shape + dtype = content.dtype + catalog[parent] = (shape, dtype) + elif isinstance(content, (tuple, list)): + for i, c in enumerate(content): + _get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog) + else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 + catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 + return catalog + + + + +""" +from numbers import Number + +issubclass(type(3), Number) # True +issubclass(type(3.1), Number) # True +issubclass(type('3'), Number) # False +issubclass(type(True), Number) # True +issubclass(type(np.zeros(3)[0]), Number) # True +isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True +isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True +isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 +is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) +""" + + + diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py new file mode 100644 index 00000000..0298fd86 --- /dev/null +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -0,0 +1,72 @@ +__all__ = [ + 'NumpyNumberPadder', + 'NumpySequencePadder', +] + +from numbers import Number +from abc import ABC +from typing import Any, Union +import numpy as np + +from .padder import Padder +from .utils import get_padded_numpy_array, is_number_or_numpy_number +from .exceptions import * + + +def _get_dtype(ele_dtype, dtype, class_name): + if not is_number_or_numpy_number(ele_dtype): + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"or numpy numbers but get `{ele_dtype}`.") + + if dtype is None: + dtype = ele_dtype + else: + if not is_number_or_numpy_number(dtype): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or numpy numbers but get `{dtype}`.") + dtype = dtype + return dtype + + +class NumpyNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return np.array(batch_field, dtype=dtype) + + +class NumpySequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) + + +class NumpyTensorPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + """ + pad 类似于 [np.array([3, 4], np.array([1])] 的 field + + :param ele_dtype: + :param pad_val: + :param dtype: + """ + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + shapes = [field.shape for field in batch_field] + max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] + array = np.full(max_shape, fill_value=pad_val, dtype=dtype) + for i, field in enumerate(batch_field): + slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) + array[slices] = field + return array + diff --git a/fastNLP/core/collators/padders/padder.py b/fastNLP/core/collators/padders/padder.py new file mode 100644 index 00000000..486574af --- /dev/null +++ b/fastNLP/core/collators/padders/padder.py @@ -0,0 +1,21 @@ + +class Padder: + def __init__(self, pad_val, dtype): + self.pad_val = pad_val + self.dtype = dtype + + def __call__(self, batch_field): + return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + raise NotImplementedError() + + +class NullPadder(Padder): + def __init__(self, ele_dtype=None, pad_val=None, dtype=None): + super().__init__(pad_val=pad_val, dtype=dtype) + + def __call__(self, batch_field): + # 直接返回,不调用 pad() 方法加快速度。 + return batch_field diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py new file mode 100644 index 00000000..66393b40 --- /dev/null +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -0,0 +1,48 @@ + + +from .padder import Padder +from .utils import get_padded_nest_list, is_number, get_padded_numpy_array +from .exceptions import * + + +def _get_dtype(ele_dtype, dtype, class_name): + if is_number(ele_dtype): + if dtype is None: + dtype = ele_dtype + elif not is_number(dtype): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " + f"get `{dtype}`.") + else: + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"but get `{ele_dtype}`.") + return dtype + + +class RawNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + def __call__(self, batch_field): + return batch_field + + @staticmethod + def pad(batch_field, pad_val, dtype): + raise NotImplementedError() + + +class RawSequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + """ + + :param batch_field: + :param pad_val: + :param dtype: 该参数无意义。 + :return: + """ + return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py new file mode 100644 index 00000000..a6768435 --- /dev/null +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -0,0 +1,157 @@ + +from inspect import isclass +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 + np.complex64: torch.complex64, + np.complex128: torch.complex128 + } + number_to_torch_dtype_dict = { + float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64 + int: torch.int64, + bool: torch.bool + } + +from .padder import Padder +from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class +from .exceptions import * + + +def is_torch_tensor(dtype): + if not isclass(dtype) and isinstance(dtype, torch.dtype): + return True + return False + + +def _get_dtype(ele_dtype, dtype, class_name): + if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") + + if dtype is not None: + if not (is_torch_tensor(dtype) or is_number(dtype)): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or torch.dtype but get `{dtype}`.") + dtype = number_to_torch_dtype_dict.get(dtype, dtype) + else: + if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): + ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) + dtype = ele_dtype + elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 + dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) + elif is_numpy_generic_class(ele_dtype): + dtype = numpy_to_torch_dtype_dict.get(ele_dtype) + + return dtype + + +class TorchNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return torch.tensor(batch_field, dtype=dtype) + + +class TorchSequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) + return tensor + + +class TorchTensorPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + """ + 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 + + :param ele_dtype: + :param pad_val: + :param dtype: + """ + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + shapes = [field.shape for field in batch_field] + max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] + if isinstance(dtype, np.dtype): + print(dtype) + tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) + for i, field in enumerate(batch_field): + slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) + if isinstance(field, np.ndarray): + field = torch.from_numpy(field) + tensor[slices] = field + return tensor + + +def fill_tensor(batch_field, padded_batch, dtype): + """ + 将 batch_field 中的值填入到 tensor 中。 + + :param batch_field: 需要填充进入 array 中的内容 + :param padded_batch: 待填充的 tensor + :param dtype: 数据的类别 + + :return: + """ + if padded_batch.ndim == 2: + for i, content_i in enumerate(batch_field): + padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype) + elif padded_batch.ndim == 3: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype) + elif padded_batch.ndim == 4: + try: # 应该是图像,所以直接应该就 ok 了。 + padded_batch = np.array(batch_field) + except: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + for k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype) + elif padded_batch.ndim == 1: + padded_batch[:] = torch.tensor(batch_field, dtype=dtype) + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0): + """ + 例如: + [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val) + tensor = fill_tensor(batch_field, tensor, dtype=dtype) + return tensor diff --git a/fastNLP/core/collators/padders/torch_utils.py b/fastNLP/core/collators/padders/torch_utils.py new file mode 100644 index 00000000..a47bea0e --- /dev/null +++ b/fastNLP/core/collators/padders/torch_utils.py @@ -0,0 +1,20 @@ + + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + +def is_torch_tensor_dtype(dtype) -> bool: + """ + 返回当前 dtype 是否是 torch 的 dtype 类型 + + + :param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 + :return: + """ + try: + return isinstance(dtype, torch.dtype) + except: + return False diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py new file mode 100644 index 00000000..f6240219 --- /dev/null +++ b/fastNLP/core/collators/padders/utils.py @@ -0,0 +1,173 @@ + +from typing import Sequence, List +from numbers import Number +import re +from inspect import isclass + +import numpy as np +np_str_obj_array_pattern = re.compile(r'[SaUO]') + + +def get_shape(batch_field:List, shape=None): + """ + 给定 field 返回这个 field pad 完成之后的 shape 。 + 例如: [[1, 2, 3], [3]] -> [2, 3] + [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3] + + :param batch_field: list,第 0 维一般为 batch 维度。 + :param shape: 无需传入。 + :return: + """ + if shape is None: + shape = [] + if isinstance(batch_field, Sequence): + num_ele = len(batch_field) + _shape = shape + [num_ele] + try: + shapes = [] + if isinstance(batch_field[0], Sequence): + for _field in batch_field: + shapes.append(get_shape(_field, _shape)) + max_shape = [max(_) for _ in zip(*shapes)] + return max_shape + except IndexError: # 空的shape + pass + return _shape # 说明是一个空的 sequence + else: + return shape + + +def fill_array(batch_field:List, padded_batch:np.ndarray): + """ + 将 batch_field 中的值填入到 array 中。 + + :param batch_field: 需要填充进入 array 中的内容 + :param padded_batch: 待填充的 np.ndarray + :return: + """ + if padded_batch.ndim == 2: + for i, content_i in enumerate(batch_field): + padded_batch[i, :len(content_i)] = content_i + elif padded_batch.ndim == 3: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + padded_batch[i, j, :len(content_ii)] = content_ii + elif padded_batch.ndim == 4: + try: # 应该是图像,所以直接应该就 ok 了。 + padded_batch = np.array(batch_field) + except: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + for k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = content_iii + elif padded_batch.ndim == 1: + padded_batch[:] = batch_field + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: + """ + 例如: + [[1,2], [3]] -> np.array([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + array = np.full(shapes, dtype=dtype, fill_value=pad_val) + array = fill_array(batch_field, array) + return array + + +def get_padded_nest_list(batch_field: List, pad_val=0) -> List: + """ + 例如: + [[1,2], [3]] -> [[1, 2], [3, 0]] + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param pad_val: pad 的 value + :return: + """ + + array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist() + return array + + +def is_number_or_numpy_number(dtype): + """ + 判断 dtype 是否是数字类型,或者 numpy 的数字类型。 + is_number_or_numpy_number(type(3)) # True + is_number_or_numpy_number(type(3.1)) # True + is_number_or_numpy_number(type('3')) # False + is_number_or_numpy_number(type(True)) # True + is_number_or_numpy_number(type(np.zeros(3)[0])) # True + is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True + is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True + is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False + is_number_or_numpy_number(np.array([1, [2]]).dtype) # False + + :param dtype: + :return: + """ + if is_number(dtype): + return True + else: + if isclass(dtype): + return is_numpy_generic_class(dtype) + elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: + return True + return False + + +def is_numpy_number_dtype(dtype): + if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: + return True + return False + + +def is_numpy_generic_class(dtype): + """ + 形如 np.int64,或者 np.zeros(1).dtype.type 的值 + + :param dtype: + :return: + """ + if isclass(dtype) and issubclass(dtype, np.generic): + return True + return False + + +def is_number(dtype): + try: + if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ + and not is_numpy_number_dtype(dtype): + return True + except: + return False + + + +if __name__ == '__main__': + # a = [[[1]], [1, 2, 3], [3]] + # a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + # b = get_padded_nest_list(a) + # print(type(b[0])) + # print(b) + # import torch + print(is_number_or_numpy_number(type(3))) # True + print(is_number_or_numpy_number(type(3.1))) # True + print(is_number_or_numpy_number(type('3'))) # False + print(is_number_or_numpy_number(type(True))) # True + print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False + print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False + diff --git a/fastNLP/core/collators/utils.py b/fastNLP/core/collators/utils.py new file mode 100644 index 00000000..9a397c66 --- /dev/null +++ b/fastNLP/core/collators/utils.py @@ -0,0 +1,103 @@ +from collections import defaultdict +from functools import reduce +from typing import Sequence, Mapping, Dict + +NESTED_DICT_SEPARATOR = '@@' + + +def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: + """ + 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} + + :param batch: + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for key, value in sample.items(): + dict_batch[key].append(value) + return dict_batch + + +def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: + """ + 将 nested 的 dict 中的内容展开到一个 flat dict 中 + + :param batch: + :param _parent: 内部使用 + :return: + """ + dict_batch = defaultdict(list) + if _parent != '': + _parent += NESTED_DICT_SEPARATOR + for sample in batch: + for key, value in sample.items(): + if isinstance(value, Mapping): + _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) + for key, value in _dict_batch.items(): + dict_batch[key].append(value) + else: + dict_batch[_parent + key].append(value) + return dict_batch + + +def _unpack_batch_nested_mapping(value, _parent)->Dict: + _dict = {} + _parent += NESTED_DICT_SEPARATOR + for k, v in value.items(): + if isinstance(v, Mapping): + __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) + _dict.update(__dict) + else: + _dict[_parent + k] = v + return _dict + + +def pack_batch_nested_mapping(batch:Mapping) -> Dict: + """ + 需要恢复出 nested 的 dict 原来的样式 + + :param batch: + :return: + """ + dicts = [] + + for key, value in batch.items(): + keys = key.split(NESTED_DICT_SEPARATOR) + d = {keys[-1]: value} + for key in keys[:-1:][::-1]: + d = {key: d} + dicts.append(d) + return reduce(_merge_dict, dicts) + + +def _merge_dict(a, b, path=None): + "merges b into a" + if path is None: path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + _merge_dict(a[key], b[key], path + [str(key)]) + else: + raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) + else: + a[key] = b[key] + return a + + +def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: + """ + 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} + + :param batch: + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for i, content in enumerate(sample): + dict_batch[f'_{i}'].append(content) + return dict_batch + + +def pack_batch_sequence(batch:Mapping)->Sequence: + return list(batch.values()) \ No newline at end of file diff --git a/tests/core/collators/__init__.py b/tests/core/collators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/collators/padders/__init__.py b/tests/core/collators/padders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py new file mode 100644 index 00000000..38fd4733 --- /dev/null +++ b/tests/core/collators/padders/test_get_padder.py @@ -0,0 +1,139 @@ +import pytest +import numpy as np + +from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ + _get_element_shape_dtype +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + + +def test_get_element_shape_dtype(): + catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2]) + catalog = _get_element_shape_dtype([['1'], [2, 3]]) + catalog = _get_element_shape_dtype([['1'], [2, 3]]) + catalog = _get_element_shape_dtype([['1'], ['2', '3']]) + catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) + + +@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) +def test_get_padder_run(backend): + if not _NEED_IMPORT_TORCH and backend == 'torch': + pytest.skip("No torch") + if not _NEED_IMPORT_PADDLE and backend == 'paddle': + pytest.skip("No paddle") + if not _NEED_IMPORT_PADDLE and backend == 'jittor': + pytest.skip("No jittor") + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + if backend is not None: + # 不能 pad + batch_field = [[1], [2, 3], [3], 2] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') + + # 不能 pad + batch_field = [['2'], ['2'], ['2', '2']] + with pytest.raises(DtypeError) as exec_info: + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') + + batch_field = [np.zeros(3), np.zeros((3, 1))] + with pytest.raises(InconsistencyError) as exec_info: + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad + + batch_field = [np.zeros((3, 1)), np.zeros((4, 1))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + +def test_raw_padder(): + backend = 'raw' + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert pad_batch == batch_field + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert np.shape(pad_batch) == (3, 3) + + batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert np.shape(pad_batch) == (3, 3, 2) + + +def test_numpy_padder(): + backend = 'numpy' + target_type = np.ndarray + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert (pad_batch == np.array(batch_field)).sum()==len(batch_field) + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + +def test_torch_padder(): + if not _NEED_IMPORT_TORCH: + pytest.skip("No torch.") + import torch + backend = 'torch' + target_type = torch.Tensor + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field) + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + diff --git a/tests/core/collators/padders/test_numpy_padder.py b/tests/core/collators/padders/test_numpy_padder.py new file mode 100644 index 00000000..42665857 --- /dev/null +++ b/tests/core/collators/padders/test_numpy_padder.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + + +class TestNumpyNumberPadder: + def test_run(self): + padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + assert isinstance(a, np.ndarray) + assert (padder(a) == np.array(a)).sum() == 3 + + +class TestNumpySequencePadder: + def test_run(self): + padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (2, 3) + b = np.array([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + if _NEED_IMPORT_TORCH: + import torch + with pytest.raises(DtypeError): + padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + + +class TestNumpyTensorPadder: + def test_run(self): + padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) + a = [np.zeros(3), np.zeros(2), np.zeros(0)] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3) + b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3, 2) + b = np.array([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3, 2) + b = np.array([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + if _NEED_IMPORT_TORCH: + import torch + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + + + diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py new file mode 100644 index 00000000..41a9de64 --- /dev/null +++ b/tests/core/collators/padders/test_raw_padder.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder +from fastNLP.core.collators.padders.exceptions import DtypeError + + +class TestRawNumberPadder: + def test_run(self): + padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + assert padder(a) == a + + +class TestRawSequencePadder: + def test_run(self): + padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = np.shape(a) + assert shape == (2, 3) + b = np.array([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + with pytest.raises(DtypeError): + padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) \ No newline at end of file diff --git a/tests/core/collators/padders/test_torch_padder.py b/tests/core/collators/padders/test_torch_padder.py new file mode 100644 index 00000000..85240b3c --- /dev/null +++ b/tests/core/collators/padders/test_torch_padder.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + +class TestTorchNumberPadder: + def test_run(self): + padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + t_a = padder(a) + assert isinstance(t_a, torch.Tensor) + assert (t_a == torch.LongTensor(a)).sum() == 3 + + +class TestTorchSequencePadder: + def test_run(self): + padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (2, 3) + b = torch.LongTensor([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) + a = padder([[1], [2, 322]]) + assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 + padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + + + +class TestTorchTensorPadder: + def test_run(self): + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3) + b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + + + diff --git a/tests/core/collators/padders/test_utils.py b/tests/core/collators/padders/test_utils.py new file mode 100644 index 00000000..4cc70400 --- /dev/null +++ b/tests/core/collators/padders/test_utils.py @@ -0,0 +1,90 @@ +import pytest +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \ + get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number + + +def test_get_shape(): + a = [[1, 2, 3], [3]] + assert get_shape(a) == [2, 3] + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + assert get_shape(a) == [2, 3, 3] + + a = [[[1], [2], [3, 4]], [[]]] + assert get_shape(a) == [2, 3, 2] + + +def test_get_padded_numpy_array(): + a = [[1, 2, 3], [3]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3) + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3, 3) + + a = [[[1], [2], [3, 4]], [[]]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3, 2) + + +def test_get_padded_nest_list(): + a = [[1, 2, 3], [3]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3) + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3, 3) + + a = [[[1], [2], [3, 4]], [[]]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3, 2) + + +def test_is_number_or_numpy_number(): + assert is_number_or_numpy_number(type(3)) is True + assert is_number_or_numpy_number(type(3.1)) is True + assert is_number_or_numpy_number(type(True)) is True + assert is_number_or_numpy_number(type('3')) is False + assert is_number_or_numpy_number(np.zeros(3).dtype) is True + assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True + assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_number_or_numpy_number(dtype) is False + + +def test_is_number(): + assert is_number(type(3)) is True + assert is_number(type(3.1)) is True + assert is_number(type(True)) is True + assert is_number(type('3')) is False + assert is_number(np.zeros(3).dtype) is False + assert is_number(np.zeros(3, dtype=int).dtype) is False + assert is_number(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_number(dtype) is False + + +def test_is_numpy_number(): + assert is_numpy_number_dtype(type(3)) is False + assert is_numpy_number_dtype(type(3.1)) is False + assert is_numpy_number_dtype(type(True)) is False + assert is_numpy_number_dtype(type('3')) is False + assert is_numpy_number_dtype(np.zeros(3).dtype) is True + assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True + assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_numpy_number_dtype(dtype) is False \ No newline at end of file diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py new file mode 100644 index 00000000..5fc82c91 --- /dev/null +++ b/tests/core/collators/test_new_collator.py @@ -0,0 +1,225 @@ + +import numpy as np +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.new_collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) + + +class TestCollator: + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + + + + + + diff --git a/tests/core/collators/test_utils.py b/tests/core/collators/test_utils.py new file mode 100644 index 00000000..d56dacc6 --- /dev/null +++ b/tests/core/collators/test_utils.py @@ -0,0 +1,37 @@ + +from fastNLP.core.collators.utils import * + + +def test_unpack_batch_mapping(): + batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] + assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} + + +def test_unpack_batch_nested_mapping(): + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} + + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} + + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, + {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], + 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + + +def test_pack_batch_nested_mapping(): + batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], + 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + new_batch = pack_batch_nested_mapping(batch) + assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], + 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} + + +def test_unpack_batch_sequence(): + batch = [[1, 2, 3], [2, 4, 6]] + new_batch = unpack_batch_sequence(batch) + assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} + + +