|
@@ -5,8 +5,39 @@ import torch |
|
|
|
|
|
|
|
|
from fastNLP import AccuracyMetric |
|
|
from fastNLP import AccuracyMetric |
|
|
from fastNLP.core.metrics import _pred_topk, _accuracy_topk |
|
|
from fastNLP.core.metrics import _pred_topk, _accuracy_topk |
|
|
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
from collections import Counter |
|
|
|
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _generate_tags(encoding_type, number_labels=4): |
|
|
|
|
|
vocab = {} |
|
|
|
|
|
for i in range(number_labels): |
|
|
|
|
|
label = str(i) |
|
|
|
|
|
for tag in encoding_type: |
|
|
|
|
|
if tag == 'O': |
|
|
|
|
|
if tag not in vocab: |
|
|
|
|
|
vocab['O'] = len(vocab) + 1 |
|
|
|
|
|
continue |
|
|
|
|
|
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count |
|
|
|
|
|
return vocab |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _convert_res_to_fastnlp_res(metric_result): |
|
|
|
|
|
allen_result = {} |
|
|
|
|
|
key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} |
|
|
|
|
|
for key, value in metric_result.items(): |
|
|
|
|
|
if key in key_map: |
|
|
|
|
|
key = key_map[key] |
|
|
|
|
|
else: |
|
|
|
|
|
label = key.split('-')[-1] |
|
|
|
|
|
if key.startswith('f1'): |
|
|
|
|
|
key = 'f-{}'.format(label) |
|
|
|
|
|
else: |
|
|
|
|
|
key = '{}-{}'.format(key[:3], label) |
|
|
|
|
|
allen_result[key] = round(value, 6) |
|
|
|
|
|
return allen_result |
|
|
|
|
|
|
|
|
class TestAccuracyMetric(unittest.TestCase): |
|
|
class TestAccuracyMetric(unittest.TestCase): |
|
|
def test_AccuracyMetric1(self): |
|
|
def test_AccuracyMetric1(self): |
|
|
# (1) only input, targets passed |
|
|
# (1) only input, targets passed |
|
@@ -160,18 +191,7 @@ class SpanF1PreRecMetric(unittest.TestCase): |
|
|
('6', (4, 5)), ('7', (6, 7))]) |
|
|
('6', (4, 5)), ('7', (6, 7))]) |
|
|
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) |
|
|
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) |
|
|
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) |
|
|
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) |
|
|
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 |
|
|
|
|
|
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans |
|
|
|
|
|
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans |
|
|
|
|
|
# for i in range(1000): |
|
|
|
|
|
# strs = list(map(str, np.random.randint(100, size=1000))) |
|
|
|
|
|
# bmes = list('bmes'.upper()) |
|
|
|
|
|
# bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] |
|
|
|
|
|
# bio = list('bio'.upper()) |
|
|
|
|
|
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] |
|
|
|
|
|
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) |
|
|
|
|
|
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case2(self): |
|
|
def test_case2(self): |
|
|
# 测试不带label的 |
|
|
# 测试不带label的 |
|
|
from fastNLP.core.metrics import _bmes_tag_to_spans |
|
|
from fastNLP.core.metrics import _bmes_tag_to_spans |
|
@@ -185,170 +205,130 @@ class SpanF1PreRecMetric(unittest.TestCase): |
|
|
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) |
|
|
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) |
|
|
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) |
|
|
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) |
|
|
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) |
|
|
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) |
|
|
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 |
|
|
|
|
|
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans |
|
|
|
|
|
# from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans |
|
|
|
|
|
# for i in range(1000): |
|
|
|
|
|
# bmes = list('bmes'.upper()) |
|
|
|
|
|
# bmes_strs = np.random.choice(bmes, size=1000) |
|
|
|
|
|
# bio = list('bio'.upper()) |
|
|
|
|
|
# bio_strs = np.random.choice(bio, size=100) |
|
|
|
|
|
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) |
|
|
|
|
|
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) |
|
|
|
|
|
|
|
|
|
|
|
def tese_case3(self): |
|
|
|
|
|
from fastNLP.core.vocabulary import Vocabulary |
|
|
|
|
|
from collections import Counter |
|
|
|
|
|
from fastNLP.core.metrics import SpanFPreRecMetric |
|
|
|
|
|
# 与allennlp测试能否正确计算f metric |
|
|
|
|
|
# |
|
|
|
|
|
def generate_allen_tags(encoding_type, number_labels=4):
    """Create a ``tag -> pseudo count`` dict for the given encoding scheme.

    For each label index in ``range(number_labels)`` and each character of
    ``encoding_type`` a ``'<tag>-<label>'`` key is written. The ``'O'``
    character is written once only, without a label suffix. The value stored
    on each write is ``len(vocab) + 1`` at that moment.

    :param encoding_type: iterable of single-character tags such as ``'BIO'``.
    :param number_labels: number of label indices to generate.
    :return: dict mapping tag strings to integers.
    """
    vocab = {}
    labels = [str(index) for index in range(number_labels)]
    for label in labels:
        for tag in encoding_type:
            if tag != 'O':
                # The stored value effectively acts as this tag's count.
                vocab['{}-{}'.format(tag, label)] = len(vocab) + 1
            elif 'O' not in vocab:
                vocab['O'] = len(vocab) + 1
    return vocab
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case3(self): |
|
|
number_labels = 4 |
|
|
number_labels = 4 |
|
|
# bio tag |
|
|
# bio tag |
|
|
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) |
|
|
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) |
|
|
fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) |
|
|
|
|
|
|
|
|
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) |
|
|
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) |
|
|
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) |
|
|
bio_sequence = torch.FloatTensor( |
|
|
|
|
|
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, |
|
|
|
|
|
0.0470, 0.0971], |
|
|
|
|
|
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, |
|
|
|
|
|
0.7987, -0.3970], |
|
|
|
|
|
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, |
|
|
|
|
|
0.6880, 1.4348], |
|
|
|
|
|
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, |
|
|
|
|
|
-1.6876, -0.8917], |
|
|
|
|
|
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, |
|
|
|
|
|
1.4217, 0.2622]], |
|
|
|
|
|
|
|
|
|
|
|
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, |
|
|
|
|
|
1.3592, -0.8973], |
|
|
|
|
|
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, |
|
|
|
|
|
-0.4025, -0.3417], |
|
|
|
|
|
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, |
|
|
|
|
|
0.2861, -0.3966], |
|
|
|
|
|
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, |
|
|
|
|
|
0.0213, 1.4777], |
|
|
|
|
|
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, |
|
|
|
|
|
1.3024, 0.2001]]] |
|
|
|
|
|
) |
|
|
|
|
|
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], |
|
|
|
|
|
[5., 6., 8., 6., 0.]]) |
|
|
|
|
|
fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) |
|
|
|
|
|
expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, |
|
|
|
|
|
'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, |
|
|
|
|
|
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, |
|
|
|
|
|
'f': 0.12499999999994846} |
|
|
|
|
|
|
|
|
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, |
|
|
|
|
|
-0.3782, 0.8240], |
|
|
|
|
|
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, |
|
|
|
|
|
-0.3562, -1.4116], |
|
|
|
|
|
[ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, |
|
|
|
|
|
2.0023, 0.7075], |
|
|
|
|
|
[-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, |
|
|
|
|
|
0.3832, -0.1540], |
|
|
|
|
|
[-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, |
|
|
|
|
|
-1.3508, -0.9513], |
|
|
|
|
|
[ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, |
|
|
|
|
|
-0.0842, -0.4294]], |
|
|
|
|
|
|
|
|
|
|
|
[[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, |
|
|
|
|
|
-1.4138, -0.8853], |
|
|
|
|
|
[-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, |
|
|
|
|
|
-1.0726, 0.0364], |
|
|
|
|
|
[ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, |
|
|
|
|
|
-0.8836, -0.9320], |
|
|
|
|
|
[ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, |
|
|
|
|
|
-1.6857, 1.1571], |
|
|
|
|
|
[ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, |
|
|
|
|
|
-0.5837, 1.0184], |
|
|
|
|
|
[ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, |
|
|
|
|
|
-0.9025, 0.0864]]]) |
|
|
|
|
|
bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], |
|
|
|
|
|
[4, 1, 7, 0, 4, 7]]) |
|
|
|
|
|
fastnlp_bio_metric({'pred': bio_sequence, 'seq_len': torch.LongTensor([6, 6])}, {'target': bio_target}) |
|
|
|
|
|
expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, |
|
|
|
|
|
'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, |
|
|
|
|
|
'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} |
|
|
|
|
|
|
|
|
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) |
|
|
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_case4(self): |
|
|
# bmes tag |
|
|
# bmes tag |
|
|
bmes_sequence = torch.FloatTensor( |
|
|
|
|
|
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, |
|
|
|
|
|
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, |
|
|
|
|
|
-0.3505, -0.6002], |
|
|
|
|
|
[0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, |
|
|
|
|
|
1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, |
|
|
|
|
|
0.4439, 0.2514], |
|
|
|
|
|
[-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, |
|
|
|
|
|
-1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, |
|
|
|
|
|
0.5802, -0.0516], |
|
|
|
|
|
[-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, |
|
|
|
|
|
4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, |
|
|
|
|
|
-0.9242, -0.0333], |
|
|
|
|
|
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, |
|
|
|
|
|
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, |
|
|
|
|
|
-0.3779, -0.3195]], |
|
|
|
|
|
|
|
|
|
|
|
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, |
|
|
|
|
|
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, |
|
|
|
|
|
-0.1103, 0.4417], |
|
|
|
|
|
[-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, |
|
|
|
|
|
-0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, |
|
|
|
|
|
-0.6386, 0.2797], |
|
|
|
|
|
[0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, |
|
|
|
|
|
0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, |
|
|
|
|
|
1.1237, -1.5385], |
|
|
|
|
|
[0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, |
|
|
|
|
|
-0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, |
|
|
|
|
|
-0.0591, -0.2461], |
|
|
|
|
|
[-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, |
|
|
|
|
|
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, |
|
|
|
|
|
-0.7344, -1.2046]]] |
|
|
|
|
|
) |
|
|
|
|
|
bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.], |
|
|
|
|
|
[6., 15., 6., 15., 5.]]) |
|
|
|
|
|
|
|
|
|
|
|
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) |
|
|
|
|
|
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) |
|
|
|
|
|
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') |
|
|
|
|
|
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) |
|
|
|
|
|
|
|
|
|
|
|
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, |
|
|
|
|
|
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, |
|
|
|
|
|
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, |
|
|
|
|
|
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, |
|
|
|
|
|
'pre': 0.499999999999995, 'rec': 0.499999999999995} |
|
|
|
|
|
|
|
|
|
|
|
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) |
|
|
|
|
|
|
|
|
|
|
|
# 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 |
|
|
|
|
|
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary |
|
|
|
|
|
# from allennlp.training.metrics import SpanBasedF1Measure |
|
|
|
|
|
# allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, |
|
|
|
|
|
# non_padded_namespaces=['tags']) |
|
|
|
|
|
# allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') |
|
|
|
|
|
# bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) |
|
|
|
|
|
# bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) |
|
|
|
|
|
# allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) |
|
|
|
|
|
# fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) |
|
|
|
|
|
# fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) |
|
|
|
|
|
# fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) |
|
|
|
|
|
# |
|
|
|
|
|
# def convert_allen_res_to_fastnlp_res(metric_result): |
|
|
|
|
|
# allen_result = {} |
|
|
|
|
|
# key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} |
|
|
|
|
|
# for key, value in metric_result.items(): |
|
|
|
|
|
# if key in key_map: |
|
|
|
|
|
# key = key_map[key] |
|
|
|
|
|
# else: |
|
|
|
|
|
# label = key.split('-')[-1] |
|
|
|
|
|
# if key.startswith('f1'): |
|
|
|
|
|
# key = 'f-{}'.format(label) |
|
|
|
|
|
# else: |
|
|
|
|
|
# key = '{}-{}'.format(key[:3], label) |
|
|
|
|
|
# allen_result[key] = value |
|
|
|
|
|
# return allen_result |
|
|
|
|
|
# |
|
|
|
|
|
# # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) |
|
|
|
|
|
# # print(fastnlp_bio_metric.get_metric()) |
|
|
|
|
|
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), |
|
|
|
|
|
# fastnlp_bio_metric.get_metric()) |
|
|
|
|
|
# |
|
|
|
|
|
# allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) |
|
|
|
|
|
# allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') |
|
|
|
|
|
# fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) |
|
|
|
|
|
# fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) |
|
|
|
|
|
# fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') |
|
|
|
|
|
# bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) |
|
|
|
|
|
# bmes_target = torch.randint(4 * number_labels, size=(2, 20)) |
|
|
|
|
|
# allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) |
|
|
|
|
|
# fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) |
|
|
|
|
|
# |
|
|
|
|
|
# # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) |
|
|
|
|
|
# # print(fastnlp_bmes_metric.get_metric()) |
|
|
|
|
|
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), |
|
|
|
|
|
# fastnlp_bmes_metric.get_metric()) |
|
|
|
|
|
|
|
|
def _generate_samples():
    """Create three random BMES-tagged sequences plus their index matrix.

    Each sequence consists of 9 "words" whose lengths are drawn with
    ``np.random.randint(1, 5)``. A length-1 word becomes ``'S'``; a longer
    word becomes ``'B'``, then ``word_len - 2`` times ``'M'``, then ``'E'``.
    All tags are added to a fresh ``Vocabulary(unknown=None, padding=None)``.

    :return: tuple ``(target_, target, seq_len, vocab)`` where ``target`` is
        the list of tag-string lists, ``seq_len`` the list of their lengths,
        ``vocab`` the vocabulary built from them, and ``target_`` a
        ``np.zeros`` array of shape ``(3, max(seq_len))`` whose row ``i``
        holds ``vocab.to_index`` of each tag in its first ``seq_len[i]``
        positions (remaining positions stay 0).
    """
    n_samples = 3
    words_per_sample = 9
    vocab = Vocabulary(unknown=None, padding=None)
    target = []
    seq_len = []
    for _ in range(n_samples):
        tags = []
        for _ in range(words_per_sample):
            word_len = np.random.randint(1, 5)
            if word_len == 1:
                tags.append('S')
            else:
                tags += ['B'] + ['M'] * (word_len - 2) + ['E']
        vocab.add_word_lst(tags)
        target.append(tags)
        # Every word contributes exactly word_len tags, so the running
        # length total equals the number of tags emitted.
        seq_len.append(len(tags))

    target_ = np.zeros((n_samples, max(seq_len)))
    for row, (tags, length) in enumerate(zip(target, seq_len)):
        target_[row, :length] = [vocab.to_index(tag) for tag in tags]
    return target_, target, seq_len, vocab
