diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py index c83ef8be..b5ffa35d 100644 --- a/fastNLP/modules/decoder/crf.py +++ b/fastNLP/modules/decoder/crf.py @@ -293,6 +293,7 @@ class ConditionalRandomField(nn.Module): batch_size, seq_len, n_tags = logits.size() logits = logits.transpose(0, 1).data # L, B, H mask = mask.transpose(0, 1).data.eq(True) # L, B + flip_mask = mask.eq(False) # dp vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) @@ -307,12 +308,11 @@ class ConditionalRandomField(nn.Module): trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) - cur_score = logits[i].view(batch_size, 1, n_tags) - score = prev_score + trans_score + cur_score + cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score + score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) best_score, best_dst = score.max(1) vpath[i] = best_dst - vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \ - vscore.masked_fill(mask[i].view(batch_size, 1), 0) + vscore = best_score if self.include_start_end_trans: vscore += transitions[:n_tags, n_tags + 1].view(1, -1) diff --git a/test/modules/decoder/test_CRF.py b/test/modules/decoder/test_CRF.py index 94b4ab7a..85173669 100644 --- a/test/modules/decoder/test_CRF.py +++ b/test/modules/decoder/test_CRF.py @@ -220,3 +220,22 @@ class TestCRF(unittest.TestCase): if _%1000==0: print(loss) self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.") + + def test_masking(self): + # 测试crf的pad masking正常运行 + import torch + from fastNLP.modules.decoder.crf import ConditionalRandomField + max_len = 5 + n_tags = 5 + pad_len = 5 + + torch.manual_seed(4) + logit = torch.rand(1, max_len+pad_len, n_tags) + # logit[0, -1, :] = 0.0 + mask = torch.ones(1, max_len+pad_len) + mask[0,-pad_len] = 0 + model = ConditionalRandomField(n_tags) + pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len]) + mask_pred, mask_score = model.viterbi_decode(logit, mask) + self.assertEqual(pred[0].tolist(), mask_pred[0,:-pad_len].tolist()) +