
[new] add ConfusionMatrix, ConfusionMatrixMetric (#272)

* add ConfusionMatrix, ConfusionMatrixMetric

* add ConfusionMatrix to utils

* add ConfusionMatrixMetric

* init tests

* begin tests

* finish tests

* finish docs
tags/v0.5.5
ROGERDJQ · 4 years ago
commit f3ee16a5f6
3 changed files with 329 additions and 6 deletions

  1. fastNLP/core/metrics.py (+93, -1)
  2. fastNLP/core/utils.py (+99, -3)
  3. test/core/test_metrics.py (+137, -2)

fastNLP/core/metrics.py (+93, -1)

@@ -7,7 +7,8 @@ __all__ = [
    "AccuracyMetric",
    "SpanFPreRecMetric",
    "CMRC2018Metric",
    "ClassifyFPreRecMetric",
    "ConfusionMatrixMetric"
]

import inspect
@@ -15,6 +16,7 @@ import warnings
from abc import abstractmethod
from collections import defaultdict
from typing import Union
from copy import deepcopy
import re

import numpy as np
@@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list
from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
from .utils import ConfusionMatrix


class MetricBase(object):
@@ -276,6 +279,95 @@ class MetricBase(object):
        return


class ConfusionMatrixMetric(MetricBase):
    r"""
    Metric that computes a confusion matrix for classification tasks (for other metrics see :mod:`fastNLP.core.metrics`).

    The final result is a dict, {'confusion_matrix': a ConfusionMatrix instance};
    printing the ConfusionMatrix instance renders the matrix as a string.

    pred_dict = {"pred": torch.Tensor([2,1,3])}
    target_dict = {'target': torch.Tensor([2,2,1])}
    metric = ConfusionMatrixMetric()
    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric())

    {'confusion_matrix':
    target  1.0     2.0     3.0     all
    pred
    1.0     0       1       0       1
    2.0     0       1       0       1
    3.0     1       0       0       1
    all     1       2       0       3}
    """

    def __init__(self, vocab=None, pred=None, target=None, seq_len=None):
        """
        :param vocab: a vocabulary object; it must provide a to_word() method.
        :param pred: mapping for `pred` in the parameter map; None means `pred` -> `pred`
        :param target: mapping for `target` in the parameter map; None means `target` -> `target`
        :param seq_len: mapping for `seq_len` in the parameter map; None means `seq_len` -> `seq_len`
        """
        super().__init__()
        self._init_param_map(pred=pred, target=target, seq_len=seq_len)
        self.confusion_matrix = ConfusionMatrix(vocab=vocab)

    def evaluate(self, pred, target, seq_len=None):
        """
        evaluate() accumulates the metric statistics over one batch of predictions.

        :param torch.Tensor pred: prediction tensor of shape torch.Size([B,]), torch.Size([B, n_classes]),
                torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
        :param torch.Tensor target: ground-truth tensor of shape torch.Size([B,]) or torch.Size([B, max_len])
        :param torch.Tensor seq_len: sequence lengths, either None or of shape torch.Size([B])
        """
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(pred)}.")
        if not isinstance(target, torch.Tensor):
            raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(target)}.")

        if seq_len is not None and not isinstance(seq_len, torch.Tensor):
            raise TypeError(f"`seq_len` in {_get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(seq_len)}.")

        if pred.dim() == target.dim():
            pass
        elif pred.dim() == target.dim() + 1:
            pred = pred.argmax(dim=-1)
            if seq_len is None and target.dim() > 1:
                warnings.warn("You are not passing `seq_len` to exclude pad.")
        else:
            raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has "
                               f"size {pred.size()}, target should have size {pred.size()} or "
                               f"{pred.size()[:-1]}, got {target.size()}.")

        target = target.to(pred)
        if seq_len is not None and target.dim() > 1:
            for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()):
                l = int(l)
                self.confusion_matrix.add_pred_target(p[:l], t[:l])
        elif target.dim() > 1:  # a multi-dimensional target without seq_len: count the full length
            for p, t in zip(pred.tolist(), target.tolist()):
                self.confusion_matrix.add_pred_target(p, t)
        else:
            self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist())

    def get_metric(self, reset=True):
        """
        get_metric() computes the final result from the statistics accumulated by evaluate().

        :param bool reset: whether to clear the accumulated statistics after get_metric() is called.
        :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix}
        """
        confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)}
        if reset:
            self.confusion_matrix.clear()
        return confusion


class AccuracyMetric(MetricBase):
    """
    Accuracy metric (for other metrics see :mod:`fastNLP.core.metrics`)

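For reference, a minimal usage sketch of the new metric on sequence-shaped inputs (the shapes follow the evaluate() docstring above; the batch contents are made up for illustration):

import torch
from fastNLP.core.metrics import ConfusionMatrixMetric

# pred is [B, max_len, n_classes], so evaluate() reduces it with argmax(dim=-1);
# seq_len excludes padding: only the first seq_len[i] steps of sample i are counted.
pred_dict = {"pred": torch.randn(2, 3, 4)}
target_dict = {"target": torch.tensor([[1, 2, 0], [3, 1, 0]]),
               "seq_len": torch.tensor([2, 3])}

metric = ConfusionMatrixMetric()
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())  # {'confusion_matrix': ...}; printing the value renders the table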

fastNLP/core/utils.py (+99, -3)

@@ -8,18 +8,22 @@ __all__ = [
    "get_seq_len"
]

import inspect
import os
import warnings
from collections import Counter, namedtuple
from copy import deepcopy
from typing import List

import _pickle
import numpy as np
import torch
import torch.nn as nn
from prettytable import PrettyTable

from ._logger import logger
from ._parallel_utils import _model_contains_inner_module
# from .vocabulary import Vocabulary

try:
    from apex import amp
@@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                     'varargs'])




class ConfusionMatrix:
    """a dict that can provide a confusion matrix"""

    def __init__(self, vocab=None):
        """
        :param vocab: must provide a to_word() method; using fastNLP.core.Vocabulary directly is recommended.
        """
        if vocab and not hasattr(vocab, 'to_word'):
            raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be fastNLP.core.Vocabulary, "
                            f"got {type(vocab)}.")
        self.confusiondict = {}  # key: pred index, value: dict of target counts
        self.predcount = {}  # key: pred index, value: count
        self.targetcount = {}  # key: target index, value: count
        self.vocab = vocab

    def add_pred_target(self, pred, target):  # one batch of results
        """
        Add one batch of prediction results to the ConfusionMatrix.

        :param list pred: list of predicted labels
        :param list target: list of ground-truth labels
        :return ConfusionMatrix

        confusion=ConfusionMatrix()
        pred = [2,1,3]
        target = [2,2,1]
        confusion.add_pred_target(pred, target)
        print(confusion)

        target  1       2       3       all
        pred
        1       0       1       0       1
        2       0       1       0       1
        3       1       0       0       1
        all     1       2       0       3
        """
        for p, t in zip(pred, target):  # <int, int>
            self.predcount[p] = self.predcount.get(p, 0) + 1
            self.targetcount[t] = self.targetcount.get(t, 0) + 1
            if p in self.confusiondict:
                self.confusiondict[p][t] = self.confusiondict[p].get(t, 0) + 1
            else:
                self.confusiondict[p] = {}
                self.confusiondict[p][t] = 1
        return self.confusiondict

    def clear(self):
        """
        Clear the accumulated counts so new results can be added.
        :return:
        """
        self.confusiondict = {}
        self.targetcount = {}
        self.predcount = {}

    def __repr__(self):
        """
        :return string output: formatted output of the ConfusionMatrix, including the header labels,
            the per-cell counts, and the row/column totals.
        """
        row2idx = {}
        idx2row = {}
        # all known keys/labels
        totallabel = sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys()))))
        length = len(totallabel)
        # namedict - key: idx, value: word/idx
        namedict = dict([(k, str(k if self.vocab is None else self.vocab.to_word(k))) for k in totallabel])

        for label, idx in zip(totallabel, range(length)):
            idx2row[label] = idx  # temporary dict - key: vocab index, value: row/column index (1,3,5,... -> 0,1,2,...)
            row2idx[idx] = label  # temporary dict - key: row/column index, value: vocab index (0,1,2,... -> 1,3,5,...)
        # build the printed output
        # header
        head = ["\ntarget"] + [str(namedict[row2idx[k]]) for k in row2idx.keys()] + ["all"]
        output = "\t".join(head) + "\n" + "pred" + "\n"
        # body
        for i in row2idx.keys():  # the i-th row
            p = row2idx[i]
            h = namedict[p]
            l = [0 for _ in range(length)]
            if self.confusiondict.get(p, None):
                for t, c in self.confusiondict[p].items():
                    l[idx2row[t]] = c  # fill one row
            l = [h] + [str(n) for n in l] + [str(sum(l))]
            output += "\t".join(l) + "\n"
        # footer with column totals
        tail = [self.targetcount.get(row2idx[k], 0) for k in row2idx.keys()]
        tail = ["all"] + [str(n) for n in tail] + [str(sum(tail))]
        output += "\t".join(tail)
        return output


class Option(dict):
    """a dict that can treat keys as attributes"""


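For reference, a minimal sketch of using the helper class on its own, with a vocabulary so that row and column labels print as words (the word list and label ids are made up for illustration):

from fastNLP.core.utils import ConfusionMatrix
from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(padding=None, unknown=None)
vocab.update(["cat", "dog", "bird"])

cm = ConfusionMatrix(vocab=vocab)
# label indices into the vocab; to_word() maps them back to words in the printout
cm.add_pred_target(pred=[0, 1, 2], target=[0, 1, 1])
print(cm)   # tab-separated table with per-row and per-column totals
cm.clear()  # reset the counts before accumulating a new run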

test/core/test_metrics.py (+137, -2)

@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk
from fastNLP.core.vocabulary import Vocabulary
from collections import Counter
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric, ConfusionMatrixMetric


def _generate_tags(encoding_type, number_labels=4):
@@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result):
        allen_result[key] = round(value, 6)
    return allen_result



class TestConfusionMatrixMetric(unittest.TestCase):
    def test_ConfusionMatrixMetric1(self):
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric()

        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())

    def test_ConfusionMatrixMetric2(self):
        # (2) with corrupted size
        try:
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric = ConfusionMatrixMetric()
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_ConfusionMatrixMetric3(self):
        # (3) the second batch has a corrupted size
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            pred_dict = {"pred": torch.zeros(4, 3, 2)}
            target_dict = {'target': torch.zeros(4)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_ConfusionMatrixMetric4(self):
        # (4) check reset
        metric = ConfusionMatrixMetric()
        pred_dict = {"pred": torch.randn(4, 3, 2)}
        target_dict = {'target': torch.ones(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric()
        self.assertTrue(isinstance(res, dict))
        print(res)

    def test_ConfusionMatrixMetric5(self):
        # (5) check that a numpy array is not acceptable
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"pred": np.zeros((4, 3, 2))}
            target_dict = {'target': np.zeros((4, 3))}
            metric(pred_dict=pred_dict, target_dict=target_dict)
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_ConfusionMatrixMetric6(self):
        # (6) check map, match
        metric = ConfusionMatrixMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.randn(4, 3, 2)}
        target_dict = {'targets': torch.zeros(4, 3)}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        res = metric.get_metric()
        print(res)

    def test_ConfusionMatrixMetric7(self):
        # (7) check map, including an unused key
        try:
            metric = ConfusionMatrixMetric(pred='prediction', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_ConfusionMatrixMetric8(self):
        # (8) check _fast_metric
        try:
            metric = ConfusionMatrixMetric()
            pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            print(metric.get_metric())
        except Exception as e:
            print(e)
            return
        self.fail("No exception was raised.")

    def test_duplicate(self):
        # guards a potential bug in 0.4.1: duplicated parameter names must not break the metric
        metric = ConfusionMatrixMetric(pred='predictions', target='targets')
        pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred': 0}
        target_dict = {'targets': torch.zeros(4, 3), 'target': 0}
        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())

    def test_seq_len(self):
        N = 256
        seq_len = torch.zeros(N).long()
        seq_len[0] = 2
        pred = {'pred': torch.ones(N, 2)}
        target = {'target': torch.ones(N, 2), 'seq_len': seq_len}
        metric = ConfusionMatrixMetric()
        metric(pred_dict=pred, target_dict=target)
        metric.get_metric(reset=False)
        seq_len[1:] = 1
        metric(pred_dict=pred, target_dict=target)
        metric.get_metric()

    def test_vocab(self):
        vocab = Vocabulary()
        word_list = "this is a word list".split()
        vocab.update(word_list)
        pred_dict = {"pred": torch.zeros(4, 3)}
        target_dict = {'target': torch.zeros(4)}
        metric = ConfusionMatrixMetric(vocab=vocab)
        metric(pred_dict=pred_dict, target_dict=target_dict)
        print(metric.get_metric())




class TestAccuracyMetric(unittest.TestCase):
    def test_AccuracyMetric1(self):
        # (1) only input, targets passed
@@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase):
    def test_AccuaryMetric8(self):
        try:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"predictions": torch.zeros(4, 3, 2)}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})

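As test_seq_len above relies on, get_metric(reset=False) keeps the accumulated counts, so later batches keep adding to the same matrix. A small sketch of that accumulate-then-reset pattern (batch contents are illustrative):

import torch
from fastNLP.core.metrics import ConfusionMatrixMetric

metric = ConfusionMatrixMetric()
batches = [
    ({"pred": torch.tensor([0, 1])}, {"target": torch.tensor([0, 0])}),
    ({"pred": torch.tensor([1, 1])}, {"target": torch.tensor([1, 0])}),
]
for pred_dict, target_dict in batches:
    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric(reset=False))  # peek at the running matrix without clearing it
final = metric.get_metric()  # default reset=True clears the counts afterwards
print(final["confusion_matrix"])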
