@@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
class ClassifyFPreRecMetric(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, | |||
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', beta=1) -> None: | |||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = False) -> None: | |||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | |||
aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
if tag_vocab: | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
@@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric): | |||
self.tag_vocab = tag_vocab | |||
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')) | |||
self._tp = {} | |||
self._fp = {} | |||
self._fn = {} | |||
if tag_vocab: | |||
for word, _ in tag_vocab: | |||
word = word.lower() | |||
if word != 'o': | |||
word = word[2:] | |||
if word in self._true_positives: | |||
continue | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
elif num_class > 0: | |||
for word in range(num_class): | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
else: | |||
raise ValueError() | |||
def get_metric(self) -> dict: | |||
r""" | |||
@@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric): | |||
tag_name = self.tag_vocab.to_word(tag) | |||
else: | |||
tag_name = int(tag) | |||
tp = self._tp[tag] | |||
fn = self._fn[tag] | |||
fp = self._fp[tag] | |||
tp = self._tp[tag].get_scalar() | |||
fn = self._fn[tag].get_scalar() | |||
fp = self._fp[tag].get_scalar() | |||
if tp == fn == fp == 0: | |||
continue | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
@@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric): | |||
if self.f_type == 'micro': | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(self._tp.values()), | |||
sum(self._fn.values()), | |||
sum(self._fp.values())) | |||
sum(val.get_scalar() for val in self._tp.values()), | |||
sum(val.get_scalar() for val in self._fn.values()), | |||
sum(val.get_scalar() for val in self._fp.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
for key, value in evaluate_result.items(): | |||
evaluate_result[key] = round(value, 6) | |||
return evaluate_result | |||
def update(self, pred, target, seq_len=None): | |||
r""" | |||
evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||
:param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||
torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||
:param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||
torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||
:param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||
如果mask也被传进来的话seq_len会被忽略. | |||
""" | |||
pred = self.tensor2numpy(pred) | |||
target = self.tensor2numpy(target) | |||
if seq_len is not None: | |||
@@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric): | |||
f"pred have element numbers: {len(target.flatten())}") | |||
pass | |||
elif len(pred.ndim) == len(target.ndim) + 1: | |||
elif pred.ndim == target.ndim + 1: | |||
pred = pred.argmax(axis=-1) | |||
if seq_len is None and len(target.ndim) > 1: | |||
if seq_len is None and target.ndim > 1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred have " | |||
f"size:{pred.ndim}, target should have size: {pred.ndim} or " | |||
f"{pred.ndim[:-1]}, got {target.ndim}.") | |||
f"size:{pred.shape}, target should have size: {pred.shape} or " | |||
f"{pred.shape[:-1]}, got {target.shape}.") | |||
if masks is not None: | |||
target = target * masks | |||
pred = pred * masks | |||
@@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric): | |||
self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | |||
self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | |||
self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() | |||
@@ -0,0 +1,88 @@ | |||
import pytest | |||
import torch | |||
import numpy as np | |||
from fastNLP.core.metrics import ClassifyFPreRecMetric | |||
class TestClassfiyFPreRecMetric: | |||
def test_case_1(self): | |||
pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910], | |||
[1.3410, 0.2889, -0.8667, -1.8580, 0.3029], | |||
[0.7459, -1.1957, 0.3231, 0.0308, -0.1847], | |||
[1.1439, -0.0057, 0.8203, 0.0312, -1.0051], | |||
[-0.4870, 0.3215, -0.8290, 0.9221, 0.4683], | |||
[0.9078, 1.0674, -0.5629, 0.3895, 0.8917], | |||
[-0.7743, -0.4041, -0.9026, 0.2112, 1.0892], | |||
[1.8232, -1.4188, -2.5615, -2.4187, 0.5907], | |||
[-1.0592, 0.4164, -0.1192, 1.4238, -0.9258], | |||
[-1.1137, 0.5773, 2.5778, 0.5398, -0.3323], | |||
[-0.3868, -0.5165, 0.2286, -1.3876, 0.5561], | |||
[-0.3304, 1.3619, -1.5744, 0.4902, -0.7661], | |||
[1.8387, 0.5234, 0.4269, 1.3748, -1.2793], | |||
[0.6692, 0.2571, 1.2425, -0.5894, -0.0184], | |||
[0.4165, 0.4084, -0.1280, 1.4489, -2.3058], | |||
[-0.5826, -0.5469, 1.5898, -0.2786, -0.9882], | |||
[-1.5548, -2.2891, 0.2983, -1.2145, -0.1947], | |||
[-0.7222, 2.3543, -0.5801, -0.0640, -1.5614], | |||
[-1.4978, 1.9297, -1.3652, -0.2358, 2.5566], | |||
[0.1561, -0.0316, 0.9331, 1.0363, 2.3949], | |||
[0.2650, -0.8459, 1.3221, 0.1321, -1.1900], | |||
[0.0664, -1.2353, -0.5242, -1.4491, 1.3300], | |||
[-0.2744, 0.0941, 0.7157, 0.1404, 1.2046], | |||
[0.9341, -0.6652, 1.4512, 0.9608, -0.3623], | |||
[-1.1641, 0.0873, 0.1163, -0.2068, -0.7002], | |||
[1.4775, -2.0025, -0.5634, -0.1589, 0.0247], | |||
[1.0151, 1.0304, -0.1042, -0.6955, -0.0629], | |||
[-0.3119, -0.4558, 0.7757, 0.0758, -1.6297], | |||
[1.0654, 0.0313, -0.7716, 0.1194, 0.6913], | |||
[-0.8088, -0.6648, -0.5018, -0.0230, -0.8207], | |||
[-0.7753, -0.3508, 1.6163, 0.7158, 1.5207], | |||
[0.8692, 0.7718, -0.6734, 0.6515, 0.0641]]) | |||
arg_max_pred = torch.argmax(pred, dim=-1) | |||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | |||
0, 3, 0, 0, 0, 1, 3, 1]) | |||
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.1882051282051282 | |||
recall = 0.1619047619047619 | |||
pre = 0.23928571428571427 | |||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.21875 | |||
recall = 0.21875 | |||
pre = 0.21875 | |||
ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
ground_truth = { | |||
'0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7}, | |||
'1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5}, | |||
'2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2}, | |||
'3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9}, | |||
'4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}, | |||
'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427, | |||
'recall': 0.1619047619047619, 'support': 32}, | |||
'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32}, | |||
'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875, | |||
'support': 32}} | |||
for keys in result_dict.keys(): | |||
if keys == "f" or "pre" or "rec": | |||
continue | |||
gl = str(keys[-1]) | |||
tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"} | |||
gk = tmp_d[keys[0]] | |||
np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001) |