Browse Source

新设计collator

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
d4dd85ed40
20 changed files with 1718 additions and 0 deletions
  1. +181
    -0
      fastNLP/core/collators/new_collator.py
  2. +0
    -0
      fastNLP/core/collators/padders/__init__.py
  3. +44
    -0
      fastNLP/core/collators/padders/exceptions.py
  4. +193
    -0
      fastNLP/core/collators/padders/get_padder.py
  5. +72
    -0
      fastNLP/core/collators/padders/numpy_padder.py
  6. +21
    -0
      fastNLP/core/collators/padders/padder.py
  7. +48
    -0
      fastNLP/core/collators/padders/raw_padder.py
  8. +157
    -0
      fastNLP/core/collators/padders/torch_padder.py
  9. +20
    -0
      fastNLP/core/collators/padders/torch_utils.py
  10. +173
    -0
      fastNLP/core/collators/padders/utils.py
  11. +103
    -0
      fastNLP/core/collators/utils.py
  12. +0
    -0
      tests/core/collators/__init__.py
  13. +0
    -0
      tests/core/collators/padders/__init__.py
  14. +139
    -0
      tests/core/collators/padders/test_get_padder.py
  15. +81
    -0
      tests/core/collators/padders/test_numpy_padder.py
  16. +29
    -0
      tests/core/collators/padders/test_raw_padder.py
  17. +105
    -0
      tests/core/collators/padders/test_torch_padder.py
  18. +90
    -0
      tests/core/collators/padders/test_utils.py
  19. +225
    -0
      tests/core/collators/test_new_collator.py
  20. +37
    -0
      tests/core/collators/test_utils.py

+ 181
- 0
fastNLP/core/collators/new_collator.py View File

@@ -0,0 +1,181 @@
from typing import List, Union, Dict, Callable, Sequence, Mapping

from fastNLP.core.log import logger
from .padders.get_padder import get_padder

import re

from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \
pack_batch_sequence, NESTED_DICT_SEPARATOR

sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1
SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None]


class Collator:
def __init__(self, backend='torch'):
"""
用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。
可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。

:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None],
若为 None ,则不进行 padding 。
"""
self.unpack_batch_func = None
self.pack_batch_func = None
self.ignore_fields = set()
self.padders = {}
self.input_fields = {}
self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。
self.set_backend(backend)

def __call__(self, batch)->Union[List, Dict]:
"""
batch可能存在三种可能性
List[Dict], List[List], List[Sample]

第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。
第二步:使用每个 field 各自的 padder 进行 pad 。
第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。

第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个
list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample
的类别。
第一次调用会根据当前 field 决定对应的 Padder 。

"""
if self.unpack_batch_func is None:
# 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型
if self.batch_data_type is None:
if isinstance(batch[0], Mapping):
self.batch_data_type = 'd'
elif isinstance(batch[0], Sequence): # 这里存在误判的风险
self.batch_data_type = 'l'
else:
self.batch_data_type = 's'
logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type "
f"is {self.batch_data_type}")
if self.batch_data_type == 's':
self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整
self.pack_batch_func = lambda x:x['_single']
elif self.batch_data_type == 'l':
self.unpack_batch_func = unpack_batch_sequence
self.pack_batch_func = pack_batch_sequence
elif self.batch_data_type == 'd':
if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value}
self.unpack_batch_func = unpack_batch_nested_mapping
self.pack_batch_func = pack_batch_nested_mapping
else:
self.unpack_batch_func = unpack_batch_mapping
self.pack_batch_func = lambda x:x

unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。

pad_batch = {}
if len(self.padders)==0: # 第一次运行,准备 padder
for key in unpack_batch.keys():
if key not in self.input_fields and key not in self.ignore_fields:
self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend}

for field_name, setting in self.input_fields.items():
pad_fn = setting.get('pad_fn', None)
if callable(pad_fn):
padder = pad_fn
else:
batch_field = unpack_batch.get(field_name)
padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'],
dtype=setting['dtype'], backend=setting['backend'],
field_name=field_name)
self.padders[field_name] = padder
if self.batch_data_type == 'l':
self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序

for key, padder in self.padders.items():
batch = unpack_batch.get(key)
pad_batch[key] = padder(batch)

