Browse Source

修改Padder

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
b2e9507118
9 changed files with 143 additions and 23 deletions
  1. +1
    -0
      fastNLP/core/__init__.py
  2. +0
    -1
      fastNLP/core/callbacks/has_monitor_callback.py
  3. +15
    -9
      fastNLP/core/collators/padders/get_padder.py
  4. +8
    -1
      fastNLP/core/collators/padders/numpy_padder.py
  5. +29
    -8
      fastNLP/core/collators/padders/paddle_padder.py
  6. +33
    -1
      fastNLP/core/collators/padders/raw_padder.py
  7. +9
    -2
      fastNLP/core/collators/padders/torch_padder.py
  8. +25
    -0
      fastNLP/core/log/print.py
  9. +23
    -1
      tests/core/collators/padders/test_get_padder.py

+ 1
- 0
fastNLP/core/__init__.py View File

@@ -59,6 +59,7 @@ __all__ = [
# log
"logger"

#
]
from .callbacks import *
from .collators import *


+ 0
- 1
fastNLP/core/callbacks/has_monitor_callback.py View File

@@ -161,7 +161,6 @@ class MonitorUtility:
return monitor_name



class HasMonitorCallback(MonitorUtility, Callback):
def __init__(self, monitor, larger_better, must_have_monitor=False):
"""


+ 15
- 9
fastNLP/core/collators/padders/get_padder.py View File

@@ -7,7 +7,7 @@ from fastNLP.core.log import logger
from .padder import Padder, NullPadder
from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
from .raw_padder import RawNumberPadder, RawSequencePadder
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
from .exceptions import *

@@ -23,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
:param field_name: 方便报错的。
:return:
"""
assert len(batch_field)!=0, "Empty batch encountered."
logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field))
if pad_val is None:
logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.")
@@ -63,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
return NullPadder()

# 再检查所有的元素 type 是否一致
ele_dtypes = set([v[1] for v in catalog.values()])
try:
ele_dtypes = set([v[1] for v in catalog.values()])
except TypeError:
ele_dtypes = set([str(v[1]) for v in catalog.values()])
num_eletypes = len(ele_dtypes)
if num_eletypes != 1:
msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \
@@ -75,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->

depth = depths.pop()
shape_len = shape_lens.pop()
ele_dtype = ele_dtypes.pop()
ele_dtype = list(catalog.values())[0][1] # 因为上面有except的情况,所以这样处理了

# 需要由 padder 自己决定是否能够 pad 。
try:
@@ -103,13 +106,16 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
else:
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")

if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
# 如果有有 shape 的话,只有当该对象拥有 tolist() 方法才行
if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)):
if backend == 'raw':
return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'numpy':
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
elif backend == 'paddle':
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
else:
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")



+ 8
- 1
fastNLP/core/collators/padders/numpy_padder.py View File

@@ -66,7 +66,7 @@ class NumpySequencePadder(Padder):
class NumpyTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
pad 类似于 [np.array([3, 4], np.array([1])] 的 field
pad 类似于 [np.array([3, 4], np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。

:param pad_val: pad 的值是多少。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。
@@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder):

@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], np.ndarray):
batch_field = [np.array(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
array = np.full(max_shape, fill_value=pad_val, dtype=dtype)


+ 29
- 8
fastNLP/core/collators/padders/paddle_padder.py View File

@@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype):


def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
if not (ele_dtype is not None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.")

@@ -80,7 +80,14 @@ def _get_dtype(ele_dtype, dtype, class_name):


class PaddleNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3])

:param pad_val: 该值无意义
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
"""
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder):


class PaddleSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, ele_dtype=None, pad_val=0, dtype=None):
"""
将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。

:param pad_val: pad 的值。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder):


class PaddleTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
目前支持 [paddle.tensor([3, 2], paddle.tensor([1])] 类似的
目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。

:param ele_dtype:
:param pad_val:
:param dtype:
:param pad_val: pad 的值。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。
:param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], paddle.Tensor):
batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):


+ 33
- 1
fastNLP/core/collators/padders/raw_padder.py View File

@@ -1,6 +1,7 @@
__all__ = [
"RawNumberPadder",
"RawSequencePadder"
"RawSequencePadder",
"RawTensorPadder"
]

