|
|
@@ -36,8 +36,9 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require |
|
|
|
|
|
|
|
class ConfusionMatrix: |
|
|
|
r"""a dict can provide Confusion Matrix""" |
|
|
|
def __init__(self, vocab=None, print_ratio=False): |
|
|
|
def __init__(self, show_result=None,vocab=None, print_ratio=False): |
|
|
|
r""" |
|
|
|
:param show_result: list type, 数据类型需要和target保持一致 |
|
|
|
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 |
|
|
|
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列 |
|
|
|
""" |
|
|
@@ -48,6 +49,7 @@ class ConfusionMatrix: |
|
|
|
self.confusiondict = {} # key: pred index, value:target word ocunt |
|
|
|
self.predcount = {} # key:pred index, value:count |
|
|
|
self.targetcount = {} # key:target index, value:count |
|
|
|
self.show_result = show_result |
|
|
|
self.vocab = vocab |
|
|
|
self.print_ratio = print_ratio |
|
|
|
|
|
|
@@ -153,16 +155,16 @@ class ConfusionMatrix: |
|
|
|
set(self.targetcount.keys()).union(set( |
|
|
|
self.predcount.keys())))) |
|
|
|
lenth = len(totallabel) |
|
|
|
# namedict key :idx value:word/idx |
|
|
|
# namedict key :label idx value: str label name/label idx |
|
|
|
namedict = dict([ |
|
|
|
(k, str(k if self.vocab == None else self.vocab.to_word(k))) |
|
|
|
for k in totallabel |
|
|
|
]) |
|
|
|
for label, idx in zip(totallabel, range(lenth)): |
|
|
|
for label, lineidx in zip(totallabel, range(lenth)): |
|
|
|
idx2row[ |
|
|
|
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... |
|
|
|
label] = lineidx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... |
|
|
|
row2idx[ |
|
|
|
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... |
|
|
|
lineidx] = label # 建立一个临时字典,key: 行列index 0,1,2...->1,3,5,...,value:vocab的index, |
|
|
|
# 这里打印东西 |
|
|
|
out = str() |
|
|
|
output = [] |
|
|
@@ -183,6 +185,7 @@ class ConfusionMatrix: |
|
|
|
for idx in range(len(col_lenths)) |
|
|
|
] |
|
|
|
output.append(l) |
|
|
|
|
|
|
|
tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"] |
|
|
|
for n in data[-1]] |
|
|
|
col_lenths = [ |
|
|
@@ -190,6 +193,18 @@ class ConfusionMatrix: |
|
|
|
for idx in range(len(col_lenths)) |
|
|
|
] |
|
|
|
output.append(tail) |
|
|
|
|
|
|
|
if self.show_result: |
|
|
|
missing_item=[] |
|
|
|
missing_item = [i for i in self.show_result if i not in idx2row] |
|
|
|
self.show_result = [i for i in self.show_result if i in idx2row] |
|
|
|
if missing_item: |
|
|
|
print(f"Noticing label(s) which is/are not in target list appeared, final output string will not contain{str(missing_item)}") |
|
|
|
if self.show_result: |
|
|
|
show_col = [0] + [i + 1 for i in [idx2row[i] for i in self.show_result]] |
|
|
|
show_row = [0]+[i+2 for i in [idx2row[i] for i in self.show_result]] |
|
|
|
output = [[row[col] for col in show_col] for row in [output[row] for row in show_row]] |
|
|
|
output.insert(1,["pred"]) |
|
|
|
for line in output: |
|
|
|
for colidx in range(len(line)): |
|
|
|
out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t" |
|
|
@@ -217,6 +232,7 @@ class ConfusionMatrix: |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Option(dict): |
|
|
|
r"""a dict can treat keys as attributes""" |
|
|
|
|
|
|
|