return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型

def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None,
pad_fn:Callable=None) -> "Collator":
"""
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。

:param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b;
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没
有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。
:param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的
field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。
:param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。
:param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor,
paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。
:param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的
batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch
形式,输出将被直接作为结果输出。
:return: 返回 Collator 自身
"""
self.padders.clear() # 重新生成

if self.batch_data_type is not None:
if self.batch_data_type == 's':
logger.debug("Set as single field mode.")
self.input_fields.clear()
elif self.batch_data_type == 'd':
assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \
f"index, but other field is set as dict mode."
elif self.batch_data_type == 'l':
assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \
f"field name is {field_name}"

if field_name == '_single':
self.batch_data_type = 's'
elif sequence_idx_str.match(field_name):
self.batch_data_type = 'l'
else:
self.batch_data_type = 'd'

if field_name in self.ignore_fields:
logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.")
if backend is None:
backend = self.backend
else:
assert backend in SUPPORTED_BACKENDS

self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn}

return self

def set_backend(self, backend:str):
"""
设置可以 pad 的 field 默认 pad 为什么类型的 tensor

:param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None],
若为 None ,则不进行 padding 。
:return:
"""
assert backend in SUPPORTED_BACKENDS
self.padders.clear()
self.backend = backend

def set_ignore(self, *field_names) -> "Collator":
"""
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。
Ex::
collator.set_ignore('field1', 'field2')

:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的
field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b;
如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。
:return: 返回 Collator 自身
"""
for field_name in field_names:
if field_name in self.input_fields:
self.input_fields.pop(field_name)
logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.")
self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。
self.ignore_fields.add(field_name)

return self



+ 0
- 0
fastNLP/core/collators/padders/__init__.py View File


+ 44
- 0
fastNLP/core/collators/padders/exceptions.py View File

@@ -0,0 +1,44 @@
__all__ = [
'InconsistencyError',
'EleDtypeUnsupportedError',
'EleDtypeDtypeConversionError',
'DtypeUnsupportedError',
"DtypeError"
]


class InconsistencyError(BaseException):
"""
当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。

"""
def __init__(self, msg, *args):
super(InconsistencyError, self).__init__(msg, *args)


class DtypeError(BaseException):
def __init__(self, msg, *args):
super(DtypeError, self).__init__(msg, *args)
self.msg = msg


class EleDtypeUnsupportedError(DtypeError):
"""
当 batch 中的 element 的类别本身无法 pad 的时候报错。
例如要求 str 类型的数据进行 padding 。

"""


class EleDtypeDtypeConversionError(DtypeError):
"""
当 batch 中的 element 的类别无法转换为 dtype 类型时报错。

"""


class DtypeUnsupportedError(DtypeError):
"""
当当前 backend 不支持这种类型的 dtype 时报错。

"""

+ 193
- 0
fastNLP/core/collators/padders/get_padder.py View File

@@ -0,0 +1,193 @@

from typing import Dict



from typing import Sequence, Any, Union, Dict
from abc import ABC

from fastNLP.core.log import logger


from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .exceptions import *


def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
"""
根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。

:param batch_field: 将某 field 的内容组合成一个 batch 传入。
:param pad_val:
:param backend:
:param dtype:
:param field_name: 方便报错的。
:return:
"""
logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field))
if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
return NullPadder()
if backend is None:
logger.debug(f"The backend for field:{field_name} is None, not padding this field.")
return NullPadder()

# 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。
must_pad = False
if pad_val != 0 or dtype is not None:
must_pad = True

catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。

# 根据 catalog 来判定当前是否可以进行 pad 。
# 首先检查是否所有的 key 是一样长的,表明深度是一致的
depths = set(map(len, catalog.keys()))
num_depth = len(depths)
if num_depth != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
logger.debug(msg)
return NullPadder()

# 再检查所有的元素 shape 是否一致?
shape_lens = set([len(v[0]) for v in catalog.values()])
num_shape = len(shape_lens)
if num_shape != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
logger.debug(msg)
return NullPadder()

# 再检查所有的元素 type 是否一致
ele_dtypes = set([v[1] for v in catalog.values()])
num_eletypes = len(ele_dtypes)
if num_eletypes != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
f"information please set logger's level to DEBUG."
if must_pad:
raise InconsistencyError(msg)
logger.debug(msg)
return NullPadder()

