diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index e17b04f3..46350945 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -194,11 +194,14 @@ class ConditionalRandomField(nn.Module): if self.include_start_end_trans: alpha += self.start_scores.view(1, -1) + flip_mask = mask.eq(0) + for i in range(1, seq_len): emit_score = logits[i].view(batch_size, 1, n_tags) trans_score = self.trans_m.view(1, n_tags, n_tags) tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score - alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1) + alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \ + alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0) if self.include_start_end_trans: alpha += self.end_scores.view(1, -1) @@ -218,12 +221,14 @@ class ConditionalRandomField(nn.Module): seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device) # trans_socre [L-1, B] - trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :] + mask = mask.byte() + flip_mask = mask.eq(0) + trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0) # emit_score [L, B] - emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask + emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0) # score [L-1, B] score = trans_score + emit_score[:seq_len-1, :] - score = score.sum(0) + emit_score[-1] * mask[-1] + score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0) if self.include_start_end_trans: st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]] last_idx = mask.long().sum(0) - 1 @@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module): """ batch_size, seq_len, n_tags = data.size() data = data.transpose(0, 1).data # L, B, H - mask = mask.transpose(0, 1).data.float() # L, B + mask = mask.transpose(0, 1).data.byte() # L, B # dp vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) @@ -284,7 +289,7 @@ class ConditionalRandomField(nn.Module): score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst - best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ + vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ vscore.masked_fill(mask[i].view(batch_size, 1), 0) vscore += transitions[:n_tags, n_tags+1].view(1, -1) diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 0fc331dc..4576d518 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -66,7 +66,7 @@ class TestCRF(unittest.TestCase): # from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label)) # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) + # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) # # score equal # self.assertListEqual([score for _, score in allen_res], fast_res[1]) # # seq equal @@ -95,7 +95,7 @@ class TestCRF(unittest.TestCase): # fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, # encoding_type='BMES')) # fast_CRF.trans_m = trans_m - # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True) + # fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True) # # score equal # self.assertListEqual([score for _, score in allen_res], fast_res[1]) # # seq equal