From 885c74022cd6a508b9c57eba6a1e3c6791529be5 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Tue, 24 Mar 2020 16:09:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9ConfusionMatrix?= =?UTF-8?q?=E7=9A=84=E6=B5=8B=E8=AF=95=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/utils.py | 13 ++----------- test/core/test_utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index daf0b050..8941b9d8 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -34,8 +34,6 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require 'varargs']) - - class ConfusionMatrix: """a dict can provide Confusion Matrix""" def __init__(self, vocab=None, print_ratio=False): @@ -83,7 +81,7 @@ class ConfusionMatrix: def clear(self): """ - 清除一些值,等待再次新加入 + 清空ConfusionMatrix,等待再次新加入 :return: """ self.confusiondict = {} @@ -102,11 +100,6 @@ class ConfusionMatrix: set(self.targetcount.keys()).union(set( self.predcount.keys())))) lenth = len(totallabel) - # namedict key :idx value:word/idx - namedict = dict([ - (k, str(k if self.vocab == None else self.vocab.to_word(k))) - for k in totallabel - ]) for label, idx in zip(totallabel, range(lenth)): idx2row[ @@ -116,7 +109,6 @@ class ConfusionMatrix: output = [] for i in row2idx.keys(): # 第i行 p = row2idx[i] - h = namedict[p] l = [0 for _ in range(lenth)] if self.confusiondict.get(p, None): for t, c in self.confusiondict[p].items(): @@ -141,7 +133,7 @@ class ConfusionMatrix: tmp = tmp * 100 elif dim == 1: tmp = np.array(result).T - mp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12) + tmp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12) tmp = tmp.T * 100 tmp = np.around(tmp, decimals=2) return tmp.tolist() @@ -172,7 +164,6 @@ class ConfusionMatrix: row2idx[ idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... # 这里打印东西 - col_lenths = [] out = str() output = [] # 表头 diff --git a/test/core/test_utils.py b/test/core/test_utils.py index 0093c3e8..f4a29658 100644 --- a/test/core/test_utils.py +++ b/test/core/test_utils.py @@ -288,3 +288,28 @@ class TestUtils(unittest.TestCase): self.assertSequenceEqual(convert_tags, iob2bioes(tags)) +class TestConfusionMatrix(unittest.TestCase): + def test1(self): + # 测试能否正常打印 + from fastNLP import Vocabulary + from fastNLP.core.utils import ConfusionMatrix + import numpy as np + vocab = Vocabulary(unknown=None, padding=None) + vocab.add_word_lst(list('abcdef')) + confusion_matrix = ConfusionMatrix(vocab) + for _ in range(3): + length = np.random.randint(1, 5) + pred = np.random.randint(0, 3, size=(length,)) + target = np.random.randint(0, 3, size=(length,)) + confusion_matrix.add_pred_target(pred, target) + print(confusion_matrix) + + # 测试print_ratio + confusion_matrix = ConfusionMatrix(vocab, print_ratio=True) + for _ in range(3): + length = np.random.randint(1, 5) + pred = np.random.randint(0, 3, size=(length,)) + target = np.random.randint(0, 3, size=(length,)) + confusion_matrix.add_pred_target(pred, target) + print(confusion_matrix) +