+ 181
- 0
from typing import List, Union, Dict, Callable, Sequence, Mapping

from fastNLP.core.log import logger
from .padders.get_padder import get_padder

import re

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
pack_batch_sequence, NESTED_DICT_SEPARATOR

# Matches list-mode field names: an underscore followed only by digits.
sequence_idx_str = re.compile(r'^_\d+$') # e.g. _0, _1
# Backend names accepted by the assertions in Collator.set_pad / Collator.set_backend.
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]

class Collator:
    """Pad a list of samples and assemble them into one batch.

    Every field that can be padded (as judged from the data) is padded, with 0
    as the default pad value.  Use :meth:`set_pad` to adjust a single field and
    :meth:`set_ignore` to drop fields from the output.
    """

    def __init__(self, backend='torch'):
        """
        :param backend: tensor type used for paddable fields; one of
            ['torch', 'jittor', 'paddle', 'numpy', 'raw', None].  With ``None``
            no padding is performed.
        """
        self.unpack_batch_func = None
        self.pack_batch_func = None
        self.ignore_fields = set()
        self.padders = {}
        self.input_fields = {}
        # One of 'd', 's', 'l': each sample of the batch is a dict, a single
        # object, or a list respectively.
        self.batch_data_type = None
        # Validates the name and stores it; __call__ and set_pad read self.backend.
        self.set_backend(backend)

    def __call__(self, batch) -> Union[List, Dict]:
        """Pad ``batch`` (List[Dict], List[List] or List[Sample]).

        Step 1: ``unpack_batch_func`` gathers the values of the same field
        from every sample into one list.
        Step 2: each field is padded by its own padder.
        Step 3: the result is packed back into the layout of the input samples.

        The first call chooses the unpack/pack functions from the type of
        ``batch[0]`` and creates one padder per field from the current data.

        :param batch: the samples of one batch.
        :return: the padded batch, in the same layout as one input sample.
        """
        if self.unpack_batch_func is None:
            # Pick the unpack function; every variant returns a flat dict.
            if self.batch_data_type is None:
                if isinstance(batch[0], Mapping):
                    self.batch_data_type = 'd'
                elif isinstance(batch[0], Sequence):  # risk of misjudging, e.g. str is a Sequence
                    self.batch_data_type = 'l'
                else:
                    self.batch_data_type = 's'
                logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
                             f"is {self.batch_data_type}")
            if self.batch_data_type == 's':
                self.unpack_batch_func = lambda x: {'_single': x}  # nothing to rearrange
                self.pack_batch_func = lambda x: x['_single']
            elif self.batch_data_type == 'l':
                self.unpack_batch_func = unpack_batch_sequence
                self.pack_batch_func = pack_batch_sequence
            elif self.batch_data_type == 'd':
                # Nested dicts are flattened: {'a': {'b': xx}} -> {'a@@b': value}
                if any(isinstance(v, Mapping) for v in batch[0].values()):
                    self.unpack_batch_func = unpack_batch_nested_mapping
                    self.pack_batch_func = pack_batch_nested_mapping
                else:
                    self.unpack_batch_func = unpack_batch_mapping
                    self.pack_batch_func = lambda x: x

        unpack_batch: Dict = self.unpack_batch_func(batch)  # one list per field

        pad_batch = {}
        if len(self.padders) == 0:  # first run (or reset by set_pad): build the padders
            for key in unpack_batch.keys():
                if key not in self.input_fields and key not in self.ignore_fields:
                    self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}

            for field_name, setting in self.input_fields.items():
                pad_fn = setting.get('pad_fn', None)
                if callable(pad_fn):
                    # A user supplied function replaces the automatic padder.
                    padder = pad_fn
                else:
                    batch_field = unpack_batch.get(field_name)
                    padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
                                        dtype=setting['dtype'], backend=setting['backend'],
                                        field_name=field_name)
                self.padders[field_name] = padder
            if self.batch_data_type == 'l':
                # Sort by index so that _0, _1, ... keep the order of the input sequence.
                self.padders = dict(sorted(self.padders.items(), key=lambda x: int(x[0][1:])))

        for key, padder in self.padders.items():
            field_batch = unpack_batch.get(key)
            pad_batch[key] = padder(field_batch)

        return self.pack_batch_func(pad_batch)  # restore the layout of the input samples

    def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None,
                pad_fn:Callable=None) -> "Collator":
        """Customise how one field is padded.

        :param field_name: name of the field.  For dict samples use the key;
            for nested dicts join the levels with ``@@`` (``a@@b`` for
            ``{'a': {'b': 1}}``); for sequence samples use ``'_0'``, ``'_1'``, ...
            for the element at that index; if a sample is a single object use
            ``"_single"``.
        :param pad_val: pad value of this field.  ``None`` means the field is
            not padded.  Fields that cannot be padded anyway do not need to be
            set to ``None`` explicitly.
        :param dtype: dtype the padded field should have.
        :param backend: one of [None, 'numpy', 'torch', 'paddle', 'jittor'];
            the output is a list, numpy.ndarray, torch.Tensor, paddle.Tensor or
            jittor.Var respectively.  Falls back to the collator's backend
            when ``None`` is passed.
        :param pad_fn: custom pad function for this field; when given,
            ``pad_val``, ``dtype`` and ``backend`` are not used.  It receives
            the field in batch form (the values of this field from all samples).
        :return: the Collator itself.
        """
        self.padders.clear()  # padders are rebuilt on the next call

        if self.batch_data_type is not None:
            if self.batch_data_type == 's':
                logger.debug("Set as single field mode.")
            elif self.batch_data_type == 'd':
                assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
                                                                   f"index, but other field is set as dict mode."
            elif self.batch_data_type == 'l':
                assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
                                                                       f"field name is {field_name}"

        # The shape of the field name decides which kind of sample is expected.
        if field_name == '_single':
            self.batch_data_type = 's'
        elif sequence_idx_str.match(field_name):
            self.batch_data_type = 'l'
        else:
            self.batch_data_type = 'd'

        if field_name in self.ignore_fields:
            logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
            self.ignore_fields.discard(field_name)
        if backend is None:
            backend = self.backend
        else:
            assert backend in SUPPORTED_BACKENDS

        self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}

        return self

    def set_backend(self, backend:str):
        """Set the default tensor type for paddable fields.

        :param backend: one of ['torch', 'jittor', 'paddle', 'numpy', 'raw', None];
            with ``None`` no padding is performed.
        """
        assert backend in SUPPORTED_BACKENDS
        self.backend = backend

    def set_ignore(self, *field_names) -> "Collator":
        """Exclude fields from the output batch.

        Example: ``collator.set_ignore('field1', 'field2')``

        :param field_names: names of the fields to ignore, using the same
            naming rules as :meth:`set_pad` (dict key, ``a@@b`` for nested
            dicts, ``'_0'``/``'_1'`` for sequence samples).
        :return: the Collator itself.
        """
        for field_name in field_names:
            if field_name in self.input_fields:
                # Otherwise the padder would be recreated the next time padders are rebuilt.
                self.input_fields.pop(field_name)
                logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
            self.padders.pop(field_name, None)  # drop its padder if one exists
            self.ignore_fields.add(field_name)

        return self

