* add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finishtags/v0.5.5
@@ -7,7 +7,8 @@ __all__ = [ | |||
"AccuracyMetric", | |||
"SpanFPreRecMetric", | |||
"CMRC2018Metric", | |||
"ClassifyFPreRecMetric" | |||
"ClassifyFPreRecMetric", | |||
"ConfusionMatrixMetric" | |||
] | |||
import inspect | |||
@@ -15,6 +16,7 @@ import warnings | |||
from abc import abstractmethod | |||
from collections import defaultdict | |||
from typing import Union | |||
from copy import deepcopy | |||
import re | |||
import numpy as np | |||
@@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from .utils import ConfusionMatrix | |||
class MetricBase(object): | |||
@@ -276,6 +279,95 @@ class MetricBase(object): | |||
return | |||
class ConfusionMatrixMetric(MetricBase): | |||
r""" | |||
分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||
最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} | |||
ConfusionMatrix实例的print()函数将输出矩阵字符串。 | |||
pred_dict = {"pred": torch.Tensor([2,1,3])} | |||
target_dict = {'target': torch.Tensor([2,2,1])} | |||
metric = ConfusionMatrixMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
{'confusion_matrix': | |||
target 1.0 2.0 3.0 all | |||
pred | |||
1.0 0 1 0 1 | |||
2.0 0 1 0 1 | |||
3.0 1 0 0 1 | |||
all 1 2 0 3} | |||
""" | |||
def __init__(self, vocab=None, pred=None, target=None, seq_len=None): | |||
""" | |||
:param vocab: vocab词表类,要求有to_word()方法。 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
""" | |||
super().__init__() | |||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
self.confusion_matrix = ConfusionMatrix(vocab=vocab) | |||
def evaluate(self, pred, target, seq_len=None): | |||
""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]). | |||
""" | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
if not isinstance(target, torch.Tensor): | |||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(target)}.") | |||
if seq_len is not None and not isinstance(seq_len, torch.Tensor): | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if pred.dim() == target.dim(): | |||
pass | |||
elif pred.dim() == target.dim() + 1: | |||
pred = pred.argmax(dim=-1) | |||
if seq_len is None and target.dim() > 1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad.") | |||
else: | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
f"{pred.size()[:-1]}, got {target.size()}.") | |||
target = target.to(pred) | |||
if seq_len is not None and target.dim() > 1: | |||
for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): | |||
l=int(l) | |||
self.confusion_matrix.add_pred_target(p[:l], t[:l]) | |||
elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 | |||
for p, t in zip(pred.tolist(), target.tolist()): | |||
self.confusion_matrix.add_pred_target(p, t) | |||
else: | |||
self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist()) | |||
def get_metric(self,reset=True): | |||
""" | |||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | |||
""" | |||
confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)} | |||
if reset: | |||
self.confusion_matrix.clear() | |||
return confusion | |||
class AccuracyMetric(MetricBase): | |||
""" | |||
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||
@@ -8,18 +8,22 @@ __all__ = [ | |||
"get_seq_len" | |||
] | |||
import _pickle | |||
import inspect | |||
import os | |||
import warnings | |||
from collections import Counter, namedtuple | |||
from copy import deepcopy | |||
from typing import List | |||
import _pickle | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from typing import List | |||
from ._logger import logger | |||
from prettytable import PrettyTable | |||
from ._logger import logger | |||
from ._parallel_utils import _model_contains_inner_module | |||
# from .vocabulary import Vocabulary | |||
try: | |||
from apex import amp | |||
@@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
'varargs']) | |||
class ConfusionMatrix: | |||
"""a dict can provide Confusion Matrix""" | |||
def __init__(self, vocab=None): | |||
""" | |||
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | |||
""" | |||
if vocab and not hasattr(vocab, 'to_word'): | |||
raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary," | |||
f"got {type(vocab)}.") | |||
self.confusiondict={} #key: pred index, value:target word ocunt | |||
self.predcount={} #key:pred index, value:count | |||
self.targetcount={} #key:target index, value:count | |||
self.vocab=vocab | |||
def add_pred_target(self, pred, target): #一组结果 | |||
""" | |||
通过这个函数向ConfusionMatrix加入一组预测结果 | |||
:param list pred: 预测的标签列表 | |||
:param list target: 真实值的标签列表 | |||
:return ConfusionMatrix | |||
confusion=ConfusionMatrix() | |||
pred = [2,1,3] | |||
target = [2,2,1] | |||
confusion.add_pred_target(pred, target) | |||
print(confusion) | |||
target 1 2 3 all | |||
pred | |||
1 0 1 0 1 | |||
2 0 1 0 1 | |||
3 1 0 0 1 | |||
all 1 2 0 3 | |||
""" | |||
for p,t in zip(pred,target): #<int, int> | |||
self.predcount[p]=self.predcount.get(p,0)+ 1 | |||
self.targetcount[t]=self.targetcount.get(t,0)+1 | |||
if p in self.confusiondict: | |||
self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1 | |||
else: | |||
self.confusiondict[p]={} | |||
self.confusiondict[p][t]= 1 | |||
return self.confusiondict | |||
def clear(self): | |||
""" | |||
清除一些值,等待再次新加入 | |||
:return: | |||
""" | |||
self.confusiondict={} | |||
self.targetcount={} | |||
self.predcount={} | |||
def __repr__(self): | |||
""" | |||
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | |||
""" | |||
row2idx={} | |||
idx2row={} | |||
# 已知的所有键/label | |||
totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys())))) | |||
lenth=len(totallabel) | |||
# namedict key :idx value:word/idx | |||
namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel]) | |||
for label,idx in zip(totallabel,range(lenth)): | |||
idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... | |||
row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... | |||
# 这里打印东西 | |||
#表头 | |||
head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"] | |||
output="\t".join(head) + "\n" + "pred" + "\n" | |||
#内容 | |||
for i in row2idx.keys(): #第i行 | |||
p=row2idx[i] | |||
h=namedict[p] | |||
l=[0 for _ in range(lenth)] | |||
if self.confusiondict.get(p,None): | |||
for t,c in self.confusiondict[p].items(): | |||
l[idx2row[t]] = c #完成一行 | |||
l=[h]+[str(n) for n in l]+[str(sum(l))] | |||
output+="\t".join(l) +"\n" | |||
#表尾 | |||
tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()] | |||
tail=["all"]+[str(n) for n in tail]+[str(sum(tail))] | |||
output+="\t".join(tail) | |||
return output | |||
class Option(dict): | |||
"""a dict can treat keys as attributes""" | |||
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric | |||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from collections import Counter | |||
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric | |||
def _generate_tags(encoding_type, number_labels=4): | |||
@@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result): | |||
allen_result[key] = round(value, 6) | |||
return allen_result | |||
class TestConfusionMatrixMetric(unittest.TestCase): | |||
def test_ConfusionMatrixMetric1(self): | |||
pred_dict = {"pred": torch.zeros(4,3)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = ConfusionMatrixMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
def test_ConfusionMatrixMetric2(self): | |||
# (2) with corrupted size | |||
try: | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = ConfusionMatrixMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
print("No exception catches.") | |||
def test_ConfusionMatrixMetric3(self): | |||
# (3) the second batch is corrupted size | |||
try: | |||
metric = ConfusionMatrixMetric() | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
assert(True, False), "No exception catches." | |||
def test_ConfusionMatrixMetric4(self): | |||
# (4) check reset | |||
metric = ConfusionMatrixMetric() | |||
pred_dict = {"pred": torch.randn(4, 3, 2)} | |||
target_dict = {'target': torch.ones(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
res = metric.get_metric() | |||
self.assertTrue(isinstance(res, dict)) | |||
print(res) | |||
def test_ConfusionMatrixMetric5(self): | |||
# (5) check numpy array is not acceptable | |||
try: | |||
metric = ConfusionMatrixMetric() | |||
pred_dict = {"pred": np.zeros((4, 3, 2))} | |||
target_dict = {'target': np.zeros((4, 3))} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_ConfusionMatrixMetric6(self): | |||
# (6) check map, match | |||
metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||
pred_dict = {"predictions": torch.randn(4, 3, 2)} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
res = metric.get_metric() | |||
print(res) | |||
def test_ConfusionMatrixMetric7(self): | |||
# (7) check map, include unused | |||
try: | |||
metric = ConfusionMatrixMetric(pred='prediction', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_ConfusionMatrixMetric8(self): | |||
# (8) check _fast_metric | |||
try: | |||
metric = ConfusionMatrixMetric() | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_duplicate(self): | |||
# 0.4.1的潜在bug,不能出现形参重复的情况 | |||
metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||
target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
def test_seq_len(self): | |||
N = 256 | |||
seq_len = torch.zeros(N).long() | |||
seq_len[0] = 2 | |||
pred = {'pred': torch.ones(N, 2)} | |||
target = {'target': torch.ones(N, 2), 'seq_len': seq_len} | |||
metric = ConfusionMatrixMetric() | |||
metric(pred_dict=pred, target_dict=target) | |||
metric.get_metric(reset=False) | |||
seq_len[1:] = 1 | |||
metric(pred_dict=pred, target_dict=target) | |||
metric.get_metric() | |||
def test_vocab(self): | |||
vocab = Vocabulary() | |||
word_list = "this is a word list".split() | |||
vocab.update(word_list) | |||
pred_dict = {"pred": torch.zeros(4,3)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = ConfusionMatrixMetric(vocab=vocab) | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
class TestAccuracyMetric(unittest.TestCase): | |||
def test_AccuracyMetric1(self): | |||
# (1) only input, targets passed | |||
@@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
def test_AccuaryMetric8(self): | |||
try: | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||