|
@@ -194,11 +194,14 @@ class ConditionalRandomField(nn.Module): |
|
|
if self.include_start_end_trans: |
|
|
if self.include_start_end_trans: |
|
|
alpha += self.start_scores.view(1, -1) |
|
|
alpha += self.start_scores.view(1, -1) |
|
|
|
|
|
|
|
|
|
|
|
flip_mask = mask.eq(0) |
|
|
|
|
|
|
|
|
for i in range(1, seq_len): |
|
|
for i in range(1, seq_len): |
|
|
emit_score = logits[i].view(batch_size, 1, n_tags) |
|
|
emit_score = logits[i].view(batch_size, 1, n_tags) |
|
|
trans_score = self.trans_m.view(1, n_tags, n_tags) |
|
|
trans_score = self.trans_m.view(1, n_tags, n_tags) |
|
|
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score |
|
|
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score |
|
|
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) |
|
|
|
|
|
|
|
|
alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ |
|
|
|
|
|
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
if self.include_start_end_trans: |
|
|
alpha += self.end_scores.view(1, -1) |
|
|
alpha += self.end_scores.view(1, -1) |
|
@@ -218,12 +221,14 @@ class ConditionalRandomField(nn.Module): |
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
|
|
|
|
|
# trans_socre [L-1, B] |
|
|
# trans_socre [L-1, B] |
|
|
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] |
|
|
|
|
|
|
|
|
mask = mask.byte() |
|
|
|
|
|
flip_mask = mask.eq(0) |
|
|
|
|
|
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0) |
|
|
# emit_score [L, B] |
|
|
# emit_score [L, B] |
|
|
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask |
|
|
|
|
|
|
|
|
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0) |
|
|
# score [L-1, B] |
|
|
# score [L-1, B] |
|
|
score = trans_score + emit_score[:seq_len-1, :] |
|
|
score = trans_score + emit_score[:seq_len-1, :] |
|
|
score = score.sum(0) + emit_score[-1] * mask[-1] |
|
|
|
|
|
|
|
|
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) |
|
|
if self.include_start_end_trans: |
|
|
if self.include_start_end_trans: |
|
|
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] |
|
|
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] |
|
|
last_idx = mask.long().sum(0) - 1 |
|
|
last_idx = mask.long().sum(0) - 1 |
|
@@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module): |
|
|
""" |
|
|
""" |
|
|
batch_size, seq_len, n_tags = data.size() |
|
|
batch_size, seq_len, n_tags = data.size() |
|
|
data = data.transpose(0, 1).data # L, B, H |
|
|
data = data.transpose(0, 1).data # L, B, H |
|
|
mask = mask.transpose(0, 1).data.float() # L, B |
|
|
|
|
|
|
|
|
mask = mask.transpose(0, 1).data.byte() # L, B |
|
|
|
|
|
|
|
|
# dp |
|
|
# dp |
|
|
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
@@ -284,7 +289,7 @@ class ConditionalRandomField(nn.Module): |
|
|
score = prev_score + trans_score + cur_score |
|
|
score = prev_score + trans_score + cur_score |
|
|
best_score, best_dst = score.max(1) |
|
|
best_score, best_dst = score.max(1) |
|
|
vpath[i] = best_dst |
|
|
vpath[i] = best_dst |
|
|
best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ |
|
|
|
|
|
|
|
|
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ |
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
vscore += transitions[:n_tags, n_tags+1].view(1, -1) |
|
|
vscore += transitions[:n_tags, n_tags+1].view(1, -1) |
|
|