|
@@ -293,6 +293,7 @@ class ConditionalRandomField(nn.Module): |
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
mask = mask.transpose(0, 1).data.eq(True) # L, B |
|
|
mask = mask.transpose(0, 1).data.eq(True) # L, B |
|
|
|
|
|
flip_mask = mask.eq(False) |
|
|
|
|
|
|
|
|
# dp |
|
|
# dp |
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
@@ -307,12 +308,11 @@ class ConditionalRandomField(nn.Module): |
|
|
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data |
|
|
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data |
|
|
for i in range(1, seq_len): |
|
|
for i in range(1, seq_len): |
|
|
prev_score = vscore.view(batch_size, n_tags, 1) |
|
|
prev_score = vscore.view(batch_size, n_tags, 1) |
|
|
cur_score = logits[i].view(batch_size, 1, n_tags) |
|
|
|
|
|
score = prev_score + trans_score + cur_score |
|
|
|
|
|
|
|
|
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score |
|
|
|
|
|
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) |
|
|
best_score, best_dst = score.max(1) |
|
|
best_score, best_dst = score.max(1) |
|
|
vpath[i] = best_dst |
|
|
vpath[i] = best_dst |
|
|
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ |
|
|
|
|
|
vscore.masked_fill(mask[i].view(batch_size, 1), 0) |
|
|
|
|
|
|
|
|
vscore = best_score |
|
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
if self.include_start_end_trans: |
|
|
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) |
|
|
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) |
|
|