diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 146b532f..97d181f9 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -282,39 +282,46 @@ class MetricBase(object): class ConfusionMatrixMetric(MetricBase): r""" 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) - 最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} ConfusionMatrix实例的print()函数将输出矩阵字符串。 - pred_dict = {"pred": torch.Tensor([2,1,3])} target_dict = {'target': torch.Tensor([2,2,1])} metric = ConfusionMatrixMetric() metric(pred_dict=pred_dict, target_dict=target_dict, ) print(metric.get_metric()) - {'confusion_matrix': - target 1.0 2.0 3.0 all - pred - 1.0 0 1 0 1 - 2.0 0 1 0 1 - 3.0 1 0 0 1 - all 1 2 0 3} + target 1.0 2.0 3.0 all + pred + 1.0 0 1 0 1 + 2.0 0 1 0 1 + 3.0 1 0 0 1 + all 1 2 0 3 +} """ - def __init__(self, vocab=None, pred=None, target=None, seq_len=None): + def __init__(self, + vocab=None, + pred=None, + target=None, + seq_len=None, + print_ratio=False + ): """ :param vocab: vocab词表类,要求有to_word()方法。 :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` + :param print_ratio: 限制print的输出,false only for result, true for result, percent(dim=0), percent(dim = 1) """ super().__init__() self._init_param_map(pred=pred, target=target, seq_len=seq_len) - self.confusion_matrix = ConfusionMatrix(vocab=vocab) + self.confusion_matrix = ConfusionMatrix( + vocab=vocab, + print_ratio=print_ratio, + ) def evaluate(self, pred, target, seq_len=None): """ evaluate函数将针对一个批次的预测结果做评价指标的累计 - :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), @@ -323,15 +330,18 @@ class ConfusionMatrixMetric(MetricBase): """ if not 
isinstance(pred, torch.Tensor): - raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(pred)}.") + raise TypeError( + f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") if not isinstance(target, torch.Tensor): - raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(target)}.") + raise TypeError( + f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(target)}.") if seq_len is not None and not isinstance(seq_len, torch.Tensor): - raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," - f"got {type(seq_len)}.") + raise TypeError( + f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(seq_len)}.") if pred.dim() == target.dim(): pass @@ -340,25 +350,27 @@ class ConfusionMatrixMetric(MetricBase): if seq_len is None and target.dim() > 1: warnings.warn("You are not passing `seq_len` to exclude pad.") else: - raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " - f"size:{pred.size()}, target should have size: {pred.size()} or " - f"{pred.size()[:-1]}, got {target.size()}.") + raise RuntimeError( + f"In {_get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") target = target.to(pred) - if seq_len is not None and target.dim() > 1: - for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()): - l=int(l) + if seq_len is not None and target.dim() > 1: + for p, t, l in zip(pred.tolist(), target.tolist(), + seq_len.tolist()): + l = int(l) self.confusion_matrix.add_pred_target(p[:l], t[:l]) - elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 - for p, t in zip(pred.tolist(), target.tolist()): + elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出 + for p, t in 
class ConfusionMatrix:
    """Accumulate (pred, target) label pairs and render them as a confusion matrix.

    Rows are predicted labels and columns are target labels. A trailing
    ``all`` column holds each row's total, and a trailing ``all`` row holds
    the per-target totals, so its last cell is the grand total.
    """

    def __init__(self, vocab=None, print_ratio=False):
        """
        :param vocab: optional object with a ``to_word(idx)`` method used to turn
            label indices into display names (fastNLP's Vocabulary is the intended
            type). When None, the raw labels are printed.
        :param print_ratio: when False, ``print`` shows only the count matrix;
            when True it additionally shows the row-wise and column-wise
            percentage matrices.
        :raises TypeError: if ``vocab`` is given but has no ``to_word`` attribute.
        """
        # `is not None` (rather than truthiness) so that an empty-but-valid
        # vocabulary is still checked and an invalid falsy object is rejected.
        if vocab is not None and not hasattr(vocab, "to_word"):
            raise TypeError(
                f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary,"
                f"got {type(vocab)}.")
        self.confusiondict = {}  # pred label -> {target label -> count}
        self.predcount = {}  # pred label -> count
        self.targetcount = {}  # target label -> count
        self.vocab = vocab
        self.print_ratio = print_ratio

    def add_pred_target(self, pred, target):
        """
        Add one batch of predictions to the matrix.

        Pairs are formed with ``zip``, so surplus items in the longer sequence
        are ignored.

        :param list pred: predicted labels
        :param list target: gold labels
        :return dict: the internal ``{pred: {target: count}}`` mapping

        Example::

            confusion = ConfusionMatrix()
            confusion.add_pred_target([2, 1, 3], [2, 2, 1])
            print(confusion)

            target  1   2   3   all
              pred
                 1  0   1   0     1
                 2  0   1   0     1
                 3  1   0   0     1
               all  1   2   0     3
        """
        for p, t in zip(pred, target):
            self.predcount[p] = self.predcount.get(p, 0) + 1
            self.targetcount[t] = self.targetcount.get(t, 0) + 1
            row = self.confusiondict.setdefault(p, {})
            row[t] = row.get(t, 0) + 1
        return self.confusiondict

    def clear(self):
        """
        Drop all accumulated counts so the instance can be reused.

        :return: None
        """
        self.confusiondict = {}
        self.targetcount = {}
        self.predcount = {}

    def _sorted_labels(self):
        """Return every label seen as a pred or a target, in sorted order."""
        return sorted(set(self.targetcount) | set(self.predcount))

    def _label_name(self, label):
        """Return the display name of ``label`` (via ``vocab.to_word`` when a vocab is set)."""
        return str(label if self.vocab is None else self.vocab.to_word(label))

    def get_result(self):
        """
        Build the count matrix.

        :return list output: ``(n + 1) x (n + 1)`` nested list for ``n`` known
            labels; the last column is the row total and the last row holds the
            per-target totals.
        """
        labels = self._sorted_labels()
        column_of = {label: col for col, label in enumerate(labels)}
        output = []
        for pred_label in labels:
            counts = [0] * len(labels)
            # A label that only ever appeared as a target has no entry here.
            for target_label, count in self.confusiondict.get(pred_label, {}).items():
                counts[column_of[target_label]] = count
            output.append(counts + [sum(counts)])
        tail = [self.targetcount.get(label, 0) for label in labels]
        output.append(tail + [sum(tail)])
        return output

    def get_percent(self, dim=0):
        """
        Build the percentage matrix, rounded to two decimals.

        :param int dim: 0 divides every row by its ``all`` (last) column;
            1 divides every column by its ``all`` (last) row.
        :return list output: nested list with the same shape as ``get_result()``;
            rows/columns whose total is zero are reported as 0.
        :raises ValueError: if ``dim`` is neither 0 nor 1.
        """
        if dim not in (0, 1):
            raise ValueError(f"`dim` must be 0 (row) or 1 (column), got {dim!r}.")
        matrix = np.array(self.get_result(), dtype=float)
        if dim == 1:
            # Work on the transpose so both directions share the same code path.
            matrix = matrix.T
        totals = matrix[:, -1:]
        # A zero total implies an all-zero line (counts are non-negative), so
        # dividing by 1 instead yields 0 without a divide-by-zero warning.
        ratio = matrix / np.where(totals == 0, 1.0, totals)
        percent = np.around(ratio * 100, decimals=2)
        if dim == 1:
            percent = percent.T
        return percent.tolist()

    def get_aligned_table(self, data, flag="result"):
        """
        Render ``data`` as a right-aligned, tab-separated table.

        :param data: matrix to print, normally the return value of
            ``get_result()`` or ``get_percent()``; it must have one row per known
            label plus a final summary row, each with one cell per header column
            after the row name.
        :param flag: ``"result"`` prints the cells as-is; any other value
            appends ``"%"`` to every cell.
        :return: the table as a string, starting with a newline.
        """
        labels = self._sorted_labels()
        names = [self._label_name(label) for label in labels]
        suffix = "" if flag == "result" else "%"

        head = ["target"] + names + ["all"]
        body = [[name] + [str(value) + suffix for value in data[row]]
                for row, name in enumerate(names)]
        tail = ["all"] + [str(value) + suffix for value in data[-1]]

        # Column width = widest cell in that column across header, body and tail.
        widths = [len(cell) for cell in head]
        for cells in body + [tail]:
            for col in range(len(widths)):
                widths[col] = max(widths[col], len(cells[col]))

        # The lone "pred" cell labels the row axis and sits under "target".
        table = [head, ["pred"]] + body + [tail]
        lines = [
            "".join("%*s\t" % (widths[col], cell) for col, cell in enumerate(cells))
            for cells in table
        ]
        return "\n" + "".join(line + "\n" for line in lines)

    def __repr__(self):
        """
        :return string output: the formatted count matrix (header, per-label rows
            and totals); when ``print_ratio`` is set, the row-wise and
            column-wise percentage tables are appended.
        """
        counts_table = self.get_aligned_table(self.get_result(), flag="result")
        if not self.print_ratio:
            return counts_table
        row_table = "\nNotice the row direction\n" + self.get_aligned_table(
            self.get_percent(dim=0), flag="percent")
        column_table = "\nNotice the column direction\n" + self.get_aligned_table(
            self.get_percent(dim=1), flag="percent")
        return counts_table + row_table + column_table