From f3ee16a5f6c2a679f9f12f490c6cb4109f5cae54 Mon Sep 17 00:00:00 2001 From: ROGERDJQ Date: Thu, 13 Feb 2020 20:38:10 +0800 Subject: [PATCH] [new] add ConfusionMatrix, ConfusionMatrixMetric (#272) * add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finish --- fastNLP/core/metrics.py | 94 +++++++++++++++++++++++++- fastNLP/core/utils.py | 102 +++++++++++++++++++++++++++- test/core/test_metrics.py | 139 +++++++++++++++++++++++++++++++++++++- 3 files changed, 329 insertions(+), 6 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 95a3331f..146b532f 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -7,7 +7,8 @@ __all__ = [ "AccuracyMetric", "SpanFPreRecMetric", "CMRC2018Metric", - "ClassifyFPreRecMetric" + "ClassifyFPreRecMetric", + "ConfusionMatrixMetric" ] import inspect @@ -15,6 +16,7 @@ import warnings from abc import abstractmethod from collections import defaultdict from typing import Union +from copy import deepcopy import re import numpy as np @@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list from .utils import _get_func_signature from .utils import seq_len_to_mask from .vocabulary import Vocabulary +from .utils import ConfusionMatrix class MetricBase(object): @@ -276,6 +279,95 @@ class MetricBase(object): return +class ConfusionMatrixMetric(MetricBase): + r""" + 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) + + 最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} + ConfusionMatrix实例的print()函数将输出矩阵字符串。 + + pred_dict = {"pred": torch.Tensor([2,1,3])} + target_dict = {'target': torch.Tensor([2,2,1])} + metric = ConfusionMatrixMetric() + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) + + {'confusion_matrix': + target 1.0 2.0 3.0 all + pred + 1.0 0 1 0 1 + 2.0 0 1 0 1 + 3.0 1 0 0 1 + all 1 2 0 3} + """ + def __init__(self, vocab=None, pred=None, target=None, seq_len=None): + """ + :param vocab: vocab词表类,要求有to_word()方法。 + :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` + :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` + :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` + """ + super().__init__() + self._init_param_map(pred=pred, target=target, seq_len=seq_len) + self.confusion_matrix = ConfusionMatrix(vocab=vocab) + + def evaluate(self, pred, target, seq_len=None): + """ + evaluate函数将针对一个批次的预测结果做评价指标的累计 + + :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), + torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) + :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), + torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) + :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]). + + """ + if not isinstance(pred, torch.Tensor): + raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") + if not isinstance(target, torch.Tensor): + raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(target)}.") + + if seq_len is not None and not isinstance(seq_len, torch.Tensor): + raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(seq_len)}.") + + if pred.dim() == target.dim(): + pass + elif pred.dim() == target.dim() + 1: + pred = pred.argmax(dim=-1) + if seq_len is None and target.dim() > 1: + warnings.warn("You are not passing `seq_len` to exclude pad.") + else: + raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") + + target = target.to(pred) + if seq_len is not None and target.dim() > 1: + for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): + l=int(l) + self.confusion_matrix.add_pred_target(p[:l], t[:l]) + elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 + for p, t in zip(pred.tolist(), target.tolist()): + self.confusion_matrix.add_pred_target(p, t) + else: + self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist()) + + def get_metric(self,reset=True): + """ + get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. + + :param bool reset: 在调用完get_metric后是否清空评价指标统计量. + :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} + """ + confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)} + if reset: + self.confusion_matrix.clear() + return confusion + + class AccuracyMetric(MetricBase): """ 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index ba9ec850..b1d5f4e2 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -8,18 +8,22 @@ __all__ = [ "get_seq_len" ] -import _pickle import inspect import os import warnings from collections import Counter, namedtuple +from copy import deepcopy +from typing import List + +import _pickle import numpy as np import torch import torch.nn as nn -from typing import List -from ._logger import logger from prettytable import PrettyTable + +from ._logger import logger from ._parallel_utils import _model_contains_inner_module +# from .vocabulary import Vocabulary try: from apex import amp @@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require 'varargs']) + + +class ConfusionMatrix: + """a dict can provide Confusion Matrix""" + def __init__(self, vocab=None): + """ + :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 + """ + if vocab and not hasattr(vocab, 'to_word'): + raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary," + f"got {type(vocab)}.") + self.confusiondict={} #key: pred index, value:target word ocunt + self.predcount={} #key:pred index, value:count + self.targetcount={} #key:target index, value:count + self.vocab=vocab + + def add_pred_target(self, pred, target): #一组结果 + """ + 通过这个函数向ConfusionMatrix加入一组预测结果 + + :param list pred: 预测的标签列表 + :param list target: 真实值的标签列表 + :return ConfusionMatrix + + confusion=ConfusionMatrix() + pred = [2,1,3] + target = [2,2,1] + confusion.add_pred_target(pred, target) + print(confusion) + + target 1 2 3 all + pred + 1 0 1 0 1 + 2 0 1 0 1 + 3 1 0 0 1 + all 1 2 0 3 + """ + for p,t in zip(pred,target): # + self.predcount[p]=self.predcount.get(p,0)+ 1 + self.targetcount[t]=self.targetcount.get(t,0)+1 + if p in self.confusiondict: + self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1 + else: + self.confusiondict[p]={} + self.confusiondict[p][t]= 1 + return self.confusiondict + + def clear(self): + """ + 清除一些值,等待再次新加入 + :return: + """ + self.confusiondict={} + self.targetcount={} + self.predcount={} + + def __repr__(self): + """ + :return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 + """ + row2idx={} + idx2row={} + # 已知的所有键/label + totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys())))) + lenth=len(totallabel) + # namedict key :idx value:word/idx + namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel]) + + for label,idx in zip(totallabel,range(lenth)): + idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... + row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... + # 这里打印东西 + #表头 + head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"] + output="\t".join(head) + "\n" + "pred" + "\n" + #内容 + for i in row2idx.keys(): #第i行 + p=row2idx[i] + h=namedict[p] + l=[0 for _ in range(lenth)] + if self.confusiondict.get(p,None): + for t,c in self.confusiondict[p].items(): + l[idx2row[t]] = c #完成一行 + l=[h]+[str(n) for n in l]+[str(sum(l))] + output+="\t".join(l) +"\n" + #表尾 + tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()] + tail=["all"]+[str(n) for n in tail]+[str(sum(tail))] + output+="\t".join(tail) + return output + + class Option(dict): """a dict can treat keys as attributes""" diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index 32581e23..f6cbbb4f 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric from fastNLP.core.metrics import _pred_topk, _accuracy_topk from fastNLP.core.vocabulary import Vocabulary from collections import Counter -from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric +from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric def _generate_tags(encoding_type, number_labels=4): @@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result): allen_result[key] = round(value, 6) return allen_result + + +class TestConfusionMatrixMetric(unittest.TestCase): + def test_ConfusionMatrixMetric1(self): + pred_dict = {"pred": torch.zeros(4,3)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + def test_ConfusionMatrixMetric2(self): + # (2) with corrupted size + try: + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric() + + metric(pred_dict=pred_dict, target_dict=target_dict, ) + print(metric.get_metric()) + except Exception as e: + print(e) + return + print("No exception catches.") + + def test_ConfusionMatrixMetric3(self): + # (3) the second batch is corrupted size + try: + metric = ConfusionMatrixMetric() + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + pred_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric(pred_dict=pred_dict, target_dict=target_dict) + + print(metric.get_metric()) + except Exception as e: + print(e) + return + assert(True, False), "No exception catches." + + def test_ConfusionMatrixMetric4(self): + # (4) check reset + metric = ConfusionMatrixMetric() + pred_dict = {"pred": torch.randn(4, 3, 2)} + target_dict = {'target': torch.ones(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + res = metric.get_metric() + self.assertTrue(isinstance(res, dict)) + print(res) + + def test_ConfusionMatrixMetric5(self): + # (5) check numpy array is not acceptable + try: + metric = ConfusionMatrixMetric() + pred_dict = {"pred": np.zeros((4, 3, 2))} + target_dict = {'target': np.zeros((4, 3))} + metric(pred_dict=pred_dict, target_dict=target_dict) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_ConfusionMatrixMetric6(self): + # (6) check map, match + metric = ConfusionMatrixMetric(pred='predictions', target='targets') + pred_dict = {"predictions": torch.randn(4, 3, 2)} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + res = metric.get_metric() + print(res) + + def test_ConfusionMatrixMetric7(self): + # (7) check map, include unused + try: + metric = ConfusionMatrixMetric(pred='prediction', target='targets') + pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_ConfusionMatrixMetric8(self): +# (8) check _fast_metric + try: + metric = ConfusionMatrixMetric() + pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." + + def test_duplicate(self): + # 0.4.1的潜在bug,不能出现形参重复的情况 + metric = ConfusionMatrixMetric(pred='predictions', target='targets') + pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} + target_dict = {'targets':torch.zeros(4, 3), 'target': 0} + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + + def test_seq_len(self): + N = 256 + seq_len = torch.zeros(N).long() + seq_len[0] = 2 + pred = {'pred': torch.ones(N, 2)} + target = {'target': torch.ones(N, 2), 'seq_len': seq_len} + metric = ConfusionMatrixMetric() + metric(pred_dict=pred, target_dict=target) + metric.get_metric(reset=False) + seq_len[1:] = 1 + metric(pred_dict=pred, target_dict=target) + metric.get_metric() + + def test_vocab(self): + vocab = Vocabulary() + word_list = "this is a word list".split() + vocab.update(word_list) + + pred_dict = {"pred": torch.zeros(4,3)} + target_dict = {'target': torch.zeros(4)} + metric = ConfusionMatrixMetric(vocab=vocab) + metric(pred_dict=pred_dict, target_dict=target_dict) + print(metric.get_metric()) + + + + class TestAccuracyMetric(unittest.TestCase): def test_AccuracyMetric1(self): # (1) only input, targets passed @@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase): def test_AccuaryMetric8(self): try: metric = AccuracyMetric(pred='predictions', target='targets') - pred_dict = {"prediction": torch.zeros(4, 3, 2)} + pred_dict = {"predictions": torch.zeros(4, 3, 2)} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict, ) self.assertDictEqual(metric.get_metric(), {'acc': 1})