* add ConfusionMatrix, ConfusionMatrixMetric * add confusionmatrix to utils * add ConfusionMatrixmetric * add ConfusionMatrixMetric * init for test * begin test * test finish * doc finishtags/v0.5.5
@@ -7,7 +7,8 @@ __all__ = [ | |||||
"AccuracyMetric", | "AccuracyMetric", | ||||
"SpanFPreRecMetric", | "SpanFPreRecMetric", | ||||
"CMRC2018Metric", | "CMRC2018Metric", | ||||
"ClassifyFPreRecMetric" | |||||
"ClassifyFPreRecMetric", | |||||
"ConfusionMatrixMetric" | |||||
] | ] | ||||
import inspect | import inspect | ||||
@@ -15,6 +16,7 @@ import warnings | |||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from collections import defaultdict | from collections import defaultdict | ||||
from typing import Union | from typing import Union | ||||
from copy import deepcopy | |||||
import re | import re | ||||
import numpy as np | import numpy as np | ||||
@@ -27,6 +29,7 @@ from .utils import _check_arg_dict_list | |||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import seq_len_to_mask | from .utils import seq_len_to_mask | ||||
from .vocabulary import Vocabulary | from .vocabulary import Vocabulary | ||||
from .utils import ConfusionMatrix | |||||
class MetricBase(object): | class MetricBase(object): | ||||
@@ -276,6 +279,95 @@ class MetricBase(object): | |||||
return | return | ||||
class ConfusionMatrixMetric(MetricBase): | |||||
r""" | |||||
分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | |||||
最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} | |||||
ConfusionMatrix实例的print()函数将输出矩阵字符串。 | |||||
pred_dict = {"pred": torch.Tensor([2,1,3])} | |||||
target_dict = {'target': torch.Tensor([2,2,1])} | |||||
metric = ConfusionMatrixMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
{'confusion_matrix': | |||||
target 1.0 2.0 3.0 all | |||||
pred | |||||
1.0 0 1 0 1 | |||||
2.0 0 1 0 1 | |||||
3.0 1 0 0 1 | |||||
all 1 2 0 3} | |||||
""" | |||||
def __init__(self, vocab=None, pred=None, target=None, seq_len=None): | |||||
""" | |||||
:param vocab: vocab词表类,要求有to_word()方法。 | |||||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||||
""" | |||||
super().__init__() | |||||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||||
self.confusion_matrix = ConfusionMatrix(vocab=vocab) | |||||
def evaluate(self, pred, target, seq_len=None): | |||||
""" | |||||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, torch.Size([B]), 或者torch.Size([B]). | |||||
""" | |||||
if not isinstance(pred, torch.Tensor): | |||||
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(pred)}.") | |||||
if not isinstance(target, torch.Tensor): | |||||
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(target)}.") | |||||
if seq_len is not None and not isinstance(seq_len, torch.Tensor): | |||||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||||
f"got {type(seq_len)}.") | |||||
if pred.dim() == target.dim(): | |||||
pass | |||||
elif pred.dim() == target.dim() + 1: | |||||
pred = pred.argmax(dim=-1) | |||||
if seq_len is None and target.dim() > 1: | |||||
warnings.warn("You are not passing `seq_len` to exclude pad.") | |||||
else: | |||||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||||
f"{pred.size()[:-1]}, got {target.size()}.") | |||||
target = target.to(pred) | |||||
if seq_len is not None and target.dim() > 1: | |||||
for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): | |||||
l=int(l) | |||||
self.confusion_matrix.add_pred_target(p[:l], t[:l]) | |||||
elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 | |||||
for p, t in zip(pred.tolist(), target.tolist()): | |||||
self.confusion_matrix.add_pred_target(p, t) | |||||
else: | |||||
self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist()) | |||||
def get_metric(self,reset=True): | |||||
""" | |||||
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. | |||||
:param bool reset: 在调用完get_metric后是否清空评价指标统计量. | |||||
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} | |||||
""" | |||||
confusion = {'confusion_matrix': deepcopy(self.confusion_matrix)} | |||||
if reset: | |||||
self.confusion_matrix.clear() | |||||
return confusion | |||||
class AccuracyMetric(MetricBase): | class AccuracyMetric(MetricBase): | ||||
""" | """ | ||||
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) | ||||
@@ -8,18 +8,22 @@ __all__ = [ | |||||
"get_seq_len" | "get_seq_len" | ||||
] | ] | ||||
import _pickle | |||||
import inspect | import inspect | ||||
import os | import os | ||||
import warnings | import warnings | ||||
from collections import Counter, namedtuple | from collections import Counter, namedtuple | ||||
from copy import deepcopy | |||||
from typing import List | |||||
import _pickle | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from typing import List | |||||
from ._logger import logger | |||||
from prettytable import PrettyTable | from prettytable import PrettyTable | ||||
from ._logger import logger | |||||
from ._parallel_utils import _model_contains_inner_module | from ._parallel_utils import _model_contains_inner_module | ||||
# from .vocabulary import Vocabulary | |||||
try: | try: | ||||
from apex import amp | from apex import amp | ||||
@@ -30,6 +34,98 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||||
'varargs']) | 'varargs']) | ||||
class ConfusionMatrix: | |||||
"""a dict can provide Confusion Matrix""" | |||||
def __init__(self, vocab=None): | |||||
""" | |||||
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 | |||||
""" | |||||
if vocab and not hasattr(vocab, 'to_word'): | |||||
raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary," | |||||
f"got {type(vocab)}.") | |||||
self.confusiondict={} #key: pred index, value:target word ocunt | |||||
self.predcount={} #key:pred index, value:count | |||||
self.targetcount={} #key:target index, value:count | |||||
self.vocab=vocab | |||||
def add_pred_target(self, pred, target): #一组结果 | |||||
""" | |||||
通过这个函数向ConfusionMatrix加入一组预测结果 | |||||
:param list pred: 预测的标签列表 | |||||
:param list target: 真实值的标签列表 | |||||
:return ConfusionMatrix | |||||
confusion=ConfusionMatrix() | |||||
pred = [2,1,3] | |||||
target = [2,2,1] | |||||
confusion.add_pred_target(pred, target) | |||||
print(confusion) | |||||
target 1 2 3 all | |||||
pred | |||||
1 0 1 0 1 | |||||
2 0 1 0 1 | |||||
3 1 0 0 1 | |||||
all 1 2 0 3 | |||||
""" | |||||
for p,t in zip(pred,target): #<int, int> | |||||
self.predcount[p]=self.predcount.get(p,0)+ 1 | |||||
self.targetcount[t]=self.targetcount.get(t,0)+1 | |||||
if p in self.confusiondict: | |||||
self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1 | |||||
else: | |||||
self.confusiondict[p]={} | |||||
self.confusiondict[p][t]= 1 | |||||
return self.confusiondict | |||||
def clear(self): | |||||
""" | |||||
清除一些值,等待再次新加入 | |||||
:return: | |||||
""" | |||||
self.confusiondict={} | |||||
self.targetcount={} | |||||
self.predcount={} | |||||
def __repr__(self): | |||||
""" | |||||
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。 | |||||
""" | |||||
row2idx={} | |||||
idx2row={} | |||||
# 已知的所有键/label | |||||
totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys())))) | |||||
lenth=len(totallabel) | |||||
# namedict key :idx value:word/idx | |||||
namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel]) | |||||
for label,idx in zip(totallabel,range(lenth)): | |||||
idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... | |||||
row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... | |||||
# 这里打印东西 | |||||
#表头 | |||||
head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"] | |||||
output="\t".join(head) + "\n" + "pred" + "\n" | |||||
#内容 | |||||
for i in row2idx.keys(): #第i行 | |||||
p=row2idx[i] | |||||
h=namedict[p] | |||||
l=[0 for _ in range(lenth)] | |||||
if self.confusiondict.get(p,None): | |||||
for t,c in self.confusiondict[p].items(): | |||||
l[idx2row[t]] = c #完成一行 | |||||
l=[h]+[str(n) for n in l]+[str(sum(l))] | |||||
output+="\t".join(l) +"\n" | |||||
#表尾 | |||||
tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()] | |||||
tail=["all"]+[str(n) for n in tail]+[str(sum(tail))] | |||||
output+="\t".join(tail) | |||||
return output | |||||
class Option(dict): | class Option(dict): | ||||
"""a dict can treat keys as attributes""" | """a dict can treat keys as attributes""" | ||||
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric | |||||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | from fastNLP.core.metrics import _pred_topk, _accuracy_topk | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from collections import Counter | from collections import Counter | ||||
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric | |||||
from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric | |||||
def _generate_tags(encoding_type, number_labels=4): | def _generate_tags(encoding_type, number_labels=4): | ||||
@@ -44,6 +44,141 @@ def _convert_res_to_fastnlp_res(metric_result): | |||||
allen_result[key] = round(value, 6) | allen_result[key] = round(value, 6) | ||||
return allen_result | return allen_result | ||||
class TestConfusionMatrixMetric(unittest.TestCase): | |||||
def test_ConfusionMatrixMetric1(self): | |||||
pred_dict = {"pred": torch.zeros(4,3)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = ConfusionMatrixMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
def test_ConfusionMatrixMetric2(self): | |||||
# (2) with corrupted size | |||||
try: | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = ConfusionMatrixMetric() | |||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
print("No exception catches.") | |||||
def test_ConfusionMatrixMetric3(self): | |||||
# (3) the second batch is corrupted size | |||||
try: | |||||
metric = ConfusionMatrixMetric() | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
assert(True, False), "No exception catches." | |||||
def test_ConfusionMatrixMetric4(self): | |||||
# (4) check reset | |||||
metric = ConfusionMatrixMetric() | |||||
pred_dict = {"pred": torch.randn(4, 3, 2)} | |||||
target_dict = {'target': torch.ones(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
res = metric.get_metric() | |||||
self.assertTrue(isinstance(res, dict)) | |||||
print(res) | |||||
def test_ConfusionMatrixMetric5(self): | |||||
# (5) check numpy array is not acceptable | |||||
try: | |||||
metric = ConfusionMatrixMetric() | |||||
pred_dict = {"pred": np.zeros((4, 3, 2))} | |||||
target_dict = {'target': np.zeros((4, 3))} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_ConfusionMatrixMetric6(self): | |||||
# (6) check map, match | |||||
metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.randn(4, 3, 2)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
res = metric.get_metric() | |||||
print(res) | |||||
def test_ConfusionMatrixMetric7(self): | |||||
# (7) check map, include unused | |||||
try: | |||||
metric = ConfusionMatrixMetric(pred='prediction', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_ConfusionMatrixMetric8(self): | |||||
# (8) check _fast_metric | |||||
try: | |||||
metric = ConfusionMatrixMetric() | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
def test_duplicate(self): | |||||
# 0.4.1的潜在bug,不能出现形参重复的情况 | |||||
metric = ConfusionMatrixMetric(pred='predictions', target='targets') | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(4) * 3, 'pred':0} | |||||
target_dict = {'targets':torch.zeros(4, 3), 'target': 0} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
def test_seq_len(self): | |||||
N = 256 | |||||
seq_len = torch.zeros(N).long() | |||||
seq_len[0] = 2 | |||||
pred = {'pred': torch.ones(N, 2)} | |||||
target = {'target': torch.ones(N, 2), 'seq_len': seq_len} | |||||
metric = ConfusionMatrixMetric() | |||||
metric(pred_dict=pred, target_dict=target) | |||||
metric.get_metric(reset=False) | |||||
seq_len[1:] = 1 | |||||
metric(pred_dict=pred, target_dict=target) | |||||
metric.get_metric() | |||||
def test_vocab(self): | |||||
vocab = Vocabulary() | |||||
word_list = "this is a word list".split() | |||||
vocab.update(word_list) | |||||
pred_dict = {"pred": torch.zeros(4,3)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = ConfusionMatrixMetric(vocab=vocab) | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
class TestAccuracyMetric(unittest.TestCase): | class TestAccuracyMetric(unittest.TestCase): | ||||
def test_AccuracyMetric1(self): | def test_AccuracyMetric1(self): | ||||
# (1) only input, targets passed | # (1) only input, targets passed | ||||
@@ -133,7 +268,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
def test_AccuaryMetric8(self): | def test_AccuaryMetric8(self): | ||||
try: | try: | ||||
metric = AccuracyMetric(pred='predictions', target='targets') | metric = AccuracyMetric(pred='predictions', target='targets') | ||||
pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | metric(pred_dict=pred_dict, target_dict=target_dict, ) | ||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||