
[new] add ConfusionMatrix, ConfusionMatrixMetric (#272)

* add ConfusionMatrix, ConfusionMatrixMetric

* add ConfusionMatrix to utils

* add ConfusionMatrixMetric

* init for tests

* begin tests

* tests finished

* docs finished
tags/v0.5.5
ROGERDJQ (GitHub) · 5 years ago · parent commit f3ee16a5f6
3 changed files with 329 additions and 6 deletions:

  1. fastNLP/core/metrics.py     (+93, -1)
  2. fastNLP/core/utils.py       (+99, -3)
  3. test/core/test_metrics.py   (+137, -2)
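
As a quick orientation before the diffs: a minimal usage sketch of the new metric, adapted from the docstring example added in fastNLP/core/metrics.py (the tensor values are illustrative only):

import torch
from fastNLP.core.metrics import ConfusionMatrixMetric

# One batch of sentence-level predictions against gold labels.
pred_dict = {"pred": torch.Tensor([2, 1, 3])}
target_dict = {"target": torch.Tensor([2, 2, 1])}

metric = ConfusionMatrixMetric()
metric(pred_dict=pred_dict, target_dict=target_dict)
# Returns {'confusion_matrix': <ConfusionMatrix>}; printing the instance renders the table.
print(metric.get_metric())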

fastNLP/core/metrics.py (+93, -1)

@@ -7,7 +7,8 @@ __all__ = [
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "CMRC2018Metric",
    "ClassifyFPreRecMetric",
    "ConfusionMatrixMetric"
]


import inspect
@@ -15,6 +16,7 @@ import warnings
from abc import abstractmethod
from collections import defaultdict
from typing import Union
from copy import deepcopy
import re


import numpy as np
@@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list
from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from .utils import ConfusionMatrix




class MetricBase(object):
@@ -276,6 +279,95 @@ class MetricBase(object):
        return




class ConfusionMatrixMetric(MetricBase):
    r"""
    Metric that builds a confusion matrix for classification tasks (for other metrics see :mod:`fastNLP.core.metrics`).

    The final result is a dict, {'confusion_matrix': a ConfusionMatrix instance};
    printing that ConfusionMatrix instance renders the matrix as a string.

    pred_dict = {"pred": torch.Tensor([2,1,3])}
    target_dict = {'target': torch.Tensor([2,2,1])}
    metric = ConfusionMatrixMetric()
    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric())

    {'confusion_matrix':
    target  1.0  2.0  3.0  all
    pred
    1.0       0    1    0    1
    2.0       0    1    0    1
    3.0       1    0    0    1
    all       1    2    0    3}
    """
    def __init__(self, vocab=None, pred=None, target=None, seq_len=None):
        """
        :param vocab: a vocabulary object; it must provide a to_word() method.
        :param pred: mapping for `pred` in the parameter map; None keeps the default `pred` -> `pred`
        :param target: mapping for `target` in the parameter map; None keeps the default `target` -> `target`
        :param seq_len: mapping for `seq_len` in the parameter map; None keeps the default `seq_len` -> `seq_len`
        """
        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
        self.confusion_matrix = ConfusionMatrix(vocab=vocab)

    def evaluate(self, pred, target, seq_len=None):
        """
        evaluate() accumulates the metric statistics over one batch of predictions.

        :param torch.Tensor pred: prediction tensor; its shape may be torch.Size([B,]), torch.Size([B, n_classes]),
                torch.Size([B, max_len]) or torch.Size([B, max_len, n_classes])
        :param torch.Tensor target: ground-truth tensor; its shape may be torch.Size([B,]) or torch.Size([B, max_len])
        :param torch.Tensor seq_len: sequence-length tensor; either None or of shape torch.Size([B])
        """
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(target)}.")

        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(seq_len)}.")

        if pred.dim() == target.dim():
            pass
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad.")
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has "
                               f"size: {pred.size()}, target should have size: {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        target = target.to(pred)
        if seq_len is not None and target.dim() > 1:
            for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()):
                l = int(l)
                self.confusion_matrix.add_pred_target(p[:l], t[:l])
        elif target.dim() > 1:  # a 2-D target without seq_len: count the full length
            for p, t in zip(pred.tolist(), target.tolist()):
                self.confusion_matrix.add_pred_target(p, t)
        else:
            self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist())

    def get_metric(self, reset=True):
        """
        get_metric() computes the final result from the statistics accumulated by evaluate().

        :param bool reset: whether to clear the accumulated statistics after this call.
        :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix}
        """
        confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)}
        if reset:
            self.confusion_matrix.clear()
        return confusion
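
Because evaluate() slices each sequence to its true length before counting, token-level usage can exclude padding. A hedged sketch (tensor values are illustrative; seq_len is passed inside target_dict, the same convention test_seq_len uses below):

import torch
from fastNLP.core.metrics import ConfusionMatrixMetric

# Token-level predictions of shape [B, max_len]; positions past seq_len are padding.
pred = torch.tensor([[1, 2, 0, 0],
                     [2, 2, 1, 0]])
target = torch.tensor([[1, 1, 0, 0],
                       [2, 2, 2, 0]])
seq_len = torch.tensor([2, 3])  # true length of each sequence

metric = ConfusionMatrixMetric()
metric(pred_dict={"pred": pred}, target_dict={"target": target, "seq_len": seq_len})
# Only the first 2 and 3 positions of the two sequences enter the matrix.
print(metric.get_metric())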


class AccuracyMetric(MetricBase):
    """
    Accuracy metric (for other metrics see :mod:`fastNLP.core.metrics`)


fastNLP/core/utils.py (+99, -3)

@@ -8,18 +8,22 @@ __all__ = [
"get_seq_len" "get_seq_len"
] ]


import _pickle
import inspect import inspect
import os import os
import warnings import warnings
from collections import Counter, namedtuple from collections import Counter, namedtuple
from copy import deepcopy
from typing import List

import _pickle
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import List
from ._logger import logger
from prettytable import PrettyTable from prettytable import PrettyTable

from ._logger import logger
from ._parallel_utils import _model_contains_inner_module from ._parallel_utils import _model_contains_inner_module
# from .vocabulary import Vocabulary


try:
    from apex import amp
@@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
                       'varargs'])






class ConfusionMatrix:
    """a dict-based confusion matrix"""
    def __init__(self, vocab=None):
        """
        :param vocab: must provide a to_word() method; using fastNLP.core.Vocabulary directly is recommended.
        """
        if vocab and not hasattr(vocab, 'to_word'):
            raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be fastNLP.core.Vocabulary, "
                            f"got {type(vocab)}.")
        self.confusiondict = {}  # key: pred index, value: dict of target index -> count
        self.predcount = {}      # key: pred index, value: count
        self.targetcount = {}    # key: target index, value: count
        self.vocab = vocab

    def add_pred_target(self, pred, target):  # one batch of results
        """
        Add one batch of predictions and targets to the ConfusionMatrix.

        :param list pred: list of predicted labels
        :param list target: list of ground-truth labels
        :return ConfusionMatrix

        confusion = ConfusionMatrix()
        pred = [2, 1, 3]
        target = [2, 2, 1]
        confusion.add_pred_target(pred, target)
        print(confusion)

        target  1  2  3  all
        pred
        1       0  1  0  1
        2       0  1  0  1
        3       1  0  0  1
        all     1  2  0  3
        """
        for p, t in zip(pred, target):  # <int, int>
            self.predcount[p] = self.predcount.get(p, 0) + 1
            self.targetcount[t] = self.targetcount.get(t, 0) + 1
            if p in self.confusiondict:
                self.confusiondict[p][t] = self.confusiondict[p].get(t, 0) + 1
            else:
                self.confusiondict[p] = {}
                self.confusiondict[p][t] = 1
        return self.confusiondict

    def clear(self):
        """
        Reset all accumulated counts so that new results can be added.
        """
        self.confusiondict = {}
        self.targetcount = {}
        self.predcount = {}

    def __repr__(self):
        """
        :return string output: formatted ConfusionMatrix, including the header row of
            labels, the per-cell counts, and the summary row and column.
        """
        row2idx = {}
        idx2row = {}
        # all labels/keys seen so far
        totallabel = sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys()))))
        length = len(totallabel)
        # namedict   key: idx   value: word/idx
        namedict = dict([(k, str(k if self.vocab is None else self.vocab.to_word(k))) for k in totallabel])

        for label, idx in zip(totallabel, range(length)):
            idx2row[label] = idx  # temporary dict, key: vocab index, value: row/column index (1,3,5,... -> 0,1,2,...)
            row2idx[idx] = label  # temporary dict, key: row/column index, value: vocab index (0,1,2,... -> 1,3,5,...)
        # build the printable string
        # header row
        head = ["\ntarget"] + [str(namedict[row2idx[k]]) for k in row2idx.keys()] + ["all"]
        output = "\t".join(head) + "\n" + "pred" + "\n"
        # matrix body
        for i in row2idx.keys():  # row i
            p = row2idx[i]
            h = namedict[p]
            l = [0 for _ in range(length)]
            if self.confusiondict.get(p, None):
                for t, c in self.confusiondict[p].items():
                    l[idx2row[t]] = c  # one finished row
            l = [h] + [str(n) for n in l] + [str(sum(l))]
            output += "\t".join(l) + "\n"
        # summary row
        tail = [self.targetcount.get(row2idx[k], 0) for k in row2idx.keys()]
        tail = ["all"] + [str(n) for n in tail] + [str(sum(tail))]
        output += "\t".join(tail)
        return output
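
The ConfusionMatrix class is also usable on its own. A small sketch, assuming a default Vocabulary (whose indices 0 and 1 are reserved for the padding and unknown tokens, so low label indices may render as '<unk>'):

from fastNLP.core.utils import ConfusionMatrix
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary()
vocab.update("this is a word list".split())

cm = ConfusionMatrix(vocab=vocab)
cm.add_pred_target([2, 1, 3], [2, 2, 1])
# __repr__ maps the row/column headers through vocab.to_word(...).
print(cm)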


class Option(dict):
    """a dict can treat keys as attributes"""




test/core/test_metrics.py (+137, -2)

@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk
from fastNLP.core.vocabulary import Vocabulary
from collections import Counter
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, ConfusionMatrixMetric




def _generate_tags(encoding_type, number_labels=4):
@@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result):
        allen_result[key] = round(value, 6)
    return allen_result




class TestConfusionMatrixMetric(unittest.TestCase):
    def test_ConfusionMatrixMetric1(self):
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric()

        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())

    def test_ConfusionMatrixMetric2(self):
        # (2) with corrupted size
        try:
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric = ConfusionMatrixMetric()
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception caught.")

    def test_ConfusionMatrixMetric3(self):
        # (3) the second batch has a corrupted size
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception caught.")

    def test_ConfusionMatrixMetric4(self):
        # (4) check reset
        metric = ConfusionMatrixMetric()
        pred_dict = {"pred": torch.randn(4, 3, 2)}
        target_dict = {'target': torch.ones(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric()
        self.assertTrue(isinstance(res, dict))
        print(res)

    def test_ConfusionMatrixMetric5(self):
        # (5) check that a numpy array is not acceptable
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"pred": np.zeros((4, 3, 2))}
            target_dict = {'target': np.zeros((4, 3))}
            metric(pred_dict=pred_dict, target_dict=target_dict)
        except Exception as e:
            print(e)
            return
        self.fail("No exception caught.")

    def test_ConfusionMatrixMetric6(self):
        # (6) check mapping with matching keys
        metric = ConfusionMatrixMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.randn(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric()
        print(res)

    def test_ConfusionMatrixMetric7(self):
        # (7) check mapping with an unused key included
        try:
            metric = ConfusionMatrixMetric(pred='prediction', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
        except Exception as e:
            print(e)
            return
        self.fail("No exception caught.")

    def test_ConfusionMatrixMetric8(self):
        # (8) check _fast_metric
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception caught.")

    def test_duplicate(self):
        # potential bug in 0.4.1: duplicated parameter names must not break the mapping
        metric = ConfusionMatrixMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred': 0}
        target_dict = {'targets': torch.zeros(4, 3), 'target': 0}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())


    def test_seq_len(self):
        N = 256
        seq_len = torch.zeros(N).long()
        seq_len[0] = 2
        pred = {'pred': torch.ones(N, 2)}
        target = {'target': torch.ones(N, 2), 'seq_len': seq_len}
        metric = ConfusionMatrixMetric()
        metric(pred_dict=pred, target_dict=target)
        metric.get_metric(reset=False)
        seq_len[1:] = 1
        metric(pred_dict=pred, target_dict=target)
        metric.get_metric()

    def test_vocab(self):
        vocab = Vocabulary()
        word_list = "this is a word list".split()
        vocab.update(word_list)
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric(vocab=vocab)
        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())




class TestAccuracyMetric(unittest.TestCase):
    def test_AccuracyMetric1(self):
        # (1) only input, targets passed
@@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase):
    def test_AccuaryMetric8(self):
        try:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"predictions": torch.zeros(4, 3, 2)}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
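
The added tests should be runnable in isolation with the standard unittest runner from the repository root (assuming the test directories are importable as packages):

python -m unittest test.core.test_metrics.TestConfusionMatrixMetric -v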

