Browse Source

!1 修复一个 CRF 的 bug

Merge pull request !1 from WillQvQ/dev
tags/v0.6.0
WillQvQ Gitee 3 years ago
parent
commit
f58e10d094
5 changed files with 185 additions and 80 deletions
  1. +6
    -3
      .Jenkinsfile
  2. +1
    -1
      fastNLP/io/pipe/cws.py
  3. +35
    -16
      fastNLP/modules/decoder/crf.py
  4. +1
    -0
      test/data_for_tests/modules/decoder/crf.json
  5. +142
    -60
      test/modules/decoder/test_CRF.py

+ 6
- 3
.Jenkinsfile View File

@@ -36,10 +36,13 @@ pipeline {
}
}
post {
always {
sh 'post'
failure {
sh 'post 1'
}
success {
sh 'post 0'
sh 'post github'
}

}

}

+ 1
- 1
fastNLP/io/pipe/cws.py View File

@@ -122,7 +122,7 @@ def _find_and_replace_digit_spans(line):
otherwise unkdgt
"""
new_line = ''
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])'
pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])'
prev_end = 0
for match in re.finditer(pattern, line):
start, end = match.span()


+ 35
- 16
fastNLP/modules/decoder/crf.py View File

@@ -198,8 +198,18 @@ class ConditionalRandomField(nn.Module):
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
constrain = torch.full((num_tags + 2, num_tags + 2), fill_value=-10000.0, dtype=torch.float)
has_start = False
has_end = False
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
if from_tag_id==num_tags:
has_start = True
if to_tag_id==num_tags+1:
has_end = True
if not has_start:
constrain[num_tags, :].fill_(0)
if not has_end:
constrain[:, num_tags+1].fill_(0)
self._constrain = nn.Parameter(constrain, requires_grad=False)

initial_parameter(self, initial_method)
@@ -290,14 +300,15 @@ class ConditionalRandomField(nn.Module):
scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。

"""
batch_size, seq_len, n_tags = logits.size()
batch_size, max_len, n_tags = logits.size()
seq_len = mask.long().sum(1)
logits = logits.transpose(0, 1).data # L, B, H
mask = mask.transpose(0, 1).data.eq(True) # L, B
flip_mask = mask.eq(False)

# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
vpath = logits.new_zeros((max_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0] # bsz x n_tags
transitions = self._constrain.data.clone()
transitions[:n_tags, :n_tags] += self.trans_m.data
if self.include_start_end_trans:
@@ -305,36 +316,44 @@ class ConditionalRandomField(nn.Module):
transitions[:n_tags, n_tags + 1] += self.end_scores.data

vscore += transitions[n_tags, :n_tags]

trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
end_trans_score = transitions[:n_tags, n_tags+1].view(1, 1, n_tags).repeat(batch_size, 1, 1) # bsz, 1, n_tags

# 针对长度为1的句子
vscore += transitions[:n_tags, n_tags+1].view(1, n_tags).repeat(batch_size, 1) \
.masked_fill(seq_len.ne(1).view(-1, 1), 0)
for i in range(1, max_len):
prev_score = vscore.view(batch_size, n_tags, 1)
cur_score = logits[i].view(batch_size, 1, n_tags) + trans_score
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0)
score = prev_score + cur_score.masked_fill(flip_mask[i].view(batch_size, 1, 1), 0) # bsz x n_tag x n_tag
# 需要考虑当前位置是该序列的最后一个
score += end_trans_score.masked_fill(seq_len.ne(i+1).view(-1, 1, 1), 0)

best_score, best_dst = score.max(1)
vpath[i] = best_dst
vscore = best_score

if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
# 由于最终是通过last_tags回溯,需要保持每个位置的vscore情况
vscore = best_score.masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)

# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
seq_idx = torch.arange(max_len, dtype=torch.long, device=logits.device)
lens = (seq_len - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % max_len

ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans = logits.new_empty((max_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
for i in range(seq_len - 1):
for i in range(max_len - 1):
last_tags = vpath[idxes[i], batch_idx, last_tags]
ans[idxes[i + 1], batch_idx] = last_tags
ans = ans.transpose(0, 1)
if unpad:
paths = []
for idx, seq_len in enumerate(lens):
paths.append(ans[idx, :seq_len + 1].tolist())
for idx, max_len in enumerate(lens):
paths.append(ans[idx, :max_len + 1].tolist())
else:
paths = ans
return paths, ans_score

+ 1
- 0
test/data_for_tests/modules/decoder/crf.json
File diff suppressed because it is too large
View File


+ 142
- 60
test/modules/decoder/test_CRF.py View File

@@ -132,68 +132,150 @@ class TestCRF(unittest.TestCase):
self.assertSetEqual(expected_res, set(
allowed_transitions(vocab, include_start_end=True)))

# def test_case2(self):
# # 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。
# pass
# import torch
# from fastNLP import seq_len_to_mask
#
# labels = ['O']
# for label in ['X', 'Y']:
# for tag in 'BI':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
# max_len = 10
# batch_size = 4
# bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
# include_start_end_transitions=False)
# bio_trans_m = allen_CRF.transitions
# bio_seq_lens = torch.randint(1, max_len, size=(batch_size,))
# bio_seq_lens[0] = 1
# bio_seq_lens[-1] = max_len
# mask = seq_len_to_mask(bio_seq_lens)
# allen_res = allen_CRF.viterbi_tags(bio_logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# include_start_end=True))
# fast_CRF.trans_m = bio_trans_m
# fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
# bio_scores = [round(score, 4) for _, score in allen_res]
# # score equal
# self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
# # seq equal
# bio_path = [_ for _, score in allen_res]
# self.assertListEqual(bio_path, fast_res[0])
#
# labels = []
# for label in ['X', 'Y']:
# for tag in 'BMES':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
# include_start_end_transitions=False)
# bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log()
# bmes_trans_m = allen_CRF.transitions
# bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,))
# bmes_seq_lens[0] = 1
# bmes_seq_lens[-1] = max_len
# mask = seq_len_to_mask(bmes_seq_lens)
# allen_res = allen_CRF.viterbi_tags(bmes_logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES',
# include_start_end=True))
# fast_CRF.trans_m = bmes_trans_m
# fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
# # score equal
# bmes_scores = [round(score, 4) for _, score in allen_res]
# self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
# # seq equal
# bmes_path = [_ for _, score in allen_res]
# self.assertListEqual(bmes_path, fast_res[0])
#
# data = {
# 'bio_logits': bio_logits.tolist(),
# 'bio_scores': bio_scores,
# 'bio_path': bio_path,
# 'bio_trans_m': bio_trans_m.tolist(),
# 'bio_seq_lens': bio_seq_lens.tolist(),
# 'bmes_logits': bmes_logits.tolist(),
# 'bmes_scores': bmes_scores,
# 'bmes_path': bmes_path,
# 'bmes_trans_m': bmes_trans_m.tolist(),
# 'bmes_seq_lens': bmes_seq_lens.tolist(),
# }
#
# with open('weights.json', 'w') as f:
# import json
# json.dump(data, f)

def test_case2(self):
# 测试CRF能否避免解码出非法跃迁, 使用allennlp做了验证。
pass
# import torch
# from fastNLP.modules.decoder.crf import seq_len_to_byte_mask
#
# labels = ['O']
# for label in ['X', 'Y']:
# for tag in 'BI':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label),
# include_start_end_transitions=False)
# batch_size = 3
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
# trans_m = allen_CRF.transitions
# seq_lens = torch.randint(1, 20, size=(batch_size,))
# seq_lens[-1] = 20
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])
#
#
# labels = []
# for label in ['X', 'Y']:
# for tag in 'BMES':
# labels.append('{}-{}'.format(tag, label))
# id2label = {idx: label for idx, label in enumerate(labels)}
# num_tags = len(id2label)
#
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label),
# include_start_end_transitions=False)
# batch_size = 3
# logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, 20, num_tags))).log()
# trans_m = allen_CRF.transitions
# seq_lens = torch.randint(1, 20, size=(batch_size,))
# seq_lens[-1] = 20
# mask = seq_len_to_byte_mask(seq_lens)
# allen_res = allen_CRF.viterbi_tags(logits, mask)
#
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
# encoding_type='BMES'))
# fast_CRF.trans_m = trans_m
# fast_res = fast_CRF.viterbi_decode(logits, mask, get_score=True, unpad=True)
# # score equal
# self.assertListEqual([score for _, score in allen_res], fast_res[1])
# # seq equal
# self.assertListEqual([_ for _, score in allen_res], fast_res[0])
# 测试CRF是否正常work。
import json
import torch
from fastNLP import seq_len_to_mask

with open('test/data_for_tests/modules/decoder/crf.json', 'r') as f:
data = json.load(f)

bio_logits = torch.FloatTensor(data['bio_logits'])
bio_scores = data['bio_scores']
bio_path = data['bio_path']
bio_trans_m = torch.FloatTensor(data['bio_trans_m'])
bio_seq_lens = torch.LongTensor(data['bio_seq_lens'])

bmes_logits = torch.FloatTensor(data['bmes_logits'])
bmes_scores = data['bmes_scores']
bmes_path = data['bmes_path']
bmes_trans_m = torch.FloatTensor(data['bmes_trans_m'])
bmes_seq_lens = torch.LongTensor(data['bmes_seq_lens'])

labels = ['O']
for label in ['X', 'Y']:
for tag in 'BI':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
num_tags = len(id2label)

mask = seq_len_to_mask(bio_seq_lens)

from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
include_start_end=True))
fast_CRF.trans_m.data = bio_trans_m
fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True)
# score equal
self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()])
# seq equal
self.assertListEqual(bio_path, fast_res[0])

labels = []
for label in ['X', 'Y']:
for tag in 'BMES':
labels.append('{}-{}'.format(tag, label))
id2label = {idx: label for idx, label in enumerate(labels)}
num_tags = len(id2label)

mask = seq_len_to_mask(bmes_seq_lens)

from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions
fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label,
encoding_type='BMES',
include_start_end=True))
fast_CRF.trans_m.data = bmes_trans_m
fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True)
# score equal
self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()])
# seq equal
self.assertListEqual(bmes_path, fast_res[0])

def test_case3(self):
# 测试crf的loss不会出现负数


Loading…
Cancel
Save