| @@ -29,14 +29,16 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
| class ClassifyFPreRecMetric(Metric): | class ClassifyFPreRecMetric(Metric): | ||||
| def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, | |||||
| tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, | |||||
| only_gross: bool = True, f_type='micro', beta=1) -> None: | |||||
| def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||||
| only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | |||||
| aggregate_when_get_metric: bool = False) -> None: | |||||
| super(ClassifyFPreRecMetric, self).__init__(backend=backend, | super(ClassifyFPreRecMetric, self).__init__(backend=backend, | ||||
| aggregate_when_get_metric=aggregate_when_get_metric) | aggregate_when_get_metric=aggregate_when_get_metric) | ||||
| if f_type not in ('micro', 'macro'): | if f_type not in ('micro', 'macro'): | ||||
| raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | ||||
| if tag_vocab: | |||||
| if not isinstance(tag_vocab, Vocabulary): | |||||
| raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||||
| self.ignore_labels = ignore_labels | self.ignore_labels = ignore_labels | ||||
| self.f_type = f_type | self.f_type = f_type | ||||
| self.beta = beta | self.beta = beta | ||||
| @@ -45,9 +47,32 @@ class ClassifyFPreRecMetric(Metric): | |||||
| self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
| self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||||
| defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||||
| defaultdict(partial(self.register_element, aggregate_method='sum')) | |||||
| self._tp = {} | |||||
| self._fp = {} | |||||
| self._fn = {} | |||||
| if tag_vocab: | |||||
| for word, _ in tag_vocab: | |||||
| word = word.lower() | |||||
| if word != 'o': | |||||
| word = word[2:] | |||||
| if word in self._true_positives: | |||||
| continue | |||||
| self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| elif num_class > 0: | |||||
| for word in range(num_class): | |||||
| self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| else: | |||||
| raise ValueError() | |||||
| def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
| r""" | r""" | ||||
| @@ -68,9 +93,11 @@ class ClassifyFPreRecMetric(Metric): | |||||
| tag_name = self.tag_vocab.to_word(tag) | tag_name = self.tag_vocab.to_word(tag) | ||||
| else: | else: | ||||
| tag_name = int(tag) | tag_name = int(tag) | ||||
| tp = self._tp[tag] | |||||
| fn = self._fn[tag] | |||||
| fp = self._fp[tag] | |||||
| tp = self._tp[tag].get_scalar() | |||||
| fn = self._fn[tag].get_scalar() | |||||
| fp = self._fp[tag].get_scalar() | |||||
| if tp == fn == fp == 0: | |||||
| continue | |||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | ||||
| f_sum += f | f_sum += f | ||||
| pre_sum += pre | pre_sum += pre | ||||
| @@ -90,20 +117,29 @@ class ClassifyFPreRecMetric(Metric): | |||||
| if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, | f, pre, rec = _compute_f_pre_rec(self.beta_square, | ||||
| sum(self._tp.values()), | |||||
| sum(self._fn.values()), | |||||
| sum(self._fp.values())) | |||||
| sum(val.get_scalar() for val in self._tp.values()), | |||||
| sum(val.get_scalar() for val in self._fn.values()), | |||||
| sum(val.get_scalar() for val in self._fp.values())) | |||||
| evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
| evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
| evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
| for key, value in evaluate_result.items(): | for key, value in evaluate_result.items(): | ||||
| evaluate_result[key] = round(value, 6) | evaluate_result[key] = round(value, 6) | ||||
| return evaluate_result | return evaluate_result | ||||
| def update(self, pred, target, seq_len=None): | def update(self, pred, target, seq_len=None): | ||||
| r""" | |||||
| evaluate函数将针对一个批次的预测结果做评价指标的累计 | |||||
| :param torch.Tensor pred: 预测的tensor, tensor的形状可以是torch.Size([B,]), torch.Size([B, n_classes]), | |||||
| torch.Size([B, max_len]), 或者torch.Size([B, max_len, n_classes]) | |||||
| :param torch.Tensor target: 真实值的tensor, tensor的形状可以是Element's can be: torch.Size([B,]), | |||||
| torch.Size([B,]), torch.Size([B, max_len]), 或者torch.Size([B, max_len]) | |||||
| :param torch.Tensor seq_len: 序列长度标记, 标记的形状可以是None, None, torch.Size([B]), 或者torch.Size([B]). | |||||
| 如果mask也被传进来的话seq_len会被忽略. | |||||
| """ | |||||
| pred = self.tensor2numpy(pred) | pred = self.tensor2numpy(pred) | ||||
| target = self.tensor2numpy(target) | target = self.tensor2numpy(target) | ||||
| if seq_len is not None: | if seq_len is not None: | ||||
| @@ -122,14 +158,14 @@ class ClassifyFPreRecMetric(Metric): | |||||
| f"pred have element numbers: {len(target.flatten())}") | f"pred have element numbers: {len(target.flatten())}") | ||||
| pass | pass | ||||
| elif len(pred.ndim) == len(target.ndim) + 1: | |||||
| elif pred.ndim == target.ndim + 1: | |||||
| pred = pred.argmax(axis=-1) | pred = pred.argmax(axis=-1) | ||||
| if seq_len is None and len(target.ndim) > 1: | |||||
| if seq_len is None and target.ndim > 1: | |||||
| warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | ||||
| else: | else: | ||||
| raise RuntimeError(f"when pred have " | raise RuntimeError(f"when pred have " | ||||
| f"size:{pred.ndim}, target should have size: {pred.ndim} or " | |||||
| f"{pred.ndim[:-1]}, got {target.ndim}.") | |||||
| f"size:{pred.shape}, target should have size: {pred.shape} or " | |||||
| f"{pred.shape[:-1]}, got {target.shape}.") | |||||
| if masks is not None: | if masks is not None: | ||||
| target = target * masks | target = target * masks | ||||
| pred = pred * masks | pred = pred * masks | ||||
| @@ -138,5 +174,3 @@ class ClassifyFPreRecMetric(Metric): | |||||
| self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | self._tp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | ||||
| self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | self._fp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | ||||
| self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() | self._fn[target_idx] += ((pred != target_idx) * (target != target_idx)).sum().item() | ||||
| @@ -0,0 +1,88 @@ | |||||
| import pytest | |||||
| import torch | |||||
| import numpy as np | |||||
| from fastNLP.core.metrics import ClassifyFPreRecMetric | |||||
| class TestClassfiyFPreRecMetric: | |||||
| def test_case_1(self): | |||||
| pred = torch.tensor([[-0.4375, -0.1779, -1.0985, -1.1592, 0.4910], | |||||
| [1.3410, 0.2889, -0.8667, -1.8580, 0.3029], | |||||
| [0.7459, -1.1957, 0.3231, 0.0308, -0.1847], | |||||
| [1.1439, -0.0057, 0.8203, 0.0312, -1.0051], | |||||
| [-0.4870, 0.3215, -0.8290, 0.9221, 0.4683], | |||||
| [0.9078, 1.0674, -0.5629, 0.3895, 0.8917], | |||||
| [-0.7743, -0.4041, -0.9026, 0.2112, 1.0892], | |||||
| [1.8232, -1.4188, -2.5615, -2.4187, 0.5907], | |||||
| [-1.0592, 0.4164, -0.1192, 1.4238, -0.9258], | |||||
| [-1.1137, 0.5773, 2.5778, 0.5398, -0.3323], | |||||
| [-0.3868, -0.5165, 0.2286, -1.3876, 0.5561], | |||||
| [-0.3304, 1.3619, -1.5744, 0.4902, -0.7661], | |||||
| [1.8387, 0.5234, 0.4269, 1.3748, -1.2793], | |||||
| [0.6692, 0.2571, 1.2425, -0.5894, -0.0184], | |||||
| [0.4165, 0.4084, -0.1280, 1.4489, -2.3058], | |||||
| [-0.5826, -0.5469, 1.5898, -0.2786, -0.9882], | |||||
| [-1.5548, -2.2891, 0.2983, -1.2145, -0.1947], | |||||
| [-0.7222, 2.3543, -0.5801, -0.0640, -1.5614], | |||||
| [-1.4978, 1.9297, -1.3652, -0.2358, 2.5566], | |||||
| [0.1561, -0.0316, 0.9331, 1.0363, 2.3949], | |||||
| [0.2650, -0.8459, 1.3221, 0.1321, -1.1900], | |||||
| [0.0664, -1.2353, -0.5242, -1.4491, 1.3300], | |||||
| [-0.2744, 0.0941, 0.7157, 0.1404, 1.2046], | |||||
| [0.9341, -0.6652, 1.4512, 0.9608, -0.3623], | |||||
| [-1.1641, 0.0873, 0.1163, -0.2068, -0.7002], | |||||
| [1.4775, -2.0025, -0.5634, -0.1589, 0.0247], | |||||
| [1.0151, 1.0304, -0.1042, -0.6955, -0.0629], | |||||
| [-0.3119, -0.4558, 0.7757, 0.0758, -1.6297], | |||||
| [1.0654, 0.0313, -0.7716, 0.1194, 0.6913], | |||||
| [-0.8088, -0.6648, -0.5018, -0.0230, -0.8207], | |||||
| [-0.7753, -0.3508, 1.6163, 0.7158, 1.5207], | |||||
| [0.8692, 0.7718, -0.6734, 0.6515, 0.0641]]) | |||||
| arg_max_pred = torch.argmax(pred, dim=-1) | |||||
| target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | |||||
| 0, 3, 0, 0, 0, 1, 3, 1]) | |||||
| metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||||
| metric.update(pred, target) | |||||
| result_dict = metric.get_metric() | |||||
| f1_score = 0.1882051282051282 | |||||
| recall = 0.1619047619047619 | |||||
| pre = 0.23928571428571427 | |||||
| ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||||
| for keys in ['f', 'pre', 'rec']: | |||||
| np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||||
| metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||||
| metric.update(pred, target) | |||||
| result_dict = metric.get_metric() | |||||
| f1_score = 0.21875 | |||||
| recall = 0.21875 | |||||
| pre = 0.21875 | |||||
| ground_truth = {'f': f1_score, 'pre': pre, 'rec': recall} | |||||
| for keys in ['f', 'pre', 'rec']: | |||||
| np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||||
| metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||||
| metric.update(pred, target) | |||||
| result_dict = metric.get_metric() | |||||
| ground_truth = { | |||||
| '0': {'f1-score': 0.13333333333333333, 'precision': 0.125, 'recall': 0.14285714285714285, 'support': 7}, | |||||
| '1': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 5}, | |||||
| '2': {'f1-score': 0.0, 'precision': 0.0, 'recall': 0.0, 'support': 2}, | |||||
| '3': {'f1-score': 0.30769230769230765, 'precision': 0.5, 'recall': 0.2222222222222222, 'support': 9}, | |||||
| '4': {'f1-score': 0.5, 'precision': 0.5714285714285714, 'recall': 0.4444444444444444, 'support': 9}, | |||||
| 'macro avg': {'f1-score': 0.1882051282051282, 'precision': 0.23928571428571427, | |||||
| 'recall': 0.1619047619047619, 'support': 32}, | |||||
| 'micro avg': {'f1-score': 0.21875, 'precision': 0.21875, 'recall': 0.21875, 'support': 32}, | |||||
| 'weighted avg': {'f1-score': 0.2563301282051282, 'precision': 0.3286830357142857, 'recall': 0.21875, | |||||
| 'support': 32}} | |||||
| for keys in result_dict.keys(): | |||||
| if keys == "f" or "pre" or "rec": | |||||
| continue | |||||
| gl = str(keys[-1]) | |||||
| tmp_d = {"p": "precision", "r": "recall", "f": "f1-score"} | |||||
| gk = tmp_d[keys[0]] | |||||
| np.allclose(result_dict[keys], ground_truth[gl][gk], atol=0.000001) | |||||