
Unify the seq_len_to_mask implementations scattered across different locations; everything now uses core.utils.seq_len_to_mask

tags/v0.4.10
yh_cc 5 years ago
parent
commit 56ff4ac7a1
13 changed files with 97 additions and 97 deletions
 1. +1   -1    fastNLP/core/__init__.py
 2. +2   -2    fastNLP/core/metrics.py
 3. +31  -38   fastNLP/core/utils.py
 4. +4   -5    fastNLP/models/biaffine_parser.py
 5. +3   -3    fastNLP/models/sequence_labeling.py
 6. +3   -3    fastNLP/models/snli.py
 7. +5   -5    fastNLP/models/star_transformer.py
 8. +2   -13   fastNLP/modules/decoder/CRF.py
 9. +2   -1    fastNLP/modules/decoder/__init__.py
10. +0   -6    fastNLP/modules/decoder/utils.py
11. +0   -16   fastNLP/modules/utils.py
12. +41  -1    test/core/test_utils.py
13. +3   -3    test/modules/decoder/test_CRF.py

+1 -1   fastNLP/core/__init__.py

@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results
+from .utils import cache_results, seq_len_to_mask
 from .vocabulary import Vocabulary

+2 -2   fastNLP/core/metrics.py

@@ -13,7 +13,7 @@ from .utils import _CheckRes
 from .utils import _build_args
 from .utils import _check_arg_dict_list
 from .utils import _get_func_signature
-from .utils import seq_lens_to_masks
+from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary

@@ -305,7 +305,7 @@ class AccuracyMetric(MetricBase):
                             f"got {type(seq_len)}.")

         if seq_len is not None:
-            masks = seq_lens_to_masks(seq_lens=seq_len)
+            masks = seq_len_to_mask(seq_len=seq_len)
         else:
             masks = None



+31 -38   fastNLP/core/utils.py

@@ -1,7 +1,7 @@
 """
 The utils module implements many utilities needed both inside and outside fastNLP. The one intended for users is the :func:`cache_results` decorator.
 """
-__all__ = ["cache_results"]
+__all__ = ["cache_results", "seq_len_to_mask"]
 import _pickle
 import inspect
 import os
@@ -600,48 +600,41 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         warnings.warn(message=_unused_warn)


-def seq_lens_to_masks(seq_lens, float=False):
+def seq_len_to_mask(seq_len):
     """
-
-    Convert seq_lens to masks.
-    :param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,)
-    :param float: if True, the return masks is in float type, otherwise it is byte.
-    :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
+    Convert a 1-d array of sequence lengths into a 2-d mask, in which positions beyond each length are 0.
+    Turns a 1-d seq_len into a 2-d mask.
+
+    Example::
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.size())
+        torch.Size([14, 15])
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+
+    :param np.ndarray,torch.LongTensor seq_len: shape is (B,)
+    :return: np.ndarray or torch.Tensor of shape (B, max_length); element type is bool or torch.uint8
     """
-    if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
-        assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
-        raise NotImplemented
-    elif isinstance(seq_lens, torch.Tensor):
-        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
-        batch_size = seq_lens.size(0)
-        max_len = seq_lens.max()
-        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long()
-        masks = indexes.lt(seq_lens.unsqueeze(1))
-
-        if float:
-            masks = masks.float()
-
-        return masks
-    elif isinstance(seq_lens, list):
-        raise NotImplemented
+    if isinstance(seq_len, np.ndarray):
+        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
+        max_len = int(seq_len.max())
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+
+    elif isinstance(seq_len, torch.Tensor):
+        assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim()}."
+        batch_size = seq_len.size(0)
+        max_len = seq_len.max().long()
+        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
     else:
-        raise NotImplemented
-
-
-def seq_mask(seq_len, max_len):
-    """Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return mask: torch.LongTensor, [batch_size, max_len]
-
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+        raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
+
+    return mask


 class _pseudo_tqdm:
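For reference, a minimal usage sketch of the unified helper after this change (not part of the commit; the lengths below are made up, and the mask dtype is uint8 or bool depending on the PyTorch version):

    import numpy as np
    import torch
    from fastNLP.core.utils import seq_len_to_mask

    # torch input -> (batch_size, max_len) mask, where max_len is inferred from seq_len.max()
    seq_len = torch.LongTensor([3, 1, 2])
    mask = seq_len_to_mask(seq_len)        # [[1, 1, 1], [1, 0, 0], [1, 1, 0]]
    padding_positions = mask == 0          # positions past each sequence's end

    # numpy input -> boolean ndarray of the same shape
    mask_np = seq_len_to_mask(np.array([3, 1, 2]))

    # the old seq_lens_to_masks(..., float=True) behaviour is now an explicit cast at the call site
    float_mask = mask.float()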


+4 -5   fastNLP/models/biaffine_parser.py

@@ -10,14 +10,13 @@ from torch.nn import functional as F
 from ..core.const import Const as C
 from ..core.losses import LossFunc
 from ..core.metrics import MetricBase
-from ..core.utils import seq_lens_to_masks
 from ..modules.dropout import TimestepDropout
 from ..modules.encoder.transformer import TransformerEncoder
 from ..modules.encoder.variational_rnn import VarLSTM
 from ..modules.utils import initial_parameter
-from ..modules.utils import seq_mask
 from ..modules.utils import get_embeddings
 from .base_model import BaseModel
+from ..core.utils import seq_len_to_mask

 def _mst(scores):
     """
@@ -346,7 +345,7 @@ class BiaffineParser(GraphParser):
         # print('forward {} {}'.format(batch_size, seq_len))

         # get sequence mask
-        mask = seq_mask(seq_len, length).long()
+        mask = seq_len_to_mask(seq_len).long()

         word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
         pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]
@@ -418,7 +417,7 @@ class BiaffineParser(GraphParser):
         """

         batch_size, length, _ = pred1.shape
-        mask = seq_mask(seq_len, length)
+        mask = seq_len_to_mask(seq_len)
         flip_mask = (mask == 0)
         _arc_pred = pred1.clone()
         _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
@@ -514,7 +513,7 @@ class ParserMetric(MetricBase):
         if seq_len is None:
             seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
         else:
-            seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
+            seq_mask = seq_len_to_mask(seq_len.long()).long()
         # mask out <root> tag
         seq_mask[:,0] = 0
         head_pred_correct = (pred1 == target1).long() * seq_mask
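One behavioural detail worth keeping in mind when reading these call sites: the removed seq_mask(seq_len, max_len) took the padded length explicitly, while seq_len_to_mask infers the width from seq_len.max(). A small sketch of the difference (not part of the commit; the lengths are made up):

    import torch
    from fastNLP.core.utils import seq_len_to_mask

    seq_len = torch.LongTensor([5, 3, 4])
    mask = seq_len_to_mask(seq_len)
    # the mask is only as wide as the longest sequence in the batch,
    # so it lines up with inputs that are padded to that same batch maximum
    assert mask.shape == (3, 5)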


+3 -3   fastNLP/models/sequence_labeling.py

@@ -3,7 +3,7 @@ import torch
 from .base_model import BaseModel
 from ..modules import decoder, encoder
 from ..modules.decoder.CRF import allowed_transitions
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask
 from ..core.const import Const as C
 from torch import nn

@@ -84,7 +84,7 @@ class SeqLabeling(BaseModel):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask
@@ -160,7 +160,7 @@ class AdvSeqLabel(nn.Module):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask


+3 -3   fastNLP/models/snli.py

@@ -6,7 +6,7 @@ from ..core.const import Const
 from ..modules import decoder as Decoder
 from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask


 my_inf = 10e12
@@ -75,12 +75,12 @@ class ESIM(BaseModel):
         hypothesis0 = self.embedding_layer(self.embedding(words2))

         if seq_len1 is not None:
-            seq_len1 = seq_mask(seq_len1, premise0.size(1))
+            seq_len1 = seq_len_to_mask(seq_len1)
         else:
             seq_len1 = torch.ones(premise0.size(0), premise0.size(1))
         seq_len1 = (seq_len1.long()).to(device=premise0.device)
         if seq_len2 is not None:
-            seq_len2 = seq_mask(seq_len2, hypothesis0.size(1))
+            seq_len2 = seq_len_to_mask(seq_len2)
         else:
             seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
         seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)


+5 -5   fastNLP/models/star_transformer.py

@@ -1,7 +1,7 @@
 """A PyTorch implementation of Star-Transformer.
 """
 from ..modules.encoder.star_transformer import StarTransformer
-from ..core.utils import seq_lens_to_masks
+from ..core.utils import seq_len_to_mask
 from ..modules.utils import get_embeddings
 from ..core.const import Const

@@ -134,7 +134,7 @@ class STSeqLabel(nn.Module):
         :param seq_len: [batch,] lengths of the input sequences
         :return output: [batch, num_cls, seq_len] class probabilities for every position of the output sequence
         """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, _ = self.enc(words, mask)
         output = self.cls(nodes)
         output = output.transpose(1,2)  # make hidden to be dim 1
@@ -195,7 +195,7 @@ class STSeqCls(nn.Module):
         :param seq_len: [batch,] lengths of the input sequences
         :return output: [batch, num_cls] class probabilities for the sequence
         """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, relay = self.enc(words, mask)
         y = 0.5 * (relay + nodes.max(1)[0])
         output = self.cls(y)  # [bsz, n_cls]
@@ -258,8 +258,8 @@ class STNLICls(nn.Module):
         :param seq_len2: [batch,] lengths of input sequence 2
         :return output: [batch, num_cls] class probabilities
         """
-        mask1 = seq_lens_to_masks(seq_len1)
-        mask2 = seq_lens_to_masks(seq_len2)
+        mask1 = seq_len_to_mask(seq_len1)
+        mask2 = seq_len_to_mask(seq_len2)
         def enc(seq, mask):
             nodes, relay = self.enc(seq, mask)
             return 0.5 * (relay + nodes.max(1)[0])


+2 -13   fastNLP/modules/decoder/CRF.py

@@ -2,17 +2,6 @@ import torch
 from torch import nn

 from ..utils import initial_parameter
-from ..decoder.utils import log_sum_exp
-
-
-def seq_len_to_byte_mask(seq_lens):
-    # usually seq_lens: LongTensor, batch_size
-    # return value: ByteTensor, batch_size x max_len
-    batch_size = seq_lens.size(0)
-    max_len = seq_lens.max()
-    broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
-    mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1))
-    return mask


 def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):
@@ -197,13 +186,13 @@ class ConditionalRandomField(nn.Module):
             emit_score = logits[i].view(batch_size, 1, n_tags)
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
-            alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
+            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                     alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)

-        return log_sum_exp(alpha, 1)
+        return torch.logsumexp(alpha, 1)

     def _gold_score(self, logits, tags, mask):
         """


+2 -1   fastNLP/modules/decoder/__init__.py

@@ -1,4 +1,5 @@
-__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode"]
+__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode", "allowed_transitions"]
 from .CRF import ConditionalRandomField
 from .MLP import MLP
 from .utils import viterbi_decode
+from .CRF import allowed_transitions
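With allowed_transitions added to __all__ and re-exported here, constrained-CRF setups can import it straight from the decoder package. A small sketch (not part of the commit; the label map is made up):

    from fastNLP.modules.decoder import allowed_transitions

    id2label = {0: "O", 1: "B-PER", 2: "I-PER"}
    # (from_tag_id, to_tag_id) pairs that a BIO-encoded CRF is allowed to take
    transitions = allowed_transitions(id2label, encoding_type="bio", include_start_end=True)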

+0 -6   fastNLP/modules/decoder/utils.py

@@ -2,12 +2,6 @@ __all__ = ["viterbi_decode"]
 import torch


-def log_sum_exp(x, dim=-1):
-    max_value, _ = x.max(dim=dim, keepdim=True)
-    res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
-    return res.squeeze(dim)
-
-
 def viterbi_decode(logits, transitions, mask=None, unpad=False):
     """Given a feature matrix and a transition score matrix, compute the best path and its score



+0 -16   fastNLP/modules/utils.py

@@ -68,22 +68,6 @@ def initial_parameter(net, initial_method=None):
     net.apply(weights_init)


-def seq_mask(seq_len, max_len):
-    """
-    Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return: mask, torch.LongTensor, [batch_size, max_len]
-
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-
 def get_embeddings(init_embed):
     """
     Get word embeddings. TODO


+41 -1   test/core/test_utils.py

@@ -9,7 +9,8 @@ import os
 import torch
 from torch import nn
 from fastNLP.core.utils import _move_model_to_device, _get_model_device
-
+import numpy as np
+from fastNLP.core.utils import seq_len_to_mask

 class Model(nn.Module):
     def __init__(self):
@@ -210,3 +211,42 @@ class TestCache(unittest.TestCase):
         finally:
             os.remove('test/demo1/demo.pkl')
             os.rmdir('test/demo1')
+
+
+class TestSeqLenToMask(unittest.TestCase):
+
+    def evaluate_mask_seq_len(self, seq_len, mask):
+        max_len = int(max(seq_len))
+        for i in range(len(seq_len)):
+            length = seq_len[i]
+            mask_i = mask[i]
+            for j in range(max_len):
+                self.assertEqual(mask_i[j], j < length)
+
+    def test_numpy_seq_len(self):
+        # test that a numpy seq_len can be converted
+        # 1. random test
+        seq_len = np.random.randint(1, 10, size=(10, ))
+        mask = seq_len_to_mask(seq_len)
+        max_len = seq_len.max()
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len, mask)
+
+        # 2. error case
+        seq_len = np.random.randint(10, size=(10, 1))
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+
+    def test_pytorch_seq_len(self):
+        # 1. random test
+        seq_len = torch.randint(1, 10, size=(10, ))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. error case
+        seq_len = torch.randn(3, 4)
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)

+3 -3   test/modules/decoder/test_CRF.py

@@ -105,7 +105,7 @@ class TestCRF(unittest.TestCase):
         # test that the CRF loss never becomes negative
         import torch
         from fastNLP.modules.decoder.CRF import ConditionalRandomField
-        from fastNLP.core.utils import seq_lens_to_masks
+        from fastNLP.core.utils import seq_len_to_mask
         from torch import optim
         from torch import nn

@@ -114,7 +114,7 @@ class TestCRF(unittest.TestCase):
         lengths = torch.randint(3, 50, size=(num_samples, )).long()
         max_len = lengths.max()
         tags = torch.randint(num_tags, size=(num_samples, max_len))
-        masks = seq_lens_to_masks(lengths)
+        masks = seq_len_to_mask(lengths)
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)
@@ -125,4 +125,4 @@
             optimizer.step()
             if _%1000==0:
                 print(loss)
-        assert loss.item()>0, "CRF loss cannot be less than 0."
+        self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")
