From 56ff4ac7a1bd3afb0fa5ba89b3726c46b8580a63 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 10 May 2019 16:34:35 +0800 Subject: [PATCH] =?UTF-8?q?=E7=BB=9F=E4=B8=80=E4=B8=8D=E5=90=8C=E4=BD=8D?= =?UTF-8?q?=E7=BD=AE=E7=9A=84seq=5Flen=5Fto=5Fmask,=20=E7=8E=B0=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E5=88=B0core.utils.seq=5Flen=5Fto=5Fmask?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 2 +- fastNLP/core/metrics.py | 4 +- fastNLP/core/utils.py | 69 +++++++++++++---------------- fastNLP/models/biaffine_parser.py | 9 ++-- fastNLP/models/sequence_labeling.py | 6 +-- fastNLP/models/snli.py | 6 +-- fastNLP/models/star_transformer.py | 10 ++--- fastNLP/modules/decoder/CRF.py | 15 +------ fastNLP/modules/decoder/__init__.py | 3 +- fastNLP/modules/decoder/utils.py | 6 --- fastNLP/modules/utils.py | 16 ------- test/core/test_utils.py | 42 +++++++++++++++++- test/modules/decoder/test_CRF.py | 6 +-- 13 files changed, 97 insertions(+), 97 deletions(-) diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 97afa364..ed1bd0c9 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -24,5 +24,5 @@ from .optimizer import Optimizer, SGD, Adam from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester from .trainer import Trainer -from .utils import cache_results +from .utils import cache_results, seq_len_to_mask from .vocabulary import Vocabulary diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index c9ba7f35..5cdfdc44 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -13,7 +13,7 @@ from .utils import _CheckRes from .utils import _build_args from .utils import _check_arg_dict_list from .utils import _get_func_signature -from .utils import seq_lens_to_masks +from .utils import seq_len_to_mask from .vocabulary import Vocabulary @@ -305,7 +305,7 @@ class AccuracyMetric(MetricBase): f"got {type(seq_len)}.") if seq_len is not None: - masks = seq_lens_to_masks(seq_lens=seq_len) + masks = seq_len_to_mask(seq_len=seq_len) else: masks = None diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 09a3a4c5..f7539fd7 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -1,7 +1,7 @@ """ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 """ -__all__ = ["cache_results"] +__all__ = ["cache_results", "seq_len_to_mask"] import _pickle import inspect import os @@ -600,48 +600,41 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): warnings.warn(message=_unused_warn) -def seq_lens_to_masks(seq_lens, float=False): +def seq_len_to_mask(seq_len): """ - Convert seq_lens to masks. - :param seq_lens: list, np.ndarray, or torch.LongTensor, shape should all be (B,) - :param float: if True, the return masks is in float type, otherwise it is byte. - :return: list, np.ndarray or torch.Tensor, shape will be (B, max_length) + 将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 + 转变 1-d seq_len到2-d mask. + + Example:: + >>> seq_len = torch.arange(2, 16) + >>> mask = seq_len_to_mask(seq_len) + >>> print(mask.size()) + torch.Size([14, 15]) + >>> seq_len = np.arange(2, 16) + >>> mask = seq_len_to_mask(seq_len) + >>> print(mask.shape) + (14, 15) + + :param np.ndarray,torch.LongTensor seq_len: shape将是(B,) + :return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8 """ - if isinstance(seq_lens, np.ndarray): - assert len(np.shape(seq_lens)) == 1, f"seq_lens can only have one dimension, got {len(np.shape(seq_lens))}." - assert seq_lens.dtype in (int, np.int32, np.int64), f"seq_lens can only be integer, not {seq_lens.dtype}." - raise NotImplemented - elif isinstance(seq_lens, torch.Tensor): - assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." - batch_size = seq_lens.size(0) - max_len = seq_lens.max() - indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long() - masks = indexes.lt(seq_lens.unsqueeze(1)) - - if float: - masks = masks.float() - - return masks - elif isinstance(seq_lens, list): - raise NotImplemented + if isinstance(seq_len, np.ndarray): + assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." + max_len = int(seq_len.max()) + broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) + mask = broad_cast_seq_len [N,L,C_0] pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] @@ -418,7 +417,7 @@ class BiaffineParser(GraphParser): """ batch_size, length, _ = pred1.shape - mask = seq_mask(seq_len, length) + mask = seq_len_to_mask(seq_len) flip_mask = (mask == 0) _arc_pred = pred1.clone() _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf')) @@ -514,7 +513,7 @@ class ParserMetric(MetricBase): if seq_len is None: seq_mask = pred1.new_ones(pred1.size(), dtype=torch.long) else: - seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long() + seq_mask = seq_len_to_mask(seq_len.long()).long() # mask out tag seq_mask[:,0] = 0 head_pred_correct = (pred1 == target1).long() * seq_mask diff --git a/fastNLP/models/sequence_labeling.py b/fastNLP/models/sequence_labeling.py index 015ae24a..98badd56 100644 --- a/fastNLP/models/sequence_labeling.py +++ b/fastNLP/models/sequence_labeling.py @@ -3,7 +3,7 @@ import torch from .base_model import BaseModel from ..modules import decoder, encoder from ..modules.decoder.CRF import allowed_transitions -from ..modules.utils import seq_mask +from ..core.utils import seq_len_to_mask from ..core.const import Const as C from torch import nn @@ -84,7 +84,7 @@ class SeqLabeling(BaseModel): def _make_mask(self, x, seq_len): batch_size, max_len = x.size(0), x.size(1) - mask = seq_mask(seq_len, max_len) + mask = seq_len_to_mask(seq_len) mask = mask.view(batch_size, max_len) mask = mask.to(x).float() return mask @@ -160,7 +160,7 @@ class AdvSeqLabel(nn.Module): def _make_mask(self, x, seq_len): batch_size, max_len = x.size(0), x.size(1) - mask = seq_mask(seq_len, max_len) + mask = seq_len_to_mask(seq_len) mask = mask.view(batch_size, max_len) mask = mask.to(x).float() return mask diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 1224c4b3..ac0a2e47 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -6,7 +6,7 @@ from ..core.const import Const from ..modules import decoder as Decoder from ..modules import encoder as Encoder from ..modules import aggregator as Aggregator -from ..modules.utils import seq_mask +from ..core.utils import seq_len_to_mask my_inf = 10e12 @@ -75,12 +75,12 @@ class ESIM(BaseModel): hypothesis0 = self.embedding_layer(self.embedding(words2)) if seq_len1 is not None: - seq_len1 = seq_mask(seq_len1, premise0.size(1)) + seq_len1 = seq_len_to_mask(seq_len1) else: seq_len1 = torch.ones(premise0.size(0), premise0.size(1)) seq_len1 = (seq_len1.long()).to(device=premise0.device) if seq_len2 is not None: - seq_len2 = seq_mask(seq_len2, hypothesis0.size(1)) + seq_len2 = seq_len_to_mask(seq_len2) else: seq_len2 = torch.ones(hypothesis0.size(0), hypothesis0.size(1)) seq_len2 = (seq_len2.long()).to(device=hypothesis0.device) diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index 93ee72f6..f7b9028e 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -1,7 +1,7 @@ """Star-Transformer 的 一个 Pytorch 实现. """ from ..modules.encoder.star_transformer import StarTransformer -from ..core.utils import seq_lens_to_masks +from ..core.utils import seq_len_to_mask from ..modules.utils import get_embeddings from ..core.const import Const @@ -134,7 +134,7 @@ class STSeqLabel(nn.Module): :param seq_len: [batch,] 输入序列的长度 :return output: [batch, num_cls, seq_len] 输出序列中每个元素的分类的概率 """ - mask = seq_lens_to_masks(seq_len) + mask = seq_len_to_mask(seq_len) nodes, _ = self.enc(words, mask) output = self.cls(nodes) output = output.transpose(1,2) # make hidden to be dim 1 @@ -195,7 +195,7 @@ class STSeqCls(nn.Module): :param seq_len: [batch,] 输入序列的长度 :return output: [batch, num_cls] 输出序列的分类的概率 """ - mask = seq_lens_to_masks(seq_len) + mask = seq_len_to_mask(seq_len) nodes, relay = self.enc(words, mask) y = 0.5 * (relay + nodes.max(1)[0]) output = self.cls(y) # [bsz, n_cls] @@ -258,8 +258,8 @@ class STNLICls(nn.Module): :param seq_len2: [batch,] 输入序列2的长度 :return output: [batch, num_cls] 输出分类的概率 """ - mask1 = seq_lens_to_masks(seq_len1) - mask2 = seq_lens_to_masks(seq_len2) + mask1 = seq_len_to_mask(seq_len1) + mask2 = seq_len_to_mask(seq_len2) def enc(seq, mask): nodes, relay = self.enc(seq, mask) return 0.5 * (relay + nodes.max(1)[0]) diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index 59efbd53..0d8ec25a 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -2,17 +2,6 @@ import torch from torch import nn from ..utils import initial_parameter -from ..decoder.utils import log_sum_exp - - -def seq_len_to_byte_mask(seq_lens): - # usually seq_lens: LongTensor, batch_size - # return value: ByteTensor, batch_size x max_len - batch_size = seq_lens.size(0) - max_len = seq_lens.max() - broadcast_arange = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) - mask = broadcast_arange.float().lt(seq_lens.float().view(-1, 1)) - return mask def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): @@ -197,13 +186,13 @@ class ConditionalRandomField(nn.Module): emit_score = logits[i].view(batch_size, 1, n_tags) trans_score = self.trans_m.view(1, n_tags, n_tags) tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score - alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) if self.include_start_end_trans: alpha = alpha + self.end_scores.view(1, -1) - return log_sum_exp(alpha, 1) + return torch.logsumexp(alpha, 1) def _gold_score(self, logits, tags, mask): """ diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py index dc05fce2..84763e03 100644 --- a/fastNLP/modules/decoder/__init__.py +++ b/fastNLP/modules/decoder/__init__.py @@ -1,4 +1,5 @@ -__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode"] +__all__ = ["MLP", "ConditionalRandomField", "viterbi_decode", "allowed_transitions"] from .CRF import ConditionalRandomField from .MLP import MLP from .utils import viterbi_decode +from .CRF import allowed_transitions diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py index 12e6893b..95b25767 100644 --- a/fastNLP/modules/decoder/utils.py +++ b/fastNLP/modules/decoder/utils.py @@ -2,12 +2,6 @@ __all__ = ["viterbi_decode"] import torch -def log_sum_exp(x, dim=-1): - max_value, _ = x.max(dim=dim, keepdim=True) - res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value - return res.squeeze(dim) - - def viterbi_decode(logits, transitions, mask=None, unpad=False): """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py index 337be64d..78851587 100644 --- a/fastNLP/modules/utils.py +++ b/fastNLP/modules/utils.py @@ -68,22 +68,6 @@ def initial_parameter(net, initial_method=None): net.apply(weights_init) -def seq_mask(seq_len, max_len): - """ - Create sequence mask. - - :param seq_len: list or torch.Tensor, the lengths of sequences in a batch. - :param max_len: int, the maximum sequence length in a batch. - :return: mask, torch.LongTensor, [batch_size, max_len] - - """ - if not isinstance(seq_len, torch.Tensor): - seq_len = torch.LongTensor(seq_len) - seq_len = seq_len.view(-1, 1).long() # [batch_size, 1] - seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1) # [1, max_len] - return torch.gt(seq_len, seq_range) # [batch_size, max_len] - - def get_embeddings(init_embed): """ 得到词嵌入 TODO diff --git a/test/core/test_utils.py b/test/core/test_utils.py index 7f218db0..b846a32d 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -9,7 +9,8 @@ import os import torch from torch import nn from fastNLP.core.utils import _move_model_to_device, _get_model_device - +import numpy as np +from fastNLP.core.utils import seq_len_to_mask class Model(nn.Module): def __init__(self): @@ -210,3 +211,42 @@ class TestCache(unittest.TestCase): finally: os.remove('test/demo1/demo.pkl') os.rmdir('test/demo1') + + +class TestSeqLenToMask(unittest.TestCase): + + def evaluate_mask_seq_len(self, seq_len, mask): + max_len = int(max(seq_len)) + for i in range(len(seq_len)): + length = seq_len[i] + mask_i = mask[i] + for j in range(max_len): + self.assertEqual(mask_i[j], j0, "CRF loss cannot be less than 0." + self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")