depth = depths.pop()
shape_len = shape_lens.pop()
ele_dtype = ele_dtypes.pop()

# 需要由 padder 自己决定是否能够 pad 。
try:
if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True]
if backend == 'raw':
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
elif backend == 'numpy':
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
elif backend == 'torch':
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种
if backend == 'raw':
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
elif backend == 'numpy':
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
elif backend == 'torch':
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)

if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."
if must_pad:
raise RuntimeError(msg)
logger.debug(msg)
return NullPadder()

except DtypeError as e:
msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \
"information please set logger's level to DEBUG."
if must_pad:
raise type(e)(msg=msg)
logger.debug(msg)
return NullPadder()

except BaseException as e:
raise e

return NullPadder()


class HasShapeDtype(ABC):
"""
检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。

"""

@classmethod
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]:
if cls is HasShapeDtype:
if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'):
return True
return False
return NotImplemented


def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict:
"""
获取对象的中 element 的基本信息,用于判断是否可以 padding。

:param content:
:param tuple parent:
:param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。
例如: [1, 2, 3] -> {(0,): ((), <class 'int'>), (1,): ((), <class 'int'>), (2,): ((), <class 'int'>)}
例如: [1, [2, 3], 4] -> {(0,): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (1, 1): ((), <class 'int'>), (2,): ((), <class 'int'>)}
例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), <class 'int'>), (0, 1): ((), <class 'int'>), (1, 0): ((), <class 'int'>), (2, 0): ((), <class 'int'>), (2, 1): ((), <class 'int'>)}
例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)]
-> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)}

:return:
"""
if catalog is None:
catalog = {}

if parent is None:
parent = ()

if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray
shape = content.shape
dtype = content.dtype
catalog[parent] = (shape, dtype)
elif isinstance(content, (tuple, list)):
for i, c in enumerate(content):
_get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog)
else: # 包括 int/float/bool/dict 以及 其它无法pad 的等
catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别
return catalog




"""
from numbers import Number

issubclass(type(3), Number) # True
issubclass(type(3.1), Number) # True
issubclass(type('3'), Number) # False
issubclass(type(True), Number) # True
issubclass(type(np.zeros(3)[0]), Number) # True
isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True
isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定
is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype)
"""




+ 72
- 0
fastNLP/core/collators/padders/numpy_padder.py View File

@@ -0,0 +1,72 @@
__all__ = [
'NumpyNumberPadder',
'NumpySequencePadder',
]

from numbers import Number
from abc import ABC
from typing import Any, Union
import numpy as np

from .padder import Padder
from .utils import get_padded_numpy_array, is_number_or_numpy_number
from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
if not is_number_or_numpy_number(ele_dtype):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers but get `{ele_dtype}`.")

if dtype is None:
dtype = ele_dtype
else:
if not is_number_or_numpy_number(dtype):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or numpy numbers but get `{dtype}`.")
dtype = dtype
return dtype


class NumpyNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
return np.array(batch_field, dtype=dtype)


class NumpySequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val)


class NumpyTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
"""
pad 类似于 [np.array([3, 4], np.array([1])] 的 field

:param ele_dtype:
:param pad_val:
:param dtype:
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
array = np.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
array[slices] = field
return array


+ 21
- 0
fastNLP/core/collators/padders/padder.py View File

@@ -0,0 +1,21 @@

class Padder:
def __init__(self, pad_val, dtype):
self.pad_val = pad_val
self.dtype = dtype

def __call__(self, batch_field):
return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
raise NotImplementedError()


class NullPadder(Padder):
def __init__(self, ele_dtype=None, pad_val=None, dtype=None):
super().__init__(pad_val=pad_val, dtype=dtype)

def __call__(self, batch_field):
# 直接返回,不调用 pad() 方法加快速度。
return batch_field

+ 48
- 0
fastNLP/core/collators/padders/raw_padder.py View File

@@ -0,0 +1,48 @@


from .padder import Padder
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array
from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
if is_number(ele_dtype):
if dtype is None:
dtype = ele_dtype
elif not is_number(dtype):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but "
f"get `{dtype}`.")
else:
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"but get `{ele_dtype}`.")
return dtype


class RawNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

def __call__(self, batch_field):
return batch_field

@staticmethod
def pad(batch_field, pad_val, dtype):
raise NotImplementedError()


class RawSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
"""

