@@ -7,7 +7,6 @@ __all__ = [
    'TorchBackend',
    'SpanFPreRecMetric',
    'ClassifyFPreRecMetric',
    'func_post_proc'
]
from .metric import Metric
@@ -15,4 +14,3 @@ from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc
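With func_post_proc dropped from the package's __all__ and imports, the remaining public metrics are still importable as before. A minimal usage sketch (the import path follows the package layout shown in this diff, and the constructor call reflects the removal of num_class below):

from fastNLP.core.metrics import Accuracy, ClassifyFPreRecMetric, SpanFPreRecMetric

metric = ClassifyFPreRecMetric(f_type='micro')  # num_class is no longer a parameter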
@@ -3,40 +3,24 @@ __all__ = [
]
from typing import Union, List
from collections import defaultdict
from functools import partial
from collections import Counter
import warnings
from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask
def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec
from .utils import _compute_f_pre_rec
class ClassifyFPreRecMetric(Metric):
    def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
    def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
                 only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
                 aggregate_when_get_metric: bool = None) -> None:
        """
        :param tag_vocab:
        :param ignore_labels:
        :param num_class:
        :param only_gross:
        :param f_type:
        :param beta:
@@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric):
        self.tag_vocab = tag_vocab
        self._tp = {}
        self._fp = {}
        self._fn = {}
        if tag_vocab:
            for word, _ in tag_vocab:
                word = word.lower()
                if word != 'o':
                    word = word[2:]
                    if word in self._true_positives:
                        continue
                self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
                                                       backend=backend)
                self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
                                                       backend=backend)
                self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
                                                       backend=backend)
        elif num_class > 0:
            for word in range(num_class):
                self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
                                                       backend=backend)
                self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
                                                       backend=backend)
                self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
                                                       backend=backend)
        else:
            raise ValueError()
        self._tp = Counter()
        self._fp = Counter()
        self._fn = Counter()
    def reset(self):
        # Since the counts are no longer backend elements, they have to be cleared manually.
        self._tp.clear()
        self._fp.clear()
        self._fn.clear()
    def get_metric(self) -> dict:
        r"""
@@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric):
        :return dict evaluate_result: {"acc": float}
        """
        evaluate_result = {}
        # Collect the per-device counts via all_gather_object and sum them.
        if self.aggregate_when_get_metric:
            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
            tps, fps, fns = zip(*ls)
            _tp, _fp, _fn = Counter(), Counter(), Counter()
            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
                for _c in cs:
                    c.update(_c)
        else:
            _tp, _fp, _fn = self._tp, self._fp, self._fn
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._fn.keys())
            tags.update(set(self._fp.keys()))
            tags.update(set(self._tp.keys()))
            tags = set(_fn.keys())
            tags.update(set(_fp.keys()))
            tags.update(set(_tp.keys()))
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
@@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric):
                    tag_name = self.tag_vocab.to_word(tag)
                else:
                    tag_name = int(tag)
                tp = self._tp[tag].get_scalar()
                fn = self._fn[tag].get_scalar()
                fp = self._fp[tag].get_scalar()
                tp = _tp[tag]
                fn = _fn[tag]
                fp = _fp[tag]
                if tp == fn == fp == 0:
                    continue
                f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
@@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric):
            evaluate_result['rec'] = rec_sum / len(tags)
        if self.f_type == 'micro':
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(val.get_scalar() for val in self._tp.values()),
                                             sum(val.get_scalar() for val in self._fn.values()),
                                             sum(val.get_scalar() for val in self._fp.values()))
            f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
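To make the new bookkeeping concrete, here is a small self-contained sketch (plain Python, not the fastNLP API; the tag names and worker counts are invented) of how per-tag tp/fn/fp Counters gathered from several workers reduce to micro and macro scores:

from collections import Counter

def f_pre_rec(beta_square, tp, fn, fp):
    # same formula as _compute_f_pre_rec above
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec

# (tp, fp, fn) Counters gathered from two hypothetical workers
worker_counts = [
    (Counter({'per': 3, 'loc': 1}), Counter({'per': 1}), Counter({'loc': 2})),
    (Counter({'per': 2}), Counter({'loc': 1}), Counter({'per': 1})),
]
tp, fp, fn = Counter(), Counter(), Counter()
for w_tp, w_fp, w_fn in worker_counts:
    tp.update(w_tp)
    fp.update(w_fp)
    fn.update(w_fn)

# micro: pool the counts over all tags, then compute once
micro_f, micro_pre, micro_rec = f_pre_rec(1, sum(tp.values()), sum(fn.values()), sum(fp.values()))
# macro: compute per tag, then average
tags = set(tp) | set(fp) | set(fn)
per_tag = [f_pre_rec(1, tp[t], fn[t], fp[t]) for t in tags]
macro_f = sum(f for f, _, _ in per_tag) / len(tags)
print(round(micro_f, 4), round(macro_f, 4))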
@@ -4,12 +4,12 @@ __all__ = [
from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial
from collections import Counter
from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
from .utils import _compute_f_pre_rec
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
@@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
    return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]
def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec
class SpanFPreRecMetric(Metric):
    def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
@@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric):
        self.only_gross = only_gross
        self.tag_vocab = tag_vocab
        self._true_positives = {}
        self._false_positives = {}
        self._false_negatives = {}
        for word, _ in tag_vocab:
            word = word.lower()
            if word != 'o':
                word = word[2:]
                if word in self._true_positives:
                    continue
            self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
            self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
            self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
        self._tp = Counter()
        self._fp = Counter()
        self._fn = Counter()
    def reset(self):
        self._tp.clear()
        self._fp.clear()
        self._fn.clear()
    def get_metric(self) -> dict:
        evaluate_result = {}
        # Collect the per-device counts via all_gather_object and sum them.
        if self.aggregate_when_get_metric:
            ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
            tps, fps, fns = zip(*ls)
            _tp, _fp, _fn = Counter(), Counter(), Counter()
            for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
                for _c in cs:
                    c.update(_c)
        else:
            _tp, _fp, _fn = self._tp, self._fp, self._fn
        if not self.only_gross or self.f_type == 'macro':
            tags = set(self._false_negatives.keys())
            tags.update(self._false_positives.keys())
            tags.update(self._true_positives.keys())
            tags = set(_fn.keys())
            tags.update(_fp.keys())
            tags.update(_tp.keys())
            f_sum = 0
            pre_sum = 0
            rec_sum = 0
            for tag in tags:
                tp = self._true_positives[tag].get_scalar()
                fn = self._false_negatives[tag].get_scalar()
                fp = self._false_positives[tag].get_scalar()
                tp = _tp[tag]
                fn = _fn[tag]
                fp = _fp[tag]
                if tp == fn == fp == 0:
                    continue
@@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric):
            evaluate_result['rec'] = rec_sum / len(tags)
        if self.f_type == 'micro':
            tp, fn, fp = [], [], []
            for val in self._true_positives.values():
                tp.append(val.get_scalar())
            for val in self._false_negatives.values():
                fn.append(val.get_scalar())
            for val in self._false_positives.values():
                fp.append(val.get_scalar())
            f, pre, rec = _compute_f_pre_rec(self.beta_square,
                                             sum(tp),
                                             sum(fn),
                                             sum(fp))
            f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
            evaluate_result['f'] = f
            evaluate_result['pre'] = pre
            evaluate_result['rec'] = rec
@@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric):
            for span in pred_spans:
                if span in gold_spans:
                    self._true_positives[span[0]] += 1
                    self._tp[span[0]] += 1
                    gold_spans.remove(span)
                else:
                    self._false_positives[span[0]] += 1
                    self._fp[span[0]] += 1
            for span in gold_spans:
                self._false_negatives[span[0]] += 1
                self._fn[span[0]] += 1
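The update logic that fills these Counters pairs predicted spans against gold spans by exact label-and-boundary match. A standalone sketch with made-up spans (the real method extracts the spans from BIO/BMES tag sequences first):

from collections import Counter

_tp, _fp, _fn = Counter(), Counter(), Counter()

# hypothetical (label, (start, end)) spans for one sentence
pred_spans = [('per', (0, 2)), ('loc', (3, 5)), ('org', (6, 7))]
gold_spans = [('per', (0, 2)), ('loc', (3, 4))]

for span in pred_spans:
    if span in gold_spans:
        _tp[span[0]] += 1   # exact label + boundary match
        gold_spans.remove(span)
    else:
        _fp[span[0]] += 1   # predicted span with no gold counterpart
for span in gold_spans:
    _fn[span[0]] += 1       # gold span that was never predicted

assert _tp == Counter({'per': 1})
assert _fp == Counter({'loc': 1, 'org': 1})
assert _fn == Counter({'loc': 1})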
@@ -1,5 +1,4 @@
__all__ = [
    'func_post_proc'
]
from typing import Any
@@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool:
    return False
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
    """
    Wrap fn around the {method_name} method of the metric object, so that the return value of metric.{method_name}
    is first processed by fn before being returned. Note that the modification of metric's {method_name} method is made in place.
    :param metric: the metric object
    :param fn: applied to the return value of the metric's accumulate method
    :param method_name: generally speaking, for
    :return: metric
    """
    assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \
        f"Parameter `metric` must have a {method_name} function."
    assert callable(fn), "Parameter `fn` must be callable."
    func = getattr(metric, method_name)
    @wraps(func)
    def wrap_method(*args, **kwargs):
        res = func(*args, **kwargs)
        return fn(res)
    wrap_method.__wrapped_by_func_post_proc__ = True
    setattr(metric, method_name, wrap_method)
    return metric
class AggregateMethodError(BaseException):
    def __init__(self, should_have_aggregate_method, only_warn=False):
        super(AggregateMethodError, self).__init__(self)
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn
def _compute_f_pre_rec(beta_square, tp, fn, fp):
    r"""
    :param tp: int, true positive
    :param fn: int, false negative
    :param fp: int, false positive
    :return: (f, pre, rec)
    """
    pre = tp / (fp + tp + 1e-13)
    rec = tp / (fn + tp + 1e-13)
    f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)
    return f, pre, rec
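As a quick sanity check of the helper that now lives in utils (assuming the module path fastNLP.core.metrics.utils used elsewhere in this diff): with beta_square=1, tp=6, fn=2, fp=3 gives precision 6/9, recall 6/8, and an F1 of about 0.706.

from fastNLP.core.metrics.utils import _compute_f_pre_rec

f, pre, rec = _compute_f_pre_rec(1, tp=6, fn=2, fp=3)
assert abs(pre - 6 / 9) < 1e-6
assert abs(rec - 6 / 8) < 1e-6
assert abs(f - 2 * pre * rec / (pre + rec)) < 1e-6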
@@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric:
        target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
                               0, 3, 0, 0, 0, 1, 3, 1])
        metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
        metric = ClassifyFPreRecMetric(f_type='macro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        f1_score = 0.1882051282051282
@@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric:
        for keys in ['f', 'pre', 'rec']:
            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
        metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
        metric = ClassifyFPreRecMetric(f_type='micro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        f1_score = 0.21875
@@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric:
        for keys in ['f', 'pre', 'rec']:
            np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)
        metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
        metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
        metric.update(pred, target)
        result_dict = metric.get_metric()
        ground_truth = {
@@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric:
        })
        metric_kwargs = {
            'f_type': f_type,
            'num_class': 5,
            'only_gross': False,
            'aggregate_when_get_metric': True
        }
@@ -102,7 +102,8 @@ class TestSpanFPreRecMetric:
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False,
                                               aggregate_when_get_metric=True)
        bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
                                            -0.3782, 0.8240],
                                           [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,
@@ -1,32 +0,0 @@
import unittest
from fastNLP.core.metrics.utils import func_post_proc
class Metric:
    def accumulate(self, x, y):
        return x, y
    def compute(self, x, y):
        return x, y
class TestMetricUtil(unittest.TestCase):
    def test_func_post_proc(self):
        metric = Metric()
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))
        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
        self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2))
        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update')
        self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2))
    def test_check_accumulate_post_special_local_variable(self):
        metric = Metric()
        self.assertFalse(hasattr(metric, '__wrapped_by_fn__'))
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update')
        self.assertTrue(hasattr(metric, '__wrapped_by_fn__'))
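Since func_post_proc and its test are removed, callers that relied on it can get the same effect by wrapping a metric method by hand. A minimal sketch mirroring the deleted helper (the Metric class and the dict post-processing below are hypothetical):

from functools import wraps

class Metric:
    def update(self, x, y):
        return x, y

def wrap_with(metric, fn, method_name):
    # same idea as the removed func_post_proc: replace the bound method in place
    func = getattr(metric, method_name)

    @wraps(func)
    def wrapped(*args, **kwargs):
        return fn(func(*args, **kwargs))

    setattr(metric, method_name, wrapped)
    return metric

metric = wrap_with(Metric(), lambda o: {'x': o[0], 'y': o[1]}, 'update')
assert metric.update(x=1, y=2) == {'x': 1, 'y': 2}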