Browse Source

Merge branch 'master' of github.com:fastnlp/fastNLP

tags/v0.5.5
yh_cc 5 years ago
parent
commit
b3ace23d11
2 changed files with 200 additions and 85 deletions
  1. +44
    -29
      fastNLP/core/metrics.py
  2. +156
    -56
      fastNLP/core/utils.py

+ 44
- 29
fastNLP/core/metrics.py View File

@@ -282,39 +282,46 @@ class MetricBase(object):
class ConfusionMatrixMetric(MetricBase): class ConfusionMatrixMetric(MetricBase):
r""" r"""
分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) 分类问题计算混淆矩阵的Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )

最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例} 最后返回结果为dict,{'confusion_matrix': ConfusionMatrix实例}
ConfusionMatrix实例的print()函数将输出矩阵字符串。 ConfusionMatrix实例的print()函数将输出矩阵字符串。

pred_dict = {"pred": torch.Tensor([2,1,3])} pred_dict = {"pred": torch.Tensor([2,1,3])}
target_dict = {'target': torch.Tensor([2,2,1])} target_dict = {'target': torch.Tensor([2,2,1])}
metric = ConfusionMatrixMetric() metric = ConfusionMatrixMetric()
metric(pred_dict=pred_dict, target_dict=target_dict, ) metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric()) print(metric.get_metric())

{'confusion_matrix': {'confusion_matrix':
target 1.0 2.0 3.0 all
pred
1.0 0 1 0 1
2.0 0 1 0 1
3.0 1 0 0 1
all 1 2 0 3}
target 1.0 2.0 3.0 all
pred
1.0 0 1 0 1
2.0 0 1 0 1
3.0 1 0 0 1
all 1 2 0 3
}
""" """
def __init__(self, vocab=None, pred=None, target=None, seq_len=None):
def __init__(self,
vocab=None,
pred=None,
target=None,
seq_len=None,
print_ratio=False
):
""" """
:param vocab: vocab词表类,要求有to_word()方法。 :param vocab: vocab词表类,要求有to_word()方法。
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` :param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len`
:param print_ratio: 限制print的输出,false only for result, true for result, percent(dim=0), percent(dim = 1)
""" """
super().__init__() super().__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len) self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.confusion_matrix = ConfusionMatrix(vocab=vocab)
self.confusion_matrix = ConfusionMatrix(
vocab=vocab,
print_ratio=print_ratio,
)


def evaluate(self, pred, target, seq_len=None): def evaluate(self, pred, target, seq_len=None):
""" """
evaluate函数将针对一个批次的预测结果做评价指标的累计 evaluate函数将针对一个批次的预测结果做评价指标的累计

:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]),
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes])
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]),
@@ -323,15 +330,18 @@ class ConfusionMatrixMetric(MetricBase):
""" """
if not isinstance(pred, torch.Tensor): if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
raise TypeError(
f"`pred` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")
if not isinstance(target, torch.Tensor): if not isinstance(target, torch.Tensor):
raise TypeError(f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")
raise TypeError(
f"`target` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(target)}.")


if seq_len is not None and not isinstance(seq_len, torch.Tensor): if seq_len is not None and not isinstance(seq_len, torch.Tensor):
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")
raise TypeError(
f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(seq_len)}.")


if pred.dim() == target.dim(): if pred.dim() == target.dim():
pass pass
@@ -340,25 +350,27 @@ class ConfusionMatrixMetric(MetricBase):
if seq_len is None and target.dim() > 1: if seq_len is None and target.dim() > 1:
warnings.warn("You are not passing `seq_len` to exclude pad.") warnings.warn("You are not passing `seq_len` to exclude pad.")
else: else:
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")
raise RuntimeError(
f"In {_get_func_signature(self.evaluate)}, when pred have "
f"size:{pred.size()}, target should have size: {pred.size()} or "
f"{pred.size()[:-1]}, got {target.size()}.")


target = target.to(pred) target = target.to(pred)
if seq_len is not None and target.dim() > 1:
for p, t, l in zip(pred.tolist(), target.tolist(), seq_len.tolist()):
l=int(l)
if seq_len is not None and target.dim() > 1:
for p, t, l in zip(pred.tolist(), target.tolist(),
seq_len.tolist()):
l = int(l)
self.confusion_matrix.add_pred_target(p[:l], t[:l]) self.confusion_matrix.add_pred_target(p[:l], t[:l])
elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出
for p, t in zip(pred.tolist(), target.tolist()):
elif target.dim() > 1: #对于没有传入seq_len,但是又是高维的target,按全长输出
for p, t in zip(pred.tolist(), target.tolist()):
self.confusion_matrix.add_pred_target(p, t) self.confusion_matrix.add_pred_target(p, t)
else: else:
self.confusion_matrix.add_pred_target(pred.tolist(), target.tolist())
self.confusion_matrix.add_pred_target(pred.tolist(),
target.tolist())


def get_metric(self,reset=True):
def get_metric(self, reset=True):
""" """
get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果. get_metric函数将根据evaluate函数累计的评价指标统计量来计算最终的评价结果.

:param bool reset: 在调用完get_metric后是否清空评价指标统计量. :param bool reset: 在调用完get_metric后是否清空评价指标统计量.
:return dict evaluate_result: {"confusion_matrix": ConfusionMatrix} :return dict evaluate_result: {"confusion_matrix": ConfusionMatrix}
""" """
@@ -368,6 +380,9 @@ class ConfusionMatrixMetric(MetricBase):
return confusion return confusion







class AccuracyMetric(MetricBase): class AccuracyMetric(MetricBase):
""" """
准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` ) 准确率Metric(其它的Metric参见 :mod:`fastNLP.core.metrics` )


+ 156
- 56
fastNLP/core/utils.py View File

@@ -38,47 +38,47 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require


class ConfusionMatrix: class ConfusionMatrix:
"""a dict can provide Confusion Matrix""" """a dict can provide Confusion Matrix"""
def __init__(self, vocab=None):
def __init__(self, vocab=None, print_ratio=False):
""" """
:param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。 :param vocab: 需要有to_word方法,建议直接使用Fastnlp.core.Vocabulary。
:param print_ratio: 限制print的输出,False只输出数量Confusion Matrix, True还会输出百分比Confusion Matrix, 分别为行/列
""" """
if vocab and not hasattr(vocab, 'to_word'):
raise TypeError(f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary,"
f"got {type(vocab)}.")
self.confusiondict={} #key: pred index, value:target word ocunt
self.predcount={} #key:pred index, value:count
self.targetcount={} #key:target index, value:count
self.vocab=vocab
def add_pred_target(self, pred, target): #一组结果
if vocab and not hasattr(vocab, "to_word"):
raise TypeError(
f"`vocab` in {_get_func_signature(self.__init__)} must be Fastnlp.core.Vocabulary,"
f"got {type(vocab)}.")
self.confusiondict = {} # key: pred index, value:target word ocunt
self.predcount = {} # key:pred index, value:count
self.targetcount = {} # key:target index, value:count
self.vocab = vocab
self.print_ratio = print_ratio

def add_pred_target(self, pred, target): # 一组结果
""" """
通过这个函数向ConfusionMatrix加入一组预测结果 通过这个函数向ConfusionMatrix加入一组预测结果

:param list pred: 预测的标签列表 :param list pred: 预测的标签列表
:param list target: 真实值的标签列表 :param list target: 真实值的标签列表
:return ConfusionMatrix :return ConfusionMatrix

confusion=ConfusionMatrix() confusion=ConfusionMatrix()
pred = [2,1,3] pred = [2,1,3]
target = [2,2,1] target = [2,2,1]
confusion.add_pred_target(pred, target) confusion.add_pred_target(pred, target)
print(confusion) print(confusion)

target 1 2 3 all target 1 2 3 all
pred
1 0 1 0 1
2 0 1 0 1
3 1 0 0 1
all 1 2 0 3
pred
1 0 1 0 1
2 0 1 0 1
3 1 0 0 1
all 1 2 0 3
""" """
for p,t in zip(pred,target): #<int, int>
self.predcount[p]=self.predcount.get(p,0)+ 1
self.targetcount[t]=self.targetcount.get(t,0)+1
for p, t in zip(pred, target): # <int, int>
self.predcount[p] = self.predcount.get(p, 0) + 1
self.targetcount[t] = self.targetcount.get(t, 0) + 1
if p in self.confusiondict: if p in self.confusiondict:
self.confusiondict[p][t]=self.confusiondict[p].get(t,0) + 1
self.confusiondict[p][t] = self.confusiondict[p].get(t, 0) + 1
else: else:
self.confusiondict[p]={}
self.confusiondict[p][t]= 1
self.confusiondict[p] = {}
self.confusiondict[p][t] = 1
return self.confusiondict return self.confusiondict


def clear(self): def clear(self):
@@ -86,44 +86,144 @@ class ConfusionMatrix:
清除一些值,等待再次新加入 清除一些值,等待再次新加入
:return: :return:
""" """
self.confusiondict={}
self.targetcount={}
self.predcount={}
def __repr__(self):
self.confusiondict = {}
self.targetcount = {}
self.predcount = {}
def get_result(self):
""" """
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。
:return list output: ConfusionMatrix content,具体值与汇总统计
""" """
row2idx={}
idx2row={}
row2idx = {}
idx2row = {}
# 已知的所有键/label # 已知的所有键/label
totallabel=sorted(list(set(self.targetcount.keys()).union(set(self.predcount.keys()))))
lenth=len(totallabel)
totallabel = sorted(
list(
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx # namedict key :idx value:word/idx
namedict=dict([(k,str(k if self.vocab == None else self.vocab.to_word(k))) for k in totallabel])
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])

for label, idx in zip(totallabel, range(lenth)):
idx2row[
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
output = []
for i in row2idx.keys(): # 第i行
p = row2idx[i]
h = namedict[p]
l = [0 for _ in range(lenth)]
if self.confusiondict.get(p, None):
for t, c in self.confusiondict[p].items():
l[idx2row[t]] = c # 完成一行
l = [n for n in l] + [sum(l)]
output.append(l)
tail = [self.targetcount.get(row2idx[k], 0) for k in row2idx.keys()]
tail += [sum(tail)]
output.append(tail)
return output


for label,idx in zip(totallabel,range(lenth)):
idx2row[label]=idx #建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[idx]=label #建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
def get_percent(self, dim=0):
"""
:param dim int: 0/1, 0 for row,1 for column
:return list output: ConfusionMatrix content,具体值与汇总统计
"""
result = self.get_result()
if dim == 0:
tmp = np.array(result)
tmp = tmp / (tmp[:, -1].reshape([len(result), -1]))
tmp[np.isnan(tmp)] = 0
tmp = tmp * 100
elif dim == 1:
tmp = np.array(result).T
mp = tmp / (tmp[:, -1].reshape([len(result), -1]) + 1e-12)
tmp = tmp.T * 100
tmp = np.around(tmp, decimals=2)
return tmp.tolist()

def get_aligned_table(self, data, flag="result"):
"""
:param data: highly recommend use get_percent/ get_result return as dataset here, or make sure data is a n*n list type data
:param flag: only difference between result and other words is whether "%" is in output string
:return: an aligned_table ready to print out
"""
row2idx = {}
idx2row = {}
# 已知的所有键/label
totallabel = sorted(
list(
set(self.targetcount.keys()).union(set(
self.predcount.keys()))))
lenth = len(totallabel)
# namedict key :idx value:word/idx
namedict = dict([
(k, str(k if self.vocab == None else self.vocab.to_word(k)))
for k in totallabel
])

for label, idx in zip(totallabel, range(lenth)):
idx2row[
label] = idx # 建立一个临时字典,key:vocab的index, value: 行列index 1,3,5...->0,1,2,...
row2idx[
idx] = label # 建立一个临时字典,value:vocab的index, key: 行列index 0,1,2...->1,3,5,...
# 这里打印东西 # 这里打印东西
#表头
head=["\ntarget"]+[str(namedict[row2idx[k]]) for k in row2idx.keys()]+["all"]
output="\t".join(head) + "\n" + "pred" + "\n"
#内容
for i in row2idx.keys(): #第i行
p=row2idx[i]
h=namedict[p]
l=[0 for _ in range(lenth)]
if self.confusiondict.get(p,None):
for t,c in self.confusiondict[p].items():
l[idx2row[t]] = c #完成一行
l=[h]+[str(n) for n in l]+[str(sum(l))]
output+="\t".join(l) +"\n"
#表尾
tail=[self.targetcount.get(row2idx[k],0) for k in row2idx.keys()]
tail=["all"]+[str(n) for n in tail]+[str(sum(tail))]
output+="\t".join(tail)
return output
col_lenths = []
out = str()
output = []
# 表头
head = (["target"] +
[str(namedict[row2idx[k]]) for k in row2idx.keys()] + ["all"])
col_lenths = [len(h) for h in head]
output.append(head)
output.append(["pred"])
# 内容
for i in row2idx.keys(): # 第i行
p = row2idx[i]
h = namedict[p]
l = [h] + [[str(n) + "%", str(n)][flag == "result"]
for n in data[i]]
col_lenths = [
max(col_lenths[idx], [len(i) for i in l][idx])
for idx in range(len(col_lenths))
]
output.append(l)
tail = ["all"] + [[str(n) + "%", str(n)][flag == "result"]
for n in data[-1]]
col_lenths = [
max(col_lenths[idx], [len(i) for i in tail][idx])
for idx in range(len(col_lenths))
]
output.append(tail)
for line in output:
for colidx in range(len(line)):
out += "%*s" % (col_lenths[colidx], line[colidx]) + "\t"
out += "\n"
return "\n" + out

def __repr__(self):
"""
:return string output: ConfusionMatrix的格式化输出,包括表头各标签字段,具体值与汇总统计。
"""
result = self.get_result()
o0 = self.get_aligned_table(result, flag="result")

out = str()
if self.print_ratio:
p1 = self.get_percent()
o1 = "\nNotice the row direction\n" + self.get_aligned_table(
p1, flag="percent")
p2 = self.get_percent(dim=1)
o2 = "\nNotice the column direction\n" + self.get_aligned_table(
p2, flag="percent")
out = out + o0 + o1 + o2
else:
out = o0
return out




class Option(dict): class Option(dict):


Loading…
Cancel
Save