Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.6.0
yh_cc 4 years ago
parent
commit
fe7ce8e448
2 changed files with 23 additions and 4 deletions
  1. +4
    -4
      fastNLP/modules/decoder/crf.py
  2. +19
    -0
      test/modules/decoder/test_CRF.py

+ 4
- 4
fastNLP/modules/decoder/crf.py View File

@@ -293,6 +293,7 @@ class ConditionalRandomField(nn.Module):
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.eq(True) # L, B
flip_mask = mask.eq(False)

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -307,12 +308,11 @@ class ConditionalRandomField(nn.Module):
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = logits[i].view(batch_size, 1, n_tags)
score = prev_score + trans_score + cur_score
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)
best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(False).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
vscore = best_score

if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)


+ 19
- 0
test/modules/decoder/test_CRF.py View File

@@ -220,3 +220,22 @@ class TestCRF(unittest.TestCase):
if _%1000==0:
print(loss)
self.assertGreater(loss.item(), 0, "CRF loss cannot be less than 0.")

def test_masking(self):
# 测试crf的pad masking正常运行
import torch
from fastNLP.modules.decoder.crf import ConditionalRandomField
max_len = 5
n_tags = 5
pad_len = 5

torch.manual_seed(4)
logit = torch.rand(1, max_len+pad_len, n_tags)
# logit[0, -1, :] = 0.0
mask = torch.ones(1, max_len+pad_len)
mask[0,-pad_len] = 0
model = ConditionalRandomField(n_tags)
pred, score = model.viterbi_decode(logit[:,:-pad_len], mask[:,:-pad_len])
mask_pred, mask_score = model.viterbi_decode(logit, mask)
self.assertEqual(pred[0].tolist(), mask_pred[0,:-pad_len].tolist())


Loading…
Cancel
Save