@@ -1,10 +1,3 @@ | |||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||
from fastNLP.envs.set_env_on_import import set_env_on_import | |||
set_env_on_import() | |||
# 再设置 backend 相关 | |||
from fastNLP.envs.set_backend import _set_backend | |||
_set_backend() | |||
from fastNLP.envs import * | |||
from fastNLP.core import Trainer, Evaluator |
@@ -84,25 +84,25 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||
try: | |||
if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | |||
if backend == 'raw': | |||
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
elif backend == 'numpy': | |||
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
elif backend == 'torch': | |||
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | |||
if backend == 'raw': | |||
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
elif backend == 'numpy': | |||
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
elif backend == 'torch': | |||
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
if depth == 1 and shape_len != 0: | |||
if backend == 'numpy': | |||
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
elif backend == 'torch': | |||
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||
if shape_len != 0 and depth>1: | |||
msg = "Does not support pad tensor under nested list. If you need this, please report." | |||
@@ -1,6 +1,7 @@ | |||
__all__ = [ | |||
'NumpyNumberPadder', | |||
'NumpySequencePadder', | |||
"NumpyTensorPadder" | |||
] | |||
from numbers import Number | |||
@@ -14,7 +15,7 @@ from .exceptions import * | |||
def _get_dtype(ele_dtype, dtype, class_name): | |||
if not is_number_or_numpy_number(ele_dtype): | |||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
f"or numpy numbers but get `{ele_dtype}`.") | |||
@@ -29,7 +30,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
class NumpyNumberPadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Converts data of the form [1, 2, 3] into np.array([1, 2, 3]). | |||
:param pad_val: this value has no effect for this padder. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to an np.array type. | |||
:param dtype: the dtype of the output data. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -39,7 +47,14 @@ class NumpyNumberPadder(Padder): | |||
class NumpySequencePadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Pads content such as [[1], [1, 2]] into np.array([[1, 0], [1, 2]]). Multiply nested data can be padded as well. | |||
:param pad_val: the value used for padding. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to an np.array type. | |||
:param dtype: the dtype of the output data. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -49,13 +64,13 @@ class NumpySequencePadder(Padder): | |||
class NumpyTensorPadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Pads a field such as [np.array([3, 4]), np.array([1])]. | |||
:param pad_val: the value used for padding. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to an np.array type. | |||
:param dtype: the dtype of the output data. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -14,6 +14,13 @@ class Padder: | |||
class NullPadder(Padder): | |||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | |||
""" | |||
An empty padder that performs neither checks nor padding. | |||
:param ele_dtype: accepted but not used by this constructor. | |||
:param pad_val: forwarded unchanged to Padder.__init__. | |||
:param dtype: forwarded unchanged to Padder.__init__. | |||
""" | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
def __call__(self, batch_field): | |||
@@ -1,25 +1,35 @@ | |||
from .padder import Padder | |||
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array | |||
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number | |||
from .exceptions import * | |||
def _get_dtype(ele_dtype, dtype, class_name): | |||
# Checks `ele_dtype` and `dtype` on behalf of the padder named `class_name` and returns the dtype to use. | |||
# Raises EleDtypeUnsupportedError for an unsupported element type and DtypeUnsupportedError for an | |||
# unsupported output dtype; when `dtype` is None it falls back to `ele_dtype`. | |||
# NOTE(review): this span appears to interleave the pre-change and post-change bodies of the function | |||
# (two element-type checks, two variants of the error message) -- read them as alternatives, TODO confirm. | |||
if is_number(ele_dtype): | |||
if dtype is None: | |||
dtype = ele_dtype | |||
elif not is_number(dtype): | |||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " | |||
f"get `{dtype}`.") | |||
else: | |||
# None is let through here; only a concrete, unsupported ele_dtype is rejected. | |||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||
f"but get `{ele_dtype}`.") | |||
f"or numpy numbers but get `{ele_dtype}`.") | |||
if dtype is None: | |||
dtype = ele_dtype | |||
else: | |||
if not is_number_or_numpy_number(dtype): | |||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||
f"or numpy numbers but get `{dtype}`.") | |||
# Self-assignment: has no effect on the returned value. | |||
dtype = dtype | |||
return dtype | |||
class RawNumberPadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Converts data of the form [1, 2, 3] into [1, 2, 3]; in effect this padder does nothing meaningful. | |||
:param pad_val: this value has no effect for this padder. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to an np.array type. | |||
:param dtype: the dtype of the output data. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -32,7 +42,14 @@ class RawNumberPadder(Padder): | |||
class RawSequencePadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Pads content such as [[1], [1, 2]] into [[1, 0], [1, 2]]. Multiply nested data can be padded as well. | |||
:param pad_val: the value used for padding. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to an np.array type. | |||
:param dtype: the dtype of the output data. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -37,7 +37,7 @@ def is_torch_tensor(dtype): | |||
def _get_dtype(ele_dtype, dtype, class_name):
    """
    Reject element types the torch padders cannot handle.

    Only the element-type check is visible in this span; the handling of ``dtype``
    follows in the part of the function that lies outside it.

    :param ele_dtype: element type of the field, or ``None`` when it is not provided
        (``None`` is accepted and skips the check).
    :param dtype: requested output dtype (not examined by the check below).
    :param class_name: name of the calling padder class, used in the error message.
    :raises EleDtypeUnsupportedError: if ``ele_dtype`` is given and is neither a python
        number, a numpy number nor a torch.Tensor type.
    """
    # `not (ele_dtype is not None and (...))` would also raise for ele_dtype=None, which is the
    # default of every torch padder; only a concrete, unsupported ele_dtype is an error.
    if ele_dtype is not None and not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)):
        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
                                       f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.")