+ 0
- 0
__all__ = [

class InconsistencyError(Exception):
    """Raised when the data inside one batch is inconsistent (shape, dtype, ...).

    Derives from ``Exception`` so that ordinary ``except Exception`` handlers
    see it; it is still a ``BaseException`` for existing handlers.
    """

    def __init__(self, msg, *args):
        super(InconsistencyError, self).__init__(msg, *args)
        # Keep the message accessible the same way DtypeError exposes it.
        self.msg = msg

class DtypeError(Exception):
    """Base class of the dtype related padding errors.

    The message is kept in ``msg`` so callers can extend it and re-raise
    with ``type(e)(msg=...)``.
    """

    def __init__(self, msg, *args):
        super(DtypeError, self).__init__(msg, *args)
        self.msg = msg


class EleDtypeUnsupportedError(DtypeError):
    """Raised when the element type in the batch cannot be padded at all.

    For example when padding is requested for ``str`` data.
    """


class EleDtypeDtypeConversionError(DtypeError):
    """Raised when the element type in the batch cannot be converted to ``dtype``."""


class DtypeUnsupportedError(DtypeError):
    """Raised when the current backend does not support the requested ``dtype``."""


from typing import Dict

from typing import Sequence, Any, Union, Dict
from abc import ABC

from fastNLP.core.log import logger

from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .exceptions import *

def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
    """Return the padder that fits ``batch_field`` and the given settings.

    :param batch_field: the values of one field from all samples of a batch.
    :param pad_val: pad value; ``None`` disables padding for this field.
    :param dtype: requested dtype of the padded result, or ``None``.
    :param backend: backend name; ``None`` disables padding for this field.
    :param field_name: only used in log and error messages.
    :return: a padder; a ``NullPadder`` when the field is not / cannot be padded.
    :raises InconsistencyError: the data has mixed depth, shape length or
        element type and the user explicitly asked for padding
        (non-zero ``pad_val`` or a ``dtype``).
    :raises DtypeError: a padder rejected the element type or ``dtype`` and
        padding was explicitly asked for.
    :raises RuntimeError: tensors nested inside lists, with padding explicitly
        asked for.
    """
    # Arguments are passed separately so logging formats them with the placeholders.
    logger.debug("The content in the field:`%s` is:\n%s", field_name, batch_field)
    if pad_val is None:
        logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
        return NullPadder()
    if backend is None:
        logger.debug(f"The backend for field:{field_name} is None, not padding this field.")
        return NullPadder()

    # A customised pad_val or dtype means the user insists on padding, so
    # failures are raised instead of silently falling back to NullPadder.
    must_pad = pad_val != 0 or dtype is not None

    catalog = _get_element_shape_dtype(batch_field)  # basic information about every element

    # All keys must have the same length, i.e. every element sits at the same depth.
    depths = set(map(len, catalog.keys()))
    if len(depths) != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        return NullPadder()

    # All elements must have a shape of the same length.
    shape_lens = set([len(v[0]) for v in catalog.values()])
    if len(shape_lens) != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        return NullPadder()

    # All elements must have the same type.
    ele_dtypes = set([v[1] for v in catalog.values()])
    if len(ele_dtypes) != 1:
        msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
              f"information please set logger's level to DEBUG."
        if must_pad:
            raise InconsistencyError(msg)
        return NullPadder()

    depth = depths.pop()
    shape_len = shape_lens.pop()
    ele_dtype = ele_dtypes.pop()

    # The padder constructors decide themselves whether they can pad this dtype.
    try:
        if depth == 1 and shape_len == 0:  # e.g. [0, 1, 2] or [True, False, True]
            if backend == 'raw':
                return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'numpy':
                return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if depth > 1 and shape_len == 0:  # e.g. [[0, 1], [2]]
            if backend == 'raw':
                return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'numpy':
                return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if depth == 1 and shape_len != 0:  # a flat list of arrays / tensors
            if backend == 'numpy':
                return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
            elif backend == 'torch':
                return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

        if shape_len != 0 and depth > 1:
            msg = "Does not support pad tensor under nested list. If you need this, please report."
            if must_pad:
                raise RuntimeError(msg)
            return NullPadder()

    except DtypeError as e:
        msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
              "information please set logger's level to DEBUG."
        if must_pad:
            raise type(e)(msg=msg) from e
        return NullPadder()

    # No padder class is selected above for this backend / data layout.
    return NullPadder()

class HasShapeDtype(ABC):
    """Virtual base class of everything whose class exposes ``shape`` and ``dtype``.

    Typically ``np.ndarray`` or the tensor types of the backends.  The check
    is made on the class, so attributes that exist only on an instance are
    not detected.
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
        # ABCMeta calls this hook with the candidate class only, so it has to
        # be a classmethod.
        if cls is HasShapeDtype:
            return hasattr(subclass, 'shape') and hasattr(subclass, 'dtype')
        return NotImplemented

def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
    """Collect shape and type of every element, to judge whether padding is possible.

    :param content: the object to inspect.
    :param tuple parent: index path of ``content`` inside the outermost object.
    :param dict catalog: result dict; each key is the index path of one
        element, each value is ``(shape, dtype_or_type)``.
        Example: [1, 2, 3] -> {(0,): ((), <class 'int'>), (1,): ((), <class 'int'>), (2,): ((), <class 'int'>)}
        Example: [1, [2, 3], 4] -> {(0,): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (1, 1): ((), <class 'int'>), (2,): ((), <class 'int'>)}
        Example: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), <class 'int'>), (0, 1): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (2, 0): ((), <class 'int'>), (2, 1): ((), <class 'int'>)}
        Example: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)]
            -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)}
    :return: the catalog.
    """
    catalog = {} if catalog is None else catalog
    parent = () if parent is None else parent

    if isinstance(content, HasShapeDtype):  # tensors of any backend or np.ndarray
        catalog[parent] = (content.shape, content.dtype)
        return catalog

    if not isinstance(content, (tuple, list)):
        # int/float/bool/dict and anything else that cannot be padded:
        # () stands for a shape of length 0, followed by the Python type.
        catalog[parent] = ((), type(content))
        return catalog

    for position, child in enumerate(content):
        _get_element_shape_dtype(child, parent=parent + (position,), catalog=catalog)
    return catalog

from numbers import Number

# Scratch notes on recognising numeric element types; each trailing comment
# records the result the author noted for that expression.
issubclass(type(3), Number) # True
issubclass(type(3.1), Number) # True
issubclass(type('3'), Number) # False
issubclass(type(True), Number) # True
issubclass(type(np.zeros(3)[0]), Number) # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True; needs a further check (the original note does not say which)
# NOTE(review): called without an argument here, while the definition of
# is_torch_tensor_dtype elsewhere in this dump takes one - confirm.
is_torch_tensor_dtype() # can be checked via isinstance(torch.zeros(3).dtype, torch.dtype)

__all__ = [

from numbers import Number
from abc import ABC
from typing import Any, Union
import numpy as np

from .padder import Padder
from .utils import get_padded_numpy_array, is_number_or_numpy_number
from .exceptions import *

def _get_dtype(ele_dtype, dtype, class_name):
    """Validate the element type and resolve the dtype of the padded numpy array.

    :param ele_dtype: type of the elements found in the batch.
    :param dtype: dtype requested by the user, or ``None`` to reuse ``ele_dtype``.
    :param class_name: name of the calling padder class, used in error messages.
    :return: the dtype to pad with.
    :raises EleDtypeUnsupportedError: ``ele_dtype`` is neither a python number
        type nor a numpy number type.
    :raises DtypeUnsupportedError: an explicit ``dtype`` is neither of those.
    """
    if not is_number_or_numpy_number(ele_dtype):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers but get `{ele_dtype}`.")

    if dtype is None:
        # Nothing requested: keep the (already validated) element type.
        return ele_dtype
    if not is_number_or_numpy_number(dtype):
        raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
                                    f"or numpy numbers but get `{dtype}`.")
    return dtype

class NumpyNumberPadder(Padder):
    """Turn a flat batch of numbers such as ``[1, 2, 3]`` into a ``np.ndarray``."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: type of the elements in the batch.
        :param pad_val: stored on the padder; not used by :meth:`pad`.
        :param dtype: dtype of the result; defaults to ``ele_dtype``.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        # Static because Padder.__call__ passes all three arguments by keyword.
        return np.array(batch_field, dtype=dtype)

