Browse Source

[update] fix padding masking in viterbi_decode() of CRF

tags/v0.6.0
yunfan 4 years ago
parent
commit
8d08626d67
2 changed files with 23 additions and 4 deletions
  1. +4
    -4
      fastNLP/modules/decoder/crf.py
  2. +19
    -0
      test/modules/decoder/test_CRF.py

+ 4
- 4
fastNLP/modules/decoder/crf.py View File

@@ -293,6 +293,7 @@ class ConditionalRandomField(nn.Module):
batch_size, seq_len, n_tags = logits.size() batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.eq(True) # L, B mask = mask.transpose(0, 1).data.eq(True) # L, B
flip_mask = mask.eq(False)


# dp # dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -307,12 +308,11 @@ class ConditionalRandomField(nn.Module):
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len): for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1) prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = logits[i].view(batch_size, 1, n_tags)
score = prev_score + trans_score + cur_score
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)
best_score, best_dst = score.max(1) best_score, best_dst = score.max(1)
vpath[i] = best_dst vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
vscore = best_score


if self.include_start_end_trans: if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1) vscore += transitions[:n_tags, n_tags + 1].view(1, -1)


+ 19
- 0
test/modules/decoder/test_CRF.py View File

@@ -220,3 +220,22 @@ class TestCRF(unittest.TestCase):
if _%1000==0: if _%1000==0:
print(loss) print(loss)
self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.") self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")

def test_masking(self):
# 测试crf的pad masking正常运行
import torch
from fastNLP.modules.decoder.crf import ConditionalRandomField
max_len = 5
n_tags = 5
pad_len = 5

torch.manual_seed(4)
logit = torch.rand(1, max_len+pad_len, n_tags)
# logit[0, -1, :] = 0.0
mask = torch.ones(1, max_len+pad_len)
mask[0,-pad_len] = 0
model = ConditionalRandomField(n_tags)
pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len])
mask_pred, mask_score = model.viterbi_decode(logit, mask)
self.assertEqual(pred[0].tolist(), mask_pred[0,:-pad_len].tolist())


Loading…
Cancel
Save