
Unify the seq_len_to_mask implementations scattered across different locations; they are now consolidated in core.utils.seq_len_to_mask.

tag: v0.4.10
yh_cc committed 5 years ago
commit 56ff4ac7a1
13 changed files with 97 additions and 97 deletions
  1. fastNLP/core/__init__.py  +1 -1
  2. fastNLP/core/metrics.py  +2 -2
  3. fastNLP/core/utils.py  +31 -38
  4. fastNLP/models/biaffine_parser.py  +4 -5
  5. fastNLP/models/sequence_labeling.py  +3 -3
  6. fastNLP/models/snli.py  +3 -3
  7. fastNLP/models/star_transformer.py  +5 -5
  8. fastNLP/modules/decoder/CRF.py  +2 -13
  9. fastNLP/modules/decoder/__init__.py  +2 -1
  10. fastNLP/modules/decoder/utils.py  +0 -6
  11. fastNLP/modules/utils.py  +0 -16
  12. test/core/test_utils.py  +41 -1
  13. test/modules/decoder/test_CRF.py  +3 -3
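
For orientation before the per-file diffs, here is a minimal usage sketch of the unified helper. The import path and the expected shapes are taken from the docstring example added in fastNLP/core/utils.py below; everything else is illustrative.

# Minimal sketch of the unified API introduced by this commit; mirrors the
# docstring example added to fastNLP/core/utils.py below.
import numpy as np
import torch
from fastNLP.core.utils import seq_len_to_mask

seq_len = torch.arange(2, 16)      # 1-d tensor of lengths, shape (B,)
mask = seq_len_to_mask(seq_len)    # 2-d mask, shape (B, max_len)
print(mask.size())                 # torch.Size([14, 15])

seq_len = np.arange(2, 16)         # the same call also accepts a 1-d np.ndarray
mask = seq_len_to_mask(seq_len)
print(mask.shape)                  # (14, 15)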

fastNLP/core/__init__.py (+1, -1)

@@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester
 from .trainer import Trainer
-from .utils import cache_results
+from .utils import cache_results, seq_len_to_mask
 from .vocabulary import Vocabulary

fastNLP/core/metrics.py (+2, -2)

@@ -13,7 +13,7 @@ from .utils import _CheckRes
 from .utils import _build_args
 from .utils import _check_arg_dict_list
 from .utils import _get_func_signature
-from .utils import seq_lens_to_masks
+from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary

@@ -305,7 +305,7 @@ class AccuracyMetric(MetricBase):
                             f"got {type(seq_len)}.")

         if seq_len is not None:
-            masks = seq_lens_to_masks(seq_lens=seq_len)
+            masks = seq_len_to_mask(seq_len=seq_len)
         else:
             masks = None

fastNLP/core/utils.py (+31, -38)

@@ -1,7 +1,7 @@
 """
 The utils module implements many utilities used inside and outside fastNLP. The one intended for users is the :func:`cache_results` decorator.
 """
-__all__ = ["cache_results"]
+__all__ = ["cache_results", "seq_len_to_mask"]
 import _pickle
 import inspect
 import os

@@ -600,48 +600,41 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         warnings.warn(message=_unused_warn)


-def seq_lens_to_masks(seq_lens, float=False):
+def seq_len_to_mask(seq_len):
     """
-
-    Convert seq_lens to masks.
-    :param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,)
-    :param float: if True, the return masks is in float type, otherwise it is byte.
-    :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length)
+    Convert a 1-d array of sequence lengths into a 2-d mask in which positions beyond each length are 0.
+    Turns a 1-d seq_len into a 2-d mask.
+
+    Example::
+
+        >>> seq_len = torch.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.size())
+        torch.Size([14, 15])
+        >>> seq_len = np.arange(2, 16)
+        >>> mask = seq_len_to_mask(seq_len)
+        >>> print(mask.shape)
+        (14, 15)
+
+    :param np.ndarray,torch.LongTensor seq_len: shape is (B,)
+    :return: np.ndarray or torch.Tensor, shape is (B, max_length). The element type is bool or torch.uint8.
     """
-    if isinstance(seq_lens, np.ndarray):
-        assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}."
-        assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}."
-        raise NotImplemented
-    elif isinstance(seq_lens, torch.Tensor):
-        assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
-        batch_size = seq_lens.size(0)
-        max_len = seq_lens.max()
-        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long()
-        masks = indexes.lt(seq_lens.unsqueeze(1))
-
-        if float:
-            masks = masks.float()
-
-        return masks
-    elif isinstance(seq_lens, list):
-        raise NotImplemented
+    if isinstance(seq_len, np.ndarray):
+        assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
+        max_len = int(seq_len.max())
+        broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
+        mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
+
+    elif isinstance(seq_len, torch.Tensor):
+        assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
+        batch_size = seq_len.size(0)
+        max_len = seq_len.max().long()
+        broad_cast_seq_len = torch.arange(max_len).expand(batch_size, -1).to(seq_len)
+        mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
     else:
-        raise NotImplemented
-
-
-def seq_mask(seq_len, max_len):
-    """Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return mask: torch.LongTensor, [batch_size, max_len]
-
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+        raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
+
+    return mask


 class _pseudo_tqdm:
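
The rewritten helper above builds the mask with a single broadcasted comparison instead of view/repeat. A standalone sketch of that trick, independent of fastNLP (the tensor values are illustrative):

# Standalone sketch of the broadcast-and-compare trick used above
# (illustrative only; not part of the commit).
import torch

seq_len = torch.tensor([2, 3, 5])                            # (B,)
max_len = seq_len.max()
positions = torch.arange(max_len).expand(len(seq_len), -1)   # (B, max_len), each row is 0..max_len-1
mask = positions < seq_len.unsqueeze(1)                      # True where position < length
# tensor([[ True,  True, False, False, False],
#         [ True,  True,  True, False, False],
#         [ True,  True,  True,  True,  True]])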


fastNLP/models/biaffine_parser.py (+4, -5)

@@ -10,14 +10,13 @@ from torch.nn import functional as F
 from ..core.const import Const as C
 from ..core.losses import LossFunc
 from ..core.metrics import MetricBase
-from ..core.utils import seq_lens_to_masks
 from ..modules.dropout import TimestepDropout
 from ..modules.encoder.transformer import TransformerEncoder
 from ..modules.encoder.variational_rnn import VarLSTM
 from ..modules.utils import initial_parameter
-from ..modules.utils import seq_mask
 from ..modules.utils import get_embeddings
 from .base_model import BaseModel
+from ..core.utils import seq_len_to_mask


 def _mst(scores):
     """

@@ -346,7 +345,7 @@ class BiaffineParser(GraphParser):
         # print('forward {} {}'.format(batch_size, seq_len))

         # get sequence mask
-        mask = seq_mask(seq_len, length).long()
+        mask = seq_len_to_mask(seq_len).long()

         word = self.word_embedding(words1)  # [N,L] -> [N,L,C_0]
         pos = self.pos_embedding(words2)  # [N,L] -> [N,L,C_1]

@@ -418,7 +417,7 @@ class BiaffineParser(GraphParser):
         """

         batch_size, length, _ = pred1.shape
-        mask = seq_mask(seq_len, length)
+        mask = seq_len_to_mask(seq_len)
         flip_mask = (mask == 0)
         _arc_pred = pred1.clone()
         _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))

@@ -514,7 +513,7 @@ class ParserMetric(MetricBase):
         if seq_len is None:
             seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long)
         else:
-            seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
+            seq_mask = seq_len_to_mask(seq_len.long()).long()
         # mask out <root> tag
         seq_mask[:,0] = 0
         head_pred_correct = (pred1 == target1).long() * seq_mask
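
A note on the call-site change in this file: the removed seq_mask(seq_len, max_len) took the mask width explicitly, while seq_len_to_mask infers it from seq_len.max(). A small sketch of the assumption that makes the two interchangeable here (illustrative values; assumes the batch is padded to exactly the longest length):

# Illustrative sketch, not part of the commit: the two calls agree when the
# padded width equals seq_len.max(), as in the parser's forward pass.
import torch
from fastNLP.core.utils import seq_len_to_mask

seq_len = torch.tensor([4, 2, 3])
length = int(seq_len.max())              # old code passed this in: seq_mask(seq_len, length)
mask = seq_len_to_mask(seq_len).long()   # new call infers the width from seq_len.max()
assert mask.shape == (3, length)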


fastNLP/models/sequence_labeling.py (+3, -3)

@@ -3,7 +3,7 @@ import torch
 from .base_model import BaseModel
 from ..modules import decoder, encoder
 from ..modules.decoder.CRF import allowed_transitions
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask
 from ..core.const import Const as C
 from torch import nn

@@ -84,7 +84,7 @@ class SeqLabeling(BaseModel):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask

@@ -160,7 +160,7 @@ class AdvSeqLabel(nn.Module):
     def _make_mask(self, x, seq_len):
         batch_size, max_len = x.size(0), x.size(1)
-        mask = seq_mask(seq_len, max_len)
+        mask = seq_len_to_mask(seq_len)
         mask = mask.view(batch_size, max_len)
         mask = mask.to(x).float()
         return mask


fastNLP/models/snli.py (+3, -3)

@@ -6,7 +6,7 @@ from ..core.const import Const
 from ..modules import decoder as Decoder
 from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
-from ..modules.utils import seq_mask
+from ..core.utils import seq_len_to_mask


 my_inf = 10e12

@@ -75,12 +75,12 @@ class ESIM(BaseModel):
         hypothesis0 = self.embedding_layer(self.embedding(words2))

         if seq_len1 is not None:
-            seq_len1 = seq_mask(seq_len1, premise0.size(1))
+            seq_len1 = seq_len_to_mask(seq_len1)
         else:
             seq_len1 = torch.ones(premise0.size(0), premise0.size(1))
         seq_len1 = (seq_len1.long()).to(device=premise0.device)
         if seq_len2 is not None:
-            seq_len2 = seq_mask(seq_len2, hypothesis0.size(1))
+            seq_len2 = seq_len_to_mask(seq_len2)
         else:
             seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1))
         seq_len2 = (seq_len2.long()).to(device=hypothesis0.device)


fastNLP/models/star_transformer.py (+5, -5)

@@ -1,7 +1,7 @@
 """A PyTorch implementation of Star-Transformer.
 """
 from ..modules.encoder.star_transformer import StarTransformer
-from ..core.utils import seq_lens_to_masks
+from ..core.utils import seq_len_to_mask
 from ..modules.utils import get_embeddings
 from ..core.const import Const

@@ -134,7 +134,7 @@ class STSeqLabel(nn.Module):
        :param seq_len: [batch,] lengths of the input sequences
        :return output: [batch, num_cls, seq_len] class probabilities for each position of the output sequence
        """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, _ = self.enc(words, mask)
         output = self.cls(nodes)
         output = output.transpose(1,2)  # make hidden to be dim 1

@@ -195,7 +195,7 @@ class STSeqCls(nn.Module):
        :param seq_len: [batch,] lengths of the input sequences
        :return output: [batch, num_cls] class probabilities for the sequence
        """
-        mask = seq_lens_to_masks(seq_len)
+        mask = seq_len_to_mask(seq_len)
         nodes, relay = self.enc(words, mask)
         y = 0.5 * (relay + nodes.max(1)[0])
         output = self.cls(y)  # [bsz, n_cls]

@@ -258,8 +258,8 @@ class STNLICls(nn.Module):
        :param seq_len2: [batch,] lengths of input sequence 2
        :return output: [batch, num_cls] output class probabilities
        """
-        mask1 = seq_lens_to_masks(seq_len1)
-        mask2 = seq_lens_to_masks(seq_len2)
+        mask1 = seq_len_to_mask(seq_len1)
+        mask2 = seq_len_to_mask(seq_len2)
         def enc(seq, mask):
             nodes, relay = self.enc(seq, mask)
             return 0.5 * (relay + nodes.max(1)[0])


fastNLP/modules/decoder/CRF.py (+2, -13)

@@ -2,17 +2,6 @@ import torch
 from torch import nn

 from ..utils import initial_parameter
-from ..decoder.utils import log_sum_exp
-
-
-def seq_len_to_byte_mask(seq_lens):
-    # usually seq_lens: LongTensor, batch_size
-    # return value: ByteTensor, batch_size x max_len
-    batch_size = seq_lens.size(0)
-    max_len = seq_lens.max()
-    broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
-    mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1))
-    return mask


 def allowed_transitions(id2label, encoding_type='bio', include_start_end=True):

@@ -197,13 +186,13 @@ class ConditionalRandomField(nn.Module):
             emit_score = logits[i].view(batch_size, 1, n_tags)
             trans_score = self.trans_m.view(1, n_tags, n_tags)
             tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
-            alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
+            alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
                 alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

         if self.include_start_end_trans:
             alpha = alpha + self.end_scores.view(1, -1)

-        return log_sum_exp(alpha, 1)
+        return torch.logsumexp(alpha, 1)

     def _gold_score(self, logits, tags, mask):
         """


fastNLP/modules/decoder/__init__.py (+2, -1)

@@ -1,4 +1,5 @@
-__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode"]
+__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode", "allowed_transitions"]
 from .CRF import ConditionalRandomField
 from .MLP import MLP
 from .utils import viterbi_decode
+from .CRF import allowed_transitions

fastNLP/modules/decoder/utils.py (+0, -6)

@@ -2,12 +2,6 @@ __all__ = ["viterbi_decode"]
 import torch


-def log_sum_exp(x, dim=-1):
-    max_value, _ = x.max(dim=dim, keepdim=True)
-    res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value
-    return res.squeeze(dim)
-
-
 def viterbi_decode(logits, transitions, mask=None, unpad=False):
     """Given a feature matrix and a transition score matrix, compute the best path and the corresponding score




fastNLP/modules/utils.py (+0, -16)

@@ -68,22 +68,6 @@ def initial_parameter(net, initial_method=None):
     net.apply(weights_init)


-def seq_mask(seq_len, max_len):
-    """
-    Create sequence mask.
-
-    :param seq_len: list or torch.Tensor, the lengths of sequences in a batch.
-    :param max_len: int, the maximum sequence length in a batch.
-    :return: mask, torch.LongTensor, [batch_size, max_len]
-
-    """
-    if not isinstance(seq_len, torch.Tensor):
-        seq_len = torch.LongTensor(seq_len)
-    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
-    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
-    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
-
-
 def get_embeddings(init_embed):
     """
     Get the word embeddings  TODO


test/core/test_utils.py (+41, -1)

@@ -9,7 +9,8 @@ import os
 import torch
 from torch import nn
 from fastNLP.core.utils import _move_model_to_device, _get_model_device
-
+import numpy as np
+from fastNLP.core.utils import seq_len_to_mask

 class Model(nn.Module):
     def __init__(self):

@@ -210,3 +211,42 @@ class TestCache(unittest.TestCase):
         finally:
             os.remove('test/demo1/demo.pkl')
             os.rmdir('test/demo1')
+
+
+class TestSeqLenToMask(unittest.TestCase):
+
+    def evaluate_mask_seq_len(self, seq_len, mask):
+        max_len = int(max(seq_len))
+        for i in range(len(seq_len)):
+            length = seq_len[i]
+            mask_i = mask[i]
+            for j in range(max_len):
+                self.assertEqual(mask_i[j], j < length)
+
+    def test_numpy_seq_len(self):
+        # test that a numpy seq_len can be converted
+        # 1. randomized test
+        seq_len = np.random.randint(1, 10, size=(10, ))
+        mask = seq_len_to_mask(seq_len)
+        max_len = seq_len.max()
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len, mask)
+
+        # 2. exception check
+        seq_len = np.random.randint(10, size=(10, 1))
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)
+
+    def test_pytorch_seq_len(self):
+        # 1. randomized test
+        seq_len = torch.randint(1, 10, size=(10, ))
+        max_len = seq_len.max()
+        mask = seq_len_to_mask(seq_len)
+        self.assertEqual(max_len, mask.shape[1])
+        self.evaluate_mask_seq_len(seq_len.tolist(), mask)
+
+        # 2. exception check
+        seq_len = torch.randn(3, 4)
+        with self.assertRaises(AssertionError):
+            mask = seq_len_to_mask(seq_len)

test/modules/decoder/test_CRF.py (+3, -3)

@@ -105,7 +105,7 @@ class TestCRF(unittest.TestCase):
         # test that the CRF loss never becomes negative
         import torch
         from fastNLP.modules.decoder.CRF import ConditionalRandomField
-        from fastNLP.core.utils import seq_lens_to_masks
+        from fastNLP.core.utils import seq_len_to_mask
         from torch import optim
         from torch import nn

@@ -114,7 +114,7 @@ class TestCRF(unittest.TestCase):
         lengths = torch.randint(3, 50, size=(num_samples, )).long()
         max_len = lengths.max()
         tags = torch.randint(num_tags, size=(num_samples, max_len))
-        masks = seq_lens_to_masks(lengths)
+        masks = seq_len_to_mask(lengths)
         feats = nn.Parameter(torch.randn(num_samples, max_len, num_tags))
         crf = ConditionalRandomField(num_tags, include_start_end_trans)
         optimizer = optim.SGD([param for param in crf.parameters() if param.requires_grad] + [feats], lr=0.1)

@@ -125,4 +125,4 @@ class TestCRF(unittest.TestCase):
             optimizer.step()
             if _%1000==0:
                 print(loss)
-        assert loss.item()>0, "CRF loss cannot be less than 0."
+        self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")