class NumpySequencePadder(Padder):
    """Pad nested lists of numbers such as ``[[1, 2], [3]]`` into a ``np.ndarray``."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: type of the elements in the batch.
        :param pad_val: value written into the padded positions.
        :param dtype: dtype of the result; defaults to ``ele_dtype``.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        # Static because Padder.__call__ passes all three arguments by keyword.
        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val)

class NumpyTensorPadder(Padder):
    """Pad a flat list of arrays such as ``[np.array([3, 4]), np.array([1])]``."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: dtype of the arrays in the batch.
        :param pad_val: value written into the padded positions.
        :param dtype: dtype of the result; defaults to ``ele_dtype``.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        shapes = [field.shape for field in batch_field]
        # max() over the tuple itself: unpacking it fails for a batch with a
        # single array, because max(3) is not a valid call.
        max_shape = [len(batch_field)] + [max(dims) for dims in zip(*shapes)]
        array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
        for i, field in enumerate(batch_field):
            # Copy each array into the leading corner of its slot.
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            array[slices] = field
        return array

class Padder:
    """Base class of all padders: stores the pad settings and applies :meth:`pad`."""

    def __init__(self, pad_val, dtype):
        """
        :param pad_val: value used for padded positions.
        :param dtype: dtype of the padded result.
        """
        self.pad_val = pad_val
        self.dtype = dtype

    def __call__(self, batch_field):
        """Pad ``batch_field`` with the stored ``pad_val`` and ``dtype``."""
        return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        """Pad one field; subclasses override this.

        Declared static because it has no ``self`` parameter and ``__call__``
        passes every argument by keyword.

        :raises NotImplementedError: always, in the base class.
        """
        raise NotImplementedError()

class NullPadder(Padder):
    """Padder that leaves the field untouched."""

    def __init__(self, ele_dtype=None, pad_val=None, dtype=None):
        # ele_dtype is accepted for a uniform constructor signature and not stored.
        super(NullPadder, self).__init__(pad_val, dtype)

    def __call__(self, batch_field):
        # Return the input directly instead of going through pad(), which is faster.
        return batch_field

from .padder import Padder
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array
from .exceptions import *

def _get_dtype(ele_dtype, dtype, class_name):
    """Validate the element type and resolve the dtype for the raw (list) padders.

    :param ele_dtype: type of the elements found in the batch.
    :param dtype: dtype requested by the user, or ``None`` to reuse ``ele_dtype``.
    :param class_name: name of the calling padder class, used in error messages.
    :return: the dtype to pad with.
    :raises EleDtypeUnsupportedError: ``ele_dtype`` is not a python number type.
    :raises DtypeUnsupportedError: an explicit ``dtype`` is not a python number type.
    """
    if not is_number(ele_dtype):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"but get `{ele_dtype}`.")
    if dtype is None:
        return ele_dtype
    if not is_number(dtype):
        raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but "
                                    f"get `{dtype}`.")
    return dtype

class RawNumberPadder(Padder):
    """Padder for a flat batch of python numbers; the batch is returned unchanged."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: type of the elements in the batch; validated on construction.
        :param pad_val: stored on the padder; not used.
        :param dtype: stored on the padder; not used.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    def __call__(self, batch_field):
        # A flat list of numbers needs no padding, so pad() is bypassed.
        return batch_field

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        raise NotImplementedError()

class RawSequencePadder(Padder):
    """Pad nested lists of python numbers and return nested lists."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: type of the elements in the batch; validated on construction.
        :param pad_val: value written into the padded positions.
        :param dtype: forwarded to the intermediate numpy array.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        """
        :param batch_field: nested lists to pad.
        :param pad_val: value written into the padded positions.
        :param dtype: described in the original note as having no meaning
            here; it is still handed to ``get_padded_numpy_array``.
        :return: the padded data as nested python lists.
        """
        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()

from inspect import isclass
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

import torch
# numpy scalar type -> torch dtype used when the user gives no explicit dtype.
numpy_to_torch_dtype_dict = {
    np.bool_: torch.bool,
    np.uint8: torch.uint8,
    np.int8: torch.int8,
    np.int16: torch.int16,
    np.int32: torch.int32,
    np.int64: torch.int64,
    np.float16: torch.float16,
    np.float32: torch.float32,
    np.float64: torch.float32,  # unified to float32, because numpy defaults to float64 most of the time
    np.complex64: torch.complex64,
    np.complex128: torch.complex128,
}
# python number type -> torch dtype.
number_to_torch_dtype_dict = {
    float: torch.float32,  # because torch.tensor([1], dtype=float) would be torch.float64
    int: torch.int64,
    bool: torch.bool,
}

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *

def is_torch_tensor(dtype):
    """Return True when ``dtype`` is a ``torch.dtype`` instance (and not a class)."""
    return not isclass(dtype) and isinstance(dtype, torch.dtype)

def _get_dtype(ele_dtype, dtype, class_name):
    """Validate the element type and resolve the torch dtype of the padded tensor.

    :param ele_dtype: type of the elements found in the batch.
    :param dtype: dtype requested by the user, or ``None`` to derive it from
        ``ele_dtype``.
    :param class_name: name of the calling padder class, used in error messages.
    :return: the dtype to create the tensor with.
    :raises EleDtypeUnsupportedError: ``ele_dtype`` is not a python number,
        numpy number or torch dtype.
    :raises DtypeUnsupportedError: an explicit ``dtype`` is neither a python
        number type nor a torch dtype.
    """
    if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

    if dtype is not None:
        if not (is_torch_tensor(dtype) or is_number(dtype)):
            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
                                        f"or torch.dtype but get `{dtype}`.")
        # An explicit request wins; python number types are mapped to torch dtypes.
        dtype = number_to_torch_dtype_dict.get(dtype, dtype)
    else:
        if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
            ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
            dtype = ele_dtype
        elif is_numpy_number_dtype(ele_dtype):  # a np.dtype instance needs converting
            dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
        elif is_numpy_generic_class(ele_dtype):
            dtype = numpy_to_torch_dtype_dict.get(ele_dtype)

    return dtype

class TorchNumberPadder(Padder):
    """Turn a flat batch of numbers such as ``[1, 2, 3]`` into a ``torch.Tensor``."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: element type; a python number, numpy number or torch dtype.
        :param pad_val: stored on the padder; not used by :meth:`pad`.
        :param dtype: dtype of the result; derived from ``ele_dtype`` when ``None``.
        """
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        # Static because Padder.__call__ passes all three arguments by keyword.
        return torch.tensor(batch_field, dtype=dtype)

