diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 0354e7cc..7a96020b 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -342,8 +342,8 @@ class AccuracyMetric(MetricBase): def _bmes_tag_to_spans(tags, ignore_labels=None): """ - 给定一个tags的lis,比如['S', 'B-singer', 'M-singer', 'E-singer', 'S', 'S']。 - 返回[('', (0, 1)), ('singer', (1, 4)), ('', (4, 5)), ('', (5, 6))] (左闭右开区间) + 给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 + 返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) :param tags: List[str], :param ignore_labels: List[str], 在该list中的label将被忽略 @@ -527,8 +527,8 @@ class SpanFPreRecMetric(MetricBase): if pred.size() == target.size() and len(target.size()) == 2: pass elif len(pred.size()) == len(target.size()) + 1 and len(target.size()) == 2: - pred = pred.argmax(dim=-1) num_classes = pred.size(-1) + pred = pred.argmax(dim=-1) if (target >= num_classes).any(): raise ValueError("A gold label passed to SpanBasedF1Metric contains an " "id >= {}, the number of classes.".format(num_classes)) @@ -538,9 +538,11 @@ class SpanFPreRecMetric(MetricBase): f"{pred.size()[:-1]}, got {target.size()}.") batch_size = pred.size(0) + pred = pred.tolist() + target = target.tolist() for i in range(batch_size): - pred_tags = pred[i, :int(seq_len[i])].tolist() - gold_tags = target[i, :int(seq_len[i])].tolist() + pred_tags = pred[i][:int(seq_len[i])] + gold_tags = target[i][:int(seq_len[i])] pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] @@ -592,15 +594,18 @@ class SpanFPreRecMetric(MetricBase): f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()), sum(self._false_negatives.values()), sum(self._false_positives.values())) - evaluate_result['f'] = round(f, 6) - evaluate_result['pre'] = round(pre, 6) - evaluate_result['rec'] = round(rec, 6) + evaluate_result['f'] = f + evaluate_result['pre'] = pre + evaluate_result['rec'] = rec if reset: self._true_positives = defaultdict(int) self._false_positives = defaultdict(int) self._false_negatives = defaultdict(int) + for key, value in evaluate_result.items(): + evaluate_result[key] = round(value, 6) + return evaluate_result def _compute_f_pre_rec(self, tp, fn, fp): diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index a5f7c0c3..f3b0178c 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -5,8 +5,39 @@ import torch from fastNLP import AccuracyMetric from fastNLP.core.metrics import _pred_topk, _accuracy_topk +from fastNLP.core.vocabulary import Vocabulary +from collections import Counter +from fastNLP.core.metrics import SpanFPreRecMetric +def _generate_tags(encoding_type, number_labels=4): + vocab = {} + for i in range(number_labels): + label = str(i) + for tag in encoding_type: + if tag == 'O': + if tag not in vocab: + vocab['O'] = len(vocab) + 1 + continue + vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count + return vocab + + +def _convert_res_to_fastnlp_res(metric_result): + allen_result = {} + key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} + for key, value in metric_result.items(): + if key in key_map: + key = key_map[key] + else: + label = key.split('-')[-1] + if key.startswith('f1'): + key = 'f-{}'.format(label) + else: + key = '{}-{}'.format(key[:3], label) + allen_result[key] = round(value, 6) + return allen_result + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): # (1) only input, targets passed @@ -160,18 +191,7 @@ class SpanF1PreRecMetric(unittest.TestCase): ('6', (4, 5)), ('7', (6, 7))]) self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) - # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 - # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans - # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans - # for i in range(1000): - # strs = list(map(str, np.random.randint(100, size=1000))) - # bmes = list('bmes'.upper()) - # bmes_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bmes, size=len(strs)))] - # bio = list('bio'.upper()) - # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] - # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) - # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) - + def test_case2(self): # 测试不带label的 from fastNLP.core.metrics import _bmes_tag_to_spans @@ -185,170 +205,130 @@ class SpanF1PreRecMetric(unittest.TestCase): expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) - # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 - # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans - # from allennlp.data.dataset_readers.dataset_utils import bmes_tags_to_spans as allen_bmes_tags_to_spans - # for i in range(1000): - # bmes = list('bmes'.upper()) - # bmes_strs = np.random.choice(bmes, size=1000) - # bio = list('bio'.upper()) - # bio_strs = np.random.choice(bio, size=100) - # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) - # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) - - def tese_case3(self): - from fastNLP.core.vocabulary import Vocabulary - from collections import Counter - from fastNLP.core.metrics import SpanFPreRecMetric - # 与allennlp测试能否正确计算f metric - # - def generate_allen_tags(encoding_type, number_labels=4): - vocab = {} - for i in range(number_labels): - label = str(i) - for tag in encoding_type: - if tag == 'O': - if tag not in vocab: - vocab['O'] = len(vocab) + 1 - continue - vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count - return vocab - + + def test_case3(self): number_labels = 4 # bio tag fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) - fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) + fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) - bio_sequence = torch.FloatTensor( - [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, - 0.0470, 0.0971], - [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, - 0.7987, -0.3970], - [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, - 0.6880, 1.4348], - [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, - -1.6876, -0.8917], - [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, - 1.4217, 0.2622]], - - [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, - 1.3592, -0.8973], - [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, - -0.4025, -0.3417], - [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, - 0.2861, -0.3966], - [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, - 0.0213, 1.4777], - [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, - 1.3024, 0.2001]]] - ) - bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], - [5., 6., 8., 6., 0.]]) - fastnlp_bio_metric({'pred': bio_sequence, 'seq_lens': torch.LongTensor([5, 5])}, {'target': bio_target}) - expect_bio_res = {'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-1': 0.33333333333327775, - 'pre-2': 0.0, 'rec-2': 0.0, 'f-2': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre-0': 0.0, - 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, - 'f': 0.12499999999994846} + bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, + -0.3782, 0.8240], + [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, + -0.3562, -1.4116], + [ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, + 2.0023, 0.7075], + [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, + 0.3832, -0.1540], + [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, + -1.3508, -0.9513], + [ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, + -0.0842, -0.4294]], + + [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, + -1.4138, -0.8853], + [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, + -1.0726, 0.0364], + [ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, + -0.8836, -0.9320], + [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, + -1.6857, 1.1571], + [ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, + -0.5837, 1.0184], + [ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, + -0.9025, 0.0864]]]) + bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], + [4, 1, 7, 0, 4, 7]]) + fastnlp_bio_metric({'pred': bio_sequence, 'seq_len': torch.LongTensor([6, 6])}, {'target': bio_target}) + expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, + 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, + 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} + self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) - + + def test_case4(self): # bmes tag - bmes_sequence = torch.FloatTensor( - [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, - -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, - -0.3505, -0.6002], - [0.3238, -1.2378, -1.3304, -0.4903, 1.4518, -0.1868, -0.7641, - 1.6199, -0.8877, 0.1449, 0.8995, -0.5810, 0.1041, 0.1002, - 0.4439, 0.2514], - [-0.8362, 2.9526, 0.8008, 0.1193, 1.0488, 0.6670, 1.1696, - -1.1006, -0.8540, -0.1600, -0.9519, -0.2749, -0.4948, -1.4753, - 0.5802, -0.0516], - [-0.8383, -1.7292, -1.4079, -1.5023, 0.5383, 0.6653, 0.3121, - 4.1249, -0.4173, -0.2043, 1.7755, 1.1110, -1.7069, -0.0390, - -0.9242, -0.0333], - [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, - 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, - -0.3779, -0.3195]], - - [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, - 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, - -0.1103, 0.4417], - [-0.2903, 0.9205, -1.5758, -1.0421, 0.2921, -0.2142, -0.3049, - -0.0879, -0.4412, -1.3195, -0.0657, -0.2986, 0.7214, 0.0631, - -0.6386, 0.2797], - [0.6440, -0.3748, 1.2912, -0.0170, 0.7447, 1.4075, -0.4947, - 0.4123, -0.8447, -0.5502, 0.3520, -0.2832, 0.5019, -0.1522, - 1.1237, -1.5385], - [0.2839, -0.7649, 0.9067, -0.1163, -1.3789, 0.2571, -1.3977, - -0.3680, -0.8902, -0.6983, -1.1583, 1.2779, 0.2197, 0.1376, - -0.0591, -0.2461], - [-0.2977, -1.8564, -0.5347, 1.0011, -1.1260, 0.4252, -2.0097, - 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, - -0.7344, -1.2046]]] - ) - bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.], - [6., 15., 6., 15., 5.]]) - - fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) - fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) - fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') - fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) - - expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, - 'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, - 'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, - 'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, - 'pre': 0.499999999999995, 'rec': 0.499999999999995} - - self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) - - # 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 - # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary - # from allennlp.training.metrics import SpanBasedF1Measure - # allen_bio_vocab = allen_Vocabulary({"tags": generate_allen_tags('BIO', number_labels)}, - # non_padded_namespaces=['tags']) - # allen_bio_metric = SpanBasedF1Measure(allen_bio_vocab, 'tags') - # bio_sequence = torch.randn(size=(2, 20, 2 * number_labels + 1)) - # bio_target = torch.randint(2 * number_labels + 1, size=(2, 20)) - # allen_bio_metric(bio_sequence, bio_target, torch.ones(2, 20)) - # fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) - # fastnlp_bio_vocab.word_count = Counter(generate_allen_tags('BIO', number_labels)) - # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) - # - # def convert_allen_res_to_fastnlp_res(metric_result): - # allen_result = {} - # key_map = {'f1-measure-overall': "f", "recall-overall": "rec", "precision-overall": "pre"} - # for key, value in metric_result.items(): - # if key in key_map: - # key = key_map[key] - # else: - # label = key.split('-')[-1] - # if key.startswith('f1'): - # key = 'f-{}'.format(label) - # else: - # key = '{}-{}'.format(key[:3], label) - # allen_result[key] = value - # return allen_result - # - # # print(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric())) - # # print(fastnlp_bio_metric.get_metric()) - # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bio_metric.get_metric()), - # fastnlp_bio_metric.get_metric()) - # - # allen_bmes_vocab = allen_Vocabulary({"tags": generate_allen_tags('BMES', number_labels)}) - # allen_bmes_metric = SpanBasedF1Measure(allen_bmes_vocab, 'tags', label_encoding='BMES') - # fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) - # fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) - # fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') - # bmes_sequence = torch.randn(size=(2, 20, 4 * number_labels)) - # bmes_target = torch.randint(4 * number_labels, size=(2, 20)) - # allen_bmes_metric(bmes_sequence, bmes_target, torch.ones(2, 20)) - # fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) - # - # # print(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric())) - # # print(fastnlp_bmes_metric.get_metric()) - # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), - # fastnlp_bmes_metric.get_metric()) + def _generate_samples(): + target = [] + seq_len = [] + vocab = Vocabulary(unknown=None, padding=None) + for i in range(3): + target_i = [] + seq_len_i = 0 + for j in range(1, 10): + word_len = np.random.randint(1, 5) + seq_len_i += word_len + if word_len==1: + target_i.append('S') + else: + target_i.append('B') + target_i.extend(['M']*(word_len-2)) + target_i.append('E') + vocab.add_word_lst(target_i) + target.append(target_i) + seq_len.append(seq_len_i) + target_ = np.zeros((3, max(seq_len))) + for i in range(3): + target_i = [vocab.to_index(t) for t in target[i]] + target_[i, :seq_len[i]] = target_i + return target_, target, seq_len, vocab + def get_eval(raw_target, pred, vocab, seq_len): + pred = pred.argmax(dim=-1).tolist() + tp = 0 + gold = 0 + seg = 0 + pred_target = [] + for i in range(len(seq_len)): + tags = [vocab.to_word(p) for p in pred[i][:seq_len[i]]] + spans = [] + prev_bmes_tag = None + for idx, tag in enumerate(tags): + if tag in ('B', 'S'): + spans.append([idx, idx]) + elif tag in ('M', 'E') and prev_bmes_tag in ('B', 'M'): + spans[-1][1] = idx + else: + spans.append([idx, idx]) + prev_bmes_tag = tag + tmp = [] + for span in spans: + if span[1]-span[0]>0: + tmp.extend(['B'] + ['M']*(span[1]-span[0]-1) + ['E']) + else: + tmp.append('S') + pred_target.append(tmp) + for i in range(len(seq_len)): + raw_pred = pred_target[i] + start = 0 + for j in range(seq_len[i]): + if raw_target[i][j] in ('E', 'S'): + flag = True + for k in range(start, j+1): + if raw_target[i][k]!=raw_pred[k]: + flag = False + break + if flag: + tp += 1 + start = j + 1 + gold += 1 + if raw_pred[j] in ('E', 'S'): + seg += 1 + + pre = round(tp/seg, 6) + rec = round(tp/gold, 6) + return {'f': round(2*pre*rec/(pre+rec), 6), 'pre': pre, 'rec':rec} + + target, raw_target, seq_len, vocab = _generate_samples() + pred = torch.randn(3, max(seq_len), 4) + expected_metric = get_eval(raw_target, pred, vocab, seq_len) + metric = SpanFPreRecMetric(vocab, encoding_type='bmes') + metric({'pred': pred, 'seq_len':torch.LongTensor(seq_len)}, {'target': torch.from_numpy(target)}) + # print(metric.get_metric(reset=False)) + # print(expected_metric) + metric_value = metric.get_metric() + for key, value in expected_metric.items(): + self.assertAlmostEqual(value, metric_value[key], places=5) class TestUsefulFunctions(unittest.TestCase):