:param batch_field:
:param pad_val:
:param dtype: 该参数无意义。
:return:
"""
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()

+ 157
- 0
fastNLP/core/collators/padders/torch_padder.py View File

@@ -0,0 +1,157 @@

from inspect import isclass
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch
numpy_to_torch_dtype_dict = {
np.bool_: torch.bool,
np.uint8: torch.uint8,
np.int8: torch.int8,
np.int16: torch.int16,
np.int32: torch.int32,
np.int64: torch.int64,
np.float16: torch.float16,
np.float32: torch.float32,
np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了
np.complex64: torch.complex64,
np.complex128: torch.complex128
}
number_to_torch_dtype_dict = {
float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64
int: torch.int64,
bool: torch.bool
}

from .padder import Padder
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
from .exceptions import *


def is_torch_tensor(dtype):
if not isclass(dtype) and isinstance(dtype, torch.dtype):
return True
return False


def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

if dtype is not None:
if not (is_torch_tensor(dtype) or is_number(dtype)):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or torch.dtype but get `{dtype}`.")
dtype = number_to_torch_dtype_dict.get(dtype, dtype)
else:
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
dtype = ele_dtype
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_torch_dtype_dict.get(ele_dtype)

return dtype


class TorchNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
return torch.tensor(batch_field, dtype=dtype)


class TorchSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val)
return tensor


class TorchTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
"""
目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的

:param ele_dtype:
:param pad_val:
:param dtype:
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
print(dtype)
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
if isinstance(field, np.ndarray):
field = torch.from_numpy(field)
tensor[slices] = field
return tensor


def fill_tensor(batch_field, padded_batch, dtype):
"""
将 batch_field 中的值填入到 tensor 中。

:param batch_field: 需要填充进入 array 中的内容
:param padded_batch: 待填充的 tensor
:param dtype: 数据的类别

:return:
"""
if padded_batch.ndim == 2:
for i, content_i in enumerate(batch_field):
padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype)
elif padded_batch.ndim == 3:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype)
elif padded_batch.ndim == 4:
try: # 应该是图像,所以直接应该就 ok 了。
padded_batch = np.array(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
for k, content_iii in enumerate(content_ii):
padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype)
elif padded_batch.ndim == 1:
padded_batch[:] = torch.tensor(batch_field, dtype=dtype)
else:
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
"report.")
return padded_batch


def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0):
"""
例如:
[[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]])

:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
/4d(多为图片)。
:param dtype: 目标类别是什么
:param pad_val: pad 的 value
:return:
"""
shapes = get_shape(batch_field)
tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val)
tensor = fill_tensor(batch_field, tensor, dtype=dtype)
return tensor

+ 20
- 0
fastNLP/core/collators/padders/torch_utils.py View File

@@ -0,0 +1,20 @@


from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch


def is_torch_tensor_dtype(dtype) -> bool:
"""
返回当前 dtype 是否是 torch 的 dtype 类型


:param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果
:return:
"""
try:
return isinstance(dtype, torch.dtype)
except:
return False

+ 173
- 0
fastNLP/core/collators/padders/utils.py View File

@@ -0,0 +1,173 @@

from typing import Sequence, List
from numbers import Number
import re
from inspect import isclass

import numpy as np
np_str_obj_array_pattern = re.compile(r'[SaUO]')


def get_shape(batch_field:List, shape=None):
"""
给定 field 返回这个 field pad 完成之后的 shape 。
例如: [[1, 2, 3], [3]] -> [2, 3]
[[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3]

:param batch_field: list,第 0 维一般为 batch 维度。
:param shape: 无需传入。
:return:
"""
if shape is None:
shape = []
if isinstance(batch_field, Sequence):
num_ele = len(batch_field)
_shape = shape + [num_ele]
try:
shapes = []
if isinstance(batch_field[0], Sequence):
for _field in batch_field:
shapes.append(get_shape(_field, _shape))
max_shape = [max(_) for _ in zip(*shapes)]
return max_shape
except IndexError: # 空的shape
pass
return _shape # 说明是一个空的 sequence
else:
return shape


