|
|
@@ -217,14 +217,14 @@ class ConditionalRandomField(nn.Module): |
|
|
|
if self.include_start_end_trans: |
|
|
|
alpha = alpha + self.start_scores.view(1, -1) |
|
|
|
|
|
|
|
flip_mask = mask.eq(0) |
|
|
|
flip_mask = mask.eq(False) |
|
|
|
|
|
|
|
for i in range(1, seq_len): |
|
|
|
emit_score = logits[i].view(batch_size, 1, n_tags) |
|
|
|
trans_score = self.trans_m.view(1, n_tags, n_tags) |
|
|
|
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score |
|
|
|
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ |
|
|
|
alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0) |
|
|
|
alpha.masked_fill(mask[i].eq(True).view(batch_size, 1), 0) |
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
|
alpha = alpha + self.end_scores.view(1, -1) |
|
|
@@ -244,8 +244,8 @@ class ConditionalRandomField(nn.Module): |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
|
|
|
|
# trans_socre [L-1, B] |
|
|
|
mask = mask.eq(1) |
|
|
|
flip_mask = mask.eq(0) |
|
|
|
mask = mask.eq(True) |
|
|
|
flip_mask = mask.eq(False) |
|
|
|
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) |
|
|
|
# emit_score [L, B] |
|
|
|
emit_score = logits[seq_idx.view(-1, 1), batch_idx.view(1, -1), tags].masked_fill(flip_mask, 0) |
|
|
@@ -292,7 +292,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
""" |
|
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
|
mask = mask.transpose(0, 1).data.eq(1) # L, B |
|
|
|
mask = mask.transpose(0, 1).data.eq(True) # L, B |
|
|
|
|
|
|
|
# dp |
|
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
@@ -311,7 +311,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
score = prev_score + trans_score + cur_score |
|
|
|
best_score, best_dst = score.max(1) |
|
|
|
vpath[i] = best_dst |
|
|
|
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ |
|
|
|
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ |
|
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
|