From 0547572d58e0bc2e2a36494600e94ddba0906dce Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Mon, 21 Oct 2019 01:37:33 +0800
Subject: [PATCH] add ClassifyFPreRecMetric

---
 fastNLP/core/metrics.py   | 198 +++++++++++++++++++++++++++++++++++---
 test/core/test_metrics.py |  67 ++++++++++++-
 2 files changed, 249 insertions(+), 16 deletions(-)

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 6ef1aea5..e06c5650 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -378,6 +378,172 @@ class AccuracyMetric(MetricBase):
         return evaluate_result
 
 
+class ClassifyFPreRecMetric(MetricBase):
+    """
+    Metric that computes f, precision and recall for classification tasks
+    (for other metrics, see :mod:`fastNLP.core.metrics`).
+
+    The resulting metric dict is::
+
+        {
+            'f': xxx,    # the key is 'f' so that f_beta values can be reported later on
+            'pre': xxx,
+            'rec': xxx
+        }
+
+    If only_gross=False, per-label statistics are returned as well::
+
+        {
+            'f': xxx,
+            'pre': xxx,
+            'rec': xxx,
+            'f-label': xxx,
+            'pre-label': xxx,
+            'rec-label': xxx,
+            ...
+        }
+    """
+
+    def __init__(self, tag_vocab=None, pred=None, target=None, seq_len=None, ignore_labels=None,
+                 only_gross=True, f_type='micro', beta=1):
+        """
+
+        :param tag_vocab: a :class:`~fastNLP.Vocabulary` of labels. Defaults to None, in which case
+            numeric indices are used as label names; otherwise label names are taken from the vocabulary.
+        :param str pred: key used to fetch the prediction from the dict passed to evaluate(). If None, the key `pred` is used.
+        :param str target: key used to fetch the target from the dict passed to evaluate(). If None, the key `target` is used.
+        :param str seq_len: key used to fetch the sequence lengths from the dict passed to evaluate(). If None, the key `seq_len` is used.
+        :param list ignore_labels: a list of str. Classes in this list are excluded from the computation.
+            For example, passing ['NN'] for POS tagging skips the 'NN' label.
+        :param bool only_gross: whether to report only the overall f1, precision and recall. If False,
+            the per-label f1, pre and rec are returned in addition to the overall values.
+        :param str f_type: `micro` or `macro`. `micro`: first compute the overall TP, FN and FP counts,
+            then compute f, precision and recall from them; `macro`: compute f, precision and recall for
+            each class separately, then average them (each class's f carries equal weight).
+        :param float beta: the beta of the f_beta score,
+            :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}`.
+            Common values are `beta=0.5, 1, 2`. With 0.5, precision outweighs recall; with 1 they are
+            weighted equally; with 2, recall outweighs precision.
+        """
+        if tag_vocab:
+            if not isinstance(tag_vocab, Vocabulary):
+                raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab)))
+        if f_type not in ('micro', 'macro'):
+            raise ValueError("f_type only supports `micro` or `macro`, got {}.".format(f_type))
+
+        self.ignore_labels = ignore_labels
+        self.f_type = f_type
+        self.beta = beta
+        self.beta_square = self.beta ** 2
+        self.only_gross = only_gross
+
+        super().__init__()
+        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
+
+        self.tag_vocab = tag_vocab
+
+        self._tp, self._fp, self._fn = defaultdict(int), defaultdict(int), defaultdict(int)
+        # tp[c]: pred == c and target == c; fp[c]: pred == c and target != c;
+        # fn[c]: pred != c and target == c
+
+    def evaluate(self, pred, target, seq_len=None):
+        """
+        evaluate() accumulates the metric statistics over one batch of predictions.
+
+        :param torch.Tensor pred: prediction tensor of shape torch.Size([B,]), torch.Size([B, n_classes]),
+            torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
+        :param torch.Tensor target: ground-truth tensor of shape torch.Size([B,]) or torch.Size([B, max_len])
+        :param torch.Tensor seq_len: sequence lengths, either None or of shape torch.Size([B])
+        """
+        # TODO: the error reporting here should be improved; users do not know what `pred`
+        #  is internally, so the actual offending value should be reported.
+        if not isinstance(pred, torch.Tensor):
+            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(pred)}.")
+        if not isinstance(target, torch.Tensor):
+            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(target)}.")
+
+        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
+            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(seq_len)}.")
+
+        if seq_len is not None and target.dim() > 1:
+            max_len = target.size(1)
+            masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
+        else:
+            masks = torch.ones_like(target).long().to(target.device)
+
+        if pred.dim() == target.dim():
+            pass
+        elif pred.dim() == target.dim() + 1:
+            pred = pred.argmax(dim=-1)
+            if seq_len is None and target.dim() > 1:
+                warnings.warn("You are not passing `seq_len` to exclude padding when calculating the metric.")
+        else:
+            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
+                               f"size:{pred.size()}, target should have size: {pred.size()} or "
+                               f"{pred.size()[:-1]}, got {target.size()}.")
+
+        # iterate over the distinct classes seen in this batch so that each class's
+        # counts are accumulated exactly once per batch
+        target_idxes = set(target.reshape(-1).tolist())
+        target_idxes.update(pred.reshape(-1).tolist())
+        target = target.to(pred)
+        for target_idx in target_idxes:
+            self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks.eq(0), 0)).item()
+            self._fp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks.eq(0), 0)).item()
+            self._fn[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks.eq(0), 0)).item()
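+
+    # A worked example of the accumulation above (illustrative values, not from the
+    # test suite): after argmax, pred=[3, 2, 3] with target=[3, 3, 3] gives class 3
+    # tp=2 (positions 0 and 2) and fn=1 (position 1), and gives class 2 fp=1;
+    # get_metric() then turns these counts into precision, recall and f.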
+
+    def get_metric(self, reset=True):
+        """
+        get_metric() computes the final evaluation result from the statistics accumulated by evaluate().
+
+        :param bool reset: whether to clear the accumulated statistics after this call.
+        :return dict evaluate_result: {"f": float, "pre": float, "rec": float}
+        """
+        evaluate_result = {}
+        if not self.only_gross or self.f_type == 'macro':
+            tags = set(self._fn.keys())
+            tags.update(set(self._fp.keys()))
+            tags.update(set(self._tp.keys()))
+            f_sum = 0
+            pre_sum = 0
+            rec_sum = 0
+            for tag in tags:
+                if self.tag_vocab is not None:
+                    tag_name = self.tag_vocab.to_word(tag)
+                else:
+                    tag_name = int(tag)
+                tp = self._tp[tag]
+                fn = self._fn[tag]
+                fp = self._fp[tag]
+                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
+                f_sum += f
+                pre_sum += pre
+                rec_sum += rec
+                if not self.only_gross and tag != '':  # tag != '' guards against the no-tag case
+                    f_key = 'f-{}'.format(tag_name)
+                    pre_key = 'pre-{}'.format(tag_name)
+                    rec_key = 'rec-{}'.format(tag_name)
+                    evaluate_result[f_key] = f
+                    evaluate_result[pre_key] = pre
+                    evaluate_result[rec_key] = rec
+
+            if self.f_type == 'macro':
+                evaluate_result['f'] = f_sum / len(tags)
+                evaluate_result['pre'] = pre_sum / len(tags)
+                evaluate_result['rec'] = rec_sum / len(tags)
+
+        if self.f_type == 'micro':
+            f, pre, rec = _compute_f_pre_rec(self.beta_square,
+                                             sum(self._tp.values()),
+                                             sum(self._fn.values()),
+                                             sum(self._fp.values()))
+            evaluate_result['f'] = f
+            evaluate_result['pre'] = pre
+            evaluate_result['rec'] = rec
+
+        if reset:
+            self._tp = defaultdict(int)
+            self._fp = defaultdict(int)
+            self._fn = defaultdict(int)
+
+        for key, value in evaluate_result.items():
+            evaluate_result[key] = round(value, 6)
+
+        return evaluate_result
+
+
 def _bmes_tag_to_spans(tags, ignore_labels=None):
     """
     给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。
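A minimal usage sketch of the new metric (the logits are made up; this assumes the patch is applied so that ClassifyFPreRecMetric is importable, matching the import added to the test file below)::

    import torch
    from fastNLP.core.metrics import ClassifyFPreRecMetric

    metric = ClassifyFPreRecMetric(f_type='macro')   # beta defaults to 1
    pred = torch.tensor([[0.2, 0.7, 0.1],            # illustrative logits for three classes
                         [0.8, 0.1, 0.1],
                         [0.1, 0.2, 0.7]])
    target = torch.tensor([1, 0, 2])
    metric.evaluate(pred=pred, target=target)        # call once per batch
    print(metric.get_metric(reset=True))             # {'f': 1.0, 'pre': 1.0, 'rec': 1.0}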
@@ -713,7 +879,7 @@ class SpanFPreRecMetric(MetricBase):
                 tp = self._true_positives[tag]
                 fn = self._false_negatives[tag]
                 fp = self._false_positives[tag]
-                f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
+                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
                 f_sum += f
                 pre_sum += pre
                 rec_sum += rec
@@ -731,9 +897,10 @@ class SpanFPreRecMetric(MetricBase):
             evaluate_result['rec'] = rec_sum / len(tags)
 
         if self.f_type == 'micro':
-            f, pre, rec = self._compute_f_pre_rec(sum(self._true_positives.values()),
-                                                  sum(self._false_negatives.values()),
-                                                  sum(self._false_positives.values()))
+            f, pre, rec = _compute_f_pre_rec(self.beta_square,
+                                             sum(self._true_positives.values()),
+                                             sum(self._false_negatives.values()),
+                                             sum(self._false_positives.values()))
             evaluate_result['f'] = f
             evaluate_result['pre'] = pre
             evaluate_result['rec'] = rec
@@ -748,19 +915,20 @@ class SpanFPreRecMetric(MetricBase):
 
         return evaluate_result
 
-    def _compute_f_pre_rec(self, tp, fn, fp):
-        """
-
-        :param tp: int, true positive
-        :param fn: int, false negative
-        :param fp: int, false positive
-        :return: (f, pre, rec)
-        """
-        pre = tp / (fp + tp + 1e-13)
-        rec = tp / (fn + tp + 1e-13)
-        f = (1 + self.beta_square) * pre * rec / (self.beta_square * pre + rec + 1e-13)
-
-        return f, pre, rec
+
+def _compute_f_pre_rec(beta_square, tp, fn, fp):
+    """
+
+    :param beta_square: the square of beta for the f_beta score
+    :param tp: int, true positive
+    :param fn: int, false negative
+    :param fp: int, false positive
+    :return: (f, pre, rec)
+    """
+    pre = tp / (fp + tp + 1e-13)
+    rec = tp / (fn + tp + 1e-13)
+    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
+
+    return f, pre, rec
 
 
 def _prepare_metrics(metrics):
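As a quick sanity check of the refactored `_compute_f_pre_rec`, plugging in the class-4 counts from the test case below (tp=8, fn=0, fp=3, with beta_square=1) reproduces the expected per-label values::

    tp, fn, fp = 8, 0, 3                                   # class-4 counts from the test below
    pre = tp / (fp + tp + 1e-13)                           # 8 / 11 ≈ 0.727273
    rec = tp / (fn + tp + 1e-13)                           # 8 / 8  = 1.0
    f = (1 + 1) * pre * rec / (1 * pre + rec + 1e-13)      # 16 / 19 ≈ 0.842105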
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index 16711064..a11bd90b 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric
 from fastNLP.core.metrics import _pred_topk, _accuracy_topk
 from fastNLP.core.vocabulary import Vocabulary
 from collections import Counter
-from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric
+from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric
 
 
 def _generate_tags(encoding_type, number_labels=4):
@@ -446,4 +446,69 @@ class TestUsefulFunctions(unittest.TestCase):
         # 跑通即可
 
 
+class TestClassifyFPreRecMetric(unittest.TestCase):
+    def test_case_1(self):
+        pred = torch.FloatTensor([[-0.1603, -1.3247, 0.2010, 0.9240, -0.6396],
+                                  [-0.7316, -1.6028, 0.2281, 0.3558, 1.2500],
+                                  [-1.2943, -1.7350, -0.7085, 1.1269, 1.0782],
+                                  [ 0.1314, -0.2578, 0.7200, 1.0920, -1.0819],
+                                  [-0.6787, -0.9081, -0.2752, -1.5818, 0.5538],
+                                  [-0.2925, 1.1320, 2.8709, -0.6225, -0.6279],
+                                  [-0.3320, -0.9009, -1.5762, 0.3810, -0.1220],
+                                  [ 0.4601, -1.0509, 1.4242, 0.3427, 2.7014],
+                                  [-0.5558, 1.0899, -1.9045, 0.3377, 1.3192],
+                                  [-0.8251, -0.1558, -0.0871, -0.6755, -0.5905],
+                                  [ 0.1019, 1.2504, -1.1627, -0.7062, 1.8654],
+                                  [ 0.9016, -0.1984, -0.0831, -0.7646, 1.5309],
+                                  [ 0.2073, 0.2250, -0.0879, 0.1608, -0.8915],
+                                  [ 0.3624, 0.3806, 0.3159, -0.3603, -0.6672],
+                                  [ 0.2714, 2.5086, -0.1053, -0.5188, 0.9229],
+                                  [ 0.3258, -0.0303, 1.1439, -0.9123, 1.5180],
+                                  [ 1.2496, -1.0298, -0.4463, 0.1186, -1.7089],
+                                  [ 0.0788, 0.6300, -1.3336, -0.7122, 1.0164],
+                                  [-1.1900, -0.9620, -0.3839, 0.1159, -1.2045],
+                                  [-0.9037, -0.1447, 1.1834, -0.2617, 2.6112],
+                                  [ 0.1507, 0.1686, -0.1535, -0.3669, -0.8425],
+                                  [ 1.0537, 1.1958, -1.2309, 1.0405, 1.3018],
+                                  [-0.9823, -0.9712, 1.1560, -0.6473, 1.0361],
+                                  [ 0.8659, -0.2166, -0.8335, -0.3557, -0.5660],
+                                  [-1.4742, -0.8773, -2.5237, 0.7410, 0.1506],
+                                  [-1.3032, -1.7157, 0.7479, 1.0755, 1.0817],
+                                  [-0.2988, 2.3745, 1.2072, 0.0054, 1.1877],
+                                  [-0.0123, 1.6513, 0.2741, -0.7791, 0.6161],
+                                  [ 1.6339, -1.0365, 0.3961, -0.9683, 0.2684],
+                                  [-0.0278, -2.0856, -0.5376, 0.5129, -0.3169],
+                                  [ 0.9386, 0.8317, 0.9518, -0.5050, -0.2808],
+                                  [-0.6907, 0.5020, -0.9039, -1.1061, 0.1656]])
+
+        # argmax of `pred` above, kept for reference (not used by the test itself)
+        arg_max_pred = torch.Tensor([3, 4, 3, 3, 4, 2, 3, 4, 4, 2, 4, 4, 1, 1,
+                                     1, 4, 0, 4, 3, 4, 1, 4, 2, 0,
+                                     3, 4, 1, 1, 0, 3, 2, 1])
+        target = torch.Tensor([3, 3, 3, 3, 4, 1, 0, 2, 1, 2, 4, 4, 1, 1,
+                               1, 4, 0, 4, 3, 4, 1, 4, 2, 0,
+                               3, 4, 1, 1, 0, 3, 2, 1])
+
+        metric = ClassifyFPreRecMetric(f_type='macro')
+        metric.evaluate(pred, target)
+        result_dict = metric.get_metric(reset=True)
+        ground_truth = {'f': 0.836278, 'pre': 0.866883, 'rec': 0.826984}
+        for keys in ['f', 'pre', 'rec']:
+            self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)
+
+        metric = ClassifyFPreRecMetric(f_type='micro')
+        metric.evaluate(pred, target)
+        result_dict = metric.get_metric(reset=True)
+        ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375}
+        for keys in ['f', 'pre', 'rec']:
+            self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)
+
+        metric = ClassifyFPreRecMetric(only_gross=False, f_type='micro')
+        metric.evaluate(pred, target)
+        result_dict = metric.get_metric(reset=True)
+        ground_truth = {'f-0': 0.857143, 'pre-0': 1.0, 'rec-0': 0.75,
+                        'f-1': 0.875, 'pre-1': 1.0, 'rec-1': 0.777778,
+                        'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75,
+                        'f-3': 0.857143, 'pre-3': 0.857143, 'rec-3': 0.857143,
+                        'f-4': 0.842105, 'pre-4': 0.727273, 'rec-4': 1.0,
+                        'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375}
+        for keys in ground_truth.keys():
+            self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)
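For reference, the micro expectations in the last block can be reproduced by hand: 27 of the 32 predictions in the batch are correct, and summing the per-class counts gives tp=27, fp=5, fn=5, so::

    tp, fn, fp = 27, 5, 5                     # summed over all five classes
    pre = tp / (fp + tp + 1e-13)              # 27 / 32 = 0.84375
    rec = tp / (fn + tp + 1e-13)              # 27 / 32 = 0.84375
    f = 2 * pre * rec / (pre + rec + 1e-13)   # 0.84375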