| @@ -0,0 +1,181 @@ | |||
| from typing import List, Union, Dict, Callable, Sequence, Mapping | |||
| from fastNLP.core.log import logger | |||
| from .padders.get_padder import get_padder | |||
| import re | |||
| from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ | |||
| pack_batch_sequence, NESTED_DICT_SEPARATOR | |||
| sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 | |||
| SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] | |||
| class Collator: | |||
| def __init__(self, backend='torch'): | |||
| """ | |||
| 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 | |||
| 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||
| 若为 None ,则不进行 padding 。 | |||
| """ | |||
| self.unpack_batch_func = None | |||
| self.pack_batch_func = None | |||
| self.ignore_fields = set() | |||
| self.padders = {} | |||
| self.input_fields = {} | |||
| self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 | |||
| self.set_backend(backend) | |||
| def __call__(self, batch)->Union[List, Dict]: | |||
| """ | |||
| batch可能存在三种可能性 | |||
| List[Dict], List[List], List[Sample] | |||
| 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 | |||
| 第二步:使用每个 field 各自的 padder 进行 pad 。 | |||
| 第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 | |||
| 第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 | |||
| list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample | |||
| 的类别。 | |||
| 第一次调用会根据当前 field 决定对应的 Padder 。 | |||
| """ | |||
| if self.unpack_batch_func is None: | |||
| # 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 | |||
| if self.batch_data_type is None: | |||
| if isinstance(batch[0], Mapping): | |||
| self.batch_data_type = 'd' | |||
| elif isinstance(batch[0], Sequence): # 这里存在误判的风险 | |||
| self.batch_data_type = 'l' | |||
| else: | |||
| self.batch_data_type = 's' | |||
| logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " | |||
| f"is {self.batch_data_type}") | |||
| if self.batch_data_type == 's': | |||
| self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 | |||
| self.pack_batch_func = lambda x:x['_single'] | |||
| elif self.batch_data_type == 'l': | |||
| self.unpack_batch_func = unpack_batch_sequence | |||
| self.pack_batch_func = pack_batch_sequence | |||
| elif self.batch_data_type == 'd': | |||
| if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} | |||
| self.unpack_batch_func = unpack_batch_nested_mapping | |||
| self.pack_batch_func = pack_batch_nested_mapping | |||
| else: | |||
| self.unpack_batch_func = unpack_batch_mapping | |||
| self.pack_batch_func = lambda x:x | |||
| unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 | |||
| pad_batch = {} | |||
| if len(self.padders)==0: # 第一次运行,准备 padder | |||
| for key in unpack_batch.keys(): | |||
| if key not in self.input_fields and key not in self.ignore_fields: | |||
| self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} | |||
| for field_name, setting in self.input_fields.items(): | |||
| pad_fn = setting.get('pad_fn', None) | |||
| if callable(pad_fn): | |||
| padder = pad_fn | |||
| else: | |||
| batch_field = unpack_batch.get(field_name) | |||
| padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], | |||
| dtype=setting['dtype'], backend=setting['backend'], | |||
| field_name=field_name) | |||
| self.padders[field_name] = padder | |||
| if self.batch_data_type == 'l': | |||
| self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 | |||
| for key, padder in self.padders.items(): | |||
| batch = unpack_batch.get(key) | |||
| pad_batch[key] = padder(batch) | |||
| return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 | |||
| def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||
| pad_fn:Callable=None) -> "Collator": | |||
| """ | |||
| 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | |||
| :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
| 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 | |||
| 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 | |||
| :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 | |||
| field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 | |||
| :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 | |||
| :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, | |||
| paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 | |||
| :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 | |||
| batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch | |||
| 形式,输出将被直接作为结果输出。 | |||
| :return: 返回 Collator 自身 | |||
| """ | |||
| self.padders.clear() # 重新生成 | |||
| if self.batch_data_type is not None: | |||
| if self.batch_data_type == 's': | |||
| logger.debug("Set as single field mode.") | |||
| self.input_fields.clear() | |||
| elif self.batch_data_type == 'd': | |||
| assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ | |||
| f"index, but other field is set as dict mode." | |||
| elif self.batch_data_type == 'l': | |||
| assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ | |||
| f"field name is {field_name}" | |||
| if field_name == '_single': | |||
| self.batch_data_type = 's' | |||
| elif sequence_idx_str.match(field_name): | |||
| self.batch_data_type = 'l' | |||
| else: | |||
| self.batch_data_type = 'd' | |||
| if field_name in self.ignore_fields: | |||
| logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") | |||
| if backend is None: | |||
| backend = self.backend | |||
| else: | |||
| assert backend in SUPPORTED_BACKENDS | |||
| self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} | |||
| return self | |||
| def set_backend(self, backend:str): | |||
| """ | |||
| 设置可以 pad 的 field 默认 pad 为什么类型的 tensor | |||
| :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], | |||
| 若为 None ,则不进行 padding 。 | |||
| :return: | |||
| """ | |||
| assert backend in SUPPORTED_BACKENDS | |||
| self.padders.clear() | |||
| self.backend = backend | |||
| def set_ignore(self, *field_names) -> "Collator": | |||
| """ | |||
| 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | |||
| Ex:: | |||
| collator.set_ignore('field1', 'field2') | |||
| :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | |||
| field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; | |||
| 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | |||
| :return: 返回 Collator 自身 | |||
| """ | |||
| for field_name in field_names: | |||
| if field_name in self.input_fields: | |||
| self.input_fields.pop(field_name) | |||
| logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") | |||
| self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 | |||
| self.ignore_fields.add(field_name) | |||
| return self | |||
| @@ -0,0 +1,44 @@ | |||
| __all__ = [ | |||
| 'InconsistencyError', | |||
| 'EleDtypeUnsupportedError', | |||
| 'EleDtypeDtypeConversionError', | |||
| 'DtypeUnsupportedError', | |||
| "DtypeError" | |||
| ] | |||
| class InconsistencyError(BaseException): | |||
| """ | |||
| 当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。 | |||
| """ | |||
| def __init__(self, msg, *args): | |||
| super(InconsistencyError, self).__init__(msg, *args) | |||
| class DtypeError(BaseException): | |||
| def __init__(self, msg, *args): | |||
| super(DtypeError, self).__init__(msg, *args) | |||
| self.msg = msg | |||
| class EleDtypeUnsupportedError(DtypeError): | |||
| """ | |||
| 当 batch 中的 element 的类别本身无法 pad 的时候报错。 | |||
| 例如要求 str 类型的数据进行 padding 。 | |||
| """ | |||
| class EleDtypeDtypeConversionError(DtypeError): | |||
| """ | |||
| 当 batch 中的 element 的类别无法转换为 dtype 类型时报错。 | |||
| """ | |||
| class DtypeUnsupportedError(DtypeError): | |||
| """ | |||
| 当当前 backend 不支持这种类型的 dtype 时报错。 | |||
| """ | |||
| @@ -0,0 +1,193 @@ | |||
| from typing import Dict | |||
| from typing import Sequence, Any, Union, Dict | |||
| from abc import ABC | |||
| from fastNLP.core.log import logger | |||
| from .padder import Padder, NullPadder | |||
| from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder | |||
| from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder | |||
| from .raw_padder import RawNumberPadder, RawSequencePadder | |||
| from .exceptions import * | |||
| def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: | |||
| """ | |||
| 根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 | |||
| :param batch_field: 将某 field 的内容组合成一个 batch 传入。 | |||
| :param pad_val: | |||
| :param backend: | |||
| :param dtype: | |||
| :param field_name: 方便报错的。 | |||
| :return: | |||
| """ | |||
| logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) | |||
| if pad_val is None: | |||
| logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") | |||
| return NullPadder() | |||
| if backend is None: | |||
| logger.debug(f"The backend for field:{field_name} is None, not padding this field.") | |||
| return NullPadder() | |||
| # 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。 | |||
| must_pad = False | |||
| if pad_val != 0 or dtype is not None: | |||
| must_pad = True | |||
| catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。 | |||
| # 根据 catalog 来判定当前是否可以进行 pad 。 | |||
| # 首先检查是否所有的 key 是一样长的,表明深度是一致的 | |||
| depths = set(map(len, catalog.keys())) | |||
| num_depth = len(depths) | |||
| if num_depth != 1: | |||
| msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \ | |||
| f"information please set logger's level to DEBUG." | |||
| if must_pad: | |||
| raise InconsistencyError(msg) | |||
| logger.debug(msg) | |||
| return NullPadder() | |||
| # 再检查所有的元素 shape 是否一致? | |||
| shape_lens = set([len(v[0]) for v in catalog.values()]) | |||
| num_shape = len(shape_lens) | |||
| if num_shape != 1: | |||
| msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \ | |||
| f"information please set logger's level to DEBUG." | |||
| if must_pad: | |||
| raise InconsistencyError(msg) | |||
| logger.debug(msg) | |||
| return NullPadder() | |||
| # 再检查所有的元素 type 是否一致 | |||
| ele_dtypes = set([v[1] for v in catalog.values()]) | |||
| num_eletypes = len(ele_dtypes) | |||
| if num_eletypes != 1: | |||
| msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \ | |||
| f"information please set logger's level to DEBUG." | |||
| if must_pad: | |||
| raise InconsistencyError(msg) | |||
| logger.debug(msg) | |||
| return NullPadder() | |||
| depth = depths.pop() | |||
| shape_len = shape_lens.pop() | |||
| ele_dtype = ele_dtypes.pop() | |||
| # 需要由 padder 自己决定是否能够 pad 。 | |||
| try: | |||
| if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | |||
| if backend == 'raw': | |||
| return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| elif backend == 'numpy': | |||
| return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| elif backend == 'torch': | |||
| return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | |||
| if backend == 'raw': | |||
| return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| elif backend == 'numpy': | |||
| return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| elif backend == 'torch': | |||
| return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| if depth == 1 and shape_len != 0: | |||
| if backend == 'numpy': | |||
| return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| elif backend == 'torch': | |||
| return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
| if shape_len != 0 and depth>1: | |||
| msg = "Does not support pad tensor under nested list. If you need this, please report." | |||
| if must_pad: | |||
| raise RuntimeError(msg) | |||
| logger.debug(msg) | |||
| return NullPadder() | |||
| except DtypeError as e: | |||
| msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ | |||
| "information please set logger's level to DEBUG." | |||
| if must_pad: | |||
| raise type(e)(msg=msg) | |||
| logger.debug(msg) | |||
| return NullPadder() | |||
| except BaseException as e: | |||
| raise e | |||
| return NullPadder() | |||
| class HasShapeDtype(ABC): | |||
| """ | |||
| 检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。 | |||
| """ | |||
| @classmethod | |||
| def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
| if cls is HasShapeDtype: | |||
| if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'): | |||
| return True | |||
| return False | |||
| return NotImplemented | |||
| def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: | |||
| """ | |||
| 获取对象的中 element 的基本信息,用于判断是否可以 padding。 | |||
| :param content: | |||
| :param tuple parent: | |||
| :param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。 | |||
| 例如: [1, 2, 3] -> {(0,): ((), <class 'int'>), (1,): ((), <class 'int'>), (2,): ((), <class 'int'>)} | |||
| 例如: [1, [2, 3], 4] -> {(0,): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (1, 1): ((), <class 'int'>), (2,): ((), <class 'int'>)} | |||
| 例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), <class 'int'>), (0, 1): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (2, 0): ((), <class 'int'>), (2, 1): ((), <class 'int'>)} | |||
| 例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)] | |||
| -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)} | |||
| :return: | |||
| """ | |||
| if catalog is None: | |||
| catalog = {} | |||
| if parent is None: | |||
| parent = () | |||
| if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray | |||
| shape = content.shape | |||
| dtype = content.dtype | |||
| catalog[parent] = (shape, dtype) | |||
| elif isinstance(content, (tuple, list)): | |||
| for i, c in enumerate(content): | |||
| _get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog) | |||
| else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 | |||
| catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 | |||
| return catalog | |||
| """ | |||
| from numbers import Number | |||
| issubclass(type(3), Number) # True | |||
| issubclass(type(3.1), Number) # True | |||
| issubclass(type('3'), Number) # False | |||
| issubclass(type(True), Number) # True | |||
| issubclass(type(np.zeros(3)[0]), Number) # True | |||
| isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True | |||
| isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True | |||
| isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 | |||
| is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) | |||
| """ | |||
| @@ -0,0 +1,72 @@ | |||
| __all__ = [ | |||
| 'NumpyNumberPadder', | |||
| 'NumpySequencePadder', | |||
| ] | |||
| from numbers import Number | |||
| from abc import ABC | |||
| from typing import Any, Union | |||
| import numpy as np | |||
| from .padder import Padder | |||
| from .utils import get_padded_numpy_array, is_number_or_numpy_number | |||
| from .exceptions import * | |||
| def _get_dtype(ele_dtype, dtype, class_name): | |||
| if not is_number_or_numpy_number(ele_dtype): | |||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
| f"or numpy numbers but get `{ele_dtype}`.") | |||
| if dtype is None: | |||
| dtype = ele_dtype | |||
| else: | |||
| if not is_number_or_numpy_number(dtype): | |||
| raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||
| f"or numpy numbers but get `{dtype}`.") | |||
| dtype = dtype | |||
| return dtype | |||
| class NumpyNumberPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| return np.array(batch_field, dtype=dtype) | |||
| class NumpySequencePadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) | |||
| class NumpyTensorPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| """ | |||
| pad 类似于 [np.array([3, 4], np.array([1])] 的 field | |||
| :param ele_dtype: | |||
| :param pad_val: | |||
| :param dtype: | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| array = np.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
| for i, field in enumerate(batch_field): | |||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||
| array[slices] = field | |||
| return array | |||
| @@ -0,0 +1,21 @@ | |||
| class Padder: | |||
| def __init__(self, pad_val, dtype): | |||
| self.pad_val = pad_val | |||
| self.dtype = dtype | |||
| def __call__(self, batch_field): | |||
| return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| raise NotImplementedError() | |||
| class NullPadder(Padder): | |||
| def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| def __call__(self, batch_field): | |||
| # 直接返回,不调用 pad() 方法加快速度。 | |||
| return batch_field | |||
| @@ -0,0 +1,48 @@ | |||
| from .padder import Padder | |||
| from .utils import get_padded_nest_list, is_number, get_padded_numpy_array | |||
| from .exceptions import * | |||
| def _get_dtype(ele_dtype, dtype, class_name): | |||
| if is_number(ele_dtype): | |||
| if dtype is None: | |||
| dtype = ele_dtype | |||
| elif not is_number(dtype): | |||
| raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " | |||
| f"get `{dtype}`.") | |||
| else: | |||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
| f"but get `{ele_dtype}`.") | |||
| return dtype | |||
| class RawNumberPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| def __call__(self, batch_field): | |||
| return batch_field | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| raise NotImplementedError() | |||
| class RawSequencePadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| """ | |||
| :param batch_field: | |||
| :param pad_val: | |||
| :param dtype: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() | |||
| @@ -0,0 +1,157 @@ | |||
| from inspect import isclass | |||
| import numpy as np | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| numpy_to_torch_dtype_dict = { | |||
| np.bool_: torch.bool, | |||
| np.uint8: torch.uint8, | |||
| np.int8: torch.int8, | |||
| np.int16: torch.int16, | |||
| np.int32: torch.int32, | |||
| np.int64: torch.int64, | |||
| np.float16: torch.float16, | |||
| np.float32: torch.float32, | |||
| np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 | |||
| np.complex64: torch.complex64, | |||
| np.complex128: torch.complex128 | |||
| } | |||
| number_to_torch_dtype_dict = { | |||
| float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64 | |||
| int: torch.int64, | |||
| bool: torch.bool | |||
| } | |||
| from .padder import Padder | |||
| from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||
| from .exceptions import * | |||
| def is_torch_tensor(dtype): | |||
| if not isclass(dtype) and isinstance(dtype, torch.dtype): | |||
| return True | |||
| return False | |||
| def _get_dtype(ele_dtype, dtype, class_name): | |||
| if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
| f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | |||
| if dtype is not None: | |||
| if not (is_torch_tensor(dtype) or is_number(dtype)): | |||
| raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||
| f"or torch.dtype but get `{dtype}`.") | |||
| dtype = number_to_torch_dtype_dict.get(dtype, dtype) | |||
| else: | |||
| if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||
| ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||
| dtype = ele_dtype | |||
| elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||
| dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||
| elif is_numpy_generic_class(ele_dtype): | |||
| dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||
| return dtype | |||
| class TorchNumberPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| return torch.tensor(batch_field, dtype=dtype) | |||
| class TorchSequencePadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||
| return tensor | |||
| class TorchTensorPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| """ | |||
| 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | |||
| :param ele_dtype: | |||
| :param pad_val: | |||
| :param dtype: | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| if isinstance(dtype, np.dtype): | |||
| print(dtype) | |||
| tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
| for i, field in enumerate(batch_field): | |||
| slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||
| if isinstance(field, np.ndarray): | |||
| field = torch.from_numpy(field) | |||
| tensor[slices] = field | |||
| return tensor | |||
| def fill_tensor(batch_field, padded_batch, dtype): | |||
| """ | |||
| 将 batch_field 中的值填入到 tensor 中。 | |||
| :param batch_field: 需要填充进入 array 中的内容 | |||
| :param padded_batch: 待填充的 tensor | |||
| :param dtype: 数据的类别 | |||
| :return: | |||
| """ | |||
| if padded_batch.ndim == 2: | |||
| for i, content_i in enumerate(batch_field): | |||
| padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype) | |||
| elif padded_batch.ndim == 3: | |||
| for i, content_i in enumerate(batch_field): | |||
| for j, content_ii in enumerate(content_i): | |||
| padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype) | |||
| elif padded_batch.ndim == 4: | |||
| try: # 应该是图像,所以直接应该就 ok 了。 | |||
| padded_batch = np.array(batch_field) | |||
| except: | |||
| for i, content_i in enumerate(batch_field): | |||
| for j, content_ii in enumerate(content_i): | |||
| for k, content_iii in enumerate(content_ii): | |||
| padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype) | |||
| elif padded_batch.ndim == 1: | |||
| padded_batch[:] = torch.tensor(batch_field, dtype=dtype) | |||
| else: | |||
| raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||
| "report.") | |||
| return padded_batch | |||
| def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0): | |||
| """ | |||
| 例如: | |||
| [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]]) | |||
| :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||
| /4d(多为图片)。 | |||
| :param dtype: 目标类别是什么 | |||
| :param pad_val: pad 的 value | |||
| :return: | |||
| """ | |||
| shapes = get_shape(batch_field) | |||
| tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val) | |||
| tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||
| return tensor | |||
| @@ -0,0 +1,20 @@ | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| def is_torch_tensor_dtype(dtype) -> bool: | |||
| """ | |||
| 返回当前 dtype 是否是 torch 的 dtype 类型 | |||
| :param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 | |||
| :return: | |||
| """ | |||
| try: | |||
| return isinstance(dtype, torch.dtype) | |||
| except: | |||
| return False | |||
| @@ -0,0 +1,173 @@ | |||
| from typing import Sequence, List | |||
| from numbers import Number | |||
| import re | |||
| from inspect import isclass | |||
| import numpy as np | |||
| np_str_obj_array_pattern = re.compile(r'[SaUO]') | |||
| def get_shape(batch_field:List, shape=None): | |||
| """ | |||
| 给定 field 返回这个 field pad 完成之后的 shape 。 | |||
| 例如: [[1, 2, 3], [3]] -> [2, 3] | |||
| [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3] | |||
| :param batch_field: list,第 0 维一般为 batch 维度。 | |||
| :param shape: 无需传入。 | |||
| :return: | |||
| """ | |||
| if shape is None: | |||
| shape = [] | |||
| if isinstance(batch_field, Sequence): | |||
| num_ele = len(batch_field) | |||
| _shape = shape + [num_ele] | |||
| try: | |||
| shapes = [] | |||
| if isinstance(batch_field[0], Sequence): | |||
| for _field in batch_field: | |||
| shapes.append(get_shape(_field, _shape)) | |||
| max_shape = [max(_) for _ in zip(*shapes)] | |||
| return max_shape | |||
| except IndexError: # 空的shape | |||
| pass | |||
| return _shape # 说明是一个空的 sequence | |||
| else: | |||
| return shape | |||
| def fill_array(batch_field:List, padded_batch:np.ndarray): | |||
| """ | |||
| 将 batch_field 中的值填入到 array 中。 | |||
| :param batch_field: 需要填充进入 array 中的内容 | |||
| :param padded_batch: 待填充的 np.ndarray | |||
| :return: | |||
| """ | |||
| if padded_batch.ndim == 2: | |||
| for i, content_i in enumerate(batch_field): | |||
| padded_batch[i, :len(content_i)] = content_i | |||
| elif padded_batch.ndim == 3: | |||
| for i, content_i in enumerate(batch_field): | |||
| for j, content_ii in enumerate(content_i): | |||
| padded_batch[i, j, :len(content_ii)] = content_ii | |||
| elif padded_batch.ndim == 4: | |||
| try: # 应该是图像,所以直接应该就 ok 了。 | |||
| padded_batch = np.array(batch_field) | |||
| except: | |||
| for i, content_i in enumerate(batch_field): | |||
| for j, content_ii in enumerate(content_i): | |||
| for k, content_iii in enumerate(content_ii): | |||
| padded_batch[i, j, k, :len(content_iii)] = content_iii | |||
| elif padded_batch.ndim == 1: | |||
| padded_batch[:] = batch_field | |||
| else: | |||
| raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||
| "report.") | |||
| return padded_batch | |||
| def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: | |||
| """ | |||
| 例如: | |||
| [[1,2], [3]] -> np.array([[1, 2], [3, 0]]) | |||
| :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||
| /4d(多为图片)。 | |||
| :param dtype: 目标类别是什么 | |||
| :param pad_val: pad 的 value | |||
| :return: | |||
| """ | |||
| shapes = get_shape(batch_field) | |||
| array = np.full(shapes, dtype=dtype, fill_value=pad_val) | |||
| array = fill_array(batch_field, array) | |||
| return array | |||
| def get_padded_nest_list(batch_field: List, pad_val=0) -> List: | |||
| """ | |||
| 例如: | |||
| [[1,2], [3]] -> [[1, 2], [3, 0]] | |||
| :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||
| /4d(多为图片)。 | |||
| :param pad_val: pad 的 value | |||
| :return: | |||
| """ | |||
| array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist() | |||
| return array | |||
| def is_number_or_numpy_number(dtype): | |||
| """ | |||
| 判断 dtype 是否是数字类型,或者 numpy 的数字类型。 | |||
| is_number_or_numpy_number(type(3)) # True | |||
| is_number_or_numpy_number(type(3.1)) # True | |||
| is_number_or_numpy_number(type('3')) # False | |||
| is_number_or_numpy_number(type(True)) # True | |||
| is_number_or_numpy_number(type(np.zeros(3)[0])) # True | |||
| is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True | |||
| is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True | |||
| is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False | |||
| is_number_or_numpy_number(np.array([1, [2]]).dtype) # False | |||
| :param dtype: | |||
| :return: | |||
| """ | |||
| if is_number(dtype): | |||
| return True | |||
| else: | |||
| if isclass(dtype): | |||
| return is_numpy_generic_class(dtype) | |||
| elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: | |||
| return True | |||
| return False | |||
| def is_numpy_number_dtype(dtype): | |||
| if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: | |||
| return True | |||
| return False | |||
| def is_numpy_generic_class(dtype): | |||
| """ | |||
| 形如 np.int64,或者 np.zeros(1).dtype.type 的值 | |||
| :param dtype: | |||
| :return: | |||
| """ | |||
| if isclass(dtype) and issubclass(dtype, np.generic): | |||
| return True | |||
| return False | |||
| def is_number(dtype): | |||
| try: | |||
| if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ | |||
| and not is_numpy_number_dtype(dtype): | |||
| return True | |||
| except: | |||
| return False | |||
| if __name__ == '__main__': | |||
| # a = [[[1]], [1, 2, 3], [3]] | |||
| # a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||
| # b = get_padded_nest_list(a) | |||
| # print(type(b[0])) | |||
| # print(b) | |||
| # import torch | |||
| print(is_number_or_numpy_number(type(3))) # True | |||
| print(is_number_or_numpy_number(type(3.1))) # True | |||
| print(is_number_or_numpy_number(type('3'))) # False | |||
| print(is_number_or_numpy_number(type(True))) # True | |||
| print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True | |||
| print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True | |||
| print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True | |||
| print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False | |||
| print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False | |||
| @@ -0,0 +1,103 @@ | |||
| from collections import defaultdict | |||
| from functools import reduce | |||
| from typing import Sequence, Mapping, Dict | |||
| NESTED_DICT_SEPARATOR = '@@' | |||
| def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: | |||
| """ | |||
| 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| :param batch: | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for key, value in sample.items(): | |||
| dict_batch[key].append(value) | |||
| return dict_batch | |||
| def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: | |||
| """ | |||
| 将 nested 的 dict 中的内容展开到一个 flat dict 中 | |||
| :param batch: | |||
| :param _parent: 内部使用 | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| if _parent != '': | |||
| _parent += NESTED_DICT_SEPARATOR | |||
| for sample in batch: | |||
| for key, value in sample.items(): | |||
| if isinstance(value, Mapping): | |||
| _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) | |||
| for key, value in _dict_batch.items(): | |||
| dict_batch[key].append(value) | |||
| else: | |||
| dict_batch[_parent + key].append(value) | |||
| return dict_batch | |||
| def _unpack_batch_nested_mapping(value, _parent)->Dict: | |||
| _dict = {} | |||
| _parent += NESTED_DICT_SEPARATOR | |||
| for k, v in value.items(): | |||
| if isinstance(v, Mapping): | |||
| __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) | |||
| _dict.update(__dict) | |||
| else: | |||
| _dict[_parent + k] = v | |||
| return _dict | |||
| def pack_batch_nested_mapping(batch:Mapping) -> Dict: | |||
| """ | |||
| 需要恢复出 nested 的 dict 原来的样式 | |||
| :param batch: | |||
| :return: | |||
| """ | |||
| dicts = [] | |||
| for key, value in batch.items(): | |||
| keys = key.split(NESTED_DICT_SEPARATOR) | |||
| d = {keys[-1]: value} | |||
| for key in keys[:-1:][::-1]: | |||
| d = {key: d} | |||
| dicts.append(d) | |||
| return reduce(_merge_dict, dicts) | |||
| def _merge_dict(a, b, path=None): | |||
| "merges b into a" | |||
| if path is None: path = [] | |||
| for key in b: | |||
| if key in a: | |||
| if isinstance(a[key], dict) and isinstance(b[key], dict): | |||
| _merge_dict(a[key], b[key], path + [str(key)]) | |||
| else: | |||
| raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) | |||
| else: | |||
| a[key] = b[key] | |||
| return a | |||
| def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: | |||
| """ | |||
| 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} | |||
| :param batch: | |||
| :return: | |||
| """ | |||
| dict_batch = defaultdict(list) | |||
| for sample in batch: | |||
| for i, content in enumerate(sample): | |||
| dict_batch[f'_{i}'].append(content) | |||
| return dict_batch | |||
| def pack_batch_sequence(batch:Mapping)->Sequence: | |||
| return list(batch.values()) | |||
| @@ -0,0 +1,139 @@ | |||
| import pytest | |||
| import numpy as np | |||
| from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ | |||
| _get_element_shape_dtype | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||
| def test_get_element_shape_dtype(): | |||
| catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2]) | |||
| catalog = _get_element_shape_dtype([['1'], [2, 3]]) | |||
| catalog = _get_element_shape_dtype([['1'], [2, 3]]) | |||
| catalog = _get_element_shape_dtype([['1'], ['2', '3']]) | |||
| catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) | |||
| @pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) | |||
| def test_get_padder_run(backend): | |||
| if not _NEED_IMPORT_TORCH and backend == 'torch': | |||
| pytest.skip("No torch") | |||
| if not _NEED_IMPORT_PADDLE and backend == 'paddle': | |||
| pytest.skip("No paddle") | |||
| if not _NEED_IMPORT_PADDLE and backend == 'jittor': | |||
| pytest.skip("No jittor") | |||
| batch_field = [1, 2, 3] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| if backend is not None: | |||
| # 不能 pad | |||
| batch_field = [[1], [2, 3], [3], 2] | |||
| with pytest.raises(InconsistencyError): | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') | |||
| # 不能 pad | |||
| batch_field = [['2'], ['2'], ['2', '2']] | |||
| with pytest.raises(DtypeError) as exec_info: | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') | |||
| batch_field = [np.zeros(3), np.zeros((3, 1))] | |||
| with pytest.raises(InconsistencyError) as exec_info: | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad | |||
| batch_field = [np.zeros((3, 1)), np.zeros((4, 1))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| def test_raw_padder(): | |||
| backend = 'raw' | |||
| batch_field = [1, 2, 3] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert pad_batch == batch_field | |||
| batch_field = [[1], [2, 2], [3, 3, 3]] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert np.shape(pad_batch) == (3, 3) | |||
| batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert np.shape(pad_batch) == (3, 3, 2) | |||
| def test_numpy_padder(): | |||
| backend = 'numpy' | |||
| target_type = np.ndarray | |||
| batch_field = [1, 2, 3] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert (pad_batch == np.array(batch_field)).sum()==len(batch_field) | |||
| batch_field = [[1], [2, 2], [3, 3, 3]] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert np.shape(pad_batch) == (3, 3) | |||
| assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3 | |||
| batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert np.shape(pad_batch) == (3, 3, 3) | |||
| assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9 | |||
| batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert np.shape(pad_batch) == (3, 3, 3) | |||
| assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | |||
| batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))] | |||
| with pytest.raises(InconsistencyError): | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| def test_torch_padder(): | |||
| if not _NEED_IMPORT_TORCH: | |||
| pytest.skip("No torch.") | |||
| import torch | |||
| backend = 'torch' | |||
| target_type = torch.Tensor | |||
| batch_field = [1, 2, 3] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field) | |||
| batch_field = [[1], [2, 2], [3, 3, 3]] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert pad_batch.shape == (3, 3) | |||
| assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3 | |||
| batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert pad_batch.shape == (3, 3, 3) | |||
| assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9 | |||
| batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert pad_batch.shape == (3, 3, 3) | |||
| assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 | |||
| batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))] | |||
| with pytest.raises(InconsistencyError): | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| @@ -0,0 +1,81 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder | |||
| from fastNLP.core.collators.padders.exceptions import DtypeError | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| class TestNumpyNumberPadder: | |||
| def test_run(self): | |||
| padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [1, 2, 3] | |||
| assert isinstance(a, np.ndarray) | |||
| assert (padder(a) == np.array(a)).sum() == 3 | |||
| class TestNumpySequencePadder: | |||
| def test_run(self): | |||
| padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [[1, 2, 3], [3]] | |||
| a = padder(a) | |||
| shape = np.shape(a) | |||
| assert isinstance(a, np.ndarray) | |||
| assert shape == (2, 3) | |||
| b = np.array([[1, 2, 3], [3, -1, -1]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1] | |||
| def test_dtype_check(self): | |||
| padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| with pytest.raises(DtypeError): | |||
| padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
| class TestNumpyTensorPadder: | |||
| def test_run(self): | |||
| padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) | |||
| a = [np.zeros(3), np.zeros(2), np.zeros(0)] | |||
| a = padder(a) | |||
| shape = np.shape(a) | |||
| assert isinstance(a, np.ndarray) | |||
| assert shape == (3, 3) | |||
| b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1] | |||
| a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))] | |||
| a = padder(a) | |||
| shape = np.shape(a) | |||
| assert isinstance(a, np.ndarray) | |||
| assert shape == (3, 3, 2) | |||
| b = np.array([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[0, -1], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||
| a = padder(a) | |||
| shape = np.shape(a) | |||
| assert isinstance(a, np.ndarray) | |||
| assert shape == (3, 3, 2) | |||
| b = np.array([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[-1, -1], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| def test_dtype_check(self): | |||
| padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| with pytest.raises(DtypeError): | |||
| padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||
| @@ -0,0 +1,29 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder | |||
| from fastNLP.core.collators.padders.exceptions import DtypeError | |||
| class TestRawNumberPadder: | |||
| def test_run(self): | |||
| padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [1, 2, 3] | |||
| assert padder(a) == a | |||
| class TestRawSequencePadder: | |||
| def test_run(self): | |||
| padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [[1, 2, 3], [3]] | |||
| a = padder(a) | |||
| shape = np.shape(a) | |||
| assert shape == (2, 3) | |||
| b = np.array([[1, 2, 3], [3, -1, -1]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1] | |||
| def test_dtype_check(self): | |||
| with pytest.raises(DtypeError): | |||
| padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
| @@ -0,0 +1,105 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder | |||
| from fastNLP.core.collators.padders.exceptions import DtypeError | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| class TestTorchNumberPadder: | |||
| def test_run(self): | |||
| padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [1, 2, 3] | |||
| t_a = padder(a) | |||
| assert isinstance(t_a, torch.Tensor) | |||
| assert (t_a == torch.LongTensor(a)).sum() == 3 | |||
| class TestTorchSequencePadder: | |||
| def test_run(self): | |||
| padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
| a = [[1, 2, 3], [3]] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (2, 3) | |||
| b = torch.LongTensor([[1, 2, 3], [3, -1, -1]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1] | |||
| def test_dtype_check(self): | |||
| padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
| padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
| padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||
| a = padder([[1], [2, 322]]) | |||
| assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | |||
| padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||
| class TestTorchTensorPadder: | |||
| def test_run(self): | |||
| padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||
| a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (3, 3) | |||
| b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1] | |||
| a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (3, 3, 2) | |||
| b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[0, 0], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (3, 3, 2) | |||
| b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[0, -1], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||
| a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (3, 3, 2) | |||
| b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[-1, -1], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) | |||
| a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||
| a = padder(a) | |||
| shape = a.shape | |||
| assert isinstance(a, torch.Tensor) | |||
| assert tuple(shape) == (3, 3, 2) | |||
| b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]], | |||
| [[0, 0], [0, 0], [-1, -1]], | |||
| [[-1, -1], [-1, -1], [-1, -1]]]) | |||
| assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
| def test_dtype_check(self): | |||
| padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
| with pytest.raises(DtypeError): | |||
| padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
| padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
| padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||
| @@ -0,0 +1,90 @@ | |||
| import pytest | |||
| import numpy as np | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \ | |||
| get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number | |||
| def test_get_shape(): | |||
| a = [[1, 2, 3], [3]] | |||
| assert get_shape(a) == [2, 3] | |||
| a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||
| assert get_shape(a) == [2, 3, 3] | |||
| a = [[[1], [2], [3, 4]], [[]]] | |||
| assert get_shape(a) == [2, 3, 2] | |||
| def test_get_padded_numpy_array(): | |||
| a = [[1, 2, 3], [3]] | |||
| a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||
| assert a.shape == (2, 3) | |||
| a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||
| a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||
| assert a.shape == (2, 3, 3) | |||
| a = [[[1], [2], [3, 4]], [[]]] | |||
| a = get_padded_numpy_array(a, dtype=int, pad_val=-1) | |||
| assert a.shape == (2, 3, 2) | |||
| def test_get_padded_nest_list(): | |||
| a = [[1, 2, 3], [3]] | |||
| a = get_padded_nest_list(a, pad_val=-1) | |||
| assert np.shape(a) == (2, 3) | |||
| a = [[[1], [2], [3, 4]], [[2, 3, 4]]] | |||
| a = get_padded_nest_list(a, pad_val=-1) | |||
| assert np.shape(a) == (2, 3, 3) | |||
| a = [[[1], [2], [3, 4]], [[]]] | |||
| a = get_padded_nest_list(a, pad_val=-1) | |||
| assert np.shape(a) == (2, 3, 2) | |||
| def test_is_number_or_numpy_number(): | |||
| assert is_number_or_numpy_number(type(3)) is True | |||
| assert is_number_or_numpy_number(type(3.1)) is True | |||
| assert is_number_or_numpy_number(type(True)) is True | |||
| assert is_number_or_numpy_number(type('3')) is False | |||
| assert is_number_or_numpy_number(np.zeros(3).dtype) is True | |||
| assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True | |||
| assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| dtype = torch.ones(3).dtype | |||
| assert is_number_or_numpy_number(dtype) is False | |||
| def test_is_number(): | |||
| assert is_number(type(3)) is True | |||
| assert is_number(type(3.1)) is True | |||
| assert is_number(type(True)) is True | |||
| assert is_number(type('3')) is False | |||
| assert is_number(np.zeros(3).dtype) is False | |||
| assert is_number(np.zeros(3, dtype=int).dtype) is False | |||
| assert is_number(np.zeros(3, dtype=object).dtype) is False | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| dtype = torch.ones(3).dtype | |||
| assert is_number(dtype) is False | |||
| def test_is_numpy_number(): | |||
| assert is_numpy_number_dtype(type(3)) is False | |||
| assert is_numpy_number_dtype(type(3.1)) is False | |||
| assert is_numpy_number_dtype(type(True)) is False | |||
| assert is_numpy_number_dtype(type('3')) is False | |||
| assert is_numpy_number_dtype(np.zeros(3).dtype) is True | |||
| assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True | |||
| assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| dtype = torch.ones(3).dtype | |||
| assert is_numpy_number_dtype(dtype) is False | |||
| @@ -0,0 +1,225 @@ | |||
| import numpy as np | |||
| import pytest | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR | |||
| from fastNLP.core.collators.new_collator import Collator | |||
| def _assert_equal(d1, d2): | |||
| try: | |||
| if 'torch' in str(type(d1)): | |||
| if 'float64' in str(d2.dtype): | |||
| print(d2.dtype) | |||
| assert (d1 == d2).all().item() | |||
| else: | |||
| assert all(d1 == d2) | |||
| except TypeError: | |||
| assert d1 == d2 | |||
| except ValueError: | |||
| assert (d1 == d2).all() | |||
| def findDictDiff(d1, d2, path=""): | |||
| for k in d1: | |||
| if k in d2: | |||
| if isinstance(d1[k], dict): | |||
| findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) | |||
| else: | |||
| _assert_equal(d1[k], d2[k]) | |||
| else: | |||
| raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) | |||
| def findListDiff(d1, d2): | |||
| assert len(d1)==len(d2) | |||
| for _d1, _d2 in zip(d1, d2): | |||
| if isinstance(_d1, list): | |||
| findListDiff(_d1, _d2) | |||
| else: | |||
| _assert_equal(_d1, _d2) | |||
| class TestCollator: | |||
| def test_run(self): | |||
| dict_batch = [{ | |||
| 'str': '1', | |||
| 'lst_str': ['1'], | |||
| 'int': 1, | |||
| 'lst_int': [1], | |||
| 'nest_lst_int': [[1]], | |||
| 'float': 1.1, | |||
| 'lst_float': [1.1], | |||
| 'bool': True, | |||
| 'numpy': np.ones(1), | |||
| 'dict': {'1': '1'}, | |||
| 'set': {'1'}, | |||
| 'nested_dict': {'a': 1, 'b':[1, 2]} | |||
| }, | |||
| { | |||
| 'str': '2', | |||
| 'lst_str': ['2', '2'], | |||
| 'int': 2, | |||
| 'lst_int': [1, 2], | |||
| 'nest_lst_int': [[1], [1, 2]], | |||
| 'float': 2.1, | |||
| 'lst_float': [2.1], | |||
| 'bool': False, | |||
| 'numpy': np.zeros(1), | |||
| 'dict': {'1': '2'}, | |||
| 'set': {'2'}, | |||
| 'nested_dict': {'a': 2, 'b': [1, 2]} | |||
| } | |||
| ] | |||
| list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||
| ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||
| raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||
| collator = Collator(backend='raw') | |||
| assert raw_pad_batch == collator(dict_batch) | |||
| collator = Collator(backend='raw') | |||
| raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||
| [{'1'}, {'2'}]] | |||
| findListDiff(raw_pad_lst, collator(list_batch)) | |||
| collator = Collator(backend='numpy') | |||
| numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), | |||
| 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), | |||
| 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), | |||
| 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), | |||
| 'b': np.array([[1, 2], [1, 2]])}} | |||
| findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||
| collator = Collator(backend='numpy') | |||
| numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), | |||
| np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||
| np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), | |||
| np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||
| [{'1'}, {'2'}]] | |||
| findListDiff(numpy_pad_lst, collator(list_batch)) | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| collator = Collator(backend='torch') | |||
| numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), | |||
| 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), | |||
| 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||
| 'float': torch.FloatTensor([1.1, 2.1]), | |||
| 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), | |||
| 'numpy': torch.FloatTensor([[1], [0]]), | |||
| 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), | |||
| 'b': torch.LongTensor( | |||
| [[1, 2], [1, 2]])}} | |||
| findDictDiff(numpy_pad_batch, collator(dict_batch)) | |||
| collator = Collator(backend='torch') | |||
| torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), | |||
| torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), | |||
| torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), | |||
| torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], | |||
| [{'1'}, {'2'}]] | |||
| findListDiff(torch_pad_lst, collator(list_batch)) | |||
| def test_pad(self): | |||
| dict_batch = [{ | |||
| 'str': '1', | |||
| 'lst_str': ['1'], | |||
| 'int': 1, | |||
| 'lst_int': [1], | |||
| 'nest_lst_int': [[1]], | |||
| 'float': 1.1, | |||
| 'lst_float': [1.1], | |||
| 'bool': True, | |||
| 'numpy': np.ones(1), | |||
| 'dict': {'1': '1'}, | |||
| 'set': {'1'}, | |||
| 'nested_dict': {'a': 1, 'b':[1, 2]} | |||
| }, | |||
| { | |||
| 'str': '2', | |||
| 'lst_str': ['2', '2'], | |||
| 'int': 2, | |||
| 'lst_int': [1, 2], | |||
| 'nest_lst_int': [[1], [1, 2]], | |||
| 'float': 2.1, | |||
| 'lst_float': [2.1], | |||
| 'bool': False, | |||
| 'numpy': np.zeros(1), | |||
| 'dict': {'1': '2'}, | |||
| 'set': {'2'}, | |||
| 'nested_dict': {'a': 2, 'b': [1, 2]} | |||
| } | |||
| ] | |||
| raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} | |||
| # 测试 ignore | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| # 测试 set_pad | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad('str', pad_val=1) | |||
| with pytest.raises(BaseException): | |||
| collator(dict_batch) | |||
| # 测试设置 pad 值 | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad('nest_lst_int', pad_val=100) | |||
| collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||
| 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| # 设置 backend 和 type | |||
| collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) | |||
| raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], | |||
| 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} | |||
| findDictDiff(raw_pad_batch, collator(dict_batch)) | |||
| # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | |||
| # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||
| # [{'1'}, {'2'}]] | |||
| list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||
| ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('_0', '_3', '_1') | |||
| collator.set_pad('_4', pad_val=None) | |||
| raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | |||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||
| [{'1'}, {'2'}]] | |||
| findListDiff(raw_pad_lst, collator(list_batch)) | |||
| collator = Collator(backend='raw') | |||
| collator.set_pad('_0', pad_val=1) | |||
| with pytest.raises(BaseException): | |||
| collator(dict_batch) | |||
| list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], | |||
| ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] | |||
| collator = Collator(backend='raw') | |||
| collator.set_ignore('_0', '_3', '_1') | |||
| collator.set_pad('_2', backend='numpy') | |||
| collator.set_pad('_4', backend='numpy', pad_val=100) | |||
| raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | |||
| [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||
| [{'1'}, {'2'}]] | |||
| findListDiff(raw_pad_lst, collator(list_batch)) | |||
| # _single | |||
| collator = Collator() | |||
| collator.set_pad('_single') | |||
| findListDiff(list_batch, collator(list_batch)) | |||
| @@ -0,0 +1,37 @@ | |||
| from fastNLP.core.collators.utils import * | |||
| def test_unpack_batch_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] | |||
| assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} | |||
| def test_unpack_batch_nested_mapping(): | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} | |||
| batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, | |||
| {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] | |||
| assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
| 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
| def test_pack_batch_nested_mapping(): | |||
| batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], | |||
| 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} | |||
| new_batch = pack_batch_nested_mapping(batch) | |||
| assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], | |||
| 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} | |||
| def test_unpack_batch_sequence(): | |||
| batch = [[1, 2, 3], [2, 4, 6]] | |||
| new_batch = unpack_batch_sequence(batch) | |||
| assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} | |||