def fill_array(batch_field:List, padded_batch:np.ndarray):
"""
将 batch_field 中的值填入到 array 中。

:param batch_field: 需要填充进入 array 中的内容
:param padded_batch: 待填充的 np.ndarray
:return:
"""
if padded_batch.ndim == 2:
for i, content_i in enumerate(batch_field):
padded_batch[i, :len(content_i)] = content_i
elif padded_batch.ndim == 3:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
padded_batch[i, j, :len(content_ii)] = content_ii
elif padded_batch.ndim == 4:
try: # 应该是图像,所以直接应该就 ok 了。
padded_batch = np.array(batch_field)
except:
for i, content_i in enumerate(batch_field):
for j, content_ii in enumerate(content_i):
for k, content_iii in enumerate(content_ii):
padded_batch[i, j, k, :len(content_iii)] = content_iii
elif padded_batch.ndim == 1:
padded_batch[:] = batch_field
else:
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please "
"report.")
return padded_batch


def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray:
"""
例如:
[[1,2], [3]] -> np.array([[1, 2], [3, 0]])

:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
/4d(多为图片)。
:param dtype: 目标类别是什么
:param pad_val: pad 的 value
:return:
"""
shapes = get_shape(batch_field)
array = np.full(shapes, dtype=dtype, fill_value=pad_val)
array = fill_array(batch_field, array)
return array


def get_padded_nest_list(batch_field: List, pad_val=0) -> List:
"""
例如:
[[1,2], [3]] -> [[1, 2], [3, 0]]

:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列)
/4d(多为图片)。
:param pad_val: pad 的 value
:return:
"""

array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist()
return array


def is_number_or_numpy_number(dtype):
"""
判断 dtype 是否是数字类型,或者 numpy 的数字类型。
is_number_or_numpy_number(type(3)) # True
is_number_or_numpy_number(type(3.1)) # True
is_number_or_numpy_number(type('3')) # False
is_number_or_numpy_number(type(True)) # True
is_number_or_numpy_number(type(np.zeros(3)[0])) # True
is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True
is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True
is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False
is_number_or_numpy_number(np.array([1, [2]]).dtype) # False

:param dtype:
:return:
"""
if is_number(dtype):
return True
else:
if isclass(dtype):
return is_numpy_generic_class(dtype)
elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
return True
return False


def is_numpy_number_dtype(dtype):
if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None:
return True
return False


def is_numpy_generic_class(dtype):
"""
形如 np.int64,或者 np.zeros(1).dtype.type 的值

:param dtype:
:return:
"""
if isclass(dtype) and issubclass(dtype, np.generic):
return True
return False


def is_number(dtype):
try:
if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \
and not is_numpy_number_dtype(dtype):
return True
except:
return False



if __name__ == '__main__':
# a = [[[1]], [1, 2, 3], [3]]
# a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
# b = get_padded_nest_list(a)
# print(type(b[0]))
# print(b)
# import torch
print(is_number_or_numpy_number(type(3))) # True
print(is_number_or_numpy_number(type(3.1))) # True
print(is_number_or_numpy_number(type('3'))) # False
print(is_number_or_numpy_number(type(True))) # True
print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True
print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True
print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True
print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False
print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False


+ 103
- 0
fastNLP/core/collators/utils.py View File

@@ -0,0 +1,103 @@
from collections import defaultdict
from functools import reduce
from typing import Sequence, Mapping, Dict

NESTED_DICT_SEPARATOR = '@@'


def unpack_batch_mapping(batch:Sequence[Mapping])->Dict:
"""
将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]}

:param batch:
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for key, value in sample.items():
dict_batch[key].append(value)
return dict_batch


def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict:
"""
将 nested 的 dict 中的内容展开到一个 flat dict 中