|
|
|
|
|
def get_eval(raw_target, pred, vocab, seq_len):
    """Compute reference word-level precision/recall/F1 for BMES tagging.

    The predicted tag of each position is ``pred.argmax(dim=-1)`` mapped
    through ``vocab.to_word``. Predicted tags are first grouped into spans
    (an ``'M'``/``'E'`` extends the current span only when the previous tag
    was ``'B'`` or ``'M'``; anything else opens a new span) and re-emitted as
    a well-formed B/M/E/S sequence. A gold word, which ends at a gold
    ``'E'`` or ``'S'``, counts as a true positive when every tag from its
    start to its end equals the re-emitted prediction.

    Zero denominators yield ``0.0`` instead of raising
    ``ZeroDivisionError``, so a prediction with no correct word does not
    crash the caller.

    :param raw_target: per-sample sequences of gold tag strings.
    :param pred: object supporting ``argmax(dim=-1).tolist()``, indexed as
        ``[sample][position]`` after the argmax.
    :param vocab: object whose ``to_word(index)`` returns a tag string.
    :param seq_len: per-sample number of valid positions.
    :return: dict with keys ``'f'``, ``'pre'`` and ``'rec'``, each rounded
        to 6 decimal places.
    """
    pred = pred.argmax(dim=-1).tolist()
    tp = 0    # gold words reproduced exactly by the prediction
    gold = 0  # number of gold words
    seg = 0   # number of predicted words

    for sample_idx, length in enumerate(seq_len):
        tags = [vocab.to_word(p) for p in pred[sample_idx][:length]]

        # Group the predicted tags into [start, end] spans.
        spans = []
        prev_tag = None
        for idx, tag in enumerate(tags):
            if tag in ('M', 'E') and prev_tag in ('B', 'M'):
                spans[-1][1] = idx
            else:
                spans.append([idx, idx])
            prev_tag = tag

        # Re-emit the spans as a well-formed tag sequence.
        raw_pred = []
        for begin, end in spans:
            if end > begin:
                raw_pred.extend(['B'] + ['M'] * (end - begin - 1) + ['E'])
            else:
                raw_pred.append('S')

        gold_tags = raw_target[sample_idx]
        start = 0
        for j in range(length):
            if gold_tags[j] in ('E', 'S'):
                if all(gold_tags[k] == raw_pred[k] for k in range(start, j + 1)):
                    tp += 1
                start = j + 1
                gold += 1
            if raw_pred[j] in ('E', 'S'):
                seg += 1

    pre = round(tp / seg, 6) if seg else 0.0
    rec = round(tp / gold, 6) if gold else 0.0
    f = round(2 * pre * rec / (pre + rec), 6) if pre + rec else 0.0
    return {'f': f, 'pre': pre, 'rec': rec}
|
|
|
|
|
|
|
|
|
|
|
target, raw_target, seq_len, vocab = _generate_samples() |
|
|
|
|
|
pred = torch.randn(3, max(seq_len), 4) |
|
|
|
|
|
|
|
|
|
|
|
expected_metric = get_eval(raw_target, pred, vocab, seq_len) |
|
|
|
|
|
metric = SpanFPreRecMetric(vocab, encoding_type='bmes') |
|
|
|
|
|
metric({'pred': pred, 'seq_len':torch.LongTensor(seq_len)}, {'target': torch.from_numpy(target)}) |
|
|
|
|
|
# print(metric.get_metric(reset=False)) |
|
|
|
|
|
# print(expected_metric) |
|
|
|
|
|
metric_value = metric.get_metric() |
|
|
|
|
|
for key, value in expected_metric.items(): |
|
|
|
|
|
self.assertAlmostEqual(value, metric_value[key], places=5) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TestUsefulFunctions(unittest.TestCase): |
|
|
class TestUsefulFunctions(unittest.TestCase): |
|
|