Browse Source

增加对ConfusionMatrix的测试用例

tags/v0.5.5
yh_cc 4 years ago
parent
commit
885c74022c
2 changed files with 27 additions and 11 deletions
  1. +2
    -11
      fastNLP/core/utils.py
  2. +25
    -0
      test/core/test_utils.py

+ 2
- 11
fastNLP/core/utils.py View File

@@ -34,8 +34,6 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require
'varargs'])




class ConfusionMatrix:
"""a dict can provide Confusion Matrix"""
def __init__(self, vocab=None, print_ratio=False):
@@ -83,7 +81,7 @@ class ConfusionMatrix:

def clear(self):
"""
除一些值,等待再次新加入
空ConfusionMatrix,等待再次新加入
:return:
"""
self.confusiondict = {}
@@ -102,11 +100,6 @@ class ConfusionMatrix:
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])

for label, idx in zip(totallabel, range(lenth)):
idx2row[
@@ -116,7 +109,6 @@ class ConfusionMatrix:
output = []
for i in row2idx.keys(): # 第i行
p = row2idx[i]
h = namedict[p]
l = [0 for _ in range(lenth)]
if self.confusiondict.get(p, None):
for t, c in self.confusiondict[p].items():
@@ -141,7 +133,7 @@ class ConfusionMatrix:
tmp = tmp * 100
elif dim == 1:
tmp = np.array(result).T
mp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
tmp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
tmp = tmp.T * 100
tmp = np.around(tmp, decimals=2)
return tmp.tolist()
@@ -172,7 +164,6 @@ class ConfusionMatrix:
row2idx[
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
# 这里打印东西
col_lenths = []
out = str()
output = []
# 表头


+ 25
- 0
test/core/test_utils.py View File

@@ -288,3 +288,28 @@ class TestUtils(unittest.TestCase):

self.assertSequenceEqual(convert_tags, iob2bioes(tags))

class TestConfusionMatrix(unittest.TestCase):
def test1(self):
# 测试能否正常打印
from fastNLP import Vocabulary
from fastNLP.core.utils import ConfusionMatrix
import numpy as np
vocab = Vocabulary(unknown=None, padding=None)
vocab.add_word_lst(list('abcdef'))
confusion_matrix = ConfusionMatrix(vocab)
for _ in range(3):
length = np.random.randint(1, 5)
pred = np.random.randint(0, 3, size=(length,))
target = np.random.randint(0, 3, size=(length,))
confusion_matrix.add_pred_target(pred, target)
print(confusion_matrix)

# 测试print_ratio
confusion_matrix = ConfusionMatrix(vocab, print_ratio=True)
for _ in range(3):
length = np.random.randint(1, 5)
pred = np.random.randint(0, 3, size=(length,))
target = np.random.randint(0, 3, size=(length,))
confusion_matrix.add_pred_target(pred, target)
print(confusion_matrix)


Loading…
Cancel
Save