class TorchSequencePadder(Padder):
    """Pad nested lists of numbers such as ``[[1, 2], [3]]`` into a ``torch.Tensor``."""

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: type of the elements in the batch.
        :param pad_val: value written into the padded positions.
        :param dtype: dtype of the result; derived from ``ele_dtype`` when ``None``.
        """
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        # Static because Padder.__call__ passes all three arguments by keyword.
        return get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val)

class TorchTensorPadder(Padder):
    """Pad a flat list of tensors such as ``[torch.tensor([3, 2]), torch.tensor([1])]``.

    ``np.ndarray`` elements are accepted as well and converted on the fly.
    """

    def __init__(self, ele_dtype, pad_val=0, dtype=None):
        """
        :param ele_dtype: dtype of the tensors / arrays in the batch.
        :param pad_val: value written into the padded positions.
        :param dtype: dtype of the result; derived from ``ele_dtype`` when ``None``.
        """
        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        shapes = [field.shape for field in batch_field]
        # max() over the tuple itself: unpacking it fails for a batch with a
        # single tensor, because max(3) is not a valid call.
        max_shape = [len(batch_field)] + [max(dims) for dims in zip(*shapes)]
        # Always allocate the output; it must not depend on the kind of dtype.
        tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
        for i, field in enumerate(batch_field):
            slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
            if isinstance(field, np.ndarray):
                field = torch.from_numpy(field)
            tensor[slices] = field
        return tensor

def fill_tensor(batch_field, padded_batch, dtype):
    """Copy the values of ``batch_field`` into the pre-allocated ``padded_batch``.

    :param batch_field: nested lists holding the values to copy.
    :param padded_batch: tensor already filled with the pad value; 1 to 4 dimensions.
    :param dtype: dtype used when converting the list slices to tensors.
    :return: the filled tensor.  For 4 dimensions a new tensor is returned when
        ``batch_field`` converts directly; otherwise ``padded_batch`` itself.
    :raises RuntimeError: ``padded_batch`` has an unsupported number of dimensions.
    """
    if padded_batch.ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype)
    elif padded_batch.ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
    elif padded_batch.ndim == 4:
        try:  # probably images, which convert directly
            # Must stay a torch tensor; np.array here would change the return type.
            padded_batch = torch.tensor(batch_field, dtype=dtype)
        except (ValueError, TypeError, RuntimeError):
            # Ragged data: fall back to copying row by row.
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype)
    elif padded_batch.ndim == 1:
        padded_batch[:] = torch.tensor(batch_field, dtype=dtype)
    else:
        raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
                           "report.")
    return padded_batch

def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0):
    """Pad nested lists into a tensor, e.g. [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]]).

    :param batch_field: the object to pad; it must be paddable.  Supports 1d
        (mostly sentence lengths) / 2d (mostly token sequences) / 3d (mostly
        character sequences).
    :param dtype: target dtype.
    :param pad_val: value written into the padded positions.
    :return: the padded tensor.
    """
    blank = torch.full(get_shape(batch_field), dtype=dtype, fill_value=pad_val)
    return fill_tensor(batch_field, blank, dtype=dtype)

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

import torch

def is_torch_tensor_dtype(dtype) -> bool:
    """Return whether ``dtype`` is a torch dtype.

    :param dtype: a value obtained like ``torch.ones(3).dtype``.
    :return: True for a ``torch.dtype`` instance; False otherwise, including
        when the name ``torch`` is not bound in this module.
    """
    try:
        return isinstance(dtype, torch.dtype)
    except NameError:
        # torch was not imported, so nothing can be a torch dtype.
        return False

from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass

import numpy as np
np_str_obj_array_pattern = re.compile(r'[SaUO]')

def get_shape(batch_field:List, shape=None):
    """Return the shape ``batch_field`` will have once it is padded.

    Example: [[1, 2, 3], [3]] -> [2, 3]
             [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3]

    Whether a level is nested is decided from its first element only.
    ``str`` and ``bytes`` are treated as scalars; as sequences of themselves
    they would otherwise recurse without end.

    :param batch_field: list whose dimension 0 is usually the batch dimension.
    :param shape: internal accumulator; callers do not pass it.
    :return: list of ints, one per dimension.
    """
    if shape is None:
        shape = []
    if isinstance(batch_field, (str, bytes)) or not isinstance(batch_field, Sequence):
        return shape

    _shape = shape + [len(batch_field)]
    if len(batch_field) == 0:
        return _shape  # an empty sequence contributes a dimension of size 0

    first = batch_field[0]
    if isinstance(first, Sequence) and not isinstance(first, (str, bytes)):
        shapes = [get_shape(_field, _shape) for _field in batch_field]
        # Every dimension takes the largest extent found among the children.
        return [max(_) for _ in zip(*shapes)]
    return _shape

def fill_array(batch_field:List, padded_batch:np.ndarray):
    """Copy the values of ``batch_field`` into the pre-allocated ``padded_batch``.

    :param batch_field: nested lists holding the values to copy.
    :param padded_batch: np.ndarray already filled with the pad value; 1 to 4 dimensions.
    :return: the filled array.  For 4 dimensions a new array is returned when
        ``batch_field`` converts directly to the padded shape; otherwise
        ``padded_batch`` itself.
    :raises RuntimeError: ``padded_batch`` has an unsupported number of dimensions.
    """
    if padded_batch.ndim == 2:
        for i, content_i in enumerate(batch_field):
            padded_batch[i, :len(content_i)] = content_i
    elif padded_batch.ndim == 3:
        for i, content_i in enumerate(batch_field):
            for j, content_ii in enumerate(content_i):
                padded_batch[i, j, :len(content_ii)] = content_ii
    elif padded_batch.ndim == 4:
        try:  # probably images, which convert directly
            # Requesting the target dtype keeps the result type consistent and
            # makes ragged input fail here rather than give an object array.
            dense = np.array(batch_field, dtype=padded_batch.dtype)
        except (ValueError, TypeError):
            dense = None
        if dense is not None and dense.shape == padded_batch.shape:
            padded_batch = dense
        else:
            # Ragged data: fall back to copying row by row.
            for i, content_i in enumerate(batch_field):
                for j, content_ii in enumerate(content_i):
                    for k, content_iii in enumerate(content_ii):
                        padded_batch[i, j, k, :len(content_iii)] = content_iii
    elif padded_batch.ndim == 1:
        padded_batch[:] = batch_field
    else:
        raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
                           "report.")
    return padded_batch

def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray:
    """Pad nested lists into an array, e.g. [[1,2], [3]] -> np.array([[1, 2], [3, 0]]).

    :param batch_field: the object to pad; it must be paddable.  Supports 1d
        (mostly sentence lengths) / 2d (mostly token sequences) / 3d (mostly
        character sequences).
    :param dtype: target dtype.
    :param pad_val: value written into the padded positions.
    :return: the padded array.
    """
    blank = np.full(get_shape(batch_field), dtype=dtype, fill_value=pad_val)
    return fill_array(batch_field, blank)

def get_padded_nest_list(batch_field: List, pad_val=0) -> List:
    """Pad nested lists and return nested lists, e.g. [[1,2], [3]] -> [[1, 2], [3, 0]].

    :param batch_field: the object to pad; it must be paddable.  Supports 1d
        (mostly sentence lengths) / 2d (mostly token sequences) / 3d (mostly
        character sequences).
    :param pad_val: value written into the padded positions.
    :return: the padded data as nested python lists.
    """
    padded = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None)
    return padded.tolist()

def is_number_or_numpy_number(dtype):
    """Tell whether ``dtype`` is a python number type or a numpy number type.

    Expected results, as noted by the original author::

        is_number_or_numpy_number(type(3))  # True
        is_number_or_numpy_number(type(3.1))  # True
        is_number_or_numpy_number(type('3'))  # False
        is_number_or_numpy_number(type(True))  # True
        is_number_or_numpy_number(type(np.zeros(3)[0]))  # True
        is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)  # True
        is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)  # True
        is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)  # False
        is_number_or_numpy_number(np.array([1, [2]]).dtype)  # False

    :param dtype: a type, or a ``np.dtype`` instance.
    :return: bool
    """
    if is_number(dtype):
        return True
    if isclass(dtype):
        return is_numpy_generic_class(dtype)
    # A np.dtype instance counts unless its type string matches the
    # string/object pattern (np_str_obj_array_pattern, characters S, a, U, O).
    return isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None

def is_numpy_number_dtype(dtype):
    """Tell whether ``dtype`` is a ``np.dtype`` instance that is not a string/object dtype.

    :param dtype: value to test, e.g. ``np.zeros(3).dtype``.
    :return: bool
    """
    if isclass(dtype) or not isinstance(dtype, np.dtype):
        return False
    # The type string matches the pattern (characters S, a, U, O) for
    # string and object dtypes; those are not numbers.
    return np_str_obj_array_pattern.search(dtype.str) is None

def is_numpy_generic_class(dtype):
    """Tell whether ``dtype`` is a numpy scalar class.

    Such as ``np.int64``, or the value of ``np.zeros(1).dtype.type``.

    :param dtype: value to test.
    :return: bool
    """
    return isclass(dtype) and issubclass(dtype, np.generic)

def is_number(dtype):
    """Tell whether ``dtype`` is one of the python number types float, int, complex, bool.

    numpy scalar classes and np.dtype instances are explicitly excluded.
    """
    return dtype in (float, int, complex, bool) and not (
        is_numpy_generic_class(dtype) or is_numpy_number_dtype(dtype)
    )

from collections import defaultdict
from functools import reduce
from typing import Sequence, Mapping, Dict


def unpack_batch_mapping(batch:Sequence[Mapping])->Dict:
    """Turn a Sequence[Mapping] into a dict of lists.

    Example: [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}

    :param batch: the samples of one batch, each a mapping.
    :return: dict mapping every key to the list of its values, in sample order.
    """
    dict_batch = defaultdict(list)
    for sample in batch:
        for key, value in sample.items():
            dict_batch[key].append(value)
    return dict_batch

def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict:
    """Flatten a batch of nested dicts into one flat dict of lists.

    Keys of nested levels are joined with ``NESTED_DICT_SEPARATOR``, the same
    separator ``pack_batch_nested_mapping`` splits on.

    :param batch: the samples of one batch, each a (possibly nested) mapping.
    :param _parent: internal use; key prefix of the current level.
    :return: dict mapping every flattened key to the list of its values.
    """
    dict_batch = defaultdict(list)
    if _parent != '':
        _parent += NESTED_DICT_SEPARATOR
    for sample in batch:
        for key, value in sample.items():
            if isinstance(value, Mapping):
                flat = _unpack_batch_nested_mapping(value, _parent=_parent + key)
                # The flattened keys already carry the full prefix.
                for flat_key, flat_value in flat.items():
                    dict_batch[flat_key].append(flat_value)
            else:
                dict_batch[_parent + key].append(value)
    return dict_batch

def _unpack_batch_nested_mapping(value, _parent)->Dict:
    """Flatten one nested mapping; every key is prefixed with ``_parent`` and the separator."""
    _dict = {}
    prefix = _parent + NESTED_DICT_SEPARATOR
    for k, v in value.items():
        if isinstance(v, Mapping):
            # Merge the deeper levels instead of dropping them.
            _dict.update(_unpack_batch_nested_mapping(v, _parent=prefix + k))
        else:
            _dict[prefix + k] = v
    return _dict

def pack_batch_nested_mapping(batch:Mapping) -> Dict:
    """Restore the nested dict layout from a flat dict.

    :param batch: flat mapping whose keys join the nesting levels with
        ``NESTED_DICT_SEPARATOR``.
    :return: the nested dict; an empty dict for an empty ``batch``.
    """
    dicts = []

    for key, value in batch.items():
        keys = key.split(NESTED_DICT_SEPARATOR)
        # Build the single-branch nested dict for this key from the inside out.
        d = {keys[-1]: value}
        for parent_key in reversed(keys[:-1]):
            d = {parent_key: d}
        dicts.append(d)
    # The initial {} keeps reduce from failing on an empty batch.
    return reduce(_merge_dict, dicts, {})

def _merge_dict(a, b, path=None):
"merges b into a"
if path is None: path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
_merge_dict(a[key], b[key], path + [str(key)])
raise Exception('Conflict at %s' % '.'.join(path + [str(key)]))
a[key] = b[key]
return a

def unpack_batch_sequence(batch:Sequence[Sequence])->Dict:
    """Turn a Sequence[Sequence] into a dict of lists keyed by position.

    Example: [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [2, 2]}

    :param batch: the samples of one batch, each a sequence.
    :return: dict mapping ``'_<index>'`` to the list of the elements found at
        that index, in sample order.
    """
    dict_batch = defaultdict(list)
    for sample in batch:
        for i, content in enumerate(sample):
            dict_batch[f'_{i}'].append(content)
    return dict_batch

def pack_batch_sequence(batch:Mapping)->Sequence:
    """Return the values of ``batch`` as a list, in the mapping's iteration order."""
    return [field for field in batch.values()]