@@ -47,20 +47,27 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||
f"or torch.dtype but get `{dtype}`.") | |||
dtype = number_to_torch_dtype_dict.get(dtype, dtype) | |||
else: | |||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||
dtype = ele_dtype | |||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||
elif is_numpy_generic_class(ele_dtype): | |||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||
if ele_dtype is not None: | |||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||
dtype = ele_dtype | |||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||
elif is_numpy_generic_class(ele_dtype): | |||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||
return dtype | |||
class TorchNumberPadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
# Only when ele_dtype is a python number / numpy number or a tensor | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Converts data of the form [1, 2, 3] into torch.Tensor([1, 2, 3]). | |||
:param pad_val: this value has no effect for this padder. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to a torch.tensor type. | |||
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -70,7 +77,14 @@ class TorchNumberPadder(Padder): | |||
class TorchSequencePadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Pads content such as [[1], [1, 2]] into torch.Tensor([[1, 0], [1, 2]]). Multiply nested data can be padded as well. | |||
:param pad_val: the value used for padding. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to a torch.tensor type. | |||
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -81,13 +95,13 @@ class TorchSequencePadder(Padder): | |||
class TorchTensorPadder(Padder): | |||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||
""" | |||
Currently only supports fields such as [torch.tensor([3, 2]), torch.tensor([1])]. | |||
:param pad_val: the value used for padding. | |||
:param ele_dtype: used to check whether the element type of the current field can be converted to a torch.tensor type. | |||
:param dtype: the dtype of the output data, e.g. torch.long, torch.float32, int, float, etc. | |||
""" | |||
# _get_dtype checks ele_dtype/dtype; its result is the dtype handed to Padder.__init__. | |||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||
super().__init__(pad_val=pad_val, dtype=dtype) | |||
@@ -96,8 +110,6 @@ class TorchTensorPadder(Padder): | |||
def pad(batch_field, pad_val, dtype): | |||
shapes = [field.shape for field in batch_field] | |||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||
if isinstance(dtype, np.dtype): | |||
print(dtype) | |||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | |||
for i, field in enumerate(batch_field): | |||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | |||
@@ -10,8 +10,7 @@ def is_torch_tensor_dtype(dtype) -> bool: | |||
""" | |||
返回当前 dtype 是否是 torch 的 dtype 类型 | |||
:param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 | |||
:param dtype: 类似与 torch.ones(3).dtype | |||
:return: | |||
""" | |||
try: | |||
@@ -64,7 +64,7 @@ class TorchDataLoader(DataLoader): | |||
:param sampler: sampler实例化对象 | |||
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 | |||
:param num_workers: 进程的数量,当num_worker=0时不开启多进程 | |||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||
:param collate_fn: 对取得到的数据进行打包的callable函数。[None, auto, callable] | |||
:param pin_memory: | |||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | |||
:param timeout: | |||
@@ -178,6 +178,16 @@ class TorchDataLoader(DataLoader): | |||
""" | |||
return self.cur_batch_indices | |||
def set_pad(self):
    """Placeholder hook: takes no arguments besides ``self`` and currently does nothing."""
    return None
def set_ignore(self):
    """Placeholder hook: takes no arguments besides ``self`` and currently does nothing."""
    return None
def set_backend(self):
    """Placeholder hook: takes no arguments besides ``self`` and currently does nothing."""
    return None
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | |||
batch_size: int = 1, | |||
@@ -760,8 +760,7 @@ class DataSet: | |||
dict_ = {key: value.content for key, value in self.field_arrays.items()} | |||
return pd.DataFrame.from_dict(dict_) | |||
# TODO 应该有返回值的吧 | |||
def to_csv(self, path: str) -> None: | |||
def to_csv(self, path: str): | |||
""" | |||
将dataset保存为csv文件 | |||
@@ -770,7 +769,7 @@ class DataSet: | |||
""" | |||
df = self.to_pandas() | |||
df.to_csv(path, encoding="utf-8") | |||
return df.to_csv(path, encoding="utf-8") | |||
def add_collate_fn(self, collate_fn: Callable) -> None: | |||
""" | |||
@@ -831,4 +830,8 @@ class DataSet: | |||
""" | |||
self.collate_fns.set_input(*field_names) | |||
@property
def collator(self):
    """Return the stored collator, creating a ``Collator()`` first if none is stored yet."""
    if self._collator is not None:
        return self._collator
    # First access: build the collator and keep it for later calls.
    self._collator = Collator()
    return self._collator
@@ -14,7 +14,11 @@ __all__ = [ | |||
from .env import * | |||
from .set_env_on_import import set_env_on_import | |||
from .set_backend import dump_fastnlp_backend | |||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||
set_env_on_import() | |||
from .set_backend import dump_fastnlp_backend, _set_backend | |||
# 再设置 backend 相关 | |||
_set_backend() | |||
from .imports import * | |||
from .utils import _module_available, get_gpu_count | |||
from .distributed import * |
@@ -5,9 +5,9 @@ import operator | |||
from fastNLP.envs.env import FASTNLP_BACKEND | |||
from fastNLP.envs.utils import _module_available, _compare_version | |||
from fastNLP.envs.set_backend import SUPPORT_BACKENDS | |||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||
backend = os.environ.get(FASTNLP_BACKEND, 'all') | |||
if backend == 'all': | |||
need_import = SUPPORT_BACKENDS | |||
@@ -1,7 +1,3 @@ | |||
""" | |||
这个文件用于自动以及手动设置某些环境变量的,该文件中的set_env()函数会在 fastNLP 被 import 的时候在set_env_on_import之后运行。可以 | |||
用于设置某些必要的环境变量。同时用户在使用时set_env()修改环境变量时,也应该保证set_env()函数在所有其它代码之前被运行。 | |||
""" | |||
import os | |||
import json | |||
import sys | |||
@@ -9,9 +5,12 @@ from collections import defaultdict | |||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | |||
from fastNLP.envs.imports import SUPPORT_BACKENDS | |||
from fastNLP.envs.utils import _module_available, get_gpu_count | |||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||
def _set_backend(): | |||
""" | |||
根据环境变量或者默认配置文件设置 backend 。 | |||
@@ -1,97 +0,0 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSLoader" | |||
] | |||
import glob | |||
import os | |||
import random | |||
import shutil | |||
import time | |||
from .loader import Loader | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class CWSLoader(Loader):
    r"""
    Loader for Chinese word segmentation (CWS) data.

    The supported format is one sentence per line, with the words separated by spaces, e.g.:

    Example::

        上海 浦东 开发 与 法制 建设 同步
        新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
        ...

    The DataSet read by this loader has the following structure

    .. csv-table::
        :header: "raw_words"

        "上海 浦东 开发 与 法制 建设 同步"
        "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
        "..."

    """

    def __init__(self, dataset_name: str = None):
        r"""
        :param str dataset_name: name of the data; supports pku, msra, cityu (traditional Chinese),
            as (traditional Chinese) and None
        """
        super().__init__()
        known_names = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'}
        # Unknown names (including None) are stored as None.
        self.dataset_name = known_names[dataset_name] if dataset_name in known_names else None

    def _load(self, path: str):
        """Read ``path`` line by line and store every non-empty stripped line as ``raw_words``."""
        dataset = DataSet()
        with open(path, 'r', encoding='utf-8') as reader:
            for raw_line in reader:
                sentence = raw_line.strip()
                if not sentence:
                    continue
                dataset.append(Instance(raw_words=sentence))
        return dataset

    def download(self, dev_ratio=0.1, re_download=False) -> str:
        r"""
        If you use this dataset, please cite: Thomas Emerson, The Second International Chinese Word
        Segmentation Bakeoff, 2005. More information is available at http://sighan.cs.uchicago.edu/bakeoff2005/

        :param float dev_ratio: if there is no dev set under the path, the share of train that is split off
            as dev data. If it is 0, no dev set is split off.
        :param bool re_download: whether to download the data again, so that it is split anew.
        :return: str
        """
        if self.dataset_name is None:
            return ''
        data_dir = self._get_dataset_path(dataset_name=self.dataset_name)

        # Modification time of the first directory entry, or 0 when the directory has none.
        modify_time = 0
        for entry in glob.glob(os.path.join(data_dir, '*')):
            modify_time = os.stat(entry).st_mtime
            break
        # A rather ugly way of telling whether the files have only just been downloaded.
        if time.time() - modify_time > 1 and re_download:
            shutil.rmtree(data_dir)
            data_dir = self._get_dataset_path(dataset_name=self.dataset_name)

        dev_path = os.path.join(data_dir, 'dev.txt')
        if os.path.exists(dev_path):
            return data_dir
        if not dev_ratio > 0:
            return data_dir

        assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
        train_path = os.path.join(data_dir, 'train.txt')
        middle_path = os.path.join(data_dir, 'middle_file.txt')
        try:
            with open(train_path, 'r', encoding='utf-8') as source, \
                    open(middle_path, 'w', encoding='utf-8') as kept, \
                    open(dev_path, 'w', encoding='utf-8') as held_out:
                for line in source:
                    # Each line goes to dev with probability dev_ratio, otherwise it stays in train.
                    target = held_out if random.random() < dev_ratio else kept
                    target.write(line)
            os.remove(train_path)
            os.renames(middle_path, train_path)
        finally:
            # Never leave the intermediate file behind, whatever happened above.
            if os.path.exists(middle_path):
                os.remove(middle_path)
        return data_dir
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
class TestNumpyNumberPadder: | |||
def test_run(self): | |||
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = NumpyNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [1, 2, 3] | |||
assert isinstance(padder(a), np.ndarray) | |||
assert (padder(a) == np.array(a)).sum() == 3 | |||
@@ -17,7 +17,7 @@ class TestNumpyNumberPadder: | |||
@pytest.mark.torch | |||
class TestNumpySequencePadder: | |||
def test_run(self): | |||
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [[1, 2, 3], [3]] | |||
a = padder(a) | |||
shape = np.shape(a) | |||
@@ -27,18 +27,18 @@ class TestNumpySequencePadder: | |||
assert (a == b).sum().item() == shape[0]*shape[1] | |||
def test_dtype_check(self): | |||
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
with pytest.raises(DtypeError): | |||
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||
class TestNumpyTensorPadder: | |||
def test_run(self): | |||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) | |||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3).dtype, dtype=int) | |||
a = [np.zeros(3), np.zeros(2), np.zeros(0)] | |||
a = padder(a) | |||
shape = np.shape(a) | |||
@@ -68,15 +68,15 @@ class TestNumpyTensorPadder: | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
def test_dtype_check(self): | |||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
with pytest.raises(DtypeError): | |||
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||
@@ -7,14 +7,14 @@ from fastNLP.core.collators.padders.exceptions import DtypeError | |||
class TestRawNumberPadder: | |||
def test_run(self): | |||
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [1, 2, 3] | |||
assert padder(a) == a | |||
class TestRawSequencePadder: | |||
def test_run(self): | |||
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = RawSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [[1, 2, 3], [3]] | |||
a = padder(a) | |||
shape = np.shape(a) | |||
@@ -24,6 +24,6 @@ class TestRawSequencePadder: | |||
def test_dtype_check(self): | |||
with pytest.raises(DtypeError): | |||
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) |
@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: | |||
@pytest.mark.torch | |||
class TestTorchNumberPadder: | |||
def test_run(self): | |||
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = TorchNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [1, 2, 3] | |||
t_a = padder(a) | |||
assert isinstance(t_a, torch.Tensor) | |||
@@ -22,7 +22,7 @@ class TestTorchNumberPadder: | |||
@pytest.mark.torch | |||
class TestTorchSequencePadder: | |||
def test_run(self): | |||
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||
a = [[1, 2, 3], [3]] | |||
a = padder(a) | |||
shape = a.shape | |||
@@ -32,20 +32,20 @@ class TestTorchSequencePadder: | |||
assert (a == b).sum().item() == shape[0]*shape[1] | |||
def test_dtype_check(self): | |||
# Builds TorchSequencePadder with several ele_dtype/dtype combinations; the pytest.raises section | |||
# expects a DtypeError. | |||
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) | |||
a = padder([[1], [2, 322]]) | |||
assert (a>67).sum()==0 # int8 spans -128..127 (not -67..66); presumably 322 wraps to 66 here, hence the bound of 67 -- TODO confirm | |||
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) | |||
@pytest.mark.torch | |||
class TestTorchTensorPadder: | |||
def test_run(self): | |||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | |||
a = padder(a) | |||
shape = a.shape | |||
@@ -74,7 +74,7 @@ class TestTorchTensorPadder: | |||
[[0, -1], [-1, -1], [-1, -1]]]) | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | |||
a = padder(a) | |||
shape = a.shape | |||
@@ -85,7 +85,7 @@ class TestTorchTensorPadder: | |||
[[-1, -1], [-1, -1], [-1, -1]]]) | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=None) | |||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | |||
a = padder(a) | |||
shape = a.shape | |||
@@ -97,11 +97,11 @@ class TestTorchTensorPadder: | |||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | |||
def test_dtype_check(self): | |||
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||
with pytest.raises(DtypeError): | |||
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||