
Add seq_len_to_mask

tags/v1.0.0alpha
yh, 2 years ago · parent commit dcdd484eb2
14 changed files with 234 additions and 54 deletions

 1. +1   -1   fastNLP/core/__init__.py
 2. +2   -2   fastNLP/core/dataloaders/jittor_dataloader/fdl.py
 3. +1   -1   fastNLP/core/dataloaders/paddle_dataloader/fdl.py
 4. +1   -1   fastNLP/core/dataloaders/prepare_dataloader.py
 5. +1   -1   fastNLP/core/metrics/accuracy.py
 6. +1   -1   fastNLP/core/metrics/classify_f1_pre_rec_metric.py
 7. +4   -2   fastNLP/core/utils/__init__.py
 8. +15  -3   fastNLP/core/utils/paddle_utils.py
 9. +84  -0   fastNLP/core/utils/seq_len_to_mask.py
10. +0   -40  fastNLP/core/utils/utils.py
11. +3   -0   fastNLP/io/__init__.py
12. +5   -2   fastNLP/io/pipe/__init__.py
13. +102 -0   tests/core/utils/test_seq_len_to_mask.py
14. +14  -0   tests/embeddings/torch/test_static_embedding.py

+1 -1   fastNLP/core/__init__.py

@@ -92,8 +92,8 @@ __all__ = [
     "cache_results",
     "f_rich_progress",
     "auto_param_call",
-    "seq_len_to_mask",
     "f_tqdm_progress",
+    "seq_len_to_mask",
 
     # vocabulary.py
     'Vocabulary'


+2 -2   fastNLP/core/dataloaders/jittor_dataloader/fdl.py

@@ -52,7 +52,7 @@ class JittorDataLoader:
 
     """
 
-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:
@@ -194,7 +194,7 @@ class JittorDataLoader:
         return self.cur_batch_indices
 
 
-def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = True,
+def prepare_jittor_dataloader(ds_or_db, batch_size: int = 16, shuffle: bool = False,
                               drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                               stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                               collate_fn: Union[None, str, Callable] = "auto",


+1 -1   fastNLP/core/dataloaders/paddle_dataloader/fdl.py

@@ -81,7 +81,7 @@ class PaddleDataLoader(DataLoader):
 
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 16, shuffle: bool = True,
+                 batch_size: int = 16, shuffle: bool = False,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,


+1 -1   fastNLP/core/dataloaders/prepare_dataloader.py

@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger
 
 
-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """


+1 -1   fastNLP/core/metrics/accuracy.py

@@ -10,7 +10,7 @@ import numpy as np
 
 from fastNLP.core.metrics.metric import Metric
 from fastNLP.core.metrics.backend import Backend
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 
 
 class Accuracy(Metric):


+1 -1   fastNLP/core/metrics/classify_f1_pre_rec_metric.py

@@ -10,7 +10,7 @@ import numpy as np
 from .metric import Metric
 from .backend import Backend
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.utils.utils import seq_len_to_mask
+from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
 from .utils import _compute_f_pre_rec




+4 -2   fastNLP/core/utils/__init__.py

@@ -21,9 +21,10 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict",
-    "f_tqdm_progress"
+    "f_tqdm_progress",
+
+    "seq_len_to_mask"
 ]
 
 from .cache_results import cache_results
@@ -34,5 +35,6 @@ from .rich_progress import f_rich_progress
 from .torch_utils import torch_move_data_to_device
 from .utils import *
 from .tqdm_progress import f_tqdm_progress
+from .seq_len_to_mask import seq_len_to_mask



+15 -3   fastNLP/core/utils/paddle_utils.py

@@ -20,6 +20,7 @@ if _NEED_IMPORT_PADDLE:
 
 from .utils import apply_to_collection
 
+
 def _convert_data_device(device: Union[str, int]) -> str:
     """
     A function that converts the ``data_device`` of a ``driver``. If the user has set ``FASTNLP_BACKEND=paddle``, **fastNLP** will
@@ -59,7 +60,9 @@ def _convert_data_device(device: Union[str, int]) -> str:
     raise ValueError(f"Can't convert device {device} when USER_CUDA_VISIBLE_DEVICES={user_visible_devices} "
                      "and CUDA_VISIBLE_DEVICES={cuda_visible_devices}. If this situation happens, please report this bug to us.")
 
-def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor":
+
+def paddle_to(data: "paddle.Tensor", device: Union[str, int, 'paddle.fluid.core_avx.Place',
+              'paddle.CPUPlace', 'paddle.CUDAPlace']) -> "paddle.Tensor":
     """
     Moves ``data`` to the given ``device``. ``paddle.Tensor`` has no ``to`` function like ``torch.Tensor``'s;
     this function simply combines the two functions :func:`paddle.Tensor.cpu` and :func:`paddle.Tensor.cuda`.
@@ -68,12 +71,21 @@ def paddle_to(data: "paddle.Tensor", device: Union[str, int]) -> "paddle.Tensor"
     :param device: the target device, a ``str`` or ``int``;
     :return: the migrated tensor;
     """
-    if device == "cpu":
+    if isinstance(device, paddle.fluid.core_avx.Place):
+        if device.is_cpu_place():
+            return data.cpu()
+        else:
+            return data.cuda(device.gpu_device_id())
+    elif isinstance(device, paddle.CPUPlace):
+        return data.cpu()
+    elif isinstance(device, paddle.CUDAPlace):
+        return data.gpu(device.get_device_id())
+    elif device == "cpu":
         return data.cpu()
     else:
         return data.cuda(get_paddle_device_id(device))
 
 
 def get_paddle_gpu_str(device: Union[str, int]) -> str:
     """
     Gets a device name in the ``gpu:x`` format::
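
A minimal usage sketch for the widened ``paddle_to`` signature (the GPU lines assume a CUDA-enabled paddle build):

    import paddle
    from fastNLP.core.utils.paddle_utils import paddle_to

    t = paddle.randn([2, 3])
    t_cpu = paddle_to(t, paddle.CPUPlace())    # explicit CPUPlace object
    t_gpu = paddle_to(t, paddle.CUDAPlace(0))  # CUDAPlace object, moved via Tensor.gpu
    t_str = paddle_to(t, "gpu:0")              # strings still resolve through get_paddle_device_id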


+84 -0   fastNLP/core/utils/seq_len_to_mask.py

@@ -0,0 +1,84 @@
from typing import Optional

import numpy as np
from ...envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
from .paddle_utils import paddle_to


if _NEED_IMPORT_TORCH:
    import torch

if _NEED_IMPORT_PADDLE:
    import paddle

if _NEED_IMPORT_JITTOR:
    import jittor


def seq_len_to_mask(seq_len, max_len: Optional[int] = None):
    r"""

    Converts a one-dimensional array of ``sequence length`` values into a two-dimensional ``mask``,
    where positions beyond each length are **0**.

    .. code-block::

        >>> seq_len = torch.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len)
        >>> print(mask.size())
        torch.Size([14, 15])
        >>> seq_len = np.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len)
        >>> print(mask.shape)
        (14, 15)
        >>> seq_len = torch.arange(2, 16)
        >>> mask = seq_len_to_mask(seq_len, max_len=100)
        >>> print(mask.size())
        torch.Size([14, 100])

    :param seq_len: a length sequence of shape ``(B,)``;
    :param int max_len: pad or truncate the mask to ``max_len``. By default (``None``) the longest
        length in ``seq_len`` is used; under :class:`torch.nn.DataParallel` and other distributed
        settings, however, ``seq_len`` may differ across cards, so pass ``max_len`` to pad or
        truncate every mask to that length.
    :return: a ``mask`` of shape ``(B, max_len)`` with element type ``bool`` or ``uint8``
    """
    max_len = int(max_len) if max_len is not None else int(seq_len.max())

    if isinstance(seq_len, np.ndarray):
        assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
        return mask

    try:  # check whether it is a torch tensor
        if isinstance(seq_len, torch.Tensor):
            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
            batch_size = seq_len.shape[0]
            broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
            return mask
    except NameError:  # torch was not imported
        pass

    try:
        if isinstance(seq_len, paddle.Tensor):
            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
            batch_size = seq_len.shape[0]
            broad_cast_seq_len = paddle.arange(max_len).expand((batch_size, -1))
            broad_cast_seq_len = paddle_to(broad_cast_seq_len, device=seq_len.place)
            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
            return mask
    except NameError:  # paddle was not imported
        pass

    try:
        if isinstance(seq_len, jittor.Var):
            assert seq_len.ndim == 1, f"seq_len can only have one dimension, got {seq_len.ndim}."
            batch_size = seq_len.shape[0]
            broad_cast_seq_len = jittor.arange(max_len).expand(batch_size, -1)
            mask = broad_cast_seq_len < seq_len.unsqueeze(1)
            return mask
    except NameError:  # jittor was not imported
        pass

    raise TypeError("seq_len_to_mask function only supports numpy.ndarray, torch.Tensor, paddle.Tensor, "
                    f"and jittor.Var, but got {type(seq_len)}")

+0 -40  fastNLP/core/utils/utils.py

@@ -14,7 +14,6 @@ import os
 from contextlib import contextmanager
 from functools import wraps
 from prettytable import PrettyTable
-import numpy as np
 from pathlib import Path
 
 from fastNLP.core.log import logger
@@ -31,7 +30,6 @@ __all__ = [
     'pretty_table_printer',
     'Option',
     'deprecated',
-    'seq_len_to_mask',
     "flat_nest_dict"
 ]
 
@@ -567,44 +565,6 @@ def deprecated(help_message: Optional[str] = None):
     return decorator
 
 
-def seq_len_to_mask(seq_len, max_len: Optional[int]):
-    r"""
-
-    Converts a one-dimensional array of ``sequence length`` values into a two-dimensional ``mask``,
-    where positions beyond each length are **0**.
-
-    .. code-block::
-
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.size())
-        torch.Size([14, 15])
-        >>> seq_len = np.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len)
-        >>> print(mask.shape)
-        (14, 15)
-        >>> seq_len = torch.arange(2, 16)
-        >>> mask = seq_len_to_mask(seq_len, max_len=100)
-        >>> print(mask.size())
-        torch.Size([14, 100])
-
-    :param seq_len: a length sequence of shape ``(B,)``;
-    :param int max_len: pad or truncate the mask to ``max_len``. By default (``None``) the longest
-        length in ``seq_len`` is used; under :class:`torch.nn.DataParallel` and other distributed
-        settings, however, ``seq_len`` may differ across cards, so pass ``max_len`` to pad or
-        truncate every mask to that length.
-    :return: a ``mask`` of shape ``(B, max_len)`` with element type ``bool`` or ``uint8``
-    """
-    if isinstance(seq_len, np.ndarray):
-        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
-        max_len = int(max_len) if max_len else int(seq_len.max())
-        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
-        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
-
-    else:
-        raise TypeError("Only support 1-d numpy.ndarray.")
-
-    return mask
 
 
 def wait_filepath(path, exist=True):
     """
     Waits until the existence state of path is {exist}, then returns


+3 -0   fastNLP/io/__init__.py

@@ -109,6 +109,9 @@ __all__ = [
 
     "CMRC2018BertPipe",
 
+    "iob2",
+    "iob2bioes"
+
 ]
 
 from .data_bundle import DataBundle


+5 -2   fastNLP/io/pipe/__init__.py

@@ -62,13 +62,16 @@ __all__ = [
     "R8PmiGraphPipe",
     "OhsumedPmiGraphPipe",
     "NG20PmiGraphPipe",
-    "MRPmiGraphPipe"
+    "MRPmiGraphPipe",
+
+    "iob2",
+    "iob2bioes"
 ]
 
 from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \
     WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe
 from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
-from .conll import Conll2003Pipe
+from .conll import Conll2003Pipe, iob2, iob2bioes
 from .coreference import CoReferencePipe
 from .cws import CWSPipe
 from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
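
``iob2`` and ``iob2bioes`` are tag-scheme converters now re-exported from the pipe package; a small sketch of the intended use (expected outputs assume the standard IOB2/BIOES conventions):

    from fastNLP.io.pipe import iob2, iob2bioes

    tags = ["I-PER", "I-PER", "O", "I-LOC"]
    tags = iob2(tags)        # IOB1 -> IOB2: ['B-PER', 'I-PER', 'O', 'B-LOC']
    tags = iob2bioes(tags)   # IOB2 -> BIOES: ['B-PER', 'E-PER', 'O', 'S-LOC']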


+102 -0  tests/core/utils/test_seq_len_to_mask.py

@@ -0,0 +1,102 @@
import pytest
import numpy as np

from fastNLP.core.utils.seq_len_to_mask import seq_len_to_mask
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH

if _NEED_IMPORT_TORCH:
    import torch

if _NEED_IMPORT_PADDLE:
    import paddle

if _NEED_IMPORT_JITTOR:
    import jittor


class TestSeqLenToMask:

    def evaluate_mask_seq_len(self, seq_len, mask):
        max_len = int(max(seq_len))
        for i in range(len(seq_len)):
            length = seq_len[i]
            mask_i = mask[i]
            for j in range(max_len):
                assert mask_i[j] == (j < length), (i, j, length)

    def test_numpy_seq_len(self):
        # test converting a numpy seq_len
        # 1. randomized test
        seq_len = np.random.randint(1, 10, size=(2,))
        mask = seq_len_to_mask(seq_len)
        max_len = seq_len.max()
        assert max_len == mask.shape[1]
        self.evaluate_mask_seq_len(seq_len, mask)

        # 2. error detection
        seq_len = np.random.randint(10, size=(10, 1))
        with pytest.raises(AssertionError):
            mask = seq_len_to_mask(seq_len)

        # 3. pad to a given length
        seq_len = np.random.randint(1, 10, size=(10,))
        mask = seq_len_to_mask(seq_len, 100)
        assert 100 == mask.shape[1]

    @pytest.mark.torch
    def test_pytorch_seq_len(self):
        # 1. randomized test
        seq_len = torch.randint(1, 10, size=(10,))
        max_len = seq_len.max()
        mask = seq_len_to_mask(seq_len)
        assert max_len == mask.shape[1]
        self.evaluate_mask_seq_len(seq_len.tolist(), mask)

        # 2. error detection
        seq_len = torch.randn(3, 4)
        with pytest.raises(AssertionError):
            mask = seq_len_to_mask(seq_len)

        # 3. pad to a given length
        seq_len = torch.randint(1, 10, size=(10,))
        mask = seq_len_to_mask(seq_len, 100)
        assert 100 == mask.size(1)

    @pytest.mark.paddle
    def test_paddle_seq_len(self):
        # 1. randomized test
        seq_len = paddle.randint(1, 10, shape=(10,))
        max_len = seq_len.max()
        mask = seq_len_to_mask(seq_len)
        assert max_len == mask.shape[1]
        self.evaluate_mask_seq_len(seq_len.tolist(), mask)

        # 2. error detection
        seq_len = paddle.randn((3, 4))
        with pytest.raises(AssertionError):
            mask = seq_len_to_mask(seq_len)

        # 3. pad to a given length
        seq_len = paddle.randint(1, 10, shape=(10,))
        mask = seq_len_to_mask(seq_len, 100)
        assert 100 == mask.shape[1]

    @pytest.mark.jittor
    def test_jittor_seq_len(self):
        # 1. randomized test
        seq_len = jittor.randint(1, 10, shape=(10,))
        max_len = seq_len.max()
        mask = seq_len_to_mask(seq_len)
        assert max_len == mask.shape[1]
        self.evaluate_mask_seq_len(seq_len.tolist(), mask)

        # 2. error detection
        seq_len = jittor.randn(3, 4)
        with pytest.raises(AssertionError):
            mask = seq_len_to_mask(seq_len)

        # 3. pad to a given length
        seq_len = jittor.randint(1, 10, shape=(10,))
        mask = seq_len_to_mask(seq_len, 100)
        assert 100 == mask.shape[1]

+14 -0  tests/embeddings/torch/test_static_embedding.py

@@ -193,3 +193,17 @@ class TestRandomSameEntry:
         embed_0 = words[0, 3]
         for i in range(3, 5):
             assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+    def test_dropout_close(self):
+        vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
+        embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True,
+                                dropout=0.5, word_dropout=0.9)
+        words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
+        embed.eval()
+        words = embed(words)
+        embed_0 = words[0, 0]
+        for i in range(1, 3):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+        embed_0 = words[0, 3]
+        for i in range(3, 5):
+            assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
