@@ -313,11 +313,9 @@ class ConfusionMatrixMetric(MetricBase):
                  pred=None,
                  target=None,
                  seq_len=None,
-                 show_result=None,
                  print_ratio=False
                  ):
         r"""
         :param vocab: the vocabulary class; it must provide a to_word() method.
         :param pred: mapping for `pred` in the parameter map; None means the mapping is `pred` -> `pred`
         :param target: mapping for `target` in the parameter map; None means the mapping is `target` -> `target`
@@ -327,7 +325,6 @@ class ConfusionMatrixMetric(MetricBase):
         super().__init__()
         self._init_param_map(pred=pred, target=target, seq_len=seq_len)
         self.confusion_matrix = ConfusionMatrix(
-            show_result=show_result,
             vocab=vocab,
             print_ratio=print_ratio,
         )
@@ -335,6 +332,7 @@ class ConfusionMatrixMetric(MetricBase):
     def evaluate(self, pred, target, seq_len=None):
         r"""
         The evaluate method accumulates the metric over one batch of predictions.
         :param torch.Tensor pred: tensor of predictions; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
         :param torch.Tensor target: tensor of ground-truth values; its shape can be torch.Size([B,]),
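For instance, a sequence-labeling batch pairs a [B, max_len, n_classes] pred with a [B, max_len] target and a seq_len vector. A minimal usage sketch; the return format of get_metric() is shown only schematically here:

```python
import torch
from fastNLP.core.metrics import ConfusionMatrixMetric

metric = ConfusionMatrixMetric()       # vocab omitted, so labels display as raw indices
pred = torch.randn(2, 4, 3)            # [B, max_len, n_classes] logits
target = torch.randint(0, 3, (2, 4))   # [B, max_len] gold labels
seq_len = torch.tensor([4, 2])         # true length of each sequence
metric.evaluate(pred=pred, target=target, seq_len=seq_len)
print(metric.get_metric())             # dict holding the accumulated confusion matrix
```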
@@ -356,6 +354,10 @@ class ConfusionMatrixMetric(MetricBase):
                                 f"got {type(seq_len)}.")
         if pred.dim() == target.dim():
+            if torch.numel(pred) != torch.numel(target):
+                raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has the same number of "
+                                   f"dimensions as target, they should also contain the same number of elements, "
+                                   f"but target has {torch.numel(target)} elements while pred has {torch.numel(pred)}.")
             pass
         elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
@@ -446,6 +448,10 @@ class AccuracyMetric(MetricBase):
         masks = None
         if pred.dim() == target.dim():
+            if torch.numel(pred) != torch.numel(target):
+                raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has the same number of "
+                                   f"dimensions as target, they should also contain the same number of elements, "
+                                   f"but target has {torch.numel(target)} elements while pred has {torch.numel(pred)}.")
             pass
         elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
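The same element-count guard is added to ConfusionMatrixMetric above and to AccuracyMetric and ClassifyFPreRecMetric below. Previously, equal-rank tensors of different sizes slipped through the `pass` branch and only failed later with an opaque broadcasting error; now the mismatch is reported up front. A minimal sketch of the new behavior, assuming a fastNLP build with this patch applied:

```python
import torch
from fastNLP import AccuracyMetric

metric = AccuracyMetric()
pred = torch.randn(4, 3)           # rank 2, 12 elements
target = torch.zeros(4, 2).long()  # rank 2, but only 8 elements
try:
    metric.evaluate(pred=pred, target=target)
except RuntimeError as err:
    print(err)  # names both element counts instead of failing on a later broadcast
```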
@@ -477,7 +483,6 @@ class AccuracyMetric(MetricBase):
             self.total = 0
         return evaluate_result


 class ClassifyFPreRecMetric(MetricBase):
     r"""
     Metric that computes F, precision and recall for classification tasks (for other metrics, see :mod:`fastNLP.core.metrics`)
@@ -567,9 +572,14 @@ class ClassifyFPreRecMetric(MetricBase):
             masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len)
         else:
             masks = torch.ones_like(target).long().to(target.device)
-        masks = masks.eq(False)
+        masks = masks.eq(1)
         if pred.dim() == target.dim():
+            if torch.numel(pred) != torch.numel(target):
+                raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred has the same number of "
+                                   f"dimensions as target, they should also contain the same number of elements, "
+                                   f"but target has {torch.numel(target)} elements while pred has {torch.numel(pred)}.")
             pass
         elif pred.dim() == target.dim() + 1:
             pred = pred.argmax(dim=-1)
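Note the polarity flip: the old `masks.eq(False)` marked padding positions (which the `masked_fill(masks, 0)` calls removed in the next hunk relied on), while the new `masks.eq(1)` marks valid positions, so the mask can feed `masked_select` directly. A brief sketch, using fastNLP's `seq_len_to_mask` helper from `fastNLP.core.utils`:

```python
import torch
from fastNLP.core.utils import seq_len_to_mask

seq_len = torch.tensor([3, 1])
masks = seq_len_to_mask(seq_len=seq_len, max_len=4)  # truthy where a position is real
masks = masks.eq(1)                                  # bool mask over valid positions
target = torch.tensor([[2, 0, 1, 9],
                       [1, 9, 9, 9]])                # 9 marks padding here
print(target.masked_select(masks))                   # tensor([2, 0, 1, 1]): padding dropped
```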
@@ -580,12 +590,14 @@ class ClassifyFPreRecMetric(MetricBase):
                                f"size:{pred.size()}, target should have size: {pred.size()} or "
                                f"{pred.size()[:-1]}, got {target.size()}.")
-        target_idxes = set(target.reshape(-1).tolist())
         target = target.to(pred)
+        target = target.masked_select(masks)
+        pred = pred.masked_select(masks)
+        target_idxes = set(target.reshape(-1).tolist())
         for target_idx in target_idxes:
-            self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item()
-            self._fp[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0).masked_fill(masks, 0)).item()
-            self._fn[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0).masked_fill(masks, 0)).item()
+            self._tp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target != target_idx, 0)).item()
+            self._fp[target_idx] += torch.sum((pred == target_idx).long().masked_fill(target == target_idx, 0)).item()
+            self._fn[target_idx] += torch.sum((pred != target_idx).long().masked_fill(target != target_idx, 0)).item()

     def get_metric(self, reset=True):
         r"""
@@ -1,13 +1,14 @@
 import unittest
+from collections import Counter
 import numpy as np
 import torch
 from fastNLP import AccuracyMetric
-from fastNLP.core.metrics import _pred_topk, _accuracy_topk
+from fastNLP.core.metrics import (ClassifyFPreRecMetric, CMRC2018Metric,
+                                  ConfusionMatrixMetric, SpanFPreRecMetric,
+                                  _accuracy_topk, _pred_topk)
 from fastNLP.core.vocabulary import Vocabulary
-from collections import Counter
-from fastNLP.core.metrics import SpanFPreRecMetric, CMRC2018Metric, ClassifyFPreRecMetric,ConfusionMatrixMetric
+from sklearn import metrics as m

 def _generate_tags(encoding_type, number_labels=4):
@@ -563,69 +564,45 @@ class TestUsefulFunctions(unittest.TestCase):
         # just needs to run without errors

 class TestClassfiyFPreRecMetric(unittest.TestCase):
     def test_case_1(self):
-        pred = torch.FloatTensor([[-0.1603, -1.3247, 0.2010, 0.9240, -0.6396],
-                                  [-0.7316, -1.6028, 0.2281, 0.3558, 1.2500],
-                                  [-1.2943, -1.7350, -0.7085, 1.1269, 1.0782],
-                                  [ 0.1314, -0.2578, 0.7200, 1.0920, -1.0819],
-                                  [-0.6787, -0.9081, -0.2752, -1.5818, 0.5538],
-                                  [-0.2925, 1.1320, 2.8709, -0.6225, -0.6279],
-                                  [-0.3320, -0.9009, -1.5762, 0.3810, -0.1220],
-                                  [ 0.4601, -1.0509, 1.4242, 0.3427, 2.7014],
-                                  [-0.5558, 1.0899, -1.9045, 0.3377, 1.3192],
-                                  [-0.8251, -0.1558, -0.0871, -0.6755, -0.5905],
-                                  [ 0.1019, 1.2504, -1.1627, -0.7062, 1.8654],
-                                  [ 0.9016, -0.1984, -0.0831, -0.7646, 1.5309],
-                                  [ 0.2073, 0.2250, -0.0879, 0.1608, -0.8915],
-                                  [ 0.3624, 0.3806, 0.3159, -0.3603, -0.6672],
-                                  [ 0.2714, 2.5086, -0.1053, -0.5188, 0.9229],
-                                  [ 0.3258, -0.0303, 1.1439, -0.9123, 1.5180],
-                                  [ 1.2496, -1.0298, -0.4463, 0.1186, -1.7089],
-                                  [ 0.0788, 0.6300, -1.3336, -0.7122, 1.0164],
-                                  [-1.1900, -0.9620, -0.3839, 0.1159, -1.2045],
-                                  [-0.9037, -0.1447, 1.1834, -0.2617, 2.6112],
-                                  [ 0.1507, 0.1686, -0.1535, -0.3669, -0.8425],
-                                  [ 1.0537, 1.1958, -1.2309, 1.0405, 1.3018],
-                                  [-0.9823, -0.9712, 1.1560, -0.6473, 1.0361],
-                                  [ 0.8659, -0.2166, -0.8335, -0.3557, -0.5660],
-                                  [-1.4742, -0.8773, -2.5237, 0.7410, 0.1506],
-                                  [-1.3032, -1.7157, 0.7479, 1.0755, 1.0817],
-                                  [-0.2988, 2.3745, 1.2072, 0.0054, 1.1877],
-                                  [-0.0123, 1.6513, 0.2741, -0.7791, 0.6161],
-                                  [ 1.6339, -1.0365, 0.3961, -0.9683, 0.2684],
-                                  [-0.0278, -2.0856, -0.5376, 0.5129, -0.3169],
-                                  [ 0.9386, 0.8317, 0.9518, -0.5050, -0.2808],
-                                  [-0.6907, 0.5020, -0.9039, -1.1061, 0.1656]])
-        arg_max_pred = torch.Tensor([3, 2, 3, 3, 4, 2, 3, 4, 4, 2, 4, 4, 1, 1,
-                                     1, 4, 0, 4, 3, 4, 1, 4, 2, 0,
-                                     3, 4, 1, 1, 0, 3, 2, 1])
-        target = torch.Tensor([3, 3, 3, 3, 4, 1, 0, 2, 1, 2, 4, 4, 1, 1,
-                               1, 4, 0, 4, 3, 4, 1, 4, 2, 0,
-                               3, 4, 1, 1, 0, 3, 2, 1])
+        pred = torch.randn(32, 5)
+        arg_max_pred = torch.argmax(pred, dim=-1)
+        target = np.random.randint(0, high=5, size=(32,))
+        target = torch.from_numpy(target)
         metric = ClassifyFPreRecMetric(f_type='macro')
         metric.evaluate(pred, target)
-        result_dict = metric.get_metric(reset=True)
-        ground_truth = {'f': 0.8362782, 'pre': 0.8269841, 'rec': 0.8668831}
+        result_dict = metric.get_metric()
+        f1_score = m.f1_score(target.tolist(), arg_max_pred.tolist(), average="macro")
+        recall = m.recall_score(target.tolist(), arg_max_pred.tolist(), average="macro")
+        pre = m.precision_score(target.tolist(), arg_max_pred.tolist(), average="macro")
+        ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
             self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)

         metric = ClassifyFPreRecMetric(f_type='micro')
         metric.evaluate(pred, target)
-        result_dict = metric.get_metric(reset=True)
-        ground_truth = {'f': 0.84375, 'pre': 0.84375, 'rec': 0.84375}
+        result_dict = metric.get_metric()
+        f1_score = m.f1_score(target.tolist(), arg_max_pred.tolist(), average="micro")
+        recall = m.recall_score(target.tolist(), arg_max_pred.tolist(), average="micro")
+        pre = m.precision_score(target.tolist(), arg_max_pred.tolist(), average="micro")
+        ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall}
         for keys in ['f', 'pre', 'rec']:
             self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)
-        metric = ClassifyFPreRecMetric(only_gross=False, f_type='micro')
+        metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
         metric.evaluate(pred, target)
         result_dict = metric.get_metric(reset=True)
-        ground_truth = {'f-0': 0.857143, 'pre-0': 0.75, 'rec-0': 1.0, 'f-1': 0.875, 'pre-1': 0.777778, 'rec-1': 1.0,
-                        'f-2': 0.75, 'pre-2': 0.75, 'rec-2': 0.75, 'f-3': 0.857143, 'pre-3': 0.857143,
-                        'rec-3': 0.857143, 'f-4': 0.842105, 'pre-4': 1.0, 'rec-4': 0.727273, 'f': 0.84375,
-                        'pre': 0.84375, 'rec': 0.84375}
-        for keys in ground_truth.keys():
-            self.assertAlmostEqual(result_dict[keys], ground_truth[keys], delta=0.0001)
+        ground_truth = m.classification_report(target.tolist(), arg_max_pred.tolist(), output_dict=True)
+        for keys in result_dict.keys():
+            if keys in ("f", "pre", "rec"):
+                continue
+            gl = str(keys[-1])
+            tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
+            gk = tmp_d[keys[0]]
+            self.assertAlmostEqual(result_dict[keys], ground_truth[gl][gk], delta=0.0001)
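Cross-checking against scikit-learn makes the test independent of the now-random inputs, where the old hardcoded expectations would no longer hold. The decoding loop relies on the per-class key scheme that `get_metric(only_gross=False)` produces, `f-<label>` / `pre-<label>` / `rec-<label>` (note that `keys[-1]` assumes single-digit labels). A hypothetical sketch of the mapping the loop performs:

```python
# hypothetical result_dict illustrating the key scheme decoded above
result_dict = {'f': 0.5, 'pre': 0.5, 'rec': 0.5,
               'f-0': 0.6, 'pre-0': 0.7, 'rec-0': 0.5,
               'f-1': 0.4, 'pre-1': 0.3, 'rec-1': 0.6}
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"}
for key in result_dict:
    if key in ("f", "pre", "rec"):  # gross scores carry no class suffix
        continue
    label, field = key[-1], tmp_d[key[0]]
    print(f"{key} -> classification_report[{label!r}][{field!r}]")
```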