|
@@ -132,68 +132,150 @@ class TestCRF(unittest.TestCase): |
|
|
self.assertSetEqual(expected_res, set( |
|
|
self.assertSetEqual(expected_res, set( |
|
|
allowed_transitions(vocab, include_start_end=True))) |
|
|
allowed_transitions(vocab, include_start_end=True))) |
|
|
|
|
|
|
|
|
|
|
|
# def test_case2(self): |
|
|
|
|
|
# # Test that the CRF never decodes an illegal transition; the results were verified against allennlp. |
|
|
|
|
|
# pass |
|
|
|
|
|
# import torch |
|
|
|
|
|
# from fastNLP import seq_len_to_mask |
|
|
|
|
|
# |
|
|
|
|
|
# labels = ['O'] |
|
|
|
|
|
# for label in ['X', 'Y']: |
|
|
|
|
|
# for tag in 'BI': |
|
|
|
|
|
# labels.append('{}-{}'.format(tag, label)) |
|
|
|
|
|
# id2label = {idx: label for idx, label in enumerate(labels)} |
|
|
|
|
|
# num_tags = len(id2label) |
|
|
|
|
|
# max_len = 10 |
|
|
|
|
|
# batch_size = 4 |
|
|
|
|
|
# bio_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() |
|
|
|
|
|
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions |
|
|
|
|
|
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BIO', id2label), |
|
|
|
|
|
# include_start_end_transitions=False) |
|
|
|
|
|
# bio_trans_m = allen_CRF.transitions |
|
|
|
|
|
# bio_seq_lens = torch.randint(1, max_len, size=(batch_size,)) |
|
|
|
|
|
# bio_seq_lens[0] = 1 |
|
|
|
|
|
# bio_seq_lens[-1] = max_len |
|
|
|
|
|
# mask = seq_len_to_mask(bio_seq_lens) |
|
|
|
|
|
# allen_res = allen_CRF.viterbi_tags(bio_logits, mask) |
|
|
|
|
|
# |
|
|
|
|
|
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions |
|
|
|
|
|
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, |
|
|
|
|
|
# include_start_end=True)) |
|
|
|
|
|
# fast_CRF.trans_m = bio_trans_m |
|
|
|
|
|
# fast_res = fast_CRF.viterbi_decode(bio_logits, mask, unpad=True) |
|
|
|
|
|
# bio_scores = [round(score, 4) for _, score in allen_res] |
|
|
|
|
|
# # score equal |
|
|
|
|
|
# self.assertListEqual(bio_scores, [round(s, 4) for s in fast_res[1].tolist()]) |
|
|
|
|
|
# # seq equal |
|
|
|
|
|
# bio_path = [_ for _, score in allen_res] |
|
|
|
|
|
# self.assertListEqual(bio_path, fast_res[0]) |
|
|
|
|
|
# |
|
|
|
|
|
# labels = [] |
|
|
|
|
|
# for label in ['X', 'Y']: |
|
|
|
|
|
# for tag in 'BMES': |
|
|
|
|
|
# labels.append('{}-{}'.format(tag, label)) |
|
|
|
|
|
# id2label = {idx: label for idx, label in enumerate(labels)} |
|
|
|
|
|
# num_tags = len(id2label) |
|
|
|
|
|
# |
|
|
|
|
|
# from allennlp.modules.conditional_random_field import ConditionalRandomField, allowed_transitions |
|
|
|
|
|
# allen_CRF = ConditionalRandomField(num_tags=num_tags, constraints=allowed_transitions('BMES', id2label), |
|
|
|
|
|
# include_start_end_transitions=False) |
|
|
|
|
|
# bmes_logits = torch.nn.functional.softmax(torch.rand(size=(batch_size, max_len, num_tags)), dim=-1).log() |
|
|
|
|
|
# bmes_trans_m = allen_CRF.transitions |
|
|
|
|
|
# bmes_seq_lens = torch.randint(1, max_len, size=(batch_size,)) |
|
|
|
|
|
# bmes_seq_lens[0] = 1 |
|
|
|
|
|
# bmes_seq_lens[-1] = max_len |
|
|
|
|
|
# mask = seq_len_to_mask(bmes_seq_lens) |
|
|
|
|
|
# allen_res = allen_CRF.viterbi_tags(bmes_logits, mask) |
|
|
|
|
|
# |
|
|
|
|
|
# from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions |
|
|
|
|
|
# fast_CRF = ConditionalRandomField(num_tags=num_tags, allowed_transitions=allowed_transitions(id2label, |
|
|
|
|
|
# encoding_type='BMES', |
|
|
|
|
|
# include_start_end=True)) |
|
|
|
|
|
# fast_CRF.trans_m = bmes_trans_m |
|
|
|
|
|
# fast_res = fast_CRF.viterbi_decode(bmes_logits, mask, unpad=True) |
|
|
|
|
|
# # score equal |
|
|
|
|
|
# bmes_scores = [round(score, 4) for _, score in allen_res] |
|
|
|
|
|
# self.assertListEqual(bmes_scores, [round(s, 4) for s in fast_res[1].tolist()]) |
|
|
|
|
|
# # seq equal |
|
|
|
|
|
# bmes_path = [_ for _, score in allen_res] |
|
|
|
|
|
# self.assertListEqual(bmes_path, fast_res[0]) |
|
|
|
|
|
# |
|
|
|
|
|
# data = { |
|
|
|
|
|
# 'bio_logits': bio_logits.tolist(), |
|
|
|
|
|
# 'bio_scores': bio_scores, |
|
|
|
|
|
# 'bio_path': bio_path, |
|
|
|
|
|
# 'bio_trans_m': bio_trans_m.tolist(), |
|
|
|
|
|
# 'bio_seq_lens': bio_seq_lens.tolist(), |
|
|
|
|
|
# 'bmes_logits': bmes_logits.tolist(), |
|
|
|
|
|
# 'bmes_scores': bmes_scores, |
|
|
|
|
|
# 'bmes_path': bmes_path, |
|
|
|
|
|
# 'bmes_trans_m': bmes_trans_m.tolist(), |
|
|
|
|
|
# 'bmes_seq_lens': bmes_seq_lens.tolist(), |
|
|
|
|
|
# } |
|
|
|
|
|
# |
|
|
|
|
|
# with open('weights.json', 'w') as f: |
|
|
|
|
|
# import json |
|
|
|
|
|
# json.dump(data, f) |
|
|
|
|
|
|
|
|
def test_case2(self):
    """Check that CRF Viterbi decoding reproduces stored reference results.

    The fixture ``test/data_for_tests/modules/decoder/crf.json`` holds, for
    both a BIO and a BMES tag scheme, the emission logits, the transition
    matrix, the sequence lengths, and the expected best paths and scores.
    The expected values appear to come from an allennlp run (see the
    commented-out generator kept next to this test) -- confirm before
    regenerating the fixture.

    For each scheme the test builds a constrained ``ConditionalRandomField``,
    overwrites its transition matrix with the stored one, decodes, and
    compares both the scores and the decoded tag sequences.
    """
    import json

    import torch

    from fastNLP import seq_len_to_mask
    from fastNLP.modules.decoder.crf import ConditionalRandomField, allowed_transitions

    # NOTE: the path is relative, so the test must be run from the repository root.
    with open('test/data_for_tests/modules/decoder/crf.json', 'r', encoding='utf-8') as f:
        data = json.load(f)

    def check_scheme(prefix, labels, **transition_kwargs):
        """Decode the fixture entries named ``<prefix>_*`` and compare with the reference."""
        id2label = dict(enumerate(labels))
        logits = torch.FloatTensor(data[prefix + '_logits'])
        trans_m = torch.FloatTensor(data[prefix + '_trans_m'])
        seq_lens = torch.LongTensor(data[prefix + '_seq_lens'])
        expected_scores = data[prefix + '_scores']
        expected_paths = data[prefix + '_path']

        mask = seq_len_to_mask(seq_lens)
        crf = ConditionalRandomField(
            num_tags=len(id2label),
            allowed_transitions=allowed_transitions(id2label, include_start_end=True,
                                                    **transition_kwargs))
        # Replace the randomly initialised transitions with the stored ones so
        # that decoding is deterministic and comparable with the reference.
        crf.trans_m.data = trans_m
        decoded = crf.viterbi_decode(logits, mask, unpad=True)
        paths, scores = decoded[0], decoded[1].tolist()

        # score equal: the reference scores were rounded to 4 decimals, so
        # compare with a matching tolerance instead of re-rounding, which can
        # flip on a rounding boundary.
        self.assertEqual(len(expected_scores), len(scores))
        for expected, actual in zip(expected_scores, scores):
            self.assertAlmostEqual(expected, actual, delta=1e-4)
        # seq equal
        self.assertListEqual(expected_paths, paths)

    bio_labels = ['O'] + ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BI']
    bmes_labels = ['{}-{}'.format(tag, label) for label in ['X', 'Y'] for tag in 'BMES']

    with self.subTest(scheme='bio'):
        # BIO relies on the default encoding type of allowed_transitions.
        check_scheme('bio', bio_labels)
    with self.subTest(scheme='bmes'):
        check_scheme('bmes', bmes_labels, encoding_type='BMES')
|
|
|
|
|
|
|
|
def test_case3(self): |
|
|
def test_case3(self): |
|
|
# 测试crf的loss不会出现负数 |
|
|
# 测试crf的loss不会出现负数 |
|
|