import pytest
import numpy as np

from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \

def test_get_element_shape_dtype():
    """Smoke test: cataloguing mixed, ragged and array inputs must not raise."""
    samples = (
        [[1], [2, 3], [3], 2],
        [['1'], [2, 3]],
        [['1'], [2, 3]],
        [['1'], ['2', '3']],
        [np.zeros(3), np.zeros((2, 1))],
    )
    for sample in samples:
        catalog = _get_element_shape_dtype(sample)

@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
def test_get_padder_run(backend):
    """get_padder must run for every backend and raise only when padding is forced."""
    # Imported here so each backend is guarded by its own availability flag.
    from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR
    if not _NEED_IMPORT_TORCH and backend == 'torch':
        pytest.skip("No torch")
    if not _NEED_IMPORT_PADDLE and backend == 'paddle':
        pytest.skip("No paddle")
    if not _NEED_IMPORT_JITTOR and backend == 'jittor':
        pytest.skip("No jittor")
    batch_field = [1, 2, 3]
    padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

    if backend is not None:
        # cannot pad: mixed depths
        batch_field = [[1], [2, 3], [3], 2]
        with pytest.raises(InconsistencyError):
            padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
        padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

        # cannot pad: str elements
        batch_field = [['2'], ['2'], ['2', '2']]
        with pytest.raises(DtypeError):
            padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
        padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

        # cannot pad: arrays with different numbers of dimensions
        batch_field = [np.zeros(3), np.zeros((3, 1))]
        with pytest.raises(InconsistencyError):
            padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
        padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')  # no pad

        batch_field = [np.zeros((3, 1)), np.zeros((4, 1))]
        padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

def test_raw_padder():
    """The raw backend returns python lists: flat input unchanged, nested input padded."""
    backend = 'raw'

    def build_and_pad(batch_field):
        padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
        return padder(batch_field)

    numbers = [1, 2, 3]
    assert build_and_pad(numbers) == numbers

    nested_cases = (
        ([[1], [2, 2], [3, 3, 3]], (3, 3)),
        ([[[1]], [[2, 2], [2]], [[3], [3], [3]]], (3, 3, 2)),
    )
    for batch_field, expected_shape in nested_cases:
        assert np.shape(build_and_pad(batch_field)) == expected_shape

def test_numpy_padder():
    """The numpy backend returns np.ndarray with the expected shape and pad count."""
    backend = 'numpy'
    target_type = np.ndarray

    def build_and_pad(batch_field):
        padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
        return padder(batch_field)

    batch_field = [1, 2, 3]
    pad_batch = build_and_pad(batch_field)
    assert isinstance(pad_batch, target_type)
    assert (pad_batch == np.array(batch_field)).sum()==len(batch_field)

    # (input, expected shape, expected number of zero entries)
    padded_cases = (
        ([[1], [2, 2], [3, 3, 3]], (3, 3), 3),
        ([np.ones((3,3)), np.ones((2,3)), np.ones((1,3))], (3, 3, 3), 9),
        ([np.ones((3,3)), np.ones((2,3)), np.ones((1,0))], (3, 3, 3), 12),
    )
    for batch_field, expected_shape, num_zeros in padded_cases:
        pad_batch = build_and_pad(batch_field)
        assert isinstance(pad_batch, target_type)
        assert np.shape(pad_batch) == expected_shape
        assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==num_zeros

    # arrays with different numbers of dimensions cannot be padded
    batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))]
    with pytest.raises(InconsistencyError):
        padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

