@@ -92,8 +92,8 @@ __all__ = [
     "cache_results",
     "f_rich_progress",
     "auto_param_call",
-    "seq_len_to_mask",
     "f_tqdm_progress",
+    "seq_len_to_mask",

     # vocabulary.py
     'Vocabulary'
@@ -52,7 +52,7 @@ class JittorDataLoader:
     """

-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:
@@ -194,7 +194,7 @@ class JittorDataLoader:
         return self.cur_batch_indices

-def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = True,
+def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False,
                               drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",
@@ -81,7 +81,7 @@ class PaddleDataLoader(DataLoader):
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 16, shuffle: bool = True,
+                 batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,
@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger

-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """
@@ -10,7 +10,7 @@ import numpy as np
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.metrics.backend import Backend
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask


 class Accuracy(Metric):
@@ -10,7 +10,7 @@ import numpy as np
 from .metric import Metric
 from .backend import Backend
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 from .utils import _compute_f_pre_rec
@@ -21,9 +21,10 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict",
-    "f_tqdm_progress"
+    "f_tqdm_progress",
+    "seq_len_to_mask"
 ]

 from .cache_results import cache_results
@@ -34,5 +35,6 @@ from .rich_progress import f_rich_progress
 from .torch_utils import torch_move_data_to_device
 from .utils import *
 from .tqdm_progress import f_tqdm_progress
+from .seq_len_to_mask import seq_len_to_mask
@@ -20,6 +20,7 @@ if _NEED_IMPORT_PADDLE:
 from .utils import apply_to_collection

+
 def _convert_data_device(device: Union[str, int]) -> str:
     """
     Function for converting the ``data_device`` of a ``driver``. If the user has set ``FASTNLP_BACKEND=paddle``, **fastNLP** will
@@ -59,7 +60,9 @@ def _convert_data_device(device: Union[str, int]) -> str:
     raise ValueError(f"Can't convert device {device} when USER_CUDA_VISIBLE_DEVICES={user_visible_devices} "
                      f"and CUDA_VISIBLE_DEVICES={cuda_visible_devices}. If this situation happens, please report this bug to us.")

-def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor":
+
+def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_avx.Place',
+              'paddle.CPUPlace', 'paddle.CUDAPlace']) -> "paddle.Tensor":
     """
     Moves ``data`` to the given ``device``. ``paddle.Tensor`` has no ``to`` method like ``torch.Tensor``;
     this function simply combines the :func:`paddle.Tensor.cpu` and :func:`paddle.Tensor.cuda` functions.
@@ -68,12 +71,21 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor"
     :param device: the target device; may be a ``str``, an ``int``, or a paddle ``Place``;
     :return: the migrated tensor;
     """
-    if device == "cpu":
+    if isinstance(device, paddle.fluid.core_avx.Place):
+        if device.is_cpu_place():
+            return data.cpu()
+        else:
+            return data.cuda(device.gpu_device_id())
+    elif isinstance(device, paddle.CPUPlace):
+        return data.cpu()
+    elif isinstance(device, paddle.CUDAPlace):
+        return data.gpu(device.get_device_id())
+    elif device == "cpu":
         return data.cpu()
     else:
         return data.cuda(get_paddle_device_id(device))

+
 def get_paddle_gpu_str(device: Union[str, int]) -> str:
     """
     Get a device name in the ``gpu:x`` format::
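A minimal usage sketch of the widened ``device`` parameter (assuming a GPU build of paddle; the device index is illustrative):

    >>> t = paddle.to_tensor([1, 2, 3])
    >>> t_cpu = paddle_to(t, paddle.CPUPlace())    # equivalent to paddle_to(t, "cpu")
    >>> t_gpu = paddle_to(t, paddle.CUDAPlace(0))  # equivalent to paddle_to(t, "gpu:0")
    >>> t_back = paddle_to(t, t_gpu.place)         # a concrete Place taken from another tensor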
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import numpy as np
+
+from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
+from .paddle_utils import paddle_to
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+
+def seq_len_to_mask(seq_len, max_len: Optional[int] = None):
+    r"""
+    Convert a one-dimensional array of ``sequence length`` values into a two-dimensional ``mask``,
+    with excluded positions set to **0**.
+
+    .. code-block::
+
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.size())
+        torch.Size([14, 15])
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len, max_len=100)
+        >>> print(mask.size())
+        torch.Size([14, 100])
+
+    :param seq_len: a length sequence of shape ``(B,)``;
+    :param int max_len: pad or truncate the mask to ``max_len``. By default (``None``) the longest
+        length in ``seq_len`` is used; but under :class:`torch.nn.DataParallel` and other distributed
+        settings different cards may see different ``seq_len``, so ``max_len`` should be passed to pad
+        or truncate the ``mask`` to that length.
+    :return: a ``mask`` of shape ``(B, max_len)`` whose element type is ``bool`` or ``uint8``
+    """
+    max_len = int(max_len) if max_len is not None else int(seq_len.max())
+
+    if isinstance(seq_len, np.ndarray):
+        assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+        return mask
+
+    try:  # check for torch; NameError means torch was never imported
+        if isinstance(seq_len, torch.Tensor):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:
+        pass
+
+    try:  # check for paddle
+        if isinstance(seq_len, paddle.Tensor):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = paddle.arange(max_len).expand((batch_size, -1))
+            broad_cast_seq_len = paddle_to(broad_cast_seq_len, device=seq_len.place)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:
+        pass
+
+    try:  # check for jittor
+        if isinstance(seq_len, jittor.Var):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:
+        pass
+
+    raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
+                    f"and jittor.Var, but got {type(seq_len)}")
@@ -14,7 +14,6 @@ import os
 from contextlib import contextmanager
 from functools import wraps
 from prettytable import PrettyTable
-import numpy as np
 from pathlib import Path

 from fastNLP.core.log import logger
@@ -31,7 +30,6 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict"
 ]
@@ -567,44 +565,6 @@ def deprecated(help_message: Optional[str] = None):
     return decorator

-def seq_len_to_mask(seq_len, max_len: Optional[int]):
-    r"""
-    Convert a one-dimensional array of ``sequence length`` values into a two-dimensional ``mask``,
-    with excluded positions set to **0**.
-
-    .. code-block::
-
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.size())
-        torch.Size([14, 15])
-        >>> seq_len = np.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.shape)
-        (14, 15)
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len, max_len=100)
-        >>>print(mask.size())
-        torch.Size([14, 100])
-
-    :param seq_len: a length sequence of shape ``(B,)``;
-    :param int max_len: pad or truncate the mask to ``max_len``. By default (``None``) the longest
-        length in ``seq_len`` is used; but under :class:`torch.nn.DataParallel` and other distributed
-        settings different cards may see different ``seq_len``, so ``max_len`` should be passed to pad
-        or truncate the ``mask`` to that length.
-    :return: a ``mask`` of shape ``(B, max_len)`` whose element type is ``bool`` or ``uint8``
-    """
-    if isinstance(seq_len, np.ndarray):
-        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
-        max_len = int(max_len) if max_len else int(seq_len.max())
-        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
-        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
-    else:
-        raise TypeError("Only support 1-d numpy.ndarray.")
-
-    return mask
-

 def wait_filepath(path, exist=True):
     """
     Wait until the existence state of ``path`` equals ``{exist}``, then return
@@ -109,6 +109,9 @@ __all__ = [
     "CMRC2018BertPipe",
+
+    "iob2",
+    "iob2bioes"
 ]

 from .data_bundle import DataBundle
@@ -62,13 +62,16 @@ __all__ = [
     "R8PmiGraphPipe",
     "OhsumedPmiGraphPipe",
     "NG20PmiGraphPipe",
-    "MRPmiGraphPipe"
+    "MRPmiGraphPipe",
+
+    "iob2",
+    "iob2bioes"
 ]

 from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
     WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
-from .conll import Conll2003Pipe
+from .conll import Conll2003Pipe, iob2, iob2bioes
 from .coreference import CoReferencePipe
 from .cws import CWSPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
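With ``iob2`` and ``iob2bioes`` now re-exported from the pipe package, tag-scheme conversion is available directly; a minimal sketch (the tag sequences are illustrative):

    >>> from fastNLP.io.pipe import iob2, iob2bioes
    >>> iob2(['I-PER', 'I-PER', 'O'])        # normalizes IOB1 tags to IOB2
    ['B-PER', 'I-PER', 'O']
    >>> iob2bioes(['B-PER', 'I-PER', 'O'])   # converts IOB2 tags to BIOES
    ['B-PER', 'E-PER', 'O']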
@@ -0,0 +1,102 @@
+import pytest
+import numpy as np
+
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+
+class TestSeqLenToMask:
+
+    def evaluate_mask_seq_len(self, seq_len, mask):
+        max_len = int(max(seq_len))
+        for i in range(len(seq_len)):
+            length = seq_len[i]
+            mask_i = mask[i]
+            for j in range(max_len):
+                assert mask_i[j] == (j < length), (i, j, length)
+
+    def test_numpy_seq_len(self):
+        # test converting a numpy seq_len
+        # 1. random test
+        seq_len = np.random.randint(1, 10, size=(2,))
+        mask = seq_len_to_mask(seq_len)
+        max_len = seq_len.max()
+        assert max_len == mask.shape[1]
+        self.evaluate_mask_seq_len(seq_len, mask)
+
+        # 2. error detection
+        seq_len = np.random.randint(10, size=(10, 1))
+        with pytest.raises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+        # 3. pad to a given length
+        seq_len = np.random.randint(1, 10, size=(10,))
+        mask = seq_len_to_mask(seq_len, 100)
+        assert 100 == mask.shape[1]
+
+    @pytest.mark.torch
+    def test_pytorch_seq_len(self):
+        # 1. random test
+        seq_len = torch.randint(1, 10, size=(10,))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        assert max_len == mask.shape[1]
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. error detection
+        seq_len = torch.randn(3, 4)
+        with pytest.raises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+        # 3. pad to a given length
+        seq_len = torch.randint(1, 10, size=(10,))
+        mask = seq_len_to_mask(seq_len, 100)
+        assert 100 == mask.size(1)
+
+    @pytest.mark.paddle
+    def test_paddle_seq_len(self):
+        # 1. random test
+        seq_len = paddle.randint(1, 10, shape=(10,))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        assert max_len == mask.shape[1]
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. error detection
+        seq_len = paddle.randn((3, 4))
+        with pytest.raises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+        # 3. pad to a given length
+        seq_len = paddle.randint(1, 10, shape=(10,))
+        mask = seq_len_to_mask(seq_len, 100)
+        assert 100 == mask.shape[1]
+
+    @pytest.mark.jittor
+    def test_jittor_seq_len(self):
+        # 1. random test
+        seq_len = jittor.randint(1, 10, shape=(10,))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        assert max_len == mask.shape[1]
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. error detection
+        seq_len = jittor.randn(3, 4)
+        with pytest.raises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+        # 3. pad to a given length
+        seq_len = jittor.randint(1, 10, shape=(10,))
+        mask = seq_len_to_mask(seq_len, 100)
+        assert 100 == mask.shape[1]
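The backend-specific tests above are gated by custom pytest markers; assuming those markers are registered in the project's pytest configuration, a single backend can be exercised with, e.g., ``pytest -m torch``.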
@@ -193,3 +193,17 @@ class TestRandomSameEntry:
         embed_0 = words[0, 3]
         for i in range(3, 5):
             assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    def test_dropout_close(self):
+        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True,
+                                dropout=0.5, word_dropout=0.9)
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
+        embed.eval()
+        words = embed(words)
+
+        embed_0 = words[0, 0]
+        for i in range(1, 3):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))