|
@@ -11,7 +11,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): |
|
|
给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 |
|
|
给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 |
|
|
|
|
|
|
|
|
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 |
|
|
:param torch.FloatTensor logits: batch_size x max_len x num_tags,特征矩阵。 |
|
|
:param torch.FloatTensor transitions: n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 |
|
|
|
|
|
|
|
|
:param torch.FloatTensor transitions: n_tags x n_tags,[i, j]位置的值认为是从tag i到tag j的转换; 或者(n_tags+2) x |
|
|
|
|
|
(n_tags+2), 其中n_tag是start的index, n_tags+1是end的index; |
|
|
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 |
|
|
:param torch.ByteTensor mask: batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 |
|
|
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 |
|
|
:param bool unpad: 是否将结果删去padding。False, 返回的是batch_size x max_len的tensor; True,返回的是 |
|
|
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 |
|
|
List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是这 |
|
@@ -22,20 +23,27 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
assert n_tags == transitions.size(0) and n_tags == transitions.size( |
|
|
|
|
|
1), "The shapes of transitions and feats are not " \ |
|
|
|
|
|
"compatible." |
|
|
|
|
|
|
|
|
if transitions.size(0) == n_tags+2: |
|
|
|
|
|
include_start_end_trans = True |
|
|
|
|
|
elif transitions.size(0) == n_tags: |
|
|
|
|
|
include_start_end_trans = False |
|
|
|
|
|
else: |
|
|
|
|
|
raise RuntimeError("The shapes of transitions and feats are not " \ |
|
|
|
|
|
"compatible.") |
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
if mask is not None: |
|
|
if mask is not None: |
|
|
mask = mask.transpose(0, 1).data.eq(True) # L, B |
|
|
mask = mask.transpose(0, 1).data.eq(True) # L, B |
|
|
else: |
|
|
else: |
|
|
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8) |
|
|
|
|
|
|
|
|
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8).eq(1) |
|
|
|
|
|
|
|
|
|
|
|
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data |
|
|
|
|
|
|
|
|
# dp |
|
|
# dp |
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
vscore = logits[0] |
|
|
vscore = logits[0] |
|
|
|
|
|
if include_start_end_trans: |
|
|
|
|
|
vscore += transitions[n_tags, :n_tags] |
|
|
|
|
|
|
|
|
trans_score = transitions.view(1, n_tags, n_tags).data |
|
|
|
|
|
for i in range(1, seq_len): |
|
|
for i in range(1, seq_len): |
|
|
prev_score = vscore.view(batch_size, n_tags, 1) |
|
|
prev_score = vscore.view(batch_size, n_tags, 1) |
|
|
cur_score = logits[i].view(batch_size, 1, n_tags) |
|
|
cur_score = logits[i].view(batch_size, 1, n_tags) |
|
@@ -45,6 +53,8 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False): |
|
|
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ |
|
|
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ |
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
|
|
|
if include_start_end_trans: |
|
|
|
|
|
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) |
|
|
# backtrace |
|
|
# backtrace |
|
|
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) |
|
|
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) |
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|