:param batch:
:param _parent: 内部使用
:return:
"""
dict_batch = defaultdict(list)
if _parent != '':
_parent += NESTED_DICT_SEPARATOR
for sample in batch:
for key, value in sample.items():
if isinstance(value, Mapping):
_dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key)
for key, value in _dict_batch.items():
dict_batch[key].append(value)
else:
dict_batch[_parent + key].append(value)
return dict_batch


def _unpack_batch_nested_mapping(value, _parent)->Dict:
_dict = {}
_parent += NESTED_DICT_SEPARATOR
for k, v in value.items():
if isinstance(v, Mapping):
__dict = _unpack_batch_nested_mapping(v, _parent=_parent + k)
_dict.update(__dict)
else:
_dict[_parent + k] = v
return _dict


def pack_batch_nested_mapping(batch:Mapping) -> Dict:
"""
需要恢复出 nested 的 dict 原来的样式

:param batch:
:return:
"""
dicts = []

for key, value in batch.items():
keys = key.split(NESTED_DICT_SEPARATOR)
d = {keys[-1]: value}
for key in keys[:-1:][::-1]:
d = {key: d}
dicts.append(d)
return reduce(_merge_dict, dicts)


def _merge_dict(a, b, path=None):
"merges b into a"
if path is None: path = []
for key in b:
if key in a:
if isinstance(a[key], dict) and isinstance(b[key], dict):
_merge_dict(a[key], b[key], path + [str(key)])
else:
raise Exception('Conflict at %s' % '.'.join(path + [str(key)]))
else:
a[key] = b[key]
return a


def unpack_batch_sequence(batch:Sequence[Sequence])->Dict:
"""
将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]}

