Browse Source

修复CRF中可能存在的bug

tags/v0.4.10
yh 5 years ago
parent
commit
f2d7d01bb7
2 changed files with 13 additions and 8 deletions
  1. +11
    -6
      fastNLP/modules/decoder/CRF.py
  2. +2
    -2
      test/modules/decoder/test_CRF.py

+ 11
- 6
fastNLP/modules/decoder/CRF.py View File

@@ -194,11 +194,14 @@ class ConditionalRandomField(nn.Module):
if self.include_start_end_trans:
alpha += self.start_scores.view(1, -1)

flip_mask = mask.eq(0)

for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = log_sum_exp(tmp, 1) * mask[i].view(batch_size, 1) + alpha * (1 - mask[i]).view(batch_size, 1)
alpha = log_sum_exp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)

if self.include_start_end_trans:
alpha += self.end_scores.view(1, -1)
@@ -218,12 +221,14 @@ class ConditionalRandomField(nn.Module):
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)

# trans_socre [L-1, B]
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]] * mask[1:, :]
mask = mask.byte()
flip_mask = mask.eq(0)
trans_score = self.trans_m[tags[:seq_len-1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags] * mask
emit_score = logits[seq_idx.view(-1,1), batch_idx.view(1,-1), tags].masked_fill(flip_mask, 0)
# score [L-1, B]
score = trans_score + emit_score[:seq_len-1, :]
score = score.sum(0) + emit_score[-1] * mask[-1]
score = score.sum(0) + emit_score[-1].masked_fill(flip_mask[-1], 0)
if self.include_start_end_trans:
st_scores = self.start_scores.view(1, -1).repeat(batch_size, 1)[batch_idx, tags[0]]
last_idx = mask.long().sum(0) - 1
@@ -265,7 +270,7 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = data.size()
data = data.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.float() # L, B
mask = mask.transpose(0, 1).data.byte() # L, B

# dp
vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
@@ -284,7 +289,7 @@ class ConditionalRandomField(nn.Module):
score = prev_score + trans_score + cur_score
best_score, best_dst = score.max(1)
vpath[i] = best_dst
best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

vscore += transitions[:n_tags, n_tags+1].view(1, -1)


+ 2
- 2
test/modules/decoder/test_CRF.py View File

@@ -66,7 +66,7 @@ class TestCRF(unittest.TestCase):
# from fastNLP.modules.decoder.CRF import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
@@ -95,7 +95,7 @@ class TestCRF(unittest.TestCase):
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES'))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True)
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal


Loading…
Cancel
Save