From 228cca44e93e61947c8f39fa83d4aee0fc6923c4 Mon Sep 17 00:00:00 2001 From: ROGERDJQ Date: Fri, 10 Jul 2020 14:28:30 +0800 Subject: [PATCH] [new] add show-result-list for confusion_matrix (#309) add ConfusionMatrix and allow choose part of the column --- fastNLP/core/metrics.py | 2 ++ fastNLP/core/utils.py | 26 +++++++++++++++++++++----- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f893bd74..cf5b82b7 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -313,6 +313,7 @@ class ConfusionMatrixMetric(MetricBase): pred=None, target=None, seq_len=None, + show_result=None, print_ratio=False ): r""" @@ -326,6 +327,7 @@ class ConfusionMatrixMetric(MetricBase): super().__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) self.confusion_matrix = ConfusionMatrix( + show_result=show_result, vocab=vocab, print_ratio=print_ratio, ) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d324af72..212a31e6 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -36,8 +36,9 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require class ConfusionMatrix: r"""a dict can provide Confusion Matrix""" - def __init__(self, vocab=None, print_ratio=False): + def __init__(self, show_result=None,vocab=None, print_ratio=False): r""" + :param show_result: list type, 数据类型需要和target保持一致 :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 :param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列 """ @@ -48,6 +49,7 @@ class ConfusionMatrix: self.confusiondict = {} # key: pred index, value:target word ocunt self.predcount = {} # key:pred index, value:count self.targetcount = {} # key:target index, value:count + self.show_result = show_result self.vocab = vocab self.print_ratio = print_ratio @@ -153,16 +155,16 @@ class ConfusionMatrix: set(self.targetcount.keys()).union(set( self.predcount.keys())))) lenth = len(totallabel) - # namedict key :idx value:word/idx + # namedict key :label idx value: str label name/label idx namedict = dict([ (k, str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel ]) - for label, idx in zip(totallabel, range(lenth)): + for label, lineidx in zip(totallabel, range(lenth)): idx2row[ - label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... + label] = lineidx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,... row2idx[ - idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,... + lineidx] = label # 建立一个临时字典,key: 行列index 0,1,2...->1,3,5,...,value:vocab的index, # 这里打印东西 out = str() output = [] @@ -183,6 +185,7 @@ class ConfusionMatrix: for idx in range(len(col_lenths)) ] output.append(l) + tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"] for n in data[-1]] col_lenths = [ @@ -190,6 +193,18 @@ class ConfusionMatrix: for idx in range(len(col_lenths)) ] output.append(tail) + + if self.show_result: + missing_item=[] + missing_item = [i for i in self.show_result if i not in idx2row] + self.show_result = [i for i in self.show_result if i in idx2row] + if missing_item: + print(f"Noticing label(s) which is/are not in target list appeared, final output string will not contain{str(missing_item)}") + if self.show_result: + show_col = [0] + [i + 1 for i in [idx2row[i] for i in self.show_result]] + show_row = [0]+[i+2 for i in [idx2row[i] for i in self.show_result]] + output = [[row[col] for col in show_col] for row in [output[row] for row in show_row]] + output.insert(1,["pred"]) for line in output: for colidx in range(len(line)): out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t" @@ -217,6 +232,7 @@ class ConfusionMatrix: return out + class Option(dict): r"""a dict can treat keys as attributes"""