
1. Enhance Viterbi decoding so that start and end transitions can be included; 2. Fix a comment typo in BertEncoder

tags/v0.5.5
yh_cc 4 years ago
commit 699a0ef74d
3 changed files with 17 additions and 9 deletions

  1. fastNLP/core/metrics.py (+0, -1)
  2. fastNLP/embeddings/bert_embedding.py (+1, -2)
  3. fastNLP/modules/decoder/utils.py (+16, -6)

fastNLP/core/metrics.py (+0, -1)

@@ -216,7 +216,6 @@ class MetricBase(object):
         :return:
         """
-
 
         if not self._checked:
             if not callable(self.evaluate):
                 raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")


fastNLP/embeddings/bert_embedding.py (+1, -2)

@@ -140,8 +140,7 @@ class BertWordPieceEncoder(nn.Module):
     :param str model_dir_or_name: Directory containing the model, or the name of the model. Defaults to ``en-base-uncased``.
     :param str layers: Which layers make up the final representation. Layer indices separated by ','; negative numbers index from the last layer.
-    :param bool pooled_cls: Whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool; only effective with include_cls_sep. If the downstream task predicts only from
-        [CLS], this should usually be True.
+    :param bool pooled_cls: Whether the [CLS] at the start of the returned sentence is mapped through the pretrained BertPool. If the downstream task predicts from [CLS], this should usually be True.
     :param float word_dropout: Probability of replacing a word with unk. This both trains unk and provides some regularization.
     :param float dropout: Probability of applying Dropout to the embedding representation. 0.1 means 10% of the values are randomly set to 0.
     :param bool requires_grad: Whether gradients are required.
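
For orientation, a minimal construction sketch based on the parameters documented above; the model name and layer choice here are just the documented defaults, not something introduced by this commit:

    from fastNLP.embeddings import BertWordPieceEncoder

    # pooled_cls=True passes the leading [CLS] through the pretrained BertPool,
    # the usual choice when the downstream task predicts from [CLS].
    encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1',
                                   pooled_cls=True, word_dropout=0.01, dropout=0.1)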


fastNLP/modules/decoder/utils.py (+16, -6)

@@ -11,7 +11,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
     Given a feature matrix and a transition score matrix, compute the best path and the corresponding score.
 
     :param torch.FloatTensor logits: batch_size x max_len x num_tags, the feature matrix.
-    :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is the score of transitioning from tag i to tag j.
+    :param torch.FloatTensor transitions: n_tags x n_tags, where the value at [i, j] is the score of transitioning from tag i to tag j; or (n_tags+2) x
+        (n_tags+2), where n_tags is the index of start and n_tags+1 is the index of end;
     :param torch.ByteTensor mask: batch_size x max_len, where positions with value 0 are treated as pad; if None, the input is assumed to have no padding.
     :param bool unpad: Whether to strip padding from the result. If False, a batch_size x max_len tensor is returned; if True,
         a List[List[int]] is returned, where each inner List[int] holds one sequence's labels with pad removed, i.e. each List[int] has the valid length of its sample.
@@ -22,20 +23,27 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
 
     """
     batch_size, seq_len, n_tags = logits.size()
-    assert n_tags == transitions.size(0) and n_tags == transitions.size(
-        1), "The shapes of transitions and feats are not " \
-            "compatible."
+    if transitions.size(0) == n_tags + 2:
+        include_start_end_trans = True
+    elif transitions.size(0) == n_tags:
+        include_start_end_trans = False
+    else:
+        raise RuntimeError("The shapes of transitions and feats are not "
+                           "compatible.")
     logits = logits.transpose(0, 1).data  # L, B, H
     if mask is not None:
         mask = mask.transpose(0, 1).data.eq(True)  # L, B
     else:
-        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
+        mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1)
+
+    trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
 
     # dp
     vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
     vscore = logits[0]
+    if include_start_end_trans:
+        vscore += transitions[n_tags, :n_tags]
 
-    trans_score = transitions.view(1, n_tags, n_tags).data
     for i in range(1, seq_len):
         prev_score = vscore.view(batch_size, n_tags, 1)
         cur_score = logits[i].view(batch_size, 1, n_tags)
@@ -45,6 +53,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
         vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
                  vscore.masked_fill(mask[i].view(batch_size, 1), 0)
 
+    if include_start_end_trans:
+        vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
     # backtrace
     batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
     seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
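
To make the new behavior concrete, a minimal usage sketch, assuming viterbi_decode returns (paths, scores) as in this version of fastNLP; shapes and random values are illustrative only:

    import torch
    from fastNLP.modules.decoder.utils import viterbi_decode

    batch_size, max_len, n_tags = 2, 5, 4
    logits = torch.randn(batch_size, max_len, n_tags)          # emission scores
    mask = torch.ones(batch_size, max_len, dtype=torch.uint8)  # no padding here

    # Before this commit: only a plain n_tags x n_tags matrix was accepted.
    transitions = torch.randn(n_tags, n_tags)
    paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)

    # After this commit: an (n_tags+2) x (n_tags+2) matrix also works, where row
    # n_tags holds start->tag scores and column n_tags+1 holds tag->end scores.
    ext_transitions = torch.randn(n_tags + 2, n_tags + 2)
    paths, scores = viterbi_decode(logits, ext_transitions, mask=mask, unpad=True)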

