Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.6.0
yh_cc 4 years ago
parent
commit
14ea125687
2 changed files with 23 additions and 5 deletions
  1. +2
    -0
      fastNLP/core/metrics.py
  2. +21
    -5
      fastNLP/core/utils.py

+ 2
- 0
fastNLP/core/metrics.py View File

@@ -313,6 +313,7 @@ class ConfusionMatrixMetric(MetricBase):
pred=None,
target=None,
seq_len=None,
show_result=None,
print_ratio=False
):
r"""
@@ -326,6 +327,7 @@ class ConfusionMatrixMetric(MetricBase):
super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.confusion_matrix = ConfusionMatrix(
show_result=show_result,
vocab=vocab,
print_ratio=print_ratio,
)


+ 21
- 5
fastNLP/core/utils.py View File

@@ -36,8 +36,9 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require

class ConfusionMatrix:
r"""a dict can provide Confusion Matrix"""
def __init__(self, vocab=None, print_ratio=False):
def __init__(self, show_result=None,vocab=None, print_ratio=False):
r"""
:param show_result: list type, 数据类型需要和target保持一致
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列
"""
@@ -48,6 +49,7 @@ class ConfusionMatrix:
self.confusiondict = {} # key: pred index, value:target word ocunt
self.predcount = {} # key:pred index, value:count
self.targetcount = {} # key:target index, value:count
self.show_result = show_result
self.vocab = vocab
self.print_ratio = print_ratio

@@ -153,16 +155,16 @@ class ConfusionMatrix:
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx
# namedict key :label idx value: str label name/label idx
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])
for label, idx in zip(totallabel, range(lenth)):
for label, lineidx in zip(totallabel, range(lenth)):
idx2row[
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
label] = lineidx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
lineidx] = label # 建立一个临时字典,key: 行列index 0,1,2...->1,3,5,...,value:vocab的index,
# 这里打印东西
out = str()
output = []
@@ -183,6 +185,7 @@ class ConfusionMatrix:
for idx in range(len(col_lenths))
]
output.append(l)

tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"]
for n in data[-1]]
col_lenths = [
@@ -190,6 +193,18 @@ class ConfusionMatrix:
for idx in range(len(col_lenths))
]
output.append(tail)

if self.show_result:
missing_item=[]
missing_item = [i for i in self.show_result if i not in idx2row]
self.show_result = [i for i in self.show_result if i in idx2row]
if missing_item:
print(f"Noticing label(s) which is/are not in target list appeared, final output string will not contain{str(missing_item)}")
if self.show_result:
show_col = [0] + [i + 1 for i in [idx2row[i] for i in self.show_result]]
show_row = [0]+[i+2 for i in [idx2row[i] for i in self.show_result]]
output = [[row[col] for col in show_col] for row in [output[row] for row in show_row]]
output.insert(1,["pred"])
for line in output:
for colidx in range(len(line)):
out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t"
@@ -217,6 +232,7 @@ class ConfusionMatrix:
return out



class Option(dict):
r"""a dict can treat keys as attributes"""



Loading…
Cancel
Save