| @@ -59,6 +59,7 @@ __all__ = [ | |||
| # log | |||
| "logger" | |||
| # | |||
| ] | |||
| from .callbacks import * | |||
| from .collators import * | |||
| @@ -161,7 +161,6 @@ class MonitorUtility: | |||
| return monitor_name | |||
| class HasMonitorCallback(MonitorUtility, Callback): | |||
| def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
| """ | |||
| @@ -7,7 +7,7 @@ from fastNLP.core.log import logger | |||
| from .padder import Padder, NullPadder | |||
| from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder | |||
| from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder | |||
| from .raw_padder import RawNumberPadder, RawSequencePadder | |||
| from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder | |||
| from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | |||
| from .exceptions import * | |||
| @@ -23,7 +23,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
| :param field_name: 方便报错的。 | |||
| :return: | |||
| """ | |||
| assert len(batch_field)!=0, "Empty batch encountered." | |||
| logger.debug(f"The content in the field:`{field_name}` is:\n" + str(batch_field)) | |||
| if pad_val is None: | |||
| logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") | |||
| @@ -63,7 +63,10 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
| return NullPadder() | |||
| # 再检查所有的元素 type 是否一致 | |||
| ele_dtypes = set([v[1] for v in catalog.values()]) | |||
| try: | |||
| ele_dtypes = set([v[1] for v in catalog.values()]) | |||
| except TypeError: | |||
| ele_dtypes = set([str(v[1]) for v in catalog.values()]) | |||
| num_eletypes = len(ele_dtypes) | |||
| if num_eletypes != 1: | |||
| msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \ | |||
| @@ -75,7 +78,7 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
| depth = depths.pop() | |||
| shape_len = shape_lens.pop() | |||
| ele_dtype = ele_dtypes.pop() | |||
| ele_dtype = list(catalog.values())[0][1] # 因为上面有except的情况,所以这样处理了 | |||
| # 需要由 padder 自己决定是否能够 pad 。 | |||
| try: | |||
| @@ -103,13 +106,16 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
| else: | |||
| raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | |||
| if depth == 1 and shape_len != 0: | |||
| if backend == 'numpy': | |||
| return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
| # 如果有有 shape 的话,只有当该对象拥有 tolist() 方法才行 | |||
| if depth == 1 and shape_len != 0 and callable(getattr(batch_field[0], 'tolist', None)): | |||
| if backend == 'raw': | |||
| return RawTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | |||
| elif backend == 'numpy': | |||
| return NumpyTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | |||
| elif backend == 'torch': | |||
| return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
| return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | |||
| elif backend == 'paddle': | |||
| return PaddleTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
| return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | |||
| else: | |||
| raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | |||
| @@ -66,7 +66,7 @@ class NumpySequencePadder(Padder): | |||
| class NumpyTensorPadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| pad 类似于 [np.array([3, 4], np.array([1])] 的 field | |||
| pad 类似于 [np.array([3, 4], np.array([1])] 的 field 。若内部元素不为 np.ndarray ,则必须含有 tolist() 方法。 | |||
| :param pad_val: pad 的值是多少。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| @@ -77,6 +77,13 @@ class NumpyTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| try: | |||
| if not isinstance(batch_field[0], np.ndarray): | |||
| batch_field = [np.array(field.tolist()) for field in batch_field] | |||
| except AttributeError: | |||
| raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " | |||
| f"it must have tolist() method.") | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| array = np.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
| @@ -56,7 +56,7 @@ def is_paddle_dtype_str(dtype): | |||
| def _get_dtype(ele_dtype, dtype, class_name): | |||
| if not (is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | |||
| if not (ele_dtype is not None or is_number_or_numpy_number(ele_dtype) or is_paddle_tensor(ele_dtype) or is_paddle_dtype_str(ele_dtype)): | |||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
| f"or numpy numbers or paddle.Tensor but get `{ele_dtype}`.") | |||
| @@ -80,7 +80,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
| class PaddleNumberPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 可以将形如 [1, 2, 3] 这类的数据转为 paddle.Tensor([1, 2, 3]) | |||
| :param pad_val: 该值无意义 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| """ | |||
| # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @@ -91,7 +98,14 @@ class PaddleNumberPadder(Padder): | |||
| class PaddleSequencePadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| def __init__(self, ele_dtype=None, pad_val=0, dtype=None): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 paddle.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||
| :param pad_val: pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @@ -102,19 +116,26 @@ class PaddleSequencePadder(Padder): | |||
| class PaddleTensorPadder(Padder): | |||
| def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 目前仅支持 [paddle.tensor([3, 2], paddle.tensor([1])] 类似的 | |||
| 目前支持 [paddle.tensor([3, 2], paddle.tensor([2, 1])] 类似的,若内部元素不为 paddle.tensor ,则必须含有 tolist() 方法。 | |||
| :param ele_dtype: | |||
| :param pad_val: | |||
| :param dtype: | |||
| :param pad_val: pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 paddle.tensor 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么。如 int, float, 'int32' 等 | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| try: | |||
| if not isinstance(batch_field[0], paddle.Tensor): | |||
| batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] | |||
| except AttributeError: | |||
| raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | |||
| f"it must have tolist() method.") | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| if isinstance(dtype, np.dtype): | |||
| @@ -1,6 +1,7 @@ | |||
| __all__ = [ | |||
| "RawNumberPadder", | |||
| "RawSequencePadder" | |||
| "RawSequencePadder", | |||
| "RawTensorPadder" | |||
| ] | |||
| from .padder import Padder | |||
| @@ -66,3 +67,34 @@ class RawSequencePadder(Padder): | |||
| :return: | |||
| """ | |||
| return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() | |||
| class RawTensorPadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||
| :param pad_val: pad 的值 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||
| :param dtype: 输出的数据的 dtype 是什么 | |||
| """ | |||
| dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
| super().__init__(pad_val=pad_val, dtype=dtype) | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| """ | |||
| :param batch_field: | |||
| :param pad_val: | |||
| :param dtype: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| try: | |||
| if not isinstance(batch_field[0], (list, tuple)): | |||
| batch_field = [field.tolist() for field in batch_field] | |||
| except AttributeError: | |||
| raise RuntimeError(f"If the field is not a list or tuple(it is {type(batch_field[0])}), " | |||
| f"it must have tolist() method.") | |||
| return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() | |||
| @@ -41,7 +41,7 @@ def is_torch_tensor(dtype): | |||
| def _get_dtype(ele_dtype, dtype, class_name): | |||
| if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | |||
| if not (ele_dtype is None or (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | |||
| raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
| f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | |||
| @@ -101,7 +101,7 @@ class TorchSequencePadder(Padder): | |||
| class TorchTensorPadder(Padder): | |||
| def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
| """ | |||
| 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | |||
| 目前支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的。若内部元素不为 torch.tensor ,则必须含有 tolist() 方法。 | |||
| :param pad_val: 需要 pad 的值。 | |||
| :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||
| @@ -112,6 +112,13 @@ class TorchTensorPadder(Padder): | |||
| @staticmethod | |||
| def pad(batch_field, pad_val, dtype): | |||
| try: | |||
| if not isinstance(batch_field[0], torch.Tensor): | |||
| batch_field = [torch.tensor(field.tolist()) for field in batch_field] | |||
| except AttributeError: | |||
| raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | |||
| f"it must have tolist() method.") | |||
| shapes = [field.shape for field in batch_field] | |||
| max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
| tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
| @@ -0,0 +1,25 @@ | |||
| __all__ = [ | |||
| 'print' | |||
| ] | |||
| from .logger import logger | |||
| def print(*args, sep=' ', end='\n', file=None, flush=False): | |||
| """ | |||
| 用来重定向 print 函数至 logger.info 的函数。 | |||
| Example: | |||
| from fastNLP import print | |||
| print("This is a test") # 等价于调用了 logger.info("This is a test") | |||
| :param args: 需要打印的内容 | |||
| :param sep: 存在多个输入时,使用的间隔。 | |||
| :param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。 | |||
| :param file: 该参数无意义。 | |||
| :param flush: 该参数无意义。 | |||
| :return: | |||
| """ | |||
| line = sep.join(args) | |||
| logger.info(line) | |||
| @@ -23,7 +23,7 @@ def test_get_padder_run(backend): | |||
| pytest.skip("No torch") | |||
| if not _NEED_IMPORT_PADDLE and backend == 'paddle': | |||
| pytest.skip("No paddle") | |||
| if not _NEED_IMPORT_PADDLE and backend == 'jittor': | |||
| if not _NEED_IMPORT_JITTOR and backend == 'jittor': | |||
| pytest.skip("No jittor") | |||
| batch_field = [1, 2, 3] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| @@ -67,6 +67,13 @@ def test_raw_padder(): | |||
| pad_batch = padder(batch_field) | |||
| assert np.shape(pad_batch) == (3, 3, 2) | |||
| batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, list) | |||
| assert np.shape(pad_batch) == (3, 3, 3) | |||
| assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | |||
| def test_numpy_padder(): | |||
| backend = 'numpy' | |||
| @@ -141,3 +148,18 @@ def test_torch_padder(): | |||
| with pytest.raises(InconsistencyError): | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| # 可以是 numpy.ndarray | |||
| batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] | |||
| padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, target_type) | |||
| assert pad_batch.shape == (3, 3, 3) | |||
| assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 | |||
| # 测试 to numpy | |||
| batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] | |||
| padder = get_padder(batch_field, pad_val=0, backend='numpy', dtype=int, field_name='test') | |||
| pad_batch = padder(batch_field) | |||
| assert isinstance(pad_batch, np.ndarray) | |||
| assert np.shape(pad_batch) == (3, 3, 3) | |||
| assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 | |||