def test_torch_padder():
    """The torch backend returns torch.Tensor with the expected shape and pad count."""
    # Skip only when torch cannot be imported, instead of unconditionally.
    torch = pytest.importorskip("torch", reason="No torch.")
    backend = 'torch'
    target_type = torch.Tensor
    batch_field = [1, 2, 3]
    padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
    pad_batch = padder(batch_field)
    assert isinstance(pad_batch, target_type)
    assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field)

    batch_field = [[1], [2, 2], [3, 3, 3]]
    padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
    pad_batch = padder(batch_field)
    assert isinstance(pad_batch, target_type)
    assert pad_batch.shape == (3, 3)
    assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3

    batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))]
    padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
    pad_batch = padder(batch_field)
    assert isinstance(pad_batch, target_type)
    assert pad_batch.shape == (3, 3, 3)
    assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9

    batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
    padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
    pad_batch = padder(batch_field)
    assert isinstance(pad_batch, target_type)
    assert pad_batch.shape == (3, 3, 3)
    assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12

    # tensors with different numbers of dimensions cannot be padded
    batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))]
    with pytest.raises(InconsistencyError):
        padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

import numpy as np
import pytest

from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

class TestNumpyNumberPadder:
    """Tests for ``NumpyNumberPadder``."""

    def test_run(self):
        """A batch of plain ints is returned as a numpy array with the same values."""
        padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
        a = [1, 2, 3]
        padded = padder(a)
        # The type check has to look at the padder output; the input ``a`` is a
        # plain list and would never be an ndarray.
        assert isinstance(padded, np.ndarray)
        assert (padded == np.array(a)).sum() == 3

class TestNumpySequencePadder:
    """Tests for ``NumpySequencePadder``."""

    def test_run(self):
        """A ragged list of ints is padded with ``pad_val`` into a 2-D numpy array."""
        padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
        a = [[1, 2, 3], [3]]
        a = padder(a)
        shape = np.shape(a)
        assert isinstance(a, np.ndarray)
        assert shape == (2, 3)
        b = np.array([[1, 2, 3], [3, -1, -1]])
        # Every position, padded ones included, must match the expectation.
        assert (a == b).sum().item() == shape[0]*shape[1]

    def test_dtype_check(self):
        """A numpy int dtype is accepted; ``str`` and torch dtypes raise ``DtypeError``."""
        padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
        with pytest.raises(DtypeError):
            padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
        # torch is optional: only run the torch-dtype check when it can be imported.
        if _NEED_IMPORT_TORCH:
            import torch
            with pytest.raises(DtypeError):
                padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)

class TestNumpyTensorPadder:
    """Tests for ``NumpyTensorPadder``."""

    def test_run(self):
        """Pad batches of 1-D and 2-D numpy arrays and compare every element."""
        padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1)

        # 1-D arrays of length 3, 2 and 0.
        a = [np.zeros(3), np.zeros(2), np.zeros(0)]
        a = padder(a)
        shape = np.shape(a)
        assert isinstance(a, np.ndarray)
        assert shape == (3, 3)
        b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
        assert (a == b).sum().item() == shape[0]*shape[1]

        # 2-D arrays; the last one is smaller along both axes.
        a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))]
        a = padder(a)
        shape = np.shape(a)
        assert isinstance(a, np.ndarray)
        assert shape == (3, 3, 2)
        b = np.array([[[0, 0], [0, 0], [0, 0]],
                      [[0, 0], [0, 0], [-1, -1]],
                      [[0, -1], [-1, -1], [-1, -1]]])
        assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

        # The last array has an empty second axis, so its slot is all padding.
        a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
        a = padder(a)
        shape = np.shape(a)
        assert isinstance(a, np.ndarray)
        assert shape == (3, 3, 2)
        b = np.array([[[0, 0], [0, 0], [0, 0]],
                      [[0, 0], [0, 0], [-1, -1]],
                      [[-1, -1], [-1, -1], [-1, -1]]])
        assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

    def test_dtype_check(self):
        """A numpy int dtype is accepted; ``str`` and torch dtypes raise ``DtypeError``."""
        padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
        with pytest.raises(DtypeError):
            padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
        # torch is optional: only run the torch-dtype checks when it can be imported.
        if _NEED_IMPORT_TORCH:
            import torch
            with pytest.raises(DtypeError):
                padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
            with pytest.raises(DtypeError):
                padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)

import numpy as np
import pytest

from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder
from fastNLP.core.collators.padders.exceptions import DtypeError

class TestRawNumberPadder:
    """Tests for ``RawNumberPadder``."""

    def test_run(self):
        """A batch of plain numbers compares equal to the input list after padding."""
        numbers = [1, 2, 3]
        raw_padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
        assert raw_padder(numbers) == numbers

class TestRawSequencePadder:
    """Tests for ``RawSequencePadder``."""

    def test_run(self):
        """A ragged list of ints is padded with ``pad_val`` to a rectangular 2 x 3 result."""
        raw_padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
        padded = raw_padder([[1, 2, 3], [3]])
        dims = np.shape(padded)
        assert dims == (2, 3)
        expected = np.array([[1, 2, 3], [3, -1, -1]])
        # Count matching positions; all of them (padding included) must agree.
        n_matching = (padded == expected).sum().item()
        assert n_matching == dims[0]*dims[1]

    def test_dtype_check(self):
        """Both a numpy dtype and ``str`` are rejected with ``DtypeError``."""
        rejected_dtypes = (np.zeros(3, dtype=np.int8).dtype, str)
        for bad_dtype in rejected_dtypes:
            with pytest.raises(DtypeError):
                RawSequencePadder(ele_dtype=bad_dtype, dtype=int, pad_val=-1)

import numpy as np
import pytest

from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

import torch

class TestTorchNumberPadder:
    """Tests for ``TorchNumberPadder``."""

    def test_run(self):
        """A batch of plain ints becomes a tensor holding the same values."""
        numbers = [1, 2, 3]
        number_padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
        as_tensor = number_padder(numbers)
        assert isinstance(as_tensor, torch.Tensor)
        n_matching = (as_tensor == torch.LongTensor(numbers)).sum()
        assert n_matching == 3

class TestTorchSequencePadder:
    """Tests for ``TorchSequencePadder``."""

    def test_run(self):
        """A ragged list of ints is padded with ``pad_val`` into a 2 x 3 tensor."""
        seq_padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
        padded = seq_padder([[1, 2, 3], [3]])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (2, 3)
        expected = torch.LongTensor([[1, 2, 3], [3, -1, -1]])
        # Every position, padded ones included, must match the expectation.
        n_matching = (padded == expected).sum().item()
        assert n_matching == dims[0]*dims[1]

    def test_dtype_check(self):
        """numpy and torch element dtypes are accepted; ``str`` raises ``DtypeError``."""
        TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
        with pytest.raises(DtypeError):
            TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
        TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)

        int8_padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1)
        padded = int8_padder([[1], [2, 322]])
        # The original note said "because the int8 range is -67 - 66".
        # NOTE(review): int8 actually spans -128..127; presumably 322 wraps to
        # 66 (322 - 256), which is why nothing exceeds 67 -- confirm.
        n_too_large = (padded > 67).sum()
        assert n_too_large == 0

        TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)

