
1. Fix a bug in metrics; 2. Add metric tests

tags/v0.4.10
yh_cc 5 years ago
parent commit 0a8f7c0e69
2 changed files with 162 additions and 177 deletions
  1. +13 -8  fastNLP/core/metrics.py
  2. +149 -169  test/core/test_metrics.py

+13 -8  fastNLP/core/metrics.py

@@ -342,8 +342,8 @@ class AccuracyMetric(MetricBase):

def _bmes_tag_to_spans(tags, ignore_labels=None):
    """
-   Given a list of tags, e.g. ['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S'],
-   return [('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (half-open [start, end) intervals).
+   Given a list of tags, e.g. ['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor'],
+   return [('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (half-open [start, end) intervals).

    :param tags: List[str],
    :param ignore_labels: List[str], labels contained in this list will be ignored
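Note: a quick usage sketch of the new docstring example (assuming `_bmes_tag_to_spans` is imported from `fastNLP.core.metrics`, as the tests below do):

    from fastNLP.core.metrics import _bmes_tag_to_spans

    # each span is (label, (start, end)) with a half-open interval
    spans = _bmes_tag_to_spans(['S-song', 'B-singer', 'M-singer', 'E-singer'])
    # expected: [('song', (0, 1)), ('singer', (1, 4))]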
@@ -527,8 +527,8 @@ class SpanFPreRecMetric(MetricBase):
        if pred.size() == target.size() and len(target.size()) == 2:
            pass
        elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2:
-           pred = pred.argmax(dim=-1)
            num_classes = pred.size(-1)
+           pred = pred.argmax(dim=-1)
            if (target >= num_classes).any():
                raise ValueError("A gold label passed to SpanBasedF1Metric contains an "
                                 "id >= {}, the number of classes.".format(num_classes))
@@ -538,9 +538,11 @@ class SpanFPreRecMetric(MetricBase):
f"{pred.size()[:-1]}, got {target.size()}.")

batch_size = pred.size(0)
pred = pred.tolist()
target = target.tolist()
for i in range(batch_size):
pred_tags = pred[i, :int(seq_len[i])].tolist()
gold_tags = target[i, :int(seq_len[i])].tolist()
pred_tags = pred[i][:int(seq_len[i])]
gold_tags = target[i][:int(seq_len[i])]

pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags]
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags]
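Note: `pred` and `target` are now converted to nested Python lists once per batch, so the per-row tensor indexing `pred[i, :n].tolist()` becomes plain list slicing `pred[i][:n]`. A minimal sketch of why the indexing had to change (shapes are illustrative assumptions):

    import torch

    pred = torch.randint(0, 9, (2, 5))
    seq_len = [5, 3]
    pred = pred.tolist()                  # one conversion for the whole batch
    for i in range(len(pred)):
        pred_tags = pred[i][:seq_len[i]]  # list slicing; pred[i, :seq_len[i]]
                                          # would raise TypeError on a list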
@@ -592,15 +594,18 @@ class SpanFPreRecMetric(MetricBase):
            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
                                                  sum(self._false_negatives.values()),
                                                  sum(self._false_positives.values()))
-           evaluate_result['f'] = round(f, 6)
-           evaluate_result['pre'] = round(pre, 6)
-           evaluate_result['rec'] = round(rec, 6)
+           evaluate_result['f'] = f
+           evaluate_result['pre'] = pre
+           evaluate_result['rec'] = rec

        if reset:
            self._true_positives = defaultdict(int)
            self._false_positives = defaultdict(int)
            self._false_negatives = defaultdict(int)

+       for key, value in evaluate_result.items():
+           evaluate_result[key] = round(value, 6)
+
        return evaluate_result

    def _compute_f_pre_rec(self, tp, fn, fp):
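Note: rounding now happens uniformly in a final pass, so the per-label keys (`f-0`, `pre-1`, ...) that `only_gross=False` adds are rounded to 6 decimal places as well, not just the overall `f`/`pre`/`rec`. A minimal sketch of the effect:

    evaluate_result = {'f': 0.12499999999994846, 'pre-1': 0.24999999999999373}
    for key, value in evaluate_result.items():
        evaluate_result[key] = round(value, 6)
    # evaluate_result == {'f': 0.125, 'pre-1': 0.25}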


+149 -169  test/core/test_metrics.py

@@ -5,8 +5,39 @@ import torch

from fastNLP import AccuracyMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk
+ from fastNLP.core.vocabulary import Vocabulary
+ from collections import Counter
+ from fastNLP.core.metrics import SpanFPreRecMetric
+
+
+ def _generate_tags(encoding_type, number_labels=4):
+     vocab = {}
+     for i in range(number_labels):
+         label = str(i)
+         for tag in encoding_type:
+             if tag == 'O':
+                 if tag not in vocab:
+                     vocab['O'] = len(vocab) + 1
+                 continue
+             vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # this value just serves as the tag's count
+     return vocab
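For reference, a worked example of what `_generate_tags` returns (the values are 1-based insertion indices that stand in for tag counts when seeding a `Vocabulary` via `Counter`):

    vocab = _generate_tags('BIO', number_labels=2)
    # {'B-0': 1, 'I-0': 2, 'O': 3, 'B-1': 4, 'I-1': 5}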


+ def _convert_res_to_fastnlp_res(metric_result):
+     allen_result = {}
+     key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"}
+     for key, value in metric_result.items():
+         if key in key_map:
+             key = key_map[key]
+         else:
+             label = key.split('-')[-1]
+             if key.startswith('f1'):
+                 key = 'f-{}'.format(label)
+             else:
+                 key = '{}-{}'.format(key[:3], label)
+         allen_result[key] = round(value, 6)
+     return allen_result
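And a worked example of the key renaming done by `_convert_res_to_fastnlp_res` (the input values here are hypothetical):

    res = _convert_res_to_fastnlp_res({'precision-overall': 0.125, 'f1-measure-1': 0.3333333})
    # res == {'pre': 0.125, 'f-1': 0.333333}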

class TestAccuracyMetric(unittest.TestCase):
def test_AccuracyMetric1(self):
# (1) only input, targets passed
@@ -160,18 +191,7 @@ class SpanF1PreRecMetric(unittest.TestCase):
('6', (4, 5)), ('7', (6, 7))])
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
-        # Verified against the corresponding allennlp functions; since the tests cannot depend on
-        # allennlp, only the example above is kept as a fixed test.
-        # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
-        # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
-        # for i in range(1000):
-        #     strs = list(map(str, np.random.randint(100, size=1000)))
-        #     bmes = list('bmes'.upper())
-        #     bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))]
-        #     bio = list('bio'.upper())
-        #     bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))]
-        #     self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)), set(bmes_tag_to_spans(bmes_strs)))
-        #     self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))

def test_case2(self):
        # test tags that carry no label
from fastNLP.core.metrics import _bmes_tag_to_spans
@@ -185,170 +205,130 @@ class SpanF1PreRecMetric(unittest.TestCase):
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
-        # Verified against the corresponding allennlp functions; since the tests cannot depend on
-        # allennlp, only the example above is kept as a fixed test.
-        # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
-        # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans
-        # for i in range(1000):
-        #     bmes = list('bmes'.upper())
-        #     bmes_strs = np.random.choice(bmes, size=1000)
-        #     bio = list('bio'.upper())
-        #     bio_strs = np.random.choice(bio, size=100)
-        #     self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)), set(bmes_tag_to_spans(bmes_strs)))
-        #     self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
-    def tese_case3(self):
-        from fastNLP.core.vocabulary import Vocabulary
-        from collections import Counter
-        from fastNLP.core.metrics import SpanFPreRecMetric
-        # check against allennlp whether the f metric is computed correctly
-        #
-        def generate_allen_tags(encoding_type, number_labels=4):
-            vocab = {}
-            for i in range(number_labels):
-                label = str(i)
-                for tag in encoding_type:
-                    if tag == 'O':
-                        if tag not in vocab:
-                            vocab['O'] = len(vocab) + 1
-                        continue
-                    vocab['{}-{}'.format(tag, label)] = len(vocab) + 1  # this value just serves as the tag's count
-            return vocab

+    def test_case3(self):
        number_labels = 4
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
-       fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
+       fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
-       bio_sequence = torch.FloatTensor(
-           [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011,
-              0.0470, 0.0971],
-             [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523,
-              0.7987, -0.3970],
-             [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898,
-              0.6880, 1.4348],
-             [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793,
-              -1.6876, -0.8917],
-             [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824,
-              1.4217, 0.2622]],
-            [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136,
-              1.3592, -0.8973],
-             [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887,
-              -0.4025, -0.3417],
-             [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698,
-              0.2861, -0.3966],
-             [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275,
-              0.0213, 1.4777],
-             [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566,
-              1.3024, 0.2001]]]
-       )
-       bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
-                                      [5., 6., 8., 6., 0.]])
-       fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target})
-       expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775,
-                         'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0,
-                         'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845,
-                         'f': 0.12499999999994846}
+       bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
+                                           -0.3782, 0.8240],
+                                          [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
+                                           -0.3562, -1.4116],
+                                          [ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
+                                            2.0023, 0.7075],
+                                          [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
+                                            0.3832, -0.1540],
+                                          [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
+                                           -1.3508, -0.9513],
+                                          [ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
+                                           -0.0842, -0.4294]],
+
+                                         [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
+                                           -1.4138, -0.8853],
+                                          [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
+                                           -1.0726, 0.0364],
+                                          [ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
+                                           -0.8836, -0.9320],
+                                          [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
+                                           -1.6857, 1.1571],
+                                          [ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
+                                           -0.5837, 1.0184],
+                                          [ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
+                                           -0.9025, 0.0864]]])
+       bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4],
+                                      [4, 1, 7, 0, 4, 7]])
+       fastnlp_bio_metric({'pred': bio_sequence, 'seq_len': torch.LongTensor([6, 6])}, {'target': bio_target})
+       expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5,
+                         'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0,
+                         'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2}

self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())

+   def test_case4(self):
        # bmes tag
-       bmes_sequence = torch.FloatTensor(
-           [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352,
-              -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332,
-              -0.3505, -0.6002],
-             [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641,
-              1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002,
-              0.4439, 0.2514],
-             [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696,
-              -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753,
-              0.5802, -0.0516],
-             [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121,
-              4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390,
-              -0.9242, -0.0333],
-             [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393,
-              0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809,
-              -0.3779, -0.3195]],
-            [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753,
-              0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957,
-              -0.1103, 0.4417],
-             [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049,
-              -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631,
-              -0.6386, 0.2797],
-             [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947,
-              0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522,
-              1.1237, -1.5385],
-             [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977,
-              -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376,
-              -0.0591, -0.2461],
-             [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097,
-              2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142,
-              -0.7344, -1.2046]]]
-       )
-       bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.],
-                                       [6., 15., 6., 15., 5.]])
-       fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
-       fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
-       fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
-       fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
-       expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
-                          'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
-                          'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
-                          'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
-                          'pre': 0.499999999999995, 'rec': 0.499999999999995}
-       self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
-       # Already verified against allennlp; the code below is commented out because the tests
-       # cannot depend on allennlp.
-       # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary
-       # from allennlp.training.metrics import SpanBasedF1Measure
-       # allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)},
-       #                                    non_padded_namespaces=['tags'])
-       # allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags')
-       # bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1))
-       # bio_target = torch.randint(2 * number_labels + 1, size=(2, 20))
-       # allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20))
-       # fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
-       # fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels))
-       # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
-       #
-       # def convert_allen_res_to_fastnlp_res(metric_result):
-       #     allen_result = {}
-       #     key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"}
-       #     for key, value in metric_result.items():
-       #         if key in key_map:
-       #             key = key_map[key]
-       #         else:
-       #             label = key.split('-')[-1]
-       #             if key.startswith('f1'):
-       #                 key = 'f-{}'.format(label)
-       #             else:
-       #                 key = '{}-{}'.format(key[:3], label)
-       #         allen_result[key] = value
-       #     return allen_result
-       #
-       # # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()))
-       # # print(fastnlp_bio_metric.get_metric())
-       # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()),
-       #                      fastnlp_bio_metric.get_metric())
-       #
-       # allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)})
-       # allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES')
-       # fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
-       # fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
-       # fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
-       # bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels))
-       # bmes_target = torch.randint(4 * number_labels, size=(2, 20))
-       # allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20))
-       # fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
-       #
-       # # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()))
-       # # print(fastnlp_bmes_metric.get_metric())
-       # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()),
-       #                      fastnlp_bmes_metric.get_metric())
+       def _generate_samples():
+           target = []
+           seq_len = []
+           vocab = Vocabulary(unknown=None, padding=None)
+           for i in range(3):
+               target_i = []
+               seq_len_i = 0
+               for j in range(1, 10):
+                   word_len = np.random.randint(1, 5)
+                   seq_len_i += word_len
+                   if word_len == 1:
+                       target_i.append('S')
+                   else:
+                       target_i.append('B')
+                       target_i.extend(['M'] * (word_len - 2))
+                       target_i.append('E')
+               vocab.add_word_lst(target_i)
+               target.append(target_i)
+               seq_len.append(seq_len_i)
+           target_ = np.zeros((3, max(seq_len)))
+           for i in range(3):
+               target_i = [vocab.to_index(t) for t in target[i]]
+               target_[i, :seq_len[i]] = target_i
+           return target_, target, seq_len, vocab
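Note: `_generate_samples` emits one BMES tag per character, so a word of length `word_len` contributes exactly `word_len` tags. A hypothetical draw to illustrate:

    # word lengths [1, 3, 2]  ->  target_i == ['S', 'B', 'M', 'E', 'B', 'E'], seq_len_i == 6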
+       def get_eval(raw_target, pred, vocab, seq_len):
+           pred = pred.argmax(dim=-1).tolist()
+           tp = 0
+           gold = 0
+           seg = 0
+           pred_target = []
+           for i in range(len(seq_len)):
+               tags = [vocab.to_word(p) for p in pred[i][:seq_len[i]]]
+               spans = []
+               prev_bmes_tag = None
+               for idx, tag in enumerate(tags):
+                   if tag in ('B', 'S'):
+                       spans.append([idx, idx])
+                   elif tag in ('M', 'E') and prev_bmes_tag in ('B', 'M'):
+                       spans[-1][1] = idx
+                   else:
+                       spans.append([idx, idx])
+                   prev_bmes_tag = tag
+               tmp = []
+               for span in spans:
+                   if span[1] - span[0] > 0:
+                       tmp.extend(['B'] + ['M'] * (span[1] - span[0] - 1) + ['E'])
+                   else:
+                       tmp.append('S')
+               pred_target.append(tmp)
+           for i in range(len(seq_len)):
+               raw_pred = pred_target[i]
+               start = 0
+               for j in range(seq_len[i]):
+                   if raw_target[i][j] in ('E', 'S'):
+                       flag = True
+                       for k in range(start, j + 1):
+                           if raw_target[i][k] != raw_pred[k]:
+                               flag = False
+                               break
+                       if flag:
+                           tp += 1
+                       start = j + 1
+                       gold += 1
+                   if raw_pred[j] in ('E', 'S'):
+                       seg += 1
+
+           pre = round(tp / seg, 6)
+           rec = round(tp / gold, 6)
+           return {'f': round(2 * pre * rec / (pre + rec), 6), 'pre': pre, 'rec': rec}
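To make the reference computation concrete, a worked example with hypothetical tags: `gold` counts segment ends in the target, `seg` counts segment ends in the prediction, and `tp` counts segments whose tags match exactly over their whole span:

    # gold: ['B', 'E', 'S']  ->  gold == 2 (segments (0,1) and (2,2))
    # pred: ['S', 'S', 'S']  ->  seg == 3
    # only the segment ending at index 2 matches exactly  ->  tp == 1
    # pre = 1/3 = 0.333333, rec = 1/2 = 0.5, f = 2*pre*rec/(pre+rec) = 0.4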

+       target, raw_target, seq_len, vocab = _generate_samples()
+       pred = torch.randn(3, max(seq_len), 4)
+
+       expected_metric = get_eval(raw_target, pred, vocab, seq_len)
+       metric = SpanFPreRecMetric(vocab, encoding_type='bmes')
+       metric({'pred': pred, 'seq_len': torch.LongTensor(seq_len)}, {'target': torch.from_numpy(target)})
+       # print(metric.get_metric(reset=False))
+       # print(expected_metric)
+       metric_value = metric.get_metric()
+       for key, value in expected_metric.items():
+           self.assertAlmostEqual(value, metric_value[key], places=5)


class TestUsefulFunctions(unittest.TestCase):