from .padder import Padder
@@ -66,3 +67,34 @@ class RawSequencePadder(Padder):
:return:
"""
return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()


class RawTensorPadder(Padder):
    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
        """
        Pad a batch of array-like elements, e.g. ``[np.ones((3, 3)), np.ones((2, 3))]``, and
        return the padded result as a nested python ``list``. Elements that are not ``list`` /
        ``tuple`` must provide a ``tolist()`` method.

        :param pad_val: value used to fill the padded positions.
        :param ele_dtype: element type of the field; passed to ``_get_dtype`` together with ``dtype``.
        :param dtype: requested dtype of the output. The value returned by ``_get_dtype`` is the one
            handed to ``Padder.__init__``.
        """
        dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
        super().__init__(pad_val=pad_val, dtype=dtype)

    @staticmethod
    def pad(batch_field, pad_val, dtype):
        """
        Convert every element of the batch to a python list (when needed) and pad them to a
        common shape.

        :param batch_field: the batch to pad. Whether conversion is needed is decided from the first
            element only: if it is not a ``list`` / ``tuple``, ``tolist()`` is called on every element.
        :param pad_val: value used to fill the padded positions.
        :param dtype: forwarded to ``get_padded_numpy_array``.
        :return: the padded batch as a nested python ``list``.
        :raises RuntimeError: if an element that needs conversion has no ``tolist()`` method.
        """
        if not isinstance(batch_field[0], (list, tuple)):
            try:
                batch_field = [field.tolist() for field in batch_field]
            except AttributeError as err:
                # Chain the original error so the offending attribute lookup stays visible.
                raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), "
                                   f"it must have tolist() method.") from err

        return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist()

+ 9
- 2
fastNLP/core/collators/padders/torch_padder.py View File

@@ -41,7 +41,7 @@ def is_torch_tensor(dtype):


def _get_dtype(ele_dtype, dtype, class_name):
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

@@ -101,7 +101,7 @@ class TorchSequencePadder(Padder):
class TorchTensorPadder(Padder):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的
目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。

:param pad_val: 需要 pad 的值。
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。
@@ -112,6 +112,13 @@ class TorchTensorPadder(Padder):

@staticmethod
def pad(batch_field, pad_val, dtype):
try:
if not isinstance(batch_field[0], torch.Tensor):
batch_field = [torch.tensor(field.tolist()) for field in batch_field]
except AttributeError:
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), "
f"it must have tolist() method.")

shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)


+ 25
- 0
fastNLP/core/log/print.py View File

@@ -0,0 +1,25 @@
# Names exported by this module via `from ... import *`.
__all__ = [
'print'
]

from .logger import logger


def print(*args, sep=' ', end='\n', file=None, flush=False):
    """
    Redirect ``print``-style calls to ``logger.info``.

    Example::

        from fastNLP import print

        print("This is a test")  # equivalent to calling logger.info("This is a test")

    :param args: the objects to output. Each one is converted with ``str()`` before joining, so
        non-string arguments (numbers, lists, ...) are accepted just like with the builtin ``print``.
    :param sep: separator placed between the arguments. ``None`` means a single space, as with
        the builtin ``print``.
    :param end: ignored by this function; the joined text is emitted through one ``logger.info`` call.
    :param file: ignored.
    :param flush: ignored.
    :return: None
    """
    if sep is None:
        sep = ' '
    # str.join only accepts strings, so stringify every argument first.
    line = sep.join(map(str, args))
    logger.info(line)

+ 23
- 1
tests/core/collators/padders/test_get_padder.py View File

@@ -23,7 +23,7 @@ def test_get_padder_run(backend):
pytest.skip("No torch")
if not _NEED_IMPORT_PADDLE and backend == 'paddle':
pytest.skip("No paddle")
if not _NEED_IMPORT_PADDLE and backend == 'jittor':
if not _NEED_IMPORT_JITTOR and backend == 'jittor':
pytest.skip("No jittor")
batch_field = [1, 2, 3]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
@@ -67,6 +67,13 @@ def test_raw_padder():
pad_batch = padder(batch_field)
assert np.shape(pad_batch) == (3, 3, 2)

batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, list)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12


def test_numpy_padder():
backend = 'numpy'
@@ -141,3 +148,18 @@ def test_torch_padder():
with pytest.raises(InconsistencyError):
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')

# 可以是 numpy.ndarray
batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, target_type)
assert pad_batch.shape == (3, 3, 3)
assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12

# 测试 to numpy
batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))]
padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test')
pad_batch = padder(batch_field)
assert isinstance(pad_batch, np.ndarray)
assert np.shape(pad_batch) == (3, 3, 3)
assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12

Loading…
Cancel
Save