class TestTorchTensorPadder:
    """Tests for ``TorchTensorPadder``."""

    def test_run(self):
        """Pad batches of tensors (and numpy arrays) and compare every element."""
        tensor_padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)

        # 1-D tensors of length 3, 2 and 0.
        padded = tensor_padder([torch.zeros(3), torch.zeros(2), torch.zeros(0)])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (3, 3)
        expected = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
        assert (padded == expected).sum().item() == dims[0]*dims[1]

        # 2-D tensors that differ only in the first dimension.
        padded = tensor_padder([torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (3, 3, 2)
        expected = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
                                     [[0, 0], [0, 0], [-1, -1]],
                                     [[0, 0], [-1, -1], [-1, -1]]])
        assert (padded == expected).sum().item() == dims[0]*dims[1]*dims[2]

        # The last tensor is smaller along both dimensions.
        padded = tensor_padder([torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (3, 3, 2)
        expected = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
                                     [[0, 0], [0, 0], [-1, -1]],
                                     [[0, -1], [-1, -1], [-1, -1]]])
        assert (padded == expected).sum().item() == dims[0]*dims[1]*dims[2]

        # Fresh padder; the last tensor has an empty second axis, so its slot is all padding.
        tensor_padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
        padded = tensor_padder([torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (3, 3, 2)
        expected = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
                                     [[0, 0], [0, 0], [-1, -1]],
                                     [[-1, -1], [-1, -1], [-1, -1]]])
        assert (padded == expected).sum().item() == dims[0]*dims[1]*dims[2]

        # dtype=None with numpy inputs: the result is still a torch tensor and is
        # compared against a float tensor.
        tensor_padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1)
        padded = tensor_padder([np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))])
        dims = padded.shape
        assert isinstance(padded, torch.Tensor)
        assert tuple(dims) == (3, 3, 2)
        expected = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]],
                                      [[0, 0], [0, 0], [-1, -1]],
                                      [[-1, -1], [-1, -1], [-1, -1]]])
        assert (padded == expected).sum().item() == dims[0]*dims[1]*dims[2]

    def test_dtype_check(self):
        """numpy and torch dtypes are accepted; a ``str`` element dtype raises ``DtypeError``."""
        TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
        with pytest.raises(DtypeError):
            TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
        TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
        TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)

import pytest
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \
get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number

def test_get_shape():
    """``get_shape`` reports the per-level maximum length of a ragged nested list."""
    cases = (
        ([[1, 2, 3], [3]], [2, 3]),
        ([[[1], [2], [3, 4]], [[2, 3, 4]]], [2, 3, 3]),
        # An empty innermost list does not change the maxima set by its siblings.
        ([[[1], [2], [3, 4]], [[]]], [2, 3, 2]),
    )
    for nested, expected_shape in cases:
        assert get_shape(nested) == expected_shape

def test_get_padded_numpy_array():
    """``get_padded_numpy_array`` returns an array shaped like the padded nested list."""
    cases = (
        ([[1, 2, 3], [3]], (2, 3)),
        ([[[1], [2], [3, 4]], [[2, 3, 4]]], (2, 3, 3)),
        ([[[1], [2], [3, 4]], [[]]], (2, 3, 2)),
    )
    for nested, expected_shape in cases:
        padded = get_padded_numpy_array(nested, dtype=int, pad_val=-1)
        assert padded.shape == expected_shape

def test_get_padded_nest_list():
    """``get_padded_nest_list`` returns a rectangular nested list of the expected shape."""
    cases = (
        ([[1, 2, 3], [3]], (2, 3)),
        ([[[1], [2], [3, 4]], [[2, 3, 4]]], (2, 3, 3)),
        ([[[1], [2], [3, 4]], [[]]], (2, 3, 2)),
    )
    for nested, expected_shape in cases:
        padded = get_padded_nest_list(nested, pad_val=-1)
        # np.shape only yields a full shape tuple when the result is rectangular.
        assert np.shape(padded) == expected_shape

def test_is_number_or_numpy_number():
    """Python numeric types and numeric numpy dtypes count as numbers; everything else does not."""
    assert is_number_or_numpy_number(type(3)) is True
    assert is_number_or_numpy_number(type(3.1)) is True
    assert is_number_or_numpy_number(type(True)) is True
    assert is_number_or_numpy_number(type('3')) is False
    assert is_number_or_numpy_number(np.zeros(3).dtype) is True
    assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True
    assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False

    # torch is optional: only check the torch dtype when it can be imported.
    if _NEED_IMPORT_TORCH:
        import torch
        dtype = torch.ones(3).dtype
        assert is_number_or_numpy_number(dtype) is False

def test_is_number():
    """Only Python numeric types count for ``is_number``; numpy and torch dtypes do not."""
    assert is_number(type(3)) is True
    assert is_number(type(3.1)) is True
    assert is_number(type(True)) is True
    assert is_number(type('3')) is False
    assert is_number(np.zeros(3).dtype) is False
    assert is_number(np.zeros(3, dtype=int).dtype) is False
    assert is_number(np.zeros(3, dtype=object).dtype) is False

    # torch is optional: only check the torch dtype when it can be imported.
    if _NEED_IMPORT_TORCH:
        import torch
        dtype = torch.ones(3).dtype
        assert is_number(dtype) is False

def test_is_numpy_number():
    """Only numeric numpy dtypes count for ``is_numpy_number_dtype``."""
    assert is_numpy_number_dtype(type(3)) is False
    assert is_numpy_number_dtype(type(3.1)) is False
    assert is_numpy_number_dtype(type(True)) is False
    assert is_numpy_number_dtype(type('3')) is False
    assert is_numpy_number_dtype(np.zeros(3).dtype) is True
    assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True
    assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False

    # torch is optional: only check the torch dtype when it can be imported.
    if _NEED_IMPORT_TORCH:
        import torch
        dtype = torch.ones(3).dtype
        assert is_numpy_number_dtype(dtype) is False

import numpy as np
import pytest


from fastNLP.core.collators.new_collator import Collator

def _assert_equal(d1, d2):
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
assert (d1 == d2).all().item()
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()

def findDictDiff(d1, d2, path=""):
    """Recursively check that every key of ``d1`` exists in ``d2`` with an equal value.

    Only the keys of ``d1`` are walked; extra keys in ``d2`` are not reported.

    :param d1: expected (possibly nested) dict.
    :param d2: actual (possibly nested) dict.
    :param path: the chain of keys leading to the current level, used in the error message.
    :raises RuntimeError: if a key of ``d1`` is missing from ``d2``.
    :raises AssertionError: if a leaf value differs (raised by ``_assert_equal``).
    """
    for k in d1:
        if k in d2:
            if isinstance(d1[k], dict):
                # Descend into nested dicts, extending the path for error reporting.
                findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
            else:
                _assert_equal(d1[k], d2[k])
        else:
            raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))

def findListDiff(d1, d2):
    """Recursively check that two (possibly nested) lists have equal length and equal elements.

    :param d1: expected list; its elements decide whether to recurse.
    :param d2: actual list.
    :raises AssertionError: if the lengths differ or a leaf value differs.
    """
    assert len(d1)==len(d2)
    for _d1, _d2 in zip(d1, d2):
        if isinstance(_d1, list):
            findListDiff(_d1, _d2)
        else:
            # Non-list leaves (numbers, arrays, tensors, dicts, sets, ...) are compared directly.
            _assert_equal(_d1, _d2)

