
1. Add docstrings/comments to some of the padders; 2. Move the set_backend call.

tags/v1.0.0alpha
yh_cc 2 years ago
commit d645e88c7f
16 changed files with 152 additions and 190 deletions
  1. +1  -8    fastNLP/__init__.py
  2. +8  -8    fastNLP/core/collators/padders/get_padder.py
  3. +22 -7    fastNLP/core/collators/padders/numpy_padder.py
  4. +7  -0    fastNLP/core/collators/padders/padder.py
  5. +28 -11   fastNLP/core/collators/padders/raw_padder.py
  6. +29 -17   fastNLP/core/collators/padders/torch_padder.py
  7. +1  -2    fastNLP/core/collators/padders/torch_utils.py
  8. +11 -1    fastNLP/core/dataloaders/torch_dataloader/fdl.py
  9. +7  -4    fastNLP/core/dataset/dataset.py
 10. +5  -1    fastNLP/envs/__init__.py
 11. +1  -1    fastNLP/envs/imports.py
 12. +4  -5    fastNLP/envs/set_backend.py
 13. +0  -97   fastNLP/io/cws.py
 14. +10 -10   tests/core/collators/padders/test_numpy_padder.py
 15. +4  -4    tests/core/collators/padders/test_raw_padder.py
 16. +14 -14   tests/core/collators/padders/test_torch_padder.py

fastNLP/__init__.py  (+1 -8)

@@ -1,10 +1,3 @@

# First make sure FASTNLP_GLOBAL_RANK is set correctly
from fastNLP.envs.set_env_on_import import set_env_on_import
set_env_on_import()

# Then set up the backend
from fastNLP.envs.set_backend import _set_backend
_set_backend()

from fastNLP.envs import *
from fastNLP.core import Trainer, Evaluator

fastNLP/core/collators/padders/get_padder.py  (+8 -8)

@@ -84,25 +84,25 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->
try:
if depth == 1 and shape_len == 0: # e.g. [0, 1, 2] or [True, False, True]
if backend == 'raw':
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'numpy':
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if depth > 1 and shape_len == 0: # e.g. [[0, 1], [2]]
if backend == 'raw':
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'numpy':
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if depth == 1 and shape_len != 0:
if backend == 'numpy':
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
elif backend == 'torch':
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype)
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)

if shape_len != 0 and depth>1:
msg = "Does not support pad tensor under nested list. If you need this, please report."
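All of the changes in this file simply swap the call sites over to the new keyword order (pad_val first, ele_dtype optional). A minimal sketch of how the dispatch then behaves, assuming the import path matches the file location and the get_padder signature shown in the hunk header; the field values are made up:

```python
from fastNLP.core.collators.padders.get_padder import get_padder

# depth == 1, shape_len == 0: a batch of plain numbers -> a *NumberPadder
num_padder = get_padder([0, 1, 2], pad_val=0, dtype=int, backend='raw', field_name='label')

# depth > 1, shape_len == 0: variable-length lists -> a *SequencePadder
seq_padder = get_padder([[0, 1], [2]], pad_val=0, dtype=int, backend='numpy', field_name='words')
print(seq_padder([[0, 1], [2]]))  # expected np.array([[0, 1], [2, 0]]) per the padder docstrings
```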


fastNLP/core/collators/padders/numpy_padder.py  (+22 -7)

@@ -1,6 +1,7 @@
__all__ = [
'NumpyNumberPadder',
'NumpySequencePadder',
"NumpyTensorPadder"
]

from numbers import Number
@@ -14,7 +15,7 @@ from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
if not is_number_or_numpy_number(ele_dtype):
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers but get `{ele_dtype}`.")

@@ -29,7 +30,14 @@ def _get_dtype(ele_dtype, dtype, class_name):


class NumpyNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Converts data like [1, 2, 3] into np.array([1, 2, 3]).

:param pad_val: this value has no effect here
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array
:param dtype: the dtype of the output data
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -39,7 +47,14 @@ class NumpyNumberPadder(Padder):


class NumpySequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads content like [[1], [1, 2]] into np.array([[1, 0], [1, 2]]); nested data of multiple depths is supported.

:param pad_val: the value used for padding.
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array
:param dtype: the dtype of the output data
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -49,13 +64,13 @@ class NumpySequencePadder(Padder):


class NumpyTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads a field like [np.array([3, 4]), np.array([1])].

:param ele_dtype:
:param pad_val:
:param dtype:
:param pad_val: the value used for padding.
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array
:param dtype: the dtype of the output data
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
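The tests further down mirror the new keyword order; a short usage sketch based on the added docstrings, with the import path assumed from the file location:

```python
import numpy as np
from fastNLP.core.collators.padders.numpy_padder import NumpyNumberPadder, NumpySequencePadder

# pad_val now comes first; ele_dtype is optional and only used for type checking
num_padder = NumpyNumberPadder(pad_val=0, ele_dtype=int, dtype=int)
assert isinstance(num_padder([1, 2, 3]), np.ndarray)

seq_padder = NumpySequencePadder(pad_val=0, ele_dtype=int, dtype=int)
print(seq_padder([[1], [1, 2]]))  # expected np.array([[1, 0], [1, 2]]) per the docstring
```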


fastNLP/core/collators/padders/padder.py  (+7 -0)

@@ -14,6 +14,13 @@ class Padder:

class NullPadder(Padder):
def __init__(self, ele_dtype=None, pad_val=None, dtype=None):
"""
An empty padder that performs no checking and no padding.

:param ele_dtype:
:param pad_val:
:param dtype:
"""
super().__init__(pad_val=pad_val, dtype=dtype)

def __call__(self, batch_field):


fastNLP/core/collators/padders/raw_padder.py  (+28 -11)

@@ -1,25 +1,35 @@


from .padder import Padder
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number
from .exceptions import *


def _get_dtype(ele_dtype, dtype, class_name):
if is_number(ele_dtype):
if dtype is None:
dtype = ele_dtype
elif not is_number(dtype):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but "
f"get `{dtype}`.")
else:
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"but get `{ele_dtype}`.")
f"or numpy numbers but get `{ele_dtype}`.")

if dtype is None:
dtype = ele_dtype
else:
if not is_number_or_numpy_number(dtype):
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
f"or numpy numbers but get `{dtype}`.")
dtype = dtype

return dtype


class RawNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Converts data like [1, 2, 3] into [1, 2, 3]; in effect this padder is a no-op.

:param pad_val: this value has no effect here
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array
:param dtype: the dtype of the output data
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -32,7 +42,14 @@ class RawNumberPadder(Padder):


class RawSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads content like [[1], [1, 2]] into [[1, 0], [1, 2]]; nested data of multiple depths is supported.

:param pad_val: the value used for padding
:param ele_dtype: used to check whether the element type of the current field can be converted to np.array
:param dtype: the dtype of the output data
"""
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
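As with the numpy padders, a short sketch of the new keyword order, following the behavior stated in the docstrings above (import path assumed from the file location):

```python
from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder

# RawNumberPadder is effectively a no-op, as its docstring notes
assert RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int)([1, 2, 3]) == [1, 2, 3]

# RawSequencePadder pads plain python lists
seq_padder = RawSequencePadder(pad_val=0, ele_dtype=int, dtype=int)
print(seq_padder([[1], [1, 2]]))  # expected [[1, 0], [1, 2]] per the docstring
```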



fastNLP/core/collators/padders/torch_padder.py  (+29 -17)

@@ -37,7 +37,7 @@ def is_torch_tensor(dtype):


def _get_dtype(ele_dtype, dtype, class_name):
if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)):
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))):
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")

@@ -47,20 +47,27 @@ def _get_dtype(ele_dtype, dtype, class_name):
f"or torch.dtype but get `{dtype}`.")
dtype = number_to_torch_dtype_dict.get(dtype, dtype)
else:
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
dtype = ele_dtype
elif is_numpy_number_dtype(ele_dtype): # there is a conversion issue here
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_torch_dtype_dict.get(ele_dtype)
if ele_dtype is not None:
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)):
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype)
dtype = ele_dtype
elif is_numpy_number_dtype(ele_dtype): # there is a conversion issue here
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type)
elif is_numpy_generic_class(ele_dtype):
dtype = numpy_to_torch_dtype_dict.get(ele_dtype)

return dtype


class TorchNumberPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
# only when ele_dtype is a python number / numpy number or a tensor
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Converts data like [1, 2, 3] into torch.Tensor([1, 2, 3]).

:param pad_val: this value has no effect here
:param ele_dtype: used to check whether the element type of the current field can be converted to torch.tensor
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc.
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -70,7 +77,14 @@ class TorchNumberPadder(Padder):


class TorchSequencePadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Pads content like [[1], [1, 2]] into torch.Tensor([[1, 0], [1, 2]]); nested data of multiple depths is supported.

:param pad_val: the value used for padding.
:param ele_dtype: used to check whether the element type of the current field can be converted to torch.tensor
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc.
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)

@@ -81,13 +95,13 @@ class TorchSequencePadder(Padder):


class TorchTensorPadder(Padder):
def __init__(self, ele_dtype, pad_val=0, dtype=None):
def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
"""
Currently only supports fields like [torch.tensor([3, 2]), torch.tensor([1])].

:param ele_dtype:
:param pad_val:
:param dtype:
:param pad_val: the value used for padding.
:param ele_dtype: used to check whether the element type of the current field can be converted to torch.tensor
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc.
"""
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
super().__init__(pad_val=pad_val, dtype=dtype)
@@ -96,8 +110,6 @@ class TorchTensorPadder(Padder):
def pad(batch_field, pad_val, dtype):
shapes = [field.shape for field in batch_field]
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)]
if isinstance(dtype, np.dtype):
print(dtype)
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype)
for i, field in enumerate(batch_field):
slices = (i, ) + tuple(slice(0, s) for s in shapes[i])
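A corresponding sketch for the torch padders, again only illustrating the new keyword order and the behavior described in the docstrings (import path assumed from the file location):

```python
import torch
from fastNLP.core.collators.padders.torch_padder import TorchSequencePadder, TorchTensorPadder

# variable-length int lists padded into a single LongTensor
seq_padder = TorchSequencePadder(pad_val=0, ele_dtype=int, dtype=torch.long)
print(seq_padder([[1], [1, 2]]))  # expected tensor([[1, 0], [1, 2]]) per the docstring

# a list of tensors with different lengths, padded along the trailing dimensions
tensor_padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int)
print(tensor_padder([torch.zeros(3), torch.zeros(2)]))
```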


fastNLP/core/collators/padders/torch_utils.py  (+1 -2)

@@ -10,8 +10,7 @@ def is_torch_tensor_dtype(dtype) -> bool:
"""
Returns whether the given dtype is a torch dtype.


:param dtype: should be something obtained like torch.ones(3).dtype
:param dtype: something like torch.ones(3).dtype
:return:
"""
try:


fastNLP/core/dataloaders/torch_dataloader/fdl.py  (+11 -1)

@@ -64,7 +64,7 @@ class TorchDataLoader(DataLoader):
:param sampler: a sampler instance
:param batch_sampler: a batch_sampler instance, which iterates and returns lists of indices
:param num_workers: number of worker processes; when num_workers=0, multiprocessing is not used
:param collate_fn: a callable used to collate the fetched data
:param collate_fn: a callable used to collate the fetched data. [None, auto, callable]
:param pin_memory:
:param drop_last: whether to drop the last batch when it does not fill batch_size
:param timeout:
@@ -178,6 +178,16 @@ class TorchDataLoader(DataLoader):
"""
return self.cur_batch_indices

def set_pad(self):
pass

def set_ignore(self):
pass

def set_backend(self):
pass



def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
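The new set_pad / set_ignore / set_backend methods are still empty stubs in this commit, so the only behavioral hint here is the updated collate_fn docstring ([None, auto, callable]). A hedged sketch of what that option looks like from the caller's side; the dict-based DataSet construction, the import paths, and the assumption that prepare_torch_dataloader forwards collate_fn are not shown in this diff:

```python
from fastNLP.core.dataset import DataSet
from fastNLP.core.dataloaders.torch_dataloader.fdl import prepare_torch_dataloader

ds = DataSet({'x': [[1, 2], [3], [4, 5, 6]], 'y': [0, 1, 0]})  # hypothetical data

# per the updated docstring, collate_fn may be None, 'auto', or any callable
dl = prepare_torch_dataloader(ds, batch_size=2, collate_fn='auto')
for batch in dl:
    print(batch)
```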


fastNLP/core/dataset/dataset.py  (+7 -4)

@@ -760,8 +760,7 @@ class DataSet:
dict_ = {key: value.content for key, value in self.field_arrays.items()}
return pd.DataFrame.from_dict(dict_)

# TODO shouldn't this have a return value?
def to_csv(self, path: str) -> None:
def to_csv(self, path: str):
"""
Saves the dataset as a csv file.

@@ -770,7 +769,7 @@ class DataSet:
"""

df = self.to_pandas()
df.to_csv(path, encoding="utf-8")
return df.to_csv(path, encoding="utf-8")

def add_collate_fn(self, collate_fn: Callable) -> None:
"""
@@ -831,4 +830,8 @@ class DataSet:
"""
self.collate_fns.set_input(*field_names)


@property
def collator(self):
if self._collator is None:
self._collator = Collator()
return self._collator
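Two small notes on this file: to_csv now returns whatever pandas.DataFrame.to_csv returns (None when a path is given), and the new collator property lazily builds a single Collator and reuses it. A sketch, with the dict-based DataSet construction assumed:

```python
from fastNLP.core.dataset import DataSet

ds = DataSet({'x': [[1, 2], [3]], 'y': [0, 1]})  # hypothetical data

# the property creates the Collator on first access and caches it afterwards
assert ds.collator is ds.collator

ds.to_csv('data.csv')  # returns None: pandas' to_csv returns None when a path is passed
```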

fastNLP/envs/__init__.py  (+5 -1)

@@ -14,7 +14,11 @@ __all__ = [

from .env import *
from .set_env_on_import import set_env_on_import
from .set_backend import dump_fastnlp_backend
# First make sure FASTNLP_GLOBAL_RANK is set correctly
set_env_on_import()
from .set_backend import dump_fastnlp_backend, _set_backend
# Then set up the backend
_set_backend()
from .imports import *
from .utils import _module_available, get_gpu_count
from .distributed import *
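Because set_env_on_import() and _set_backend() now run when fastNLP.envs is imported, choosing a backend still has to happen through the environment before the first fastNLP import. A sketch, assuming the FASTNLP_BACKEND variable key is the literal string 'FASTNLP_BACKEND' and that the value is one of the SUPPORT_BACKENDS listed in imports.py:

```python
import os

# must be set before fastNLP (and therefore fastNLP.envs) is imported for the first time
os.environ['FASTNLP_BACKEND'] = 'torch'  # assumed key; value one of 'torch', 'paddle', 'jittor'

import fastNLP  # importing the package now runs set_env_on_import() and _set_backend()
```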

fastNLP/envs/imports.py  (+1 -1)

@@ -5,9 +5,9 @@ import operator

from fastNLP.envs.env import FASTNLP_BACKEND
from fastNLP.envs.utils import _module_available, _compare_version
from fastNLP.envs.set_backend import SUPPORT_BACKENDS


SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor']
backend = os.environ.get(FASTNLP_BACKEND, 'all')
if backend == 'all':
need_import = SUPPORT_BACKENDS


fastNLP/envs/set_backend.py  (+4 -5)

@@ -1,7 +1,3 @@
"""
This file is used to set certain environment variables, both automatically and manually. The set_env() function in this file runs after set_env_on_import when fastNLP is imported and can be used to set required environment variables. When users call set_env() to modify environment variables, they should also make sure set_env() runs before any other code.
"""
import os
import json
import sys
@@ -9,9 +5,12 @@ from collections import defaultdict


from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED
from fastNLP.envs.imports import SUPPORT_BACKENDS
from fastNLP.envs.utils import _module_available, get_gpu_count


SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor']


def _set_backend():
"""
Sets the backend according to environment variables or the default configuration file.


fastNLP/io/cws.py  (+0 -97)

@@ -1,97 +0,0 @@
r"""undocumented"""

__all__ = [
"CWSLoader"
]

import glob
import os
import random
import shutil
import time

from .loader import Loader
from fastNLP.core.dataset import DataSet, Instance


class CWSLoader(Loader):
r"""
The data format supported by CWSLoader is one sentence per line, with words separated by spaces, for example:

Example::

上海 浦东 开发 与 法制 建设 同步
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
...

The DataSet produced by this Loader has the following structure:

.. csv-table::
:header: "raw_words"

"上海 浦东 开发 与 法制 建设 同步"
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
"..."
"""

def __init__(self, dataset_name: str = None):
r"""
:param str dataset_name: name of the dataset; supports pku, msra, cityu (traditional Chinese), as (traditional Chinese), None
"""
super().__init__()
datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'}
if dataset_name in datanames:
self.dataset_name = datanames[dataset_name]
else:
self.dataset_name = None

def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if line:
ds.append(Instance(raw_words=line))
return ds

def download(self, dev_ratio=0.1, re_download=False) -> str:
r"""
If you use this dataset, please cite: Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff,
2005. More information is available at http://sighan.cs.uchicago.edu/bakeoff2005/

:param float dev_ratio: if the path contains no dev set, the fraction of train to split off as dev; if 0, no dev split is made.
:param bool re_download: whether to re-download the data in order to re-split it.
:return: str
"""
if self.dataset_name is None:
return ''
data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
modify_time = 0
for filepath in glob.glob(os.path.join(data_dir, '*')):
modify_time = os.stat(filepath).st_mtime
break
if time.time() - modify_time > 1 and re_download: # a rather ugly way to check whether the file was just downloaded
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=self.dataset_name)

if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
try:
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
for line in f:
if random.random() < dev_ratio:
f2.write(line)
else:
f1.write(line)
os.remove(os.path.join(data_dir, 'train.txt'))
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
os.remove(os.path.join(data_dir, 'middle_file.txt'))

return data_dir

tests/core/collators/padders/test_numpy_padder.py  (+10 -10)

@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH

class TestNumpyNumberPadder:
def test_run(self):
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = NumpyNumberPadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [1, 2, 3]
assert isinstance(padder(a), np.ndarray)
assert (padder(a) == np.array(a)).sum() == 3
@@ -17,7 +17,7 @@ class TestNumpyNumberPadder:
@pytest.mark.torch
class TestNumpySequencePadder:
def test_run(self):
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = NumpySequencePadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
@@ -27,18 +27,18 @@ class TestNumpySequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = NumpySequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = NumpySequencePadder(pad_val=-1, ele_dtype=str, dtype=int)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = NumpySequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int)


class TestNumpyTensorPadder:
def test_run(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1)
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3).dtype, dtype=int)
a = [np.zeros(3), np.zeros(2), np.zeros(0)]
a = padder(a)
shape = np.shape(a)
@@ -68,15 +68,15 @@ class TestNumpyTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=str, dtype=int)
if _NEED_IMPORT_TORCH:
import torch
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int)
with pytest.raises(DtypeError):
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long)




tests/core/collators/padders/test_raw_padder.py  (+4 -4)

@@ -7,14 +7,14 @@ from fastNLP.core.collators.padders.exceptions import DtypeError

class TestRawNumberPadder:
def test_run(self):
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [1, 2, 3]
assert padder(a) == a


class TestRawSequencePadder:
def test_run(self):
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = RawSequencePadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = np.shape(a)
@@ -24,6 +24,6 @@ class TestRawSequencePadder:

def test_dtype_check(self):
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)

tests/core/collators/padders/test_torch_padder.py  (+14 -14)

@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH:
@pytest.mark.torch
class TestTorchNumberPadder:
def test_run(self):
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = TorchNumberPadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [1, 2, 3]
t_a = padder(a)
assert isinstance(t_a, torch.Tensor)
@@ -22,7 +22,7 @@ class TestTorchNumberPadder:
@pytest.mark.torch
class TestTorchSequencePadder:
def test_run(self):
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=int, dtype=int)
a = [[1, 2, 3], [3]]
a = padder(a)
shape = a.shape
@@ -32,20 +32,20 @@ class TestTorchSequencePadder:
assert (a == b).sum().item() == shape[0]*shape[1]

def test_dtype_check(self):
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=str, dtype=int)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None)
a = padder([[1], [2, 322]])
assert (a>67).sum()==0 # because the int8 range here means 322 wraps to 66
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1)
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None)


@pytest.mark.torch
class TestTorchTensorPadder:
def test_run(self):
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int)
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)]
a = padder(a)
shape = a.shape
@@ -74,7 +74,7 @@ class TestTorchTensorPadder:
[[0, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int)
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))]
a = padder(a)
shape = a.shape
@@ -85,7 +85,7 @@ class TestTorchTensorPadder:
[[-1, -1], [-1, -1], [-1, -1]]])
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=None)
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))]
a = padder(a)
shape = a.shape
@@ -97,11 +97,11 @@ class TestTorchTensorPadder:
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2]

def test_dtype_check(self):
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int)
with pytest.raises(DtypeError):
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1)
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=str, dtype=int)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int)
padder = TorchTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long)



