From dcdd484eb2ab7363df170cb5811d46f880d49344 Mon Sep 17 00:00:00 2001
From: yh
Date: Thu, 26 May 2022 00:05:36 +0800
Subject: [PATCH] Add seq_len_to_mask
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/__init__.py                      |   2 +-
 .../core/dataloaders/jittor_dataloader/fdl.py |   4 +-
 .../core/dataloaders/paddle_dataloader/fdl.py |   2 +-
 .../core/dataloaders/prepare_dataloader.py    |   2 +-
 fastNLP/core/metrics/accuracy.py              |   2 +-
 .../metrics/classify_f1_pre_rec_metric.py     |   2 +-
 fastNLP/core/utils/__init__.py                |   6 +-
 fastNLP/core/utils/paddle_utils.py            |  18 +++-
 fastNLP/core/utils/seq_len_to_mask.py         |  84 +++++++++++++++
 fastNLP/core/utils/utils.py                   |  40 -------
 fastNLP/io/__init__.py                        |   3 +
 fastNLP/io/pipe/__init__.py                   |   7 +-
 tests/core/utils/test_seq_len_to_mask.py      | 102 ++++++++++++++++++
 .../embeddings/torch/test_static_embedding.py |  14 +++
 14 files changed, 234 insertions(+), 54 deletions(-)
 create mode 100644 fastNLP/core/utils/seq_len_to_mask.py
 create mode 100644 tests/core/utils/test_seq_len_to_mask.py

diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 855c335f..4f5ee3d8 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -92,8 +92,8 @@ __all__ = [
     "cache_results",
     "f_rich_progress",
     "auto_param_call",
-    "seq_len_to_mask",
     "f_tqdm_progress",
+    "seq_len_to_mask",
 
     # vocabulary.py
     'Vocabulary'
diff --git a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
index 349fb444..a93fb55c 100644
--- a/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/jittor_dataloader/fdl.py
@@ -52,7 +52,7 @@ class JittorDataLoader:
 
     """
 
-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:
@@ -194,7 +194,7 @@ class JittorDataLoader:
         return self.cur_batch_indices
 
 
-def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = True,
+def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False,
                               drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index 02bf9bef..3f1b6acd 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -81,7 +81,7 @@ class PaddleDataLoader(DataLoader):
 
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 16, shuffle: bool = True,
+                 batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,
diff --git a/fastNLP/core/dataloaders/prepare_dataloader.py b/fastNLP/core/dataloaders/prepare_dataloader.py
index 81bd1bdb..358578fc 100644
--- a/fastNLP/core/dataloaders/prepare_dataloader.py
+++ b/fastNLP/core/dataloaders/prepare_dataloader.py
@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger
 
 
-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """
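Note: this patch flips the default of ``shuffle`` from ``True`` to ``False`` in
``JittorDataLoader``, ``prepare_jittor_dataloader``, ``PaddleDataLoader`` and
``prepare_dataloader``, so training code that relied on the old default must now opt in
explicitly. A minimal sketch (``train_ds`` and ``dev_ds`` are hypothetical datasets):

    >>> from fastNLP.core.dataloaders.prepare_dataloader import prepare_dataloader
    >>> train_dl = prepare_dataloader(train_ds, batch_size=16, shuffle=True)  # opt in for training
    >>> dev_dl = prepare_dataloader(dev_ds, batch_size=16)                    # shuffle now defaults to False
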
diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py
index 9fa2152b..e78729b3 100644
--- a/fastNLP/core/metrics/accuracy.py
+++ b/fastNLP/core/metrics/accuracy.py
@@ -10,7 +10,7 @@ import numpy as np
 
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.metrics.backend import Backend
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 
 
 class Accuracy(Metric):
diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
index aa1e8440..abb68f38 100644
--- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
+++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py
@@ -10,7 +10,7 @@ import numpy as np
 from .metric import Metric
 from .backend import Backend
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 from .utils import _compute_f_pre_rec
 
 
diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py
index 62b4cb7e..0857f450 100644
--- a/fastNLP/core/utils/__init__.py
+++ b/fastNLP/core/utils/__init__.py
@@ -21,9 +21,10 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict",
-    "f_tqdm_progress"
+    "f_tqdm_progress",
+
+    "seq_len_to_mask"
 ]
 
 from .cache_results import cache_results
@@ -34,5 +35,6 @@ from .rich_progress import f_rich_progress
 from .torch_utils import torch_move_data_to_device
 from .utils import *
 from .tqdm_progress import f_tqdm_progress
+from .seq_len_to_mask import seq_len_to_mask
 
diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py
index 2d7b65cc..f14a2bce 100644
--- a/fastNLP/core/utils/paddle_utils.py
+++ b/fastNLP/core/utils/paddle_utils.py
@@ -20,6 +20,7 @@ if _NEED_IMPORT_PADDLE:
 
 from .utils import apply_to_collection
 
+
 def _convert_data_device(device: Union[str, int]) -> str:
     """
     用于转换 ``driver`` 的 ``data_device`` 的函数。如果用户设置了 ``FASTNLP_BACKEND=paddle``,那么 **fastNLP** 会将
@@ -59,7 +60,9 @@
     raise ValueError(f"Can't convert device {device} when USER_CUDA_VISIBLE_DEVICES={user_visible_devices} "
                      "and CUDA_VISIBLE_DEVICES={cuda_visible_devices}. If this situation happens, please report this bug to us.")
 
-def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor":
+
+def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_avx.Place',
+                                                   'paddle.CPUPlace', 'paddle.CUDAPlace']) -> "paddle.Tensor":
     """
     将 ``data`` 迁移到指定的 ``device`` 上。``paddle.Tensor`` 没有类似 ``torch.Tensor`` 的 ``to`` 函数,
     该函数只是集成了 :func:`paddle.Tensor.cpu` 和 :func:`paddle.Tensor.cuda` 两个函数。
@@ -68,12 +71,21 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor"
     :param device: 目标设备,可以是 ``str`` 或 ``int`` 类型;
     :return: 迁移后的张量;
     """
-
-    if device == "cpu":
+    if isinstance(device, paddle.fluid.core_avx.Place):
+        if device.is_cpu_place():
+            return data.cpu()
+        else:
+            return data.cuda(device.gpu_device_id())
+    elif isinstance(device, paddle.CPUPlace):
+        return data.cpu()
+    elif isinstance(device, paddle.CUDAPlace):
+        return data.cuda(device.get_device_id())
+    elif device == "cpu":
         return data.cpu()
     else:
         return data.cuda(get_paddle_device_id(device))
 
+
 def get_paddle_gpu_str(device: Union[str, int]) -> str:
     """
     获得 ``gpu:x`` 格式的设备名::
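Note: ``paddle_to`` now accepts ``paddle.CPUPlace``, ``paddle.CUDAPlace`` and
``paddle.fluid.core_avx.Place`` objects in addition to ``str``/``int`` devices, so a tensor
can be moved to the place another tensor already occupies, which is exactly how the new
``seq_len_to_mask`` below uses it. A usage sketch, assuming a GPU-enabled paddle build:

    >>> import paddle
    >>> from fastNLP.core.utils.paddle_utils import paddle_to
    >>> t = paddle.ones((2, 3))
    >>> paddle_to(t, "cpu")                # str device, as before
    >>> paddle_to(t, paddle.CUDAPlace(0))  # Place object, new in this patch
    >>> paddle_to(t, t.place)              # reuse another tensor's place
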
diff --git a/fastNLP/core/utils/seq_len_to_mask.py b/fastNLP/core/utils/seq_len_to_mask.py
new file mode 100644
index 00000000..e244603c
--- /dev/null
+++ b/fastNLP/core/utils/seq_len_to_mask.py
@@ -0,0 +1,84 @@
+from typing import Optional
+
+import numpy as np
+from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
+from .paddle_utils import paddle_to
+
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+
+def seq_len_to_mask(seq_len, max_len: Optional[int] = None):
+    r"""
+    Convert a one-dimensional array of ``sequence length`` values into a two-dimensional
+    ``mask``; positions past each length are **0**.
+
+    .. code-block::
+
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.size())
+        torch.Size([14, 15])
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len, max_len=100)
+        >>> print(mask.size())
+        torch.Size([14, 100])
+
+    :param seq_len: a sequence of lengths of shape ``(B,)``;
+    :param int max_len: pad or truncate the mask to length ``max_len``. By default (``None``)
+        the longest length in ``seq_len`` is used, but under :class:`torch.nn.DataParallel`
+        and other distributed settings the ``seq_len`` on different cards may differ, so pass
+        ``max_len`` to pad or truncate every mask to the same length.
+    :return: a ``mask`` of shape ``(B, max_len)`` with element type ``bool`` or ``uint8``
+    """
+    max_len = int(max_len) if max_len is not None else int(seq_len.max())
+
+    if isinstance(seq_len, np.ndarray):
+        assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+        return mask
+
+    try:  # try the torch branch first
+        if isinstance(seq_len, torch.Tensor):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:  # torch is not installed
+        pass
+
+    try:
+        if isinstance(seq_len, paddle.Tensor):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = paddle.arange(max_len).expand((batch_size, -1))
+            broad_cast_seq_len = paddle_to(broad_cast_seq_len, device=seq_len.place)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:  # paddle is not installed
+        pass
+
+    try:
+        if isinstance(seq_len, jittor.Var):
+            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
+            batch_size = seq_len.shape[0]
+            broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1)
+            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
+            return mask
+    except NameError:  # jittor is not installed
+        pass
+
+    raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
+                    f"and jittor.Var, but got {type(seq_len)}")
\ No newline at end of file
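Note: every backend branch above computes the same broadcast comparison, ``arange(max_len)``
laid out as rows against ``seq_len`` reshaped to a column. A worked example through the
numpy branch:

    >>> import numpy as np
    >>> from fastNLP.core.utils import seq_len_to_mask
    >>> seq_len_to_mask(np.array([1, 3, 2]), max_len=4)
    array([[ True, False, False, False],
           [ True,  True,  True, False],
           [ True,  True, False, False]])
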
diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py
index 0890f5ec..6864d984 100644
--- a/fastNLP/core/utils/utils.py
+++ b/fastNLP/core/utils/utils.py
@@ -14,7 +14,6 @@ import os
 from contextlib import contextmanager
 from functools import wraps
 from prettytable import PrettyTable
-import numpy as np
 from pathlib import Path
 
 from fastNLP.core.log import logger
@@ -31,7 +30,6 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict"
 ]
 
@@ -567,44 +565,6 @@ def deprecated(help_message: Optional[str] = None):
     return decorator
 
 
-def seq_len_to_mask(seq_len, max_len: Optional[int]):
-    r"""
-
-    将一个表示 ``sequence length`` 的一维数组转换为二维的 ``mask`` ,不包含的位置为 **0**。
-
-    .. code-block::
-
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.size())
-        torch.Size([14, 15])
-        >>> seq_len = np.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.shape)
-        (14, 15)
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len, max_len=100)
-        >>>print(mask.size())
-        torch.Size([14, 100])
-
-    :param seq_len: 大小为 ``(B,)`` 的长度序列;
-    :param int max_len: 将长度补齐或截断到 ``max_len``。默认情况(为 ``None``)使用的是 ``seq_len`` 中最长的长度;
-        但在 :class:`torch.nn.DataParallel` 等分布式的场景下可能不同卡的 ``seq_len`` 会有区别,所以需要传入
-        ``max_len`` 使得 ``mask`` 的补齐或截断到该长度。
-    :return: 大小为 ``(B, max_len)`` 的 ``mask``, 元素类型为 ``bool`` 或 ``uint8``
-    """
-    if isinstance(seq_len, np.ndarray):
-        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
-        max_len = int(max_len) if max_len else int(seq_len.max())
-        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
-        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
-
-    else:
-        raise TypeError("Only support 1-d numpy.ndarray.")
-
-    return mask
-
-
 def wait_filepath(path, exist=True):
     """
     等待当 path 的存在状态为 {exist} 时返回
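Note: with the numpy-only implementation removed from ``utils.py``, both remaining import
paths come from this patch; the package-level re-export is the public one:

    >>> from fastNLP.core.utils import seq_len_to_mask                   # public path
    >>> from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask   # module path used by the metrics above
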
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index 290d8ffe..3897cb0d 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -109,6 +109,9 @@
 
     "CMRC2018BertPipe",
 
+    "iob2",
+    "iob2bioes"
+
 ]
 
 from .data_bundle import DataBundle
diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py
index 35965ca3..5d269cc5 100644
--- a/fastNLP/io/pipe/__init__.py
+++ b/fastNLP/io/pipe/__init__.py
@@ -62,13 +62,16 @@ __all__ = [
     "R8PmiGraphPipe",
     "OhsumedPmiGraphPipe",
     "NG20PmiGraphPipe",
-    "MRPmiGraphPipe"
+    "MRPmiGraphPipe",
+
+    "iob2",
+    "iob2bioes"
 ]
 
 from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
     WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
-from .conll import Conll2003Pipe
+from .conll import Conll2003Pipe, iob2, iob2bioes
 from .coreference import CoReferencePipe
 from .cws import CWSPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
diff --git a/tests/core/utils/test_seq_len_to_mask.py b/tests/core/utils/test_seq_len_to_mask.py
new file mode 100644
index 00000000..0a17bae6
--- /dev/null
+++ b/tests/core/utils/test_seq_len_to_mask.py
@@ -0,0 +1,102 @@
+import pytest
+import numpy as np
+
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+
+class TestSeqLenToMask:
+
+    def evaluate_mask_seq_len(self, seq_len, mask):
+        # the mask must be True exactly at the positions within each length
+        max_len = int(max(seq_len))
+        for i in range(len(seq_len)):
+            length = seq_len[i]
+            mask_i = mask[i]
+            for j in range(max_len):
+                assert mask_i[j] == (j < length)
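Note: ``iob2`` and ``iob2bioes`` are newly re-exported from ``fastNLP.io``. A behaviour
sketch under the conventional tagging-scheme semantics (IOB1 to BIO for ``iob2``, BIO to
BIOES for ``iob2bioes``); the outputs shown are an assumption about those semantics, not
taken from this patch:

    >>> from fastNLP.io import iob2, iob2bioes
    >>> iob2(['I-ORG', 'I-ORG', 'O'])                # IOB1 input: a span may open with I-
    ['B-ORG', 'I-ORG', 'O']
    >>> iob2bioes(['B-ORG', 'I-ORG', 'O', 'B-PER'])
    ['B-ORG', 'E-ORG', 'O', 'S-PER']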