@@ -1,10 +1,3 @@ | |||||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||||
from fastNLP.envs.set_env_on_import import set_env_on_import | |||||
set_env_on_import() | |||||
# 再设置 backend 相关 | |||||
from fastNLP.envs.set_backend import _set_backend | |||||
_set_backend() | |||||
from fastNLP.envs import * | |||||
from fastNLP.core import Trainer, Evaluator | from fastNLP.core import Trainer, Evaluator |
@@ -84,25 +84,25 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
try: | try: | ||||
if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] | ||||
if backend == 'raw': | if backend == 'raw': | ||||
return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return RawNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'numpy': | elif backend == 'numpy': | ||||
return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpyNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 | ||||
if backend == 'raw': | if backend == 'raw': | ||||
return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return RawSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'numpy': | elif backend == 'numpy': | ||||
return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpySequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if depth == 1 and shape_len != 0: | if depth == 1 and shape_len != 0: | ||||
if backend == 'numpy': | if backend == 'numpy': | ||||
return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return NumpyTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
elif backend == 'torch': | elif backend == 'torch': | ||||
return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) | |||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
if shape_len != 0 and depth>1: | if shape_len != 0 and depth>1: | ||||
msg = "Does not support pad tensor under nested list. If you need this, please report." | msg = "Does not support pad tensor under nested list. If you need this, please report." | ||||
@@ -1,6 +1,7 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'NumpyNumberPadder', | 'NumpyNumberPadder', | ||||
'NumpySequencePadder', | 'NumpySequencePadder', | ||||
"NumpyTensorPadder" | |||||
] | ] | ||||
from numbers import Number | from numbers import Number | ||||
@@ -14,7 +15,7 @@ from .exceptions import * | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if not is_number_or_numpy_number(ele_dtype): | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers but get `{ele_dtype}`.") | f"or numpy numbers but get `{ele_dtype}`.") | ||||
@@ -29,7 +30,14 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
class NumpyNumberPadder(Padder): | class NumpyNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 np.array([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -39,7 +47,14 @@ class NumpyNumberPadder(Padder): | |||||
class NumpySequencePadder(Padder): | class NumpySequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 np.array([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -49,13 +64,13 @@ class NumpySequencePadder(Padder): | |||||
class NumpyTensorPadder(Padder): | class NumpyTensorPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | """ | ||||
pad 类似于 [np.array([3, 4], np.array([1])] 的 field | pad 类似于 [np.array([3, 4], np.array([1])] 的 field | ||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
:param pad_val: pad 的值是多少。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | """ | ||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -14,6 +14,13 @@ class Padder: | |||||
class NullPadder(Padder): | class NullPadder(Padder): | ||||
def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | def __init__(self, ele_dtype=None, pad_val=None, dtype=None): | ||||
""" | |||||
不进行任何 检查 与 pad 的空 padder 。 | |||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
""" | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
def __call__(self, batch_field): | def __call__(self, batch_field): | ||||
@@ -1,25 +1,35 @@ | |||||
from .padder import Padder | from .padder import Padder | ||||
from .utils import get_padded_nest_list, is_number, get_padded_numpy_array | |||||
from .utils import is_number, get_padded_numpy_array, is_number_or_numpy_number | |||||
from .exceptions import * | from .exceptions import * | ||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if is_number(ele_dtype): | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
elif not is_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " | |||||
f"get `{dtype}`.") | |||||
else: | |||||
if ele_dtype is not None and not is_number_or_numpy_number(ele_dtype): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"but get `{ele_dtype}`.") | |||||
f"or numpy numbers but get `{ele_dtype}`.") | |||||
if dtype is None: | |||||
dtype = ele_dtype | |||||
else: | |||||
if not is_number_or_numpy_number(dtype): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or numpy numbers but get `{dtype}`.") | |||||
dtype = dtype | |||||
return dtype | return dtype | ||||
class RawNumberPadder(Padder): | class RawNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 [1, 2, 3] 。实际上该 padder 无意义。 | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -32,7 +42,14 @@ class RawNumberPadder(Padder): | |||||
class RawSequencePadder(Padder): | class RawSequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 [[1, 0], [1, 2]] 。可以 pad 多重嵌套的数据。 | |||||
:param pad_val: pad 的值 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 np.array 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -37,7 +37,7 @@ def is_torch_tensor(dtype): | |||||
def _get_dtype(ele_dtype, dtype, class_name): | def _get_dtype(ele_dtype, dtype, class_name): | ||||
if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
if not (ele_dtype is not None and (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype))): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | ||||
f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") | ||||
@@ -47,20 +47,27 @@ def _get_dtype(ele_dtype, dtype, class_name): | |||||
f"or torch.dtype but get `{dtype}`.") | f"or torch.dtype but get `{dtype}`.") | ||||
dtype = number_to_torch_dtype_dict.get(dtype, dtype) | dtype = number_to_torch_dtype_dict.get(dtype, dtype) | ||||
else: | else: | ||||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||||
if ele_dtype is not None: | |||||
if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): | |||||
ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) | |||||
dtype = ele_dtype | |||||
elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) | |||||
elif is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_torch_dtype_dict.get(ele_dtype) | |||||
return dtype | return dtype | ||||
class TorchNumberPadder(Padder): | class TorchNumberPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
# 仅当 ele_dtype 是 python number/ numpy number 或者 tensor | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 torch.Tensor([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -70,7 +77,14 @@ class TorchNumberPadder(Padder): | |||||
class TorchSequencePadder(Padder): | class TorchSequencePadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 torch.Tensor([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -81,13 +95,13 @@ class TorchSequencePadder(Padder): | |||||
class TorchTensorPadder(Padder): | class TorchTensorPadder(Padder): | ||||
def __init__(self, ele_dtype, pad_val=0, dtype=None): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | """ | ||||
目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 | ||||
:param ele_dtype: | |||||
:param pad_val: | |||||
:param dtype: | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 torch.tensor 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 torch.long, torch.float32, int, float 等 | |||||
""" | """ | ||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | ||||
super().__init__(pad_val=pad_val, dtype=dtype) | super().__init__(pad_val=pad_val, dtype=dtype) | ||||
@@ -96,8 +110,6 @@ class TorchTensorPadder(Padder): | |||||
def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
shapes = [field.shape for field in batch_field] | shapes = [field.shape for field in batch_field] | ||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | ||||
if isinstance(dtype, np.dtype): | |||||
print(dtype) | |||||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
@@ -10,8 +10,7 @@ def is_torch_tensor_dtype(dtype) -> bool: | |||||
""" | """ | ||||
返回当前 dtype 是否是 torch 的 dtype 类型 | 返回当前 dtype 是否是 torch 的 dtype 类型 | ||||
:param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 | |||||
:param dtype: 类似与 torch.ones(3).dtype | |||||
:return: | :return: | ||||
""" | """ | ||||
try: | try: | ||||
@@ -64,7 +64,7 @@ class TorchDataLoader(DataLoader): | |||||
:param sampler: sampler实例化对象 | :param sampler: sampler实例化对象 | ||||
:param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 | :param batch_sampler: batch_sampler实例化对象,其能迭代返回一个list的index数据 | ||||
:param num_workers: 进程的数量,当num_worker=0时不开启多进程 | :param num_workers: 进程的数量,当num_worker=0时不开启多进程 | ||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数。[None, auto, callable] | |||||
:param pin_memory: | :param pin_memory: | ||||
:param drop_last: 是否去掉最后一个不符合batch_size的数据 | :param drop_last: 是否去掉最后一个不符合batch_size的数据 | ||||
:param timeout: | :param timeout: | ||||
@@ -178,6 +178,16 @@ class TorchDataLoader(DataLoader): | |||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def set_pad(self): | |||||
pass | |||||
def set_ignore(self): | |||||
pass | |||||
def set_backend(self): | |||||
pass | |||||
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | ||||
batch_size: int = 1, | batch_size: int = 1, | ||||
@@ -760,8 +760,7 @@ class DataSet: | |||||
dict_ = {key: value.content for key, value in self.field_arrays.items()} | dict_ = {key: value.content for key, value in self.field_arrays.items()} | ||||
return pd.DataFrame.from_dict(dict_) | return pd.DataFrame.from_dict(dict_) | ||||
# TODO 应该有返回值的吧 | |||||
def to_csv(self, path: str) -> None: | |||||
def to_csv(self, path: str): | |||||
""" | """ | ||||
将dataset保存为csv文件 | 将dataset保存为csv文件 | ||||
@@ -770,7 +769,7 @@ class DataSet: | |||||
""" | """ | ||||
df = self.to_pandas() | df = self.to_pandas() | ||||
df.to_csv(path, encoding="utf-8") | |||||
return df.to_csv(path, encoding="utf-8") | |||||
def add_collate_fn(self, collate_fn: Callable) -> None: | def add_collate_fn(self, collate_fn: Callable) -> None: | ||||
""" | """ | ||||
@@ -831,4 +830,8 @@ class DataSet: | |||||
""" | """ | ||||
self.collate_fns.set_input(*field_names) | self.collate_fns.set_input(*field_names) | ||||
@property | |||||
def collator(self): | |||||
if self._collator is None: | |||||
self._collator = Collator() | |||||
return self._collator |
@@ -14,7 +14,11 @@ __all__ = [ | |||||
from .env import * | from .env import * | ||||
from .set_env_on_import import set_env_on_import | from .set_env_on_import import set_env_on_import | ||||
from .set_backend import dump_fastnlp_backend | |||||
# 首先保证 FASTNLP_GLOBAL_RANK 正确设置 | |||||
set_env_on_import() | |||||
from .set_backend import dump_fastnlp_backend, _set_backend | |||||
# 再设置 backend 相关 | |||||
_set_backend() | |||||
from .imports import * | from .imports import * | ||||
from .utils import _module_available, get_gpu_count | from .utils import _module_available, get_gpu_count | ||||
from .distributed import * | from .distributed import * |
@@ -5,9 +5,9 @@ import operator | |||||
from fastNLP.envs.env import FASTNLP_BACKEND | from fastNLP.envs.env import FASTNLP_BACKEND | ||||
from fastNLP.envs.utils import _module_available, _compare_version | from fastNLP.envs.utils import _module_available, _compare_version | ||||
from fastNLP.envs.set_backend import SUPPORT_BACKENDS | |||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
backend = os.environ.get(FASTNLP_BACKEND, 'all') | backend = os.environ.get(FASTNLP_BACKEND, 'all') | ||||
if backend == 'all': | if backend == 'all': | ||||
need_import = SUPPORT_BACKENDS | need_import = SUPPORT_BACKENDS | ||||
@@ -1,7 +1,3 @@ | |||||
""" | |||||
这个文件用于自动以及手动设置某些环境变量的,该文件中的set_env()函数会在 fastNLP 被 import 的时候在set_env_on_import之后运行。可以 | |||||
用于设置某些必要的环境变量。同时用户在使用时set_env()修改环境变量时,也应该保证set_env()函数在所有其它代码之前被运行。 | |||||
""" | |||||
import os | import os | ||||
import json | import json | ||||
import sys | import sys | ||||
@@ -9,9 +5,12 @@ from collections import defaultdict | |||||
from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | from fastNLP.envs.env import FASTNLP_BACKEND, FASTNLP_GLOBAL_RANK, USER_CUDA_VISIBLE_DEVICES, FASTNLP_GLOBAL_SEED | ||||
from fastNLP.envs.imports import SUPPORT_BACKENDS | |||||
from fastNLP.envs.utils import _module_available, get_gpu_count | from fastNLP.envs.utils import _module_available, get_gpu_count | ||||
SUPPORT_BACKENDS = ['torch', 'paddle', 'jittor'] | |||||
def _set_backend(): | def _set_backend(): | ||||
""" | """ | ||||
根据环境变量或者默认配置文件设置 backend 。 | 根据环境变量或者默认配置文件设置 backend 。 | ||||
@@ -1,97 +0,0 @@ | |||||
r"""undocumented""" | |||||
__all__ = [ | |||||
"CWSLoader" | |||||
] | |||||
import glob | |||||
import os | |||||
import random | |||||
import shutil | |||||
import time | |||||
from .loader import Loader | |||||
from fastNLP.core.dataset import DataSet, Instance | |||||
class CWSLoader(Loader): | |||||
r""" | |||||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||||
Example:: | |||||
上海 浦东 开发 与 法制 建设 同步 | |||||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||||
... | |||||
该Loader读取后的DataSet具有如下的结构 | |||||
.. csv-table:: | |||||
:header: "raw_words" | |||||
"上海 浦东 开发 与 法制 建设 同步" | |||||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||||
"..." | |||||
""" | |||||
def __init__(self, dataset_name: str = None): | |||||
r""" | |||||
:param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | |||||
""" | |||||
super().__init__() | |||||
datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} | |||||
if dataset_name in datanames: | |||||
self.dataset_name = datanames[dataset_name] | |||||
else: | |||||
self.dataset_name = None | |||||
def _load(self, path: str): | |||||
ds = DataSet() | |||||
with open(path, 'r', encoding='utf-8') as f: | |||||
for line in f: | |||||
line = line.strip() | |||||
if line: | |||||
ds.append(Instance(raw_words=line)) | |||||
return ds | |||||
def download(self, dev_ratio=0.1, re_download=False) -> str: | |||||
r""" | |||||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | |||||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | |||||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||||
:return: str | |||||
""" | |||||
if self.dataset_name is None: | |||||
return '' | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
modify_time = 0 | |||||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||||
modify_time = os.stat(filepath).st_mtime | |||||
break | |||||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||||
shutil.rmtree(data_dir) | |||||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')): | |||||
if dev_ratio > 0: | |||||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||||
try: | |||||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||||
for line in f: | |||||
if random.random() < dev_ratio: | |||||
f2.write(line) | |||||
else: | |||||
f1.write(line) | |||||
os.remove(os.path.join(data_dir, 'train.txt')) | |||||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||||
finally: | |||||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||||
return data_dir |
@@ -8,7 +8,7 @@ from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
class TestNumpyNumberPadder: | class TestNumpyNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = NumpyNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
assert isinstance(padder(a), np.ndarray) | assert isinstance(padder(a), np.ndarray) | ||||
assert (padder(a) == np.array(a)).sum() == 3 | assert (padder(a) == np.array(a)).sum() == 3 | ||||
@@ -17,7 +17,7 @@ class TestNumpyNumberPadder: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestNumpySequencePadder: | class TestNumpySequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -27,18 +27,18 @@ class TestNumpySequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = NumpySequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
class TestNumpyTensorPadder: | class TestNumpyTensorPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3).dtype, dtype=int) | |||||
a = [np.zeros(3), np.zeros(2), np.zeros(0)] | a = [np.zeros(3), np.zeros(2), np.zeros(0)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -68,15 +68,15 @@ class TestNumpyTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
padder = NumpyTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||||
@@ -7,14 +7,14 @@ from fastNLP.core.collators.padders.exceptions import DtypeError | |||||
class TestRawNumberPadder: | class TestRawNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = RawNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
assert padder(a) == a | assert padder(a) == a | ||||
class TestRawSequencePadder: | class TestRawSequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = np.shape(a) | shape = np.shape(a) | ||||
@@ -24,6 +24,6 @@ class TestRawSequencePadder: | |||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = RawSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) |
@@ -12,7 +12,7 @@ if _NEED_IMPORT_TORCH: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchNumberPadder: | class TestTorchNumberPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = TorchNumberPadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [1, 2, 3] | a = [1, 2, 3] | ||||
t_a = padder(a) | t_a = padder(a) | ||||
assert isinstance(t_a, torch.Tensor) | assert isinstance(t_a, torch.Tensor) | ||||
@@ -22,7 +22,7 @@ class TestTorchNumberPadder: | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchSequencePadder: | class TestTorchSequencePadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=int, dtype=int) | |||||
a = [[1, 2, 3], [3]] | a = [[1, 2, 3], [3]] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -32,20 +32,20 @@ class TestTorchSequencePadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1] | assert (a == b).sum().item() == shape[0]*shape[1] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.int8, dtype=None) | |||||
a = padder([[1], [2, 322]]) | a = padder([[1], [2, 322]]) | ||||
assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 | ||||
padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) | |||||
padder = TorchSequencePadder(pad_val=-1, ele_dtype=np.zeros(2).dtype, dtype=None) | |||||
@pytest.mark.torch | @pytest.mark.torch | ||||
class TestTorchTensorPadder: | class TestTorchTensorPadder: | ||||
def test_run(self): | def test_run(self): | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||||
a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -74,7 +74,7 @@ class TestTorchTensorPadder: | |||||
[[0, -1], [-1, -1], [-1, -1]]]) | [[0, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=int) | |||||
a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -85,7 +85,7 @@ class TestTorchTensorPadder: | |||||
[[-1, -1], [-1, -1], [-1, -1]]]) | [[-1, -1], [-1, -1], [-1, -1]]]) | ||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.zeros(3).dtype, dtype=None) | |||||
a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] | ||||
a = padder(a) | a = padder(a) | ||||
shape = a.shape | shape = a.shape | ||||
@@ -97,11 +97,11 @@ class TestTorchTensorPadder: | |||||
assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] | ||||
def test_dtype_check(self): | def test_dtype_check(self): | ||||
padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int) | |||||
with pytest.raises(DtypeError): | with pytest.raises(DtypeError): | ||||
padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) | |||||
padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=str, dtype=int) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=torch.long, dtype=int) | |||||
padder = TorchTensorPadder(pad_val=-1, ele_dtype=int, dtype=torch.long) | |||||