From d645e88c7f236fbc28c4d45292d8b05a335f2c93 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 1 May 2022 23:17:56 +0800 Subject: [PATCH] =?UTF-8?q?1.=E5=A2=9E=E5=8A=A0=E9=83=A8=E5=88=86padder?= =?UTF-8?q?=E7=9A=84=E6=B3=A8=E9=87=8A;2.=E4=BF=AE=E6=94=B9set=5Fbackend?= =?UTF-8?q?=E7=9A=84=E4=BD=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/__init__.py | 9 +- fastNLP/core/collators/padders/get_padder.py | 16 +-- .../core/collators/padders/numpy_padder.py | 29 ++++-- fastNLP/core/collators/padders/padder.py | 7 ++ fastNLP/core/collators/padders/raw_padder.py | 39 +++++--- .../core/collators/padders/torch_padder.py | 46 +++++---- fastNLP/core/collators/padders/torch_utils.py | 3 +- .../core/dataloaders/torch_dataloader/fdl.py | 12 ++- fastNLP/core/dataset/dataset.py | 11 ++- fastNLP/envs/__init__.py | 6 +- fastNLP/envs/imports.py | 2 +- fastNLP/envs/set_backend.py | 9 +- fastNLP/io/cws.py | 97 ------------------- .../collators/padders/test_numpy_padder.py | 20 ++-- .../core/collators/padders/test_raw_padder.py | 8 +- .../collators/padders/test_torch_padder.py | 28 +++--- 16 files changed, 152 insertions(+), 190 deletions(-) delete mode 100644 fastNLP/io/cws.py diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 6007ed07..b856faff 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -1,10 +1,3 @@ -# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 -from fastNLP.envs.set_env_on_import import set_env_on_import -set_env_on_import() - -# 再设置 backend 相关 -from fastNLP.envs.set_backend import _set_backend -_set_backend() - +from fastNLP.envs import * from fastNLP.core import Trainer, Evaluator \ No newline at end of file diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py index 051a0ffc..ecef9fcf 100644 --- a/fastNLP/core/collators/padders/get_padder.py +++ b/fastNLP/core/collators/padders/get_padder.py @@ -84,25 +84,25 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> try: if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] if backend == 'raw': - return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'numpy': - return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': - return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 if backend == 'raw': - return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'numpy': - return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': - return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if depth == 1 and shape_len != 0: if backend == 'numpy': - return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) elif backend == 'torch': - return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) if shape_len != 0 and depth>1: msg = "Does not support pad tensor under nested list. If you need this, please report." diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py index 0298fd86..40b70683 100644 --- a/fastNLP/core/collators/padders/numpy_padder.py +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -1,6 +1,7 @@ __all__ = [ 'NumpyNumberPadder', 'NumpySequencePadder', + "NumpyTensorPadder" ] from numbers import Number @@ -14,7 +15,7 @@ from .exceptions import * def _get_dtype(ele_dtype, dtype, class_name): - if not is_number_or_numpy_number(ele_dtype): + if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " f"or numpy numbers but get `{ele_dtype}`.") @@ -29,7 +30,14 @@ def _get_dtype(ele_dtype, dtype, class_name): class NumpyNumberPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) + + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 + """ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -39,7 +47,14 @@ class NumpyNumberPadder(Padder): class NumpySequencePadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + + :param pad_val: pad 的值是多少。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 + """ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -49,13 +64,13 @@ class NumpySequencePadder(Padder): class NumpyTensorPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ pad 类似于 [np.array([3, 4], np.array([1])] 的 field - :param ele_dtype: - :param pad_val: - :param dtype: + :param pad_val: pad 的值是多少。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 """ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) diff --git a/fastNLP/core/collators/padders/padder.py b/fastNLP/core/collators/padders/padder.py index 486574af..065ce39f 100644 --- a/fastNLP/core/collators/padders/padder.py +++ b/fastNLP/core/collators/padders/padder.py @@ -14,6 +14,13 @@ class Padder: class NullPadder(Padder): def __init__(self, ele_dtype=None, pad_val=None, dtype=None): + """ + 不进行任何 检查 与 pad 的空 padder 。 + + :param ele_dtype: + :param pad_val: + :param dtype: + """ super().__init__(pad_val=pad_val, dtype=dtype) def __call__(self, batch_field): diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py index 66393b40..fe0c0b14 100644 --- a/fastNLP/core/collators/padders/raw_padder.py +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -1,25 +1,35 @@ from .padder import Padder -from .utils import get_padded_nest_list, is_number, get_padded_numpy_array +from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number from .exceptions import * def _get_dtype(ele_dtype, dtype, class_name): - if is_number(ele_dtype): - if dtype is None: - dtype = ele_dtype - elif not is_number(dtype): - raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " - f"get `{dtype}`.") - else: + if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " - f"but get `{ele_dtype}`.") + f"or numpy numbers but get `{ele_dtype}`.") + + if dtype is None: + dtype = ele_dtype + else: + if not is_number_or_numpy_number(dtype): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or numpy numbers but get `{dtype}`.") + dtype = dtype + return dtype class RawNumberPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 + + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 + """ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -32,7 +42,14 @@ class RawNumberPadder(Padder): class RawSequencePadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 + + :param pad_val: pad 的值 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 + :param dtype: 输出的数据的 dtype 是什么 + """ dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index a6768435..7c0eaa33 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -37,7 +37,7 @@ def is_torch_tensor(dtype): def _get_dtype(ele_dtype, dtype, class_name): - if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): + if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") @@ -47,20 +47,27 @@ def _get_dtype(ele_dtype, dtype, class_name): f"or torch.dtype but get `{dtype}`.") dtype = number_to_torch_dtype_dict.get(dtype, dtype) else: - if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): - ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) - dtype = ele_dtype - elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 - dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) - elif is_numpy_generic_class(ele_dtype): - dtype = numpy_to_torch_dtype_dict.get(ele_dtype) + if ele_dtype is not None: + if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): + ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) + dtype = ele_dtype + elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 + dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) + elif is_numpy_generic_class(ele_dtype): + dtype = numpy_to_torch_dtype_dict.get(ele_dtype) return dtype class TorchNumberPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): - # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) + + :param pad_val: 该值无意义 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 + """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -70,7 +77,14 @@ class TorchNumberPadder(Padder): class TorchSequencePadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): + """ + 将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 + + :param pad_val: 需要 pad 的值。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 + """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -81,13 +95,13 @@ class TorchSequencePadder(Padder): class TorchTensorPadder(Padder): - def __init__(self, ele_dtype, pad_val=0, dtype=None): + def __init__(self, pad_val=0, ele_dtype=None, dtype=None): """ 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 - :param ele_dtype: - :param pad_val: - :param dtype: + :param pad_val: 需要 pad 的值。 + :param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 + :param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 """ dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) super().__init__(pad_val=pad_val, dtype=dtype) @@ -96,8 +110,6 @@ class TorchTensorPadder(Padder): def pad(batch_field, pad_val, dtype): shapes = [field.shape for field in batch_field] max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] - if isinstance(dtype, np.dtype): - print(dtype) tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) diff --git a/fastNLP/core/collators/padders/torch_utils.py b/fastNLP/core/collators/padders/torch_utils.py index a47bea0e..3f21333b 100644 --- a/fastNLP/core/collators/padders/torch_utils.py +++ b/fastNLP/core/collators/padders/torch_utils.py @@ -10,8 +10,7 @@ def is_torch_tensor_dtype(dtype) -> bool: """ 返回当前 dtype 是否是 torch 的 dtype 类型 - - :param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 + :param dtype: 类似与 torch.ones(3).dtype :return: """ try: diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 02721aaf..a21a7e9d 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -64,7 +64,7 @@ class TorchDataLoader(DataLoader): :param sampler: sampler实例化对象 :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 :param num_workers: 进程的数量,当num_worker=0时不开启多进程 - :param collate_fn: 对取得到的数据进行打包的callable函数 + :param collate_fn: 对取得到的数据进行打包的callable函数。[None, auto, callable] :param pin_memory: :param drop_last: 是否去掉最后一个不符合batch_size的数据 :param timeout: @@ -178,6 +178,16 @@ class TorchDataLoader(DataLoader): """ return self.cur_batch_indices + def set_pad(self): + pass + + def set_ignore(self): + pass + + def set_backend(self): + pass + + def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], batch_size: int = 1, diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index cd887253..a2e0ddae 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -760,8 +760,7 @@ class DataSet: dict_ = {key: value.content for key, value in self.field_arrays.items()} return pd.DataFrame.from_dict(dict_) - # TODO 应该有返回值的吧 - def to_csv(self, path: str) -> None: + def to_csv(self, path: str): """ 将dataset保存为csv文件 @@ -770,7 +769,7 @@ class DataSet: """ df = self.to_pandas() - df.to_csv(path, encoding="utf-8") + return df.to_csv(path, encoding="utf-8") def add_collate_fn(self, collate_fn: Callable) -> None: """ @@ -831,4 +830,8 @@ class DataSet: """ self.collate_fns.set_input(*field_names) - + @property + def collator(self): + if self._collator is None: + self._collator = Collator() + return self._collator diff --git a/fastNLP/envs/__init__.py b/fastNLP/envs/__init__.py index 8ec8f3a4..bc09c33b 100644 --- a/fastNLP/envs/__init__.py +++ b/fastNLP/envs/__init__.py @@ -14,7 +14,11 @@ __all__ = [ from .env import * from .set_env_on_import import set_env_on_import -from .set_backend import dump_fastnlp_backend +# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 +set_env_on_import() +from .set_backend import dump_fastnlp_backend, _set_backend +# 再设置 backend 相关 +_set_backend() from .imports import * from .utils import _module_available, get_gpu_count from .distributed import * diff --git a/fastNLP/envs/imports.py b/fastNLP/envs/imports.py index 3d49afe4..a2b63953 100644 --- a/fastNLP/envs/imports.py +++ b/fastNLP/envs/imports.py @@ -5,9 +5,9 @@ import operator from fastNLP.envs.env import FASTNLP_BACKEND from fastNLP.envs.utils import _module_available, _compare_version +from fastNLP.envs.set_backend import SUPPORT_BACKENDS -SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] backend = os.environ.get(FASTNLP_BACKEND, 'all') if backend == 'all': need_import = SUPPORT_BACKENDS diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index 4d6fb915..553dd389 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -1,7 +1,3 @@ -""" -这个文件用于自动以及手动设置某些环境变量的,该文件中的set_env()函数会在 fastNLP 被 import 的时候在set_env_on_import之后运行。可以 - 用于设置某些必要的环境变量。同时用户在使用时set_env()修改环境变量时,也应该保证set_env()函数在所有其它代码之前被运行。 -""" import os import json import sys @@ -9,9 +5,12 @@ from collections import defaultdict from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED -from fastNLP.envs.imports import SUPPORT_BACKENDS from fastNLP.envs.utils import _module_available, get_gpu_count + +SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] + + def _set_backend(): """ 根据环境变量或者默认配置文件设置 backend 。 diff --git a/fastNLP/io/cws.py b/fastNLP/io/cws.py deleted file mode 100644 index d88d6a00..00000000 --- a/fastNLP/io/cws.py +++ /dev/null @@ -1,97 +0,0 @@ -r"""undocumented""" - -__all__ = [ - "CWSLoader" -] - -import glob -import os -import random -import shutil -import time - -from .loader import Loader -from fastNLP.core.dataset import DataSet, Instance - - -class CWSLoader(Loader): - r""" - CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: - - Example:: - - 上海 浦东 开发 与 法制 建设 同步 - 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) - ... - - 该Loader读取后的DataSet具有如下的结构 - - .. csv-table:: - :header: "raw_words" - - "上海 浦东 开发 与 法制 建设 同步" - "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" - "..." - - """ - - def __init__(self, dataset_name: str = None): - r""" - - :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None - """ - super().__init__() - datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} - if dataset_name in datanames: - self.dataset_name = datanames[dataset_name] - else: - self.dataset_name = None - - def _load(self, path: str): - ds = DataSet() - with open(path, 'r', encoding='utf-8') as f: - for line in f: - line = line.strip() - if line: - ds.append(Instance(raw_words=line)) - return ds - - def download(self, dev_ratio=0.1, re_download=False) -> str: - r""" - 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, - 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str - """ - if self.dataset_name is None: - return '' - data_dir = self._get_dataset_path(dataset_name=self.dataset_name) - modify_time = 0 - for filepath in glob.glob(os.path.join(data_dir, '*')): - modify_time = os.stat(filepath).st_mtime - break - if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 - shutil.rmtree(data_dir) - data_dir = self._get_dataset_path(dataset_name=self.dataset_name) - - if not os.path.exists(os.path.join(data_dir, 'dev.txt')): - if dev_ratio > 0: - assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." - try: - with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ - open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ - open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: - for line in f: - if random.random() < dev_ratio: - f2.write(line) - else: - f1.write(line) - os.remove(os.path.join(data_dir, 'train.txt')) - os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) - finally: - if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): - os.remove(os.path.join(data_dir, 'middle_file.txt')) - - return data_dir diff --git a/tests/core/collators/padders/test_numpy_padder.py b/tests/core/collators/padders/test_numpy_padder.py index 0e93aa21..6c05e00b 100644 --- a/tests/core/collators/padders/test_numpy_padder.py +++ b/tests/core/collators/padders/test_numpy_padder.py @@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH class TestNumpyNumberPadder: def test_run(self): - padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = NumpyNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) a = [1, 2, 3] assert isinstance(padder(a), np.ndarray) assert (padder(a) == np.array(a)).sum() == 3 @@ -17,7 +17,7 @@ class TestNumpyNumberPadder: @pytest.mark.torch class TestNumpySequencePadder: def test_run(self): - padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = NumpySequencePadder(pad_val=-1, ele_dtype=int, dtype=int) a = [[1, 2, 3], [3]] a = padder(a) shape = np.shape(a) @@ -27,18 +27,18 @@ class TestNumpySequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = NumpySequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): - padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = NumpySequencePadder(pad_val=-1, ele_dtype=str, dtype=int) if _NEED_IMPORT_TORCH: import torch with pytest.raises(DtypeError): - padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = NumpySequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) class TestNumpyTensorPadder: def test_run(self): - padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) + padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3).dtype, dtype=int) a = [np.zeros(3), np.zeros(2), np.zeros(0)] a = padder(a) shape = np.shape(a) @@ -68,15 +68,15 @@ class TestNumpyTensorPadder: assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): - padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): - padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = NumpyTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) if _NEED_IMPORT_TORCH: import torch with pytest.raises(DtypeError): - padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = NumpyTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) with pytest.raises(DtypeError): - padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + padder = NumpyTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py index 41a9de64..9742bc9a 100644 --- a/tests/core/collators/padders/test_raw_padder.py +++ b/tests/core/collators/padders/test_raw_padder.py @@ -7,14 +7,14 @@ from fastNLP.core.collators.padders.exceptions import DtypeError class TestRawNumberPadder: def test_run(self): - padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) a = [1, 2, 3] assert padder(a) == a class TestRawSequencePadder: def test_run(self): - padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = RawSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) a = [[1, 2, 3], [3]] a = padder(a) shape = np.shape(a) @@ -24,6 +24,6 @@ class TestRawSequencePadder: def test_dtype_check(self): with pytest.raises(DtypeError): - padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): - padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) \ No newline at end of file + padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) \ No newline at end of file diff --git a/tests/core/collators/padders/test_torch_padder.py b/tests/core/collators/padders/test_torch_padder.py index 0bb363ba..53bd31e3 100644 --- a/tests/core/collators/padders/test_torch_padder.py +++ b/tests/core/collators/padders/test_torch_padder.py @@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: @pytest.mark.torch class TestTorchNumberPadder: def test_run(self): - padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = TorchNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) a = [1, 2, 3] t_a = padder(a) assert isinstance(t_a, torch.Tensor) @@ -22,7 +22,7 @@ class TestTorchNumberPadder: @pytest.mark.torch class TestTorchSequencePadder: def test_run(self): - padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) a = [[1, 2, 3], [3]] a = padder(a) shape = a.shape @@ -32,20 +32,20 @@ class TestTorchSequencePadder: assert (a == b).sum().item() == shape[0]*shape[1] def test_dtype_check(self): - padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): - padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) - padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) a = padder([[1], [2, 322]]) assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 - padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) @pytest.mark.torch class TestTorchTensorPadder: def test_run(self): - padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] a = padder(a) shape = a.shape @@ -74,7 +74,7 @@ class TestTorchTensorPadder: [[0, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] a = padder(a) shape = a.shape @@ -85,7 +85,7 @@ class TestTorchTensorPadder: [[-1, -1], [-1, -1], [-1, -1]]]) assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] - padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=None) a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] a = padder(a) shape = a.shape @@ -97,11 +97,11 @@ class TestTorchTensorPadder: assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] def test_dtype_check(self): - padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) with pytest.raises(DtypeError): - padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) - padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) - padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) + padder = TorchTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long)