|
|
@@ -210,7 +210,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
trans_score = self.trans_m.view(1, n_tags, n_tags) |
|
|
|
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score |
|
|
|
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ |
|
|
|
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) |
|
|
|
alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0) |
|
|
|
|
|
|
|
if self.include_start_end_trans: |
|
|
|
alpha = alpha + self.end_scores.view(1, -1) |
|
|
@@ -230,7 +230,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) |
|
|
|
|
|
|
|
# trans_socre [L-1, B] |
|
|
|
mask = mask.byte() |
|
|
|
mask = mask.eq(1) |
|
|
|
flip_mask = mask.eq(0) |
|
|
|
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0) |
|
|
|
# emit_score [L, B] |
|
|
@@ -278,7 +278,7 @@ class ConditionalRandomField(nn.Module): |
|
|
|
""" |
|
|
|
batch_size, seq_len, n_tags = logits.size() |
|
|
|
logits = logits.transpose(0, 1).data # L, B, H |
|
|
|
mask = mask.transpose(0, 1).data.byte() # L, B |
|
|
|
mask = mask.transpose(0, 1).data.eq(1) # L, B |
|
|
|
|
|
|
|
# dp |
|
|
|
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) |
|
|
|