|
|
@@ -3,7 +3,6 @@ from torch import nn |
|
|
|
|
|
|
|
from fastNLP.modules.utils import initial_parameter |
|
|
|
|
|
|
|
|
|
|
|
def log_sum_exp(x, dim=-1): |
|
|
|
max_value, _ = x.max(dim=dim, keepdim=True) |
|
|
|
res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value |
|
|
@@ -21,7 +20,7 @@ def seq_len_to_byte_mask(seq_lens): |
|
|
|
|
|
|
|
|
|
|
|
class ConditionalRandomField(nn.Module): |
|
|
|
def __init__(self, tag_size, include_start_end_trans=False, initial_method=None): |
|
|
|
def __init__(self, tag_size, include_start_end_trans=True ,initial_method = None): |
|
|
|
""" |
|
|
|
:param tag_size: int, num of tags |
|
|
|
:param include_start_end_trans: bool, whether to include start/end tag |
|
|
@@ -39,7 +38,6 @@ class ConditionalRandomField(nn.Module): |
|
|
|
|
|
|
|
# self.reset_parameter() |
|
|
|
initial_parameter(self, initial_method) |
|
|
|
|
|
|
|
def reset_parameter(self): |
|
|
|
nn.init.xavier_normal_(self.trans_m) |
|
|
|
if self.include_start_end_trans: |
|
|
@@ -83,16 +81,17 @@ class ConditionalRandomField(nn.Module): |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
|
|
|
|
# trans_score [L-1, B] |
|
|
|
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]] * mask[1:, :] |
|
|
|
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] |
|
|
|
# emit_score [L, B] |
|
|
|
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags] * mask |
|
|
|
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask |
|
|
|
# score [L-1, B] |
|
|
|
score = trans_score + emit_score[:seq_len - 1, :] |
|
|
|
score = trans_score + emit_score[:seq_len-1, :] |
|
|
|
score = score.sum(0) + emit_score[-1] |
|
|
|
if self.include_start_end_trans: |
|
|
|
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] |
|
|
|
last_idx = mask.long().sum(0) |
|
|
|
last_idx = mask.long().sum(0) - 1 |
|
|
|
ed_scores = self.end_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[last_idx, batch_idx]] |
|
|
|
print(score.size(), st_scores.size(), ed_scores.size()) |
|
|
|
score += st_scores + ed_scores |
|
|
|
# return [B,] |
|
|
|
return score |
|
|
@@ -106,8 +105,8 @@ class ConditionalRandomField(nn.Module): |
|
|
|
:return: FloatTensor, batch_size |
|
|
|
""" |
|
|
|
feats = feats.transpose(0, 1) |
|
|
|
tags = tags.transpose(0, 1) |
|
|
|
mask = mask.transpose(0, 1) |
|
|
|
tags = tags.transpose(0, 1).long() |
|
|
|
mask = mask.transpose(0, 1).float() |
|
|
|
all_path_score = self._normalizer_likelihood(feats, mask) |
|
|
|
gold_path_score = self._glod_score(feats, tags, mask) |
|
|
|
|
|
|
@@ -122,14 +121,14 @@ class ConditionalRandomField(nn.Module): |
|
|
|
:return: scores, paths |
|
|
|
""" |
|
|
|
batch_size, seq_len, n_tags = data.size() |
|
|
|
data = data.transpose(0, 1).data # L, B, H |
|
|
|
mask = mask.transpose(0, 1).data.float() # L, B |
|
|
|
data = data.transpose(0, 1).data # L, B, H |
|
|
|
mask = mask.transpose(0, 1).data.float() # L, B |
|
|
|
|
|
|
|
# dp |
|
|
|
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
|
vscore = data[0] |
|
|
|
if self.include_start_end_trans: |
|
|
|
vscore += self.start_scores.view(1. - 1) |
|
|
|
vscore += self.start_scores.view(1. -1) |
|
|
|
for i in range(1, seq_len): |
|
|
|
prev_score = vscore.view(batch_size, n_tags, 1) |
|
|
|
cur_score = data[i].view(batch_size, 1, n_tags) |
|
|
@@ -147,14 +146,14 @@ class ConditionalRandomField(nn.Module): |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) |
|
|
|
lens = (mask.long().sum(0) - 1) |
|
|
|
# idxes [L, B], batched idx from seq_len-1 to 0 |
|
|
|
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len |
|
|
|
idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len |
|
|
|
|
|
|
|
ans = data.new_empty((seq_len, batch_size), dtype=torch.long) |
|
|
|
ans_score, last_tags = vscore.max(1) |
|
|
|
ans[idxes[0], batch_idx] = last_tags |
|
|
|
for i in range(seq_len - 1): |
|
|
|
last_tags = vpath[idxes[i], batch_idx, last_tags] |
|
|
|
ans[idxes[i + 1], batch_idx] = last_tags |
|
|
|
ans[idxes[i+1], batch_idx] = last_tags |
|
|
|
|
|
|
|
if get_score: |
|
|
|
return ans_score, ans.transpose(0, 1) |
|
|
|