:param batch:
:return:
"""
dict_batch = defaultdict(list)
for sample in batch:
for i, content in enumerate(sample):
dict_batch[f'_{i}'].append(content)
return dict_batch


def pack_batch_sequence(batch:Mapping)->Sequence:
return list(batch.values())

+ 0
- 0
tests/core/collators/__init__.py View File


+ 0
- 0
tests/core/collators/padders/__init__.py View File


+ 139
- 0
tests/core/collators/padders/test_get_padder.py View File

@@ -0,0 +1,139 @@
import pytest
import numpy as np

from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \
_get_element_shape_dtype
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR


def test_get_element_shape_dtype():
catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2])
catalog = _get_element_shape_dtype([['1'], [2, 3]])
catalog = _get_element_shape_dtype([['1'], [2, 3]])
catalog = _get_element_shape_dtype([['1'], ['2', '3']])
catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))])


@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle'])
def test_get_padder_run(backend):
if not _NEED_IMPORT_TORCH and backend == 'torch':
pytest.skip("No torch")
if not _NEED_IMPORT_PADDLE and backend == 'paddle':
pytest.skip("No paddle")
if not _NEED_IMPORT_PADDLE and backend == 'jittor':
pytest.skip("No jittor")
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

if backend is not None:
# 不能 pad
batch_field = [[1], [2, 3], [3], 2]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

# 不能 pad
batch_field = [['2'], ['2'], ['2', '2']]
with pytest.raises(DtypeError) as exec_info:
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test')

batch_field = [np.zeros(3), np.zeros((3, 1))]
with pytest.raises(InconsistencyError) as exec_info:
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad

batch_field = [np.zeros((3, 1)), np.zeros((4, 1))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


def test_raw_padder():
backend = 'raw'
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert pad_batch == batch_field

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3)

batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3, 2)


def test_numpy_padder():
backend = 'numpy'
target_type = np.ndarray
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert (pad_batch == np.array(batch_field)).sum()==len(batch_field)

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


def test_torch_padder():
if not _NEED_IMPORT_TORCH:
pytest.skip("No torch.")
import torch
backend = 'torch'
target_type = torch.Tensor
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field)

batch_field = [[1], [2, 2], [3, 3, 3]]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12

batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))]
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')


+ 81
- 0
tests/core/collators/padders/test_numpy_padder.py View File

@@ -0,0 +1,81 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH


class TestNumpyNumberPadder:
def test_run(self):
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
assert isinstance(a, np.ndarray)
assert (padder(a) == np.array(a)).sum() == 3


class TestNumpySequencePadder:
def test_run(self):
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (2, 3)
b = np.array([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)


class TestNumpyTensorPadder:
def test_run(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1)
a = [np.zeros(3), np.zeros(2), np.zeros(0)]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3)
b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3, 2)
b = np.array([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
a = padder(a)
shape = np.shape(a)
assert isinstance(a, np.ndarray)
assert shape == (3, 3, 2)
b = np.array([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)




+ 29
- 0
tests/core/collators/padders/test_raw_padder.py View File

@@ -0,0 +1,29 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder
from fastNLP.core.collators.padders.exceptions import DtypeError


class TestRawNumberPadder:
def test_run(self):
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
assert padder(a) == a


class TestRawSequencePadder:
def test_run(self):
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
assert shape == (2, 3)
b = np.array([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)

+ 105
- 0
tests/core/collators/padders/test_torch_padder.py View File

@@ -0,0 +1,105 @@
import numpy as np
import pytest

from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder
from fastNLP.core.collators.padders.exceptions import DtypeError
from fastNLP.envs.imports import _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
import torch


class TestTorchNumberPadder:
def test_run(self):
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, torch.Tensor)
assert (t_a == torch.LongTensor(a)).sum() == 3


class TestTorchSequencePadder:
def test_run(self):
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (2, 3)
b = torch.LongTensor([[1, 2, 3], [3, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1)
a = padder([[1], [2, 322]])
assert (a>67).sum()==0 # 因为int8的范围为-67 - 66
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)



class TestTorchTensorPadder:
def test_run(self):
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3)
b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]])
assert (a == b).sum().item() == shape[0]*shape[1]

a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, 0], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1)
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
a = padder(a)
shape = a.shape
assert isinstance(a, torch.Tensor)
assert tuple(shape) == (3, 3, 2)
b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]],
[[0, 0], [0, 0], [-1, -1]],
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
with pytest.raises(DtypeError):
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)




+ 90
- 0
tests/core/collators/padders/test_utils.py View File

@@ -0,0 +1,90 @@
import pytest
import numpy as np

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \
get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number


def test_get_shape():
a = [[1, 2, 3], [3]]
assert get_shape(a) == [2, 3]

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
assert get_shape(a) == [2, 3, 3]

a = [[[1], [2], [3, 4]], [[]]]
assert get_shape(a) == [2, 3, 2]


def test_get_padded_numpy_array():
a = [[1, 2, 3], [3]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3)

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3, 3)

a = [[[1], [2], [3, 4]], [[]]]
a = get_padded_numpy_array(a, dtype=int, pad_val=-1)
assert a.shape == (2, 3, 2)


def test_get_padded_nest_list():
a = [[1, 2, 3], [3]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3)

a = [[[1], [2], [3, 4]], [[2, 3, 4]]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3, 3)

a = [[[1], [2], [3, 4]], [[]]]
a = get_padded_nest_list(a, pad_val=-1)
assert np.shape(a) == (2, 3, 2)


def test_is_number_or_numpy_number():
assert is_number_or_numpy_number(type(3)) is True
assert is_number_or_numpy_number(type(3.1)) is True
assert is_number_or_numpy_number(type(True)) is True
assert is_number_or_numpy_number(type('3')) is False
assert is_number_or_numpy_number(np.zeros(3).dtype) is True
assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True
assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_number_or_numpy_number(dtype) is False


def test_is_number():
assert is_number(type(3)) is True
assert is_number(type(3.1)) is True
assert is_number(type(True)) is True
assert is_number(type('3')) is False
assert is_number(np.zeros(3).dtype) is False
assert is_number(np.zeros(3, dtype=int).dtype) is False
assert is_number(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_number(dtype) is False


def test_is_numpy_number():
assert is_numpy_number_dtype(type(3)) is False
assert is_numpy_number_dtype(type(3.1)) is False
assert is_numpy_number_dtype(type(True)) is False
assert is_numpy_number_dtype(type('3')) is False
assert is_numpy_number_dtype(np.zeros(3).dtype) is True
assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True
assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False

if _NEED_IMPORT_TORCH:
import torch
dtype = torch.ones(3).dtype
assert is_numpy_number_dtype(dtype) is False

+ 225
- 0
tests/core/collators/test_new_collator.py View File

@@ -0,0 +1,225 @@

import numpy as np
import pytest

from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR

from fastNLP.core.collators.new_collator import Collator


def _assert_equal(d1, d2):
try:
if 'torch' in str(type(d1)):
if 'float64' in str(d2.dtype):
print(d2.dtype)
assert (d1 == d2).all().item()
else:
assert all(d1 == d2)
except TypeError:
assert d1 == d2
except ValueError:
assert (d1 == d2).all()


def findDictDiff(d1, d2, path=""):
for k in d1:
if k in d2:
if isinstance(d1[k], dict):
findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k)
else:
_assert_equal(d1[k], d2[k])
else:
raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k))


def findListDiff(d1, d2):
assert len(d1)==len(d2)
for _d1, _d2 in zip(d1, d2):
if isinstance(_d1, list):
findListDiff(_d1, _d2)
else:
_assert_equal(_d1, _d2)


class TestCollator:
def test_run(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}
collator = Collator(backend='raw')
assert raw_pad_batch == collator(dict_batch)
collator = Collator(backend='raw')
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='numpy')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]),
'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]),
'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]),
'b': np.array([[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='numpy')
numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]),
np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]),
np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(numpy_pad_lst, collator(list_batch))

if _NEED_IMPORT_TORCH:
import torch
collator = Collator(backend='torch')
numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]),
'lst_int': torch.LongTensor([[1, 0], [1, 2]]),
'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
'float': torch.FloatTensor([1.1, 2.1]),
'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]),
'numpy': torch.FloatTensor([[1], [0]]),
'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]),
'b': torch.LongTensor(
[[1, 2], [1, 2]])}}

findDictDiff(numpy_pad_batch, collator(dict_batch))
collator = Collator(backend='torch')
torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]),
torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]),
torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]),
torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(torch_pad_lst, collator(list_batch))

def test_pad(self):
dict_batch = [{
'str': '1',
'lst_str': ['1'],
'int': 1,
'lst_int': [1],
'nest_lst_int': [[1]],
'float': 1.1,
'lst_float': [1.1],
'bool': True,
'numpy': np.ones(1),
'dict': {'1': '1'},
'set': {'1'},
'nested_dict': {'a': 1, 'b':[1, 2]}
},
{
'str': '2',
'lst_str': ['2', '2'],
'int': 2,
'lst_int': [1, 2],
'nest_lst_int': [[1], [1, 2]],
'float': 2.1,
'lst_float': [2.1],
'bool': False,
'numpy': np.zeros(1),
'dict': {'1': '2'},
'set': {'2'},
'nested_dict': {'a': 2, 'b': [1, 2]}
}
]

raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}}

# 测试 ignore
collator = Collator(backend='raw')
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 测试 set_pad
collator = Collator(backend='raw')
collator.set_pad('str', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

# 测试设置 pad 值
collator = Collator(backend='raw')
collator.set_pad('nest_lst_int', pad_val=100)
collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a')
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))

# 设置 backend 和 type
collator.set_pad('float', pad_val=100, backend='numpy', dtype=int)
raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]],
'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}}
findDictDiff(raw_pad_batch, collator(dict_batch))


# raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]],
# [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
# [{'1'}, {'2'}]]
list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_4', pad_val=None)
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]],
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

collator = Collator(backend='raw')
collator.set_pad('_0', pad_val=1)
with pytest.raises(BaseException):
collator(dict_batch)

list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}],
['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]]
collator = Collator(backend='raw')
collator.set_ignore('_0', '_3', '_1')
collator.set_pad('_2', backend='numpy')
collator.set_pad('_4', backend='numpy', pad_val=100)
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]),
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}],
[{'1'}, {'2'}]]
findListDiff(raw_pad_lst, collator(list_batch))

# _single
collator = Collator()
collator.set_pad('_single')
findListDiff(list_batch, collator(list_batch))








+ 37
- 0
tests/core/collators/test_utils.py View File

@@ -0,0 +1,37 @@

from fastNLP.core.collators.utils import *


def test_unpack_batch_mapping():
batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}]
assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]}


def test_unpack_batch_nested_mapping():
batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]}

batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}},
{'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}]
assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}


def test_pack_batch_nested_mapping():
batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2],
'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]}
new_batch = pack_batch_nested_mapping(batch)
assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2],
'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}}


def test_unpack_batch_sequence():
batch = [[1, 2, 3], [2, 4, 6]]
new_batch = unpack_batch_sequence(batch)
assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]}




Loading…
Cancel
Save