|
|
@@ -11,7 +11,7 @@ from ..utils import initial_parameter |
|
|
|
|
|
|
|
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions` |
|
|
|
别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions` |
|
|
|
|
|
|
|
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 |
|
|
|
|
|
|
@@ -31,7 +31,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) |
|
|
|
id_label_lst = list(id2target.items()) |
|
|
|
if include_start_end: |
|
|
|
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] |
|
|
|
|
|
|
|
|
|
|
|
def split_tag_label(from_label): |
|
|
|
from_label = from_label.lower() |
|
|
|
if from_label in ['start', 'end']: |
|
|
@@ -41,7 +41,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False) |
|
|
|
from_tag = from_label[:1] |
|
|
|
from_label = from_label[2:] |
|
|
|
return from_tag, from_label |
|
|
|
|
|
|
|
|
|
|
|
for from_id, from_label in id_label_lst: |
|
|
|
if from_label in ['<pad>', '<unk>']: |
|
|
|
continue |
|
|
@@ -93,7 +93,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label |
|
|
|
return to_tag in ['end', 'b', 'o'] |
|
|
|
else: |
|
|
|
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag)) |
|
|
|
|
|
|
|
|
|
|
|
elif encoding_type == 'bmes': |
|
|
|
""" |
|
|
|
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转 |
|
|
@@ -151,7 +151,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label |
|
|
|
|
|
|
|
class ConditionalRandomField(nn.Module): |
|
|
|
""" |
|
|
|
别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField` |
|
|
|
别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.ConditionalRandomField` |
|
|
|
|
|
|
|
条件随机场。 |
|
|
|
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 |
|
|
@@ -163,21 +163,21 @@ class ConditionalRandomField(nn.Module): |
|
|
|
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 |
|
|
|
:param str initial_method: 初始化方法。见initial_parameter |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, |
|
|
|
initial_method=None): |
|
|
|
|
|
|
|
|
|
|
|
super(ConditionalRandomField, self).__init__() |
|
|
|
|
|
|
|
|
|
|
|
self.include_start_end_trans = include_start_end_trans |
|
|
|
self.num_tags = num_tags |
|
|
|
|
|
|
|
|
|
|
|
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score |
|
|
|
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags)) |
|
|
|
if self.include_start_end_trans: |
|
|
|
self.start_scores = nn.Parameter(torch.randn(num_tags)) |
|
|
|
self.end_scores = nn.Parameter(torch.randn(num_tags)) |
|
|
|
|
|
|
|
|
|
|
|
if allowed_transitions is None: |
|
|
|
constrain = torch.zeros(num_tags + 2, num_tags + 2) |
|
|
|
else: |
|
|
@@ -185,9 +185,9 @@ class ConditionalRandomField(nn.Module): |
|
|
|
for from_tag_id, to_tag_id in allowed_transitions: |
|
|
|
constrain[from_tag_id, to_tag_id] = 0 |
|
|
|
self._constrain = nn.Parameter(constrain, requires_grad=False) |
|
|
|
|
|
|
|
|
|
|
|
initial_parameter(self, initial_method) |
|
|
|
|
|
|
|
|
|
|
|
def _normalizer_likelihood(self, logits, mask): |
|
|
|
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the |
|
|
|
sum of the likelihoods across all possible state sequences. |
|
|
@@ -200,21 +200,21 @@ class ConditionalRandomField(nn.Module): |
|
|
|
alpha = logits[0] |
|
|
|
if self.include_start_end_trans: |
|
|
|
alpha = alpha + self.start_scores.view(1, -1) |
|
|
|
|
|
|
|
|
|
|
|
flip_mask = mask.eq(0) |
|
|
|
|
|
|
|
|
|
|
|
for i in range(1, seq_len): |
|
|
|
emit_score = logits[i].view(batch_size, 1, n_tags) |
|
|
|
trans_score = self.trans_m.view(1, n_tags, n_tags) |
|
|
|
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score |
|
|
|
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ |
|
|
|
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
|
alpha = alpha + self.end_scores.view(1, -1) |
|
|
|
|
|
|
|
|
|
|
|
return torch.logsumexp(alpha, 1) |
|
|
|
|
|
|
|
|
|
|
|
def _gold_score(self, logits, tags, mask): |
|
|
|
""" |
|
|
|
Compute the score for the gold path. |
|
|
@@ -226,7 +226,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
seq_len, batch_size, _ = logits.size() |
|
|
|
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
|
|
|
|
|
|
|
|
# trans_score [L-1, B] |
|
|
|
mask = mask.byte() |
|
|
|
flip_mask = mask.eq(0) |
|
|
@@ -243,7 +243,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
score = score + st_scores + ed_scores |
|
|
|
# return [B,] |
|
|
|
return score |
|
|
|
|
|
|
|
|
|
|
|
def forward(self, feats, tags, mask): |
|
|
|
""" |
|
|
|
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 |
|
|
@@ -258,9 +258,9 @@ class ConditionalRandomField(nn.Module): |
|
|
|
mask = mask.transpose(0, 1).float() |
|
|
|
all_path_score = self._normalizer_likelihood(feats, mask) |
|
|
|
gold_path_score = self._gold_score(feats, tags, mask) |
|
|
|
|
|
|
|
|
|
|
|
return all_path_score - gold_path_score |
|
|
|
|
|
|
|
|
|
|
|
def viterbi_decode(self, logits, mask, unpad=False): |
|
|
|
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 |
|
|
|
|
|
|
@@ -277,7 +277,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
|
mask = mask.transpose(0, 1).data.byte() # L, B |
|
|
|
|
|
|
|
|
|
|
|
# dp |
|
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
|
vscore = logits[0] |
|
|
@@ -286,7 +286,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
if self.include_start_end_trans: |
|
|
|
transitions[n_tags, :n_tags] += self.start_scores.data |
|
|
|
transitions[:n_tags, n_tags + 1] += self.end_scores.data |
|
|
|
|
|
|
|
|
|
|
|
vscore += transitions[n_tags, :n_tags] |
|
|
|
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data |
|
|
|
for i in range(1, seq_len): |
|
|
@@ -297,17 +297,17 @@ class ConditionalRandomField(nn.Module): |
|
|
|
vpath[i] = best_dst |
|
|
|
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ |
|
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
|
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) |
|
|
|
|
|
|
|
|
|
|
|
# backtrace |
|
|
|
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device) |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
lens = (mask.long().sum(0) - 1) |
|
|
|
# idxes [L, B], batched idx from seq_len-1 to 0 |
|
|
|
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len |
|
|
|
|
|
|
|
|
|
|
|
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long) |
|
|
|
ans_score, last_tags = vscore.max(1) |
|
|
|
ans[idxes[0], batch_idx] = last_tags |
|
|
|