class TestCollator:
    """End-to-end tests for ``Collator`` on dict-style and list-style batches."""

    @staticmethod
    def _make_dict_batch():
        """Return a fresh two-sample batch whose samples are dicts with mixed field types."""
        return [
            {
                'str': '1',
                'lst_str': ['1'],
                'int': 1,
                'lst_int': [1],
                'nest_lst_int': [[1]],
                'float': 1.1,
                'lst_float': [1.1],
                'bool': True,
                'numpy': np.ones(1),
                'dict': {'1': '1'},
                'set': {'1'},
                'nested_dict': {'a': 1, 'b': [1, 2]},
            },
            {
                'str': '2',
                'lst_str': ['2', '2'],
                'int': 2,
                'lst_int': [1, 2],
                'nest_lst_int': [[1], [1, 2]],
                'float': 2.1,
                'lst_float': [2.1],
                'bool': False,
                'numpy': np.zeros(1),
                'dict': {'1': '2'},
                'set': {'2'},
                'nested_dict': {'a': 2, 'b': [1, 2]},
            },
        ]

    @staticmethod
    def _make_list_batch():
        """Return a fresh two-sample batch whose samples are lists with mixed field types."""
        return [
            ['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
            ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}],
        ]

    def test_run(self):
        """Collate the dict and list batches with the raw, numpy and torch backends."""
        dict_batch = self._make_dict_batch()
        list_batch = self._make_list_batch()

        # raw backend
        raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2],
                         'lst_int': [[1, 0], [1, 2]],
                         'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
                         'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
                         'numpy': [np.array([1.]), np.array([0.])],
                         'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                         'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
        collator = Collator(backend='raw')
        assert raw_pad_batch == collator(dict_batch)
        collator = Collator(backend='raw')
        raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]],
                       [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        # numpy backend
        collator = Collator(backend='numpy')
        numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]),
                           'lst_int': np.array([[1, 0], [1, 2]]),
                           'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                           'float': np.array([1.1, 2.1]),
                           'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]),
                           'numpy': np.array([[1], [0]]),
                           'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                           'nested_dict': {'a': np.array([1, 2]),
                                           'b': np.array([[1, 2], [1, 2]])}}

        findDictDiff(numpy_pad_batch, collator(dict_batch))
        collator = Collator(backend='numpy')
        numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
                         np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                         np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
                         np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
                         [{'1'}, {'2'}]]
        findListDiff(numpy_pad_lst, collator(list_batch))

        # torch backend
        import torch
        collator = Collator(backend='torch')
        torch_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
                           'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
                           'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                           'float': torch.FloatTensor([1.1, 2.1]),
                           'lst_float': torch.FloatTensor([[1.1], [2.1]]),
                           'bool': torch.BoolTensor([True, False]),
                           'numpy': torch.FloatTensor([[1], [0]]),
                           'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                           'nested_dict': {'a': torch.LongTensor([1, 2]),
                                           'b': torch.LongTensor([[1, 2], [1, 2]])}}

        findDictDiff(torch_pad_batch, collator(dict_batch))
        collator = Collator(backend='torch')
        torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]),
                         torch.LongTensor([[1, 0], [2, 2]]),
                         torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
                         torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]),
                         torch.BoolTensor([True, False]),
                         torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
                         [{'1'}, {'2'}]]
        findListDiff(torch_pad_lst, collator(list_batch))

    def test_pad(self):
        """Exercise ``set_ignore`` and ``set_pad`` on dict-style and list-style batches."""
        dict_batch = self._make_dict_batch()

        # Test ignore: ignored fields (including a nested one addressed with '@@') are absent.
        collator = Collator(backend='raw')
        collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
        raw_pad_batch = {'lst_str': [['1'], ['2', '2']],
                         'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
                         'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
                         'numpy': [np.array([1.]), np.array([0.])],
                         'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                         'nested_dict': {'b': [[1, 2], [1, 2]]}}
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # Test set_pad on a string field: an error is expected.
        collator = Collator(backend='raw')
        collator.set_pad('str', pad_val=1)
        with pytest.raises(BaseException):
            # NOTE(review): the statement under this ``with`` was missing; calling the
            # collator on the batch is presumed to be the step that raises -- confirm.
            collator(dict_batch)

        # Test setting the pad value.
        collator = Collator(backend='raw')
        collator.set_pad('nest_lst_int', pad_val=100)
        collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
        raw_pad_batch = {'lst_str': [['1'], ['2', '2']],
                         'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
                         'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
                         'numpy': [np.array([1.]), np.array([0.])],
                         'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                         'nested_dict': {'b': [[1, 2], [1, 2]]}}
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # Set backend and dtype for a single field on the same collator.
        collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
        raw_pad_batch = {'lst_str': [['1'], ['2', '2']],
                         'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
                         'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False],
                         'numpy': [np.array([1.]), np.array([0.])],
                         'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}],
                         'nested_dict': {'b': [[1, 2], [1, 2]]}}
        findDictDiff(raw_pad_batch, collator(dict_batch))

        # List-style batch: fields are addressed by position as '_0', '_1', ...
        list_batch = self._make_list_batch()
        collator = Collator(backend='raw')
        collator.set_ignore('_0', '_3', '_1')
        # pad_val=None: field '_4' is expected back unpadded (see the expected value below).
        collator.set_pad('_4', pad_val=None)
        raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        collator = Collator(backend='raw')
        collator.set_pad('_0', pad_val=1)
        with pytest.raises(BaseException):
            # NOTE(review): the statement under this ``with`` was missing; calling the
            # collator on the batch is presumed to be the step that raises -- confirm.
            collator(list_batch)

        list_batch = self._make_list_batch()
        collator = Collator(backend='raw')
        collator.set_ignore('_0', '_3', '_1')
        collator.set_pad('_2', backend='numpy')
        collator.set_pad('_4', backend='numpy', pad_val=100)
        raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
                       [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)],
                       [{'1': '1'}, {'2': '2'}],
                       [{'1'}, {'2'}]]
        findListDiff(raw_pad_lst, collator(list_batch))

        # _single
        # NOTE(review): a default Collator is expected to return the batch unchanged
        # here; a call configuring the '_single' mode may be missing -- confirm.
        collator = Collator()
        findListDiff(list_batch, collator(list_batch))

from fastNLP.core.collators.utils import *

def test_unpack_batch_mapping():
    """``unpack_batch_mapping`` groups the values of each key across samples into one list."""
    samples = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]
    expected = {'a': [[1, 2], [3]], 'b': [1, 2]}
    unpacked = unpack_batch_mapping(samples)
    assert unpacked == expected

def test_unpack_batch_nested_mapping():
    """Nested dict keys are flattened into '@@'-joined keys, one list of values per key."""
    cases = (
        # one level of nesting
        ([{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}],
         {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]}),
        # two levels of nesting
        ([{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}],
         {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]}),
        # several leaves at different depths
        ([{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd': [1, 1]}, 'd': [1]}},
          {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}],
         {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
          'c@@c@@d': [[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}),
    )
    for samples, expected in cases:
        assert unpack_batch_nested_mapping(samples) == expected

def test_pack_batch_nested_mapping():
    """'@@'-joined flat keys are packed back into nested dicts."""
    flat = {
        'a': [[1, 2], [3]],
        'b': [1, 2],
        'c@@c@@c': [1, 2],
        'c@@c@@d': [[1, 1], [2, 2]],
        'c@@d': [[1], [2, 2]],
    }
    expected = {
        'a': [[1, 2], [3]],
        'b': [1, 2],
        'c': {'c': {'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd': [[1], [2, 2]]},
    }
    assert pack_batch_nested_mapping(flat) == expected

def test_unpack_batch_sequence():
    """List samples are unpacked into a dict keyed by position ('_0', '_1', ...)."""
    samples = [[1, 2, 3], [2, 4, 6]]
    expected = {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}
    assert unpack_batch_sequence(samples) == expected
