From b9b688d23311ee655180cafc84d0650f24de67d6 Mon Sep 17 00:00:00 2001 From: roger Date: Thu, 30 Jul 2020 04:24:28 +0000 Subject: [PATCH] f1 fix --- fastNLP/core/metrics.py | 30 ++++++++++---- test/core/test_metrics.py | 87 ++++++++++++++------------------------- 2 files changed, 53 insertions(+), 64 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index cf5b82b7..31f69cb9 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -313,11 +313,9 @@ class ConfusionMatrixMetric(MetricBase): pred=None, target=None, seq_len=None, - show_result=None, print_ratio=False ): r""" - :param vocab: vocab词表类,要求有to_word()方法。 :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` @@ -327,7 +325,6 @@ class ConfusionMatrixMetric(MetricBase): super().__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.confusion_matrix = ConfusionMatrix( - show_result=show_result, vocab=vocab, print_ratio=print_ratio, ) @@ -335,6 +332,7 @@ class ConfusionMatrixMetric(MetricBase): def evaluate(self, pred, target, seq_len=None): r""" evaluate函数将针对一个批次的预测结果做评价指标的累计 + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), @@ -356,6 +354,10 @@ class ConfusionMatrixMetric(MetricBase): f"got {type(seq_len)}.") if pred.dim() == target.dim(): + if torch.numel(pred) !=torch.numel(target): + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have " + f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}") + pass elif pred.dim() == target.dim() + 1: pred = pred.argmax(dim=-1) @@ -446,6 +448,10 @@ class AccuracyMetric(MetricBase): masks = None if pred.dim() == target.dim(): + if torch.numel(pred) !=torch.numel(target): + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have " + f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}") + pass elif pred.dim() == target.dim() + 1: pred = pred.argmax(dim=-1) @@ -477,7 +483,6 @@ class AccuracyMetric(MetricBase): self.total = 0 return evaluate_result - class ClassifyFPreRecMetric(MetricBase): r""" 分类问题计算FPR值的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) @@ -567,9 +572,14 @@ class ClassifyFPreRecMetric(MetricBase): masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) else: masks = torch.ones_like(target).long().to(target.device) - masks = masks.eq(False) + + masks = masks.eq(1) if pred.dim() == target.dim(): + if torch.numel(pred) !=torch.numel(target): + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have same dimensions with target, they should have same element numbers. while target have " + f"element numbers:{torch.numel(target)}, pred have element numbers: {torch.numel(pred)}") + pass elif pred.dim() == target.dim() + 1: pred = pred.argmax(dim=-1) @@ -580,12 +590,14 @@ class ClassifyFPreRecMetric(MetricBase): f"size:{pred.size()}, target should have size: {pred.size()} or " f"{pred.size()[:-1]}, got {target.size()}.") - target_idxes = set(target.reshape(-1).tolist()) target = target.to(pred) + target = target.masked_select(masks) + pred = pred.masked_select(masks) + target_idxes = set(target.reshape(-1).tolist()) for target_idx in target_idxes: - self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() - self._fp[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item() - self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item() + self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0)).item() + self._fp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0)).item() + self._fn[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0)).item() def get_metric(self, reset=True): r""" diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 27799c54..14096ff5 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -1,13 +1,14 @@ import unittest +from collections import Counter import numpy as np import torch - from fastNLP import AccuracyMetric -from fastNLP.core.metrics import _pred_topk, _accuracy_topk +from fastNLP.core.metrics import (ClassifyFPreRecMetric, CMRC2018Metric, + ConfusionMatrixMetric, SpanFPreRecMetric, + _accuracy_topk, _pred_topk) from fastNLP.core.vocabulary import Vocabulary -from collections import Counter -from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric +from sklearn import metrics as m def _generate_tags(encoding_type, number_labels=4): @@ -563,69 +564,45 @@ class TestUsefulFunctions(unittest.TestCase): # 跑通即可 + class TestClassfiyFPreRecMetric(unittest.TestCase): def test_case_1(self): - pred = torch.FloatTensor([[-0.1603, -1.3247, 0.2010, 0.9240, -0.6396], - [-0.7316, -1.6028, 0.2281, 0.3558, 1.2500], - [-1.2943, -1.7350, -0.7085, 1.1269, 1.0782], - [ 0.1314, -0.2578, 0.7200, 1.0920, -1.0819], - [-0.6787, -0.9081, -0.2752, -1.5818, 0.5538], - [-0.2925, 1.1320, 2.8709, -0.6225, -0.6279], - [-0.3320, -0.9009, -1.5762, 0.3810, -0.1220], - [ 0.4601, -1.0509, 1.4242, 0.3427, 2.7014], - [-0.5558, 1.0899, -1.9045, 0.3377, 1.3192], - [-0.8251, -0.1558, -0.0871, -0.6755, -0.5905], - [ 0.1019, 1.2504, -1.1627, -0.7062, 1.8654], - [ 0.9016, -0.1984, -0.0831, -0.7646, 1.5309], - [ 0.2073, 0.2250, -0.0879, 0.1608, -0.8915], - [ 0.3624, 0.3806, 0.3159, -0.3603, -0.6672], - [ 0.2714, 2.5086, -0.1053, -0.5188, 0.9229], - [ 0.3258, -0.0303, 1.1439, -0.9123, 1.5180], - [ 1.2496, -1.0298, -0.4463, 0.1186, -1.7089], - [ 0.0788, 0.6300, -1.3336, -0.7122, 1.0164], - [-1.1900, -0.9620, -0.3839, 0.1159, -1.2045], - [-0.9037, -0.1447, 1.1834, -0.2617, 2.6112], - [ 0.1507, 0.1686, -0.1535, -0.3669, -0.8425], - [ 1.0537, 1.1958, -1.2309, 1.0405, 1.3018], - [-0.9823, -0.9712, 1.1560, -0.6473, 1.0361], - [ 0.8659, -0.2166, -0.8335, -0.3557, -0.5660], - [-1.4742, -0.8773, -2.5237, 0.7410, 0.1506], - [-1.3032, -1.7157, 0.7479, 1.0755, 1.0817], - [-0.2988, 2.3745, 1.2072, 0.0054, 1.1877], - [-0.0123, 1.6513, 0.2741, -0.7791, 0.6161], - [ 1.6339, -1.0365, 0.3961, -0.9683, 0.2684], - [-0.0278, -2.0856, -0.5376, 0.5129, -0.3169], - [ 0.9386, 0.8317, 0.9518, -0.5050, -0.2808], - [-0.6907, 0.5020, -0.9039, -1.1061, 0.1656]]) - - arg_max_pred = torch.Tensor([3, 2, 3, 3, 4, 2, 3, 4, 4, 2, 4, 4, 1, 1, - 1, 4, 0, 4, 3, 4, 1, 4, 2, 0, - 3, 4, 1, 1, 0, 3, 2, 1]) - target = torch.Tensor([3, 3, 3, 3, 4, 1, 0, 2, 1, 2, 4, 4, 1, 1, - 1, 4, 0, 4, 3, 4, 1, 4, 2, 0, - 3, 4, 1, 1, 0, 3, 2, 1]) + pred= torch.randn(32,5) + arg_max_pred = torch.argmax(pred,dim=-1) + target=np.random.randint(0, high=5, size=(32,1)) + target = torch.from_numpy(target) + metric = ClassifyFPreRecMetric(f_type='macro') metric.evaluate(pred, target) - result_dict = metric.get_metric(reset=True) - ground_truth = {'f': 0.8362782, 'pre': 0.8269841, 'rec': 0.8668831} + result_dict = metric.get_metric() + f1_score = m.f1_score(target.tolist(), arg_max_pred.tolist(), average="macro") + recall = m.recall_score(target.tolist(), arg_max_pred.tolist(), average="macro") + pre = m.precision_score(target.tolist(), arg_max_pred.tolist(), average="macro") + + ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} for keys in ['f', 'pre', 'rec']: self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) metric = ClassifyFPreRecMetric(f_type='micro') metric.evaluate(pred, target) - result_dict = metric.get_metric(reset=True) - ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375} + result_dict = metric.get_metric() + f1_score = m.f1_score(target.tolist(), arg_max_pred.tolist(), average="micro") + recall = m.recall_score(target.tolist(), arg_max_pred.tolist(), average="micro") + pre = m.precision_score(target.tolist(), arg_max_pred.tolist(), average="micro") + + ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} for keys in ['f', 'pre', 'rec']: self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) - metric = ClassifyFPreRecMetric(only_gross=False, f_type='micro') + metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro') metric.evaluate(pred, target) result_dict = metric.get_metric(reset=True) - ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0, - 'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143, - 'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.84375, - 'pre': 0.84375, 'rec': 0.84375} - for keys in ground_truth.keys(): - self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001) - + ground_truth = m.classification_report(target.tolist(), arg_max_pred.tolist(),output_dict=True) + for keys in result_dict.keys(): + if keys=="f" or "pre" or "rec": + continue + gl=str(keys[-1]) + tmp_d={"p":"precision","r":"recall","f":"f1-score"} + gk=tmp_d[keys[0]] + self.assertAlmostEqual(result_dict[keys], ground_truth[gl][gk], delta=0.0001)