|
|
@@ -31,7 +31,7 @@ class ConditionalRandomField(nn.Module):
         self.tag_size = tag_size

         # each entry in this matrix is a (from_tag_id, to_tag_id) transition score
-        self.transition_m = nn.Parameter(torch.randn(tag_size, tag_size))
+        self.trans_m = nn.Parameter(torch.randn(tag_size, tag_size))
         if self.include_start_end_trans:
             self.start_scores = nn.Parameter(torch.randn(tag_size))
             self.end_scores = nn.Parameter(torch.randn(tag_size))
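As the comment above notes, the renamed `trans_m` keeps the (from_tag_id, to_tag_id) layout: the row indexes the source tag and the column the destination tag. A tiny illustrative sketch (the tag ids and size are arbitrary, not from this diff):

    import torch

    tag_size = 4
    trans_m = torch.randn(tag_size, tag_size)

    # trans_m[i, j] scores a transition from tag i at position t
    # to tag j at position t + 1.
    bigram_score = trans_m[1, 3]  # score of moving from tag 1 to tag 3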
|
|
@@ -39,137 +39,121 @@ class ConditionalRandomField(nn.Module):
         # self.reset_parameter()
         initial_parameter(self, initial_method)

     def reset_parameter(self):
-        nn.init.xavier_normal_(self.transition_m)
+        nn.init.xavier_normal_(self.trans_m)
         if self.include_start_end_trans:
             nn.init.normal_(self.start_scores)
             nn.init.normal_(self.end_scores)
-    def _normalizer_likelihood(self, feats, masks):
+    def _normalizer_likelihood(self, logits, mask):
         """
         Computes the (batch_size,) denominator term for the log-likelihood, which is the
         sum of the likelihoods across all possible state sequences.
-        :param feats:FloatTensor, batch_size x max_len x tag_size
-        :param masks:ByteTensor, batch_size x max_len
+        :param logits:FloatTensor, max_len x batch_size x tag_size
+        :param mask:ByteTensor, max_len x batch_size
         :return:FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, batch_size x tag_size
+        seq_len, batch_size, n_tags = logits.size()
+        alpha = logits[0]
         if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-
-        # broadcast_trans_m: entries are indexed [batch_idx, to_tag_id, from_tag_id]
-        broadcast_trans_m = self.transition_m.permute(
-            1, 0).unsqueeze(0).repeat(batch_size, 1, 1)
-        # loop
-        for i in range(1, max_len):
-            emit_score = feats[:, i].unsqueeze(2)
-            new_alpha = broadcast_trans_m + alpha.unsqueeze(1) + emit_score
-
-            new_alpha = log_sum_exp(new_alpha, dim=2)
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
+            alpha += self.start_scores.view(1, -1)
+
+        for i in range(1, seq_len):
+            emit_score = logits[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags)
+            tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
+            alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)

         if self.include_start_end_trans:
-            alpha = alpha + self.end_scores.view(1, -1)
+            alpha += self.end_scores.view(1, -1)

-        return log_sum_exp(alpha)
+        return log_sum_exp(alpha, 1)
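Both the removed and the added loop reduce with a `log_sum_exp` helper that is defined elsewhere in the module. A minimal numerically stable sketch of what such a helper does, assuming it reduces over a caller-chosen dimension (an illustration, not the module's exact code):

    import torch

    def log_sum_exp(x, dim=-1):
        # Subtract the per-slice max before exponentiating so exp() cannot
        # overflow, then add it back outside the log.
        max_value, _ = x.max(dim=dim, keepdim=True)
        res = (x - max_value).exp().sum(dim=dim, keepdim=True).log() + max_value
        return res.squeeze(dim)

In the added loop, `tmp` is batch_size x n_tags x n_tags indexed [batch, from_tag, to_tag], so reducing over dim 1 marginalizes out the previous tag, which is exactly the forward-algorithm recurrence.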
|
|
|
|
|
|
|
-    def _glod_score(self, feats, tags, masks):
+    def _glod_score(self, logits, tags, mask):
         """
         Compute the score for the gold path.
-        :param feats: FloatTensor, batch_size x max_len x tag_size
-        :param tags: LongTensor, batch_size x max_len
-        :param masks: ByteTensor, batch_size x max_len
+        :param logits: FloatTensor, max_len x batch_size x tag_size
+        :param tags: LongTensor, max_len x batch_size
+        :param mask: ByteTensor, max_len x batch_size
         :return:FloatTensor, batch_size
         """
-        batch_size, max_len, _ = feats.size()
-
-        # alpha, B x 1
+        seq_len, batch_size, _ = logits.size()
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
+
+        # trans_score [L-1, B]
+        trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :]
+        # emit_score [L, B]
+        emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask
+        # score [L-1, B]
+        score = trans_score + emit_score[:seq_len - 1, :]
+        score = score.sum(0) + emit_score[-1]
         if self.include_start_end_trans:
-            alpha = self.start_scores.view(1, -1).repeat(batch_size, 1).gather(dim=1, index=tags[:, :1]) + \
-                feats[:, 0].gather(dim=1, index=tags[:, :1])
-        else:
-            alpha = feats[:, 0].gather(dim=1, index=tags[:, :1])
-
-        for i in range(1, max_len):
-            trans_score = self.transition_m[(
-                tags[:, i - 1], tags[:, i])].unsqueeze(1)
-            emit_score = feats[:, i].gather(dim=1, index=tags[:, i:i + 1])
-            new_alpha = alpha + trans_score + emit_score
-
-            alpha = new_alpha * \
-                masks[:, i:i + 1].float() + alpha * \
-                (1 - masks[:, i:i + 1].float())
-
-        if self.include_start_end_trans:
-            last_tag_index = masks.cumsum(dim=1, dtype=torch.long)[:, -1:] - 1
-            last_from_tag_id = tags.gather(dim=1, index=last_tag_index)
-            trans_score = self.end_scores.view(
-                1, -1).repeat(batch_size, 1).gather(dim=1, index=last_from_tag_id)
-            alpha = alpha + trans_score
-
-        return alpha.squeeze(1)
+            st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
+            last_idx = mask.long().sum(0) - 1
+            ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]]
+            score += st_scores + ed_scores
+        # return [B,]
+        return score
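The new `_glod_score` replaces the old per-step `gather` loop with two advanced-indexing lookups. A self-contained sketch of just those lookups on toy shapes (all names and sizes here are illustrative):

    import torch

    seq_len, batch_size, n_tags = 4, 2, 5
    logits = torch.randn(seq_len, batch_size, n_tags)
    trans_m = torch.randn(n_tags, n_tags)
    tags = torch.randint(n_tags, (seq_len, batch_size))

    batch_idx = torch.arange(batch_size)
    seq_idx = torch.arange(seq_len)

    # Entry [i, b] = trans_m[tags[i, b], tags[i + 1, b]]: the score of the
    # gold tag bigram at positions (i, i + 1). Shape: (seq_len - 1, batch_size).
    trans_score = trans_m[tags[:seq_len - 1], tags[1:]]

    # Entry [i, b] = logits[i, b, tags[i, b]]: the emission score of the gold
    # tag at each position. Shape: (seq_len, batch_size).
    emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags]

Masking and the start/end scores are then applied exactly as in the method body above.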
|
|
|
|
|
|
|
-    def forward(self, feats, tags, masks):
+    def forward(self, feats, tags, mask):
         """
         Calculate the negative log likelihood
         :param feats:FloatTensor, batch_size x max_len x tag_size
         :param tags:LongTensor, batch_size x max_len
-        :param masks:ByteTensor batch_size x max_len
+        :param mask:ByteTensor batch_size x max_len
         :return:FloatTensor, batch_size
         """
-        all_path_score = self._normalizer_likelihood(feats, masks)
-        gold_path_score = self._glod_score(feats, tags, masks)
+        feats = feats.transpose(0, 1)
+        tags = tags.transpose(0, 1)
+        mask = mask.transpose(0, 1)
+        all_path_score = self._normalizer_likelihood(feats, mask)
+        gold_path_score = self._glod_score(feats, tags, mask)

         return all_path_score - gold_path_score
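The public `forward` still takes batch-first tensors; the transposes above adapt them to the new length-first internals. A hedged usage sketch (the tag count, sizes, and the float mask are assumptions made for illustration; the docstring asks for a ByteTensor mask):

    import torch

    # Assuming ConditionalRandomField is in scope and its constructor
    # takes the tag count first, as in the hunk above.
    crf = ConditionalRandomField(5)
    feats = torch.randn(2, 6, 5)       # batch_size x max_len x tag_size
    tags = torch.randint(5, (2, 6))    # batch_size x max_len
    mask = torch.tensor([[1.] * 6,               # first sequence: full length
                         [1.] * 4 + [0.] * 2])   # second: length 4, padded

    loss = crf(feats, tags, mask).mean()  # per-sequence NLL, averaged
    loss.backward()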
|
|
|
|
|
|
|
-    def viterbi_decode(self, feats, masks, get_score=False):
+    def viterbi_decode(self, data, mask, get_score=False):
         """
         Given a feats matrix, return best decode path and best score.
-        :param feats:
-        :param masks:
+        :param data:FloatTensor, batch_size x max_len x tag_size
+        :param mask:ByteTensor batch_size x max_len
         :param get_score: bool, whether to output the decode score.
-        :return:List[Tuple(List, float)],
+        :return: scores, paths
         """
-        batch_size, max_len, tag_size = feats.size()
-
-        paths = torch.zeros(batch_size, max_len - 1, self.tag_size)
+        batch_size, seq_len, n_tags = data.size()
+        data = data.transpose(0, 1).data # L, B, H
+        mask = mask.transpose(0, 1).data.float() # L, B
+
+        # dp
+        vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
+        vscore = data[0]
         if self.include_start_end_trans:
-            alpha = self.start_scores.repeat(batch_size, 1) + feats[:, 0]
-        else:
-            alpha = feats[:, 0]
-        for i in range(1, max_len):
-            new_alpha = alpha.clone()
-            for t in range(self.tag_size):
-                pre_scores = self.transition_m[:, t].view(
-                    1, self.tag_size) + alpha
-                max_score, indices = pre_scores.max(dim=1)
-                new_alpha[:, t] = max_score + feats[:, i, t]
-                paths[:, i - 1, t] = indices
-            alpha = new_alpha * masks[:, i:i + 1].float() + alpha * (1 - masks[:, i:i + 1].float())
+            vscore += self.start_scores.view(1, -1)
+        for i in range(1, seq_len):
+            prev_score = vscore.view(batch_size, n_tags, 1)
+            cur_score = data[i].view(batch_size, 1, n_tags)
+            trans_score = self.trans_m.view(1, n_tags, n_tags).data
+            score = prev_score + trans_score + cur_score
+            best_score, best_dst = score.max(1)
+            vpath[i] = best_dst
+            vscore = best_score * mask[i].view(batch_size, 1) + vscore * (1 - mask[i]).view(batch_size, 1)

         if self.include_start_end_trans:
-            alpha += self.end_scores.view(1, -1)
-
-        max_scores, indices = alpha.max(dim=1)
-        indices = indices.cpu().numpy()
-        final_paths = []
-        paths = paths.cpu().numpy().astype(int)
-
-        seq_lens = masks.cumsum(dim=1, dtype=torch.long)[:, -1]
-
-        for b in range(batch_size):
-            path = [indices[b]]
-            for i in range(seq_lens[b] - 2, -1, -1):
-                index = paths[b, i, path[-1]]
-                path.append(index)
-            final_paths.append(path[::-1])
+            vscore += self.end_scores.view(1, -1)
+
+        # backtrace
+        batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device)
+        seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device)
+        lens = (mask.long().sum(0) - 1)
+        # idxes [L, B], batched idx from seq_len-1 to 0
+        idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len
+
+        ans = data.new_empty((seq_len, batch_size), dtype=torch.long)
+        ans_score, last_tags = vscore.max(1)
+        ans[idxes[0], batch_idx] = last_tags
+        for i in range(seq_len - 1):
+            last_tags = vpath[idxes[i], batch_idx, last_tags]
+            ans[idxes[i+1], batch_idx] = last_tags
         if get_score:
-            return list(zip(final_paths, max_scores.detach().cpu().numpy()))
-        else:
-            return final_paths
+            return ans_score, ans.transpose(0, 1)
+        return ans.transpose(0, 1)
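In the backtrace, `lens` holds each sequence's last real index and `idxes` walks every sequence from its own end back to its start (the modulo keeps the wrapped indices of shorter sequences in range), so a single loop decodes the whole batch. A hedged decoding sketch, reusing the toy `crf`, `feats`, and `mask` from the sketch after forward():

    import torch

    with torch.no_grad():
        scores, paths = crf.viterbi_decode(feats, mask, get_score=True)

    # paths: LongTensor, batch_size x max_len; positions past a sequence's
    # true length are filler and should be dropped using the mask.
    for b, length in enumerate(mask.long().sum(1).tolist()):
        print(paths[b, :length].tolist())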