From 699a0ef74d1f9faa3300c60e40d9ccd10f0416f1 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Thu, 26 Dec 2019 13:05:13 +0800
Subject: [PATCH] 1. Enhance viterbi decoding so that start and end
 transitions can be included; 2. Fix a typo in a comment of BertEncoder
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/metrics.py              |  1 -
 fastNLP/embeddings/bert_embedding.py |  3 +--
 fastNLP/modules/decoder/utils.py     | 22 ++++++++++++++++------
 3 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index fa46df24..95a3331f 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -216,7 +216,6 @@ class MetricBase(object):
 
         :return:
         """
-
         if not self._checked:
             if not callable(self.evaluate):
                 raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 3e2b98be..44824dc0 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -140,8 +140,7 @@ class BertWordPieceEncoder(nn.Module):
 
     :param str model_dir_or_name: The directory containing the model, or the model name. Defaults to ``en-base-uncased``.
     :param str layers: The layers used in the final representation, given as comma-separated indices; negative numbers index from the last layers.
-    :param bool pooled_cls: Whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool; only effective
-        when include_cls_sep is set. If the downstream task predicts only from [CLS], this should usually be True.
+    :param bool pooled_cls: Whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool. If the downstream task predicts from [CLS], this should usually be True.
     :param float word_dropout: The probability of replacing a word with unk, which both trains unk and provides some regularization.
     :param float dropout: The probability of applying Dropout to the embedding representation; 0.1 means 10% of the values are randomly zeroed.
     :param bool requires_grad: Whether gradients are required.
diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py
index 9b019abe..997b3453 100644
--- a/fastNLP/modules/decoder/utils.py
+++ b/fastNLP/modules/decoder/utils.py
@@ -11,7 +11,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
     Given a feature matrix and a transition score matrix, compute the best path and its score.
 
     :param torch.FloatTensor logits: batch_size x max_len x num_tags, the feature matrix.
-    :param torch.FloatTensor transitions: n_tags x n_tags. The value at position [i, j] is treated as the transition from tag i to tag j.
+    :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is treated as the transition from tag i to tag j; or (n_tags+2) x
+        (n_tags+2), where n_tags is the index of start and n_tags+1 is the index of end;
     :param torch.ByteTensor mask: batch_size x max_len; positions with value 0 are treated as pad. If None, no padding is assumed.
     :param bool unpad: whether to strip padding from the result. If False, a batch_size x max_len tensor is returned; if True,
         List[List[int]] is returned, where each inner List[int] holds the labels of one sequence with pad already removed, i.e. the length of each List[int] is the
@@ -22,20 +23,27 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
 
     """
     batch_size, seq_len, n_tags = logits.size()
-    assert n_tags == transitions.size(0) and n_tags == transitions.size(
-        1), "The shapes of transitions and feats are not " \
-            "compatible."
+    if transitions.size(0) == n_tags+2:
+        include_start_end_trans = True
+    elif transitions.size(0) == n_tags:
+        include_start_end_trans = False
+    else:
+        raise RuntimeError("The shapes of transitions and feats are not " \
+                           "compatible.")
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
         mask = mask.transpose(0, 1).data.eq(True)  # L, B
     else:
-        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
+        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1)
+
+    trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
 
     # dp
     vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
     vscore = logits[0]
+    if include_start_end_trans:
+        vscore += transitions[n_tags, :n_tags]
 
-    trans_score = transitions.view(1, n_tags, n_tags).data
     for i in range(1, seq_len):
         prev_score = vscore.view(batch_size, n_tags, 1)
         cur_score = logits[i].view(batch_size, 1, n_tags)
@@ -45,6 +53,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
         vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
                  vscore.masked_fill(mask[i].view(batch_size, 1), 0)
 
+    if include_start_end_trans:
+        vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
     # backtrace
     batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
     seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
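
Usage note: a minimal sketch of calling the patched viterbi_decode. The tensor
shapes and random scores below are illustrative, and it assumes the function
returns (paths, scores) as described in the full docstring; the start/end
layout follows the docstring above (row n_tags holds start->tag scores,
column n_tags+1 holds tag->end scores).

    import torch
    from fastNLP.modules.decoder.utils import viterbi_decode

    batch_size, max_len, n_tags = 2, 6, 4

    # Emission scores, plus a mask marking the last two steps of the
    # second sequence as padding.
    logits = torch.randn(batch_size, max_len, n_tags)
    mask = torch.ones(batch_size, max_len, dtype=torch.uint8)
    mask[1, 4:] = 0

    # n_tags x n_tags transitions: the pre-patch behaviour, unchanged.
    transitions = torch.randn(n_tags, n_tags)
    paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)

    # (n_tags+2) x (n_tags+2) transitions: start/end scores are applied at
    # the first and last (unpadded) step of every sequence.
    transitions_se = torch.randn(n_tags + 2, n_tags + 2)
    paths_se, scores_se = viterbi_decode(logits, transitions_se, mask=mask,
                                         unpad=True)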