| @@ -7,7 +7,6 @@ __all__ = [ | |||||
| 'TorchBackend', | 'TorchBackend', | ||||
| 'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
| 'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
| 'func_post_proc' | |||||
| ] | ] | ||||
| from .metric import Metric | from .metric import Metric | ||||
| @@ -15,4 +14,3 @@ from .accuracy import Accuracy | |||||
| from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | ||||
| from .span_f1_pre_rec_metric import SpanFPreRecMetric | from .span_f1_pre_rec_metric import SpanFPreRecMetric | ||||
| from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | ||||
| from .utils import func_post_proc | |||||
| @@ -3,40 +3,24 @@ __all__ = [ | |||||
| ] | ] | ||||
| from typing import Union, List | from typing import Union, List | ||||
| from collections import defaultdict | |||||
| from functools import partial | |||||
| from collections import Counter | |||||
| import warnings | import warnings | ||||
| from .metric import Metric | from .metric import Metric | ||||
| from .backend import Backend | from .backend import Backend | ||||
| from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
| from fastNLP.core.utils.utils import seq_len_to_mask | from fastNLP.core.utils.utils import seq_len_to_mask | ||||
| def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
| r""" | |||||
| :param tp: int, true positive | |||||
| :param fn: int, false negative | |||||
| :param fp: int, false positive | |||||
| :return: (f, pre, rec) | |||||
| """ | |||||
| pre = tp / (fp + tp + 1e-13) | |||||
| rec = tp / (fn + tp + 1e-13) | |||||
| f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
| return f, pre, rec | |||||
| from .utils import _compute_f_pre_rec | |||||
| class ClassifyFPreRecMetric(Metric): | class ClassifyFPreRecMetric(Metric): | ||||
| def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||||
| def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, | |||||
| only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | ||||
| aggregate_when_get_metric: bool = None) -> None: | aggregate_when_get_metric: bool = None) -> None: | ||||
| """ | """ | ||||
| :param tag_vocab: | :param tag_vocab: | ||||
| :param ignore_labels: | :param ignore_labels: | ||||
| :param num_class: | |||||
| :param only_gross: | :param only_gross: | ||||
| :param f_type: | :param f_type: | ||||
| :param beta: | :param beta: | ||||
| @@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric): | |||||
| self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
| self._tp = {} | |||||
| self._fp = {} | |||||
| self._fn = {} | |||||
| if tag_vocab: | |||||
| for word, _ in tag_vocab: | |||||
| word = word.lower() | |||||
| if word != 'o': | |||||
| word = word[2:] | |||||
| if word in self._true_positives: | |||||
| continue | |||||
| self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| elif num_class > 0: | |||||
| for word in range(num_class): | |||||
| self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
| backend=backend) | |||||
| else: | |||||
| raise ValueError() | |||||
| self._tp = Counter() | |||||
| self._fp = Counter() | |||||
| self._fn = Counter() | |||||
def reset(self):
    """Empty the accumulated true-positive, false-positive and false-negative counters."""
    # These counters are no longer registered elements, so they have to be
    # cleared by hand here.
    for counter in (self._tp, self._fp, self._fn):
        counter.clear()
| def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
| r""" | r""" | ||||
| @@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric): | |||||
| :return dict evaluate_result: {"acc": float} | :return dict evaluate_result: {"acc": float} | ||||
| """ | """ | ||||
| evaluate_result = {} | evaluate_result = {} | ||||
# Collect the counters from every process via all_gather_object and sum them.
if self.aggregate_when_get_metric:
    # presumably one [tp, fp, fn] triple per process -- confirm against the backend
    ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
    tps, fps, fns = zip(*ls)
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            # Counter.update adds counts rather than replacing them.
            c.update(_c)
else:
    # Fix: the false-negative counter must be self._fn. It used to be bound to
    # self._tp, which made fn equal tp whenever aggregation was disabled.
    _tp, _fp, _fn = self._tp, self._fp, self._fn
| if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
| tags = set(self._fn.keys()) | |||||
| tags.update(set(self._fp.keys())) | |||||
| tags.update(set(self._tp.keys())) | |||||
| tags = set(_fn.keys()) | |||||
| tags.update(set(_fp.keys())) | |||||
| tags.update(set(_tp.keys())) | |||||
| f_sum = 0 | f_sum = 0 | ||||
| pre_sum = 0 | pre_sum = 0 | ||||
| rec_sum = 0 | rec_sum = 0 | ||||
| @@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric): | |||||
| tag_name = self.tag_vocab.to_word(tag) | tag_name = self.tag_vocab.to_word(tag) | ||||
| else: | else: | ||||
| tag_name = int(tag) | tag_name = int(tag) | ||||
| tp = self._tp[tag].get_scalar() | |||||
| fn = self._fn[tag].get_scalar() | |||||
| fp = self._fp[tag].get_scalar() | |||||
| tp = _tp[tag] | |||||
| fn = _fn[tag] | |||||
| fp = _fp[tag] | |||||
| if tp == fn == fp == 0: | if tp == fn == fp == 0: | ||||
| continue | continue | ||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | ||||
| @@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric): | |||||
| evaluate_result['rec'] = rec_sum / len(tags) | evaluate_result['rec'] = rec_sum / len(tags) | ||||
| if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
| sum(val.get_scalar() for val in self._tp.values()), | |||||
| sum(val.get_scalar() for val in self._fn.values()), | |||||
| sum(val.get_scalar() for val in self._fp.values())) | |||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||||
| evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
| evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
| evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
| @@ -4,12 +4,12 @@ __all__ = [ | |||||
| from typing import Union, List, Optional | from typing import Union, List, Optional | ||||
| import warnings | import warnings | ||||
| from collections import defaultdict | |||||
| from functools import partial | |||||
| from collections import Counter | |||||
| from fastNLP.core.metrics.backend import Backend | from fastNLP.core.metrics.backend import Backend | ||||
| from fastNLP.core.metrics.metric import Metric | from fastNLP.core.metrics.metric import Metric | ||||
| from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
| from .utils import _compute_f_pre_rec | |||||
| def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | ||||
| @@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
| return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | ||||
| def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
| r""" | |||||
| :param tp: int, true positive | |||||
| :param fn: int, false negative | |||||
| :param fp: int, false positive | |||||
| :return: (f, pre, rec) | |||||
| """ | |||||
| pre = tp / (fp + tp + 1e-13) | |||||
| rec = tp / (fn + tp + 1e-13) | |||||
| f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
| return f, pre, rec | |||||
| class SpanFPreRecMetric(Metric): | class SpanFPreRecMetric(Metric): | ||||
| def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
| @@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric): | |||||
| self.only_gross = only_gross | self.only_gross = only_gross | ||||
| self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
| self._true_positives = {} | |||||
| self._false_positives = {} | |||||
| self._false_negatives = {} | |||||
| for word, _ in tag_vocab: | |||||
| word = word.lower() | |||||
| if word != 'o': | |||||
| word = word[2:] | |||||
| if word in self._true_positives: | |||||
| continue | |||||
| self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||||
| self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) | |||||
| self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) | |||||
| self._tp = Counter() | |||||
| self._fp = Counter() | |||||
| self._fn = Counter() | |||||
def reset(self):
    """Empty the accumulated true-positive, false-positive and false-negative counters."""
    for counter in (self._tp, self._fp, self._fn):
        counter.clear()
| def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
| evaluate_result = {} | evaluate_result = {} | ||||
# Collect the counters from every process via all_gather_object and sum them.
if self.aggregate_when_get_metric:
    # presumably one [tp, fp, fn] triple per process -- confirm against the backend
    ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
    tps, fps, fns = zip(*ls)
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            # Counter.update adds counts rather than replacing them.
            c.update(_c)
else:
    # Fix: the false-negative counter must be self._fn. It used to be bound to
    # self._tp, which made fn equal tp whenever aggregation was disabled.
    _tp, _fp, _fn = self._tp, self._fp, self._fn
| if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
| tags = set(self._false_negatives.keys()) | |||||
| tags.update(self._false_positives.keys()) | |||||
| tags.update(self._true_positives.keys()) | |||||
| tags = set(_fn.keys()) | |||||
| tags.update(_fp.keys()) | |||||
| tags.update(_tp.keys()) | |||||
| f_sum = 0 | f_sum = 0 | ||||
| pre_sum = 0 | pre_sum = 0 | ||||
| rec_sum = 0 | rec_sum = 0 | ||||
| for tag in tags: | for tag in tags: | ||||
| tp = self._true_positives[tag].get_scalar() | |||||
| fn = self._false_negatives[tag].get_scalar() | |||||
| fp = self._false_positives[tag].get_scalar() | |||||
| tp = _tp[tag] | |||||
| fn = _fn[tag] | |||||
| fp = _fp[tag] | |||||
| if tp == fn == fp == 0: | if tp == fn == fp == 0: | ||||
| continue | continue | ||||
| @@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric): | |||||
| evaluate_result['rec'] = rec_sum / len(tags) | evaluate_result['rec'] = rec_sum / len(tags) | ||||
| if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
| tp, fn, fp = [], [], [] | |||||
| for val in self._true_positives.values(): | |||||
| tp.append(val.get_scalar()) | |||||
| for val in self._false_negatives.values(): | |||||
| fn.append(val.get_scalar()) | |||||
| for val in self._false_positives.values(): | |||||
| fp.append(val.get_scalar()) | |||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
| sum(tp), | |||||
| sum(fn), | |||||
| sum(fp)) | |||||
| f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||||
| evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
| evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
| evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
| @@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric): | |||||
| for span in pred_spans: | for span in pred_spans: | ||||
| if span in gold_spans: | if span in gold_spans: | ||||
| self._true_positives[span[0]] += 1 | |||||
| self._tp[span[0]] += 1 | |||||
| gold_spans.remove(span) | gold_spans.remove(span) | ||||
| else: | else: | ||||
| self._false_positives[span[0]] += 1 | |||||
| self._fp[span[0]] += 1 | |||||
| for span in gold_spans: | for span in gold_spans: | ||||
| self._false_negatives[span[0]] += 1 | |||||
| self._fn[span[0]] += 1 | |||||
| @@ -1,5 +1,4 @@ | |||||
| __all__ = [ | __all__ = [ | ||||
| 'func_post_proc' | |||||
| ] | ] | ||||
| from typing import Any | from typing import Any | ||||
| @@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool: | |||||
| return False | return False | ||||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
    """
    Wrap ``metric.{method_name}`` so that its return value is passed through ``fn``
    before being handed back to the caller. The method is replaced on ``metric``
    in place.

    :param metric: the metric object whose method is wrapped
    :param fn: callable applied to the return value of the wrapped method
    :param method_name: name of the method on ``metric`` to wrap
    :return: metric (the same object that was passed in)
    """
    assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), (
        f"Parameter `metric` must have a {method_name} function."
    )
    assert callable(fn), "Parameter `fn` must be callable."

    original = getattr(metric, method_name)

    @wraps(original)
    def wrap_method(*args, **kwargs):
        return fn(original(*args, **kwargs))

    # Marker attribute recording that this method went through func_post_proc.
    wrap_method.__wrapped_by_func_post_proc__ = True
    setattr(metric, method_name, wrap_method)
    return metric
class AggregateMethodError(BaseException):
    """
    Exception carrying two flags, ``should_have_aggregate_method`` and ``only_warn``.

    NOTE(review): this derives from ``BaseException``, so ``except Exception`` does
    not catch it -- presumably intentional; confirm against the callers.

    :param should_have_aggregate_method: stored unchanged on the instance
    :param only_warn: stored unchanged on the instance, defaults to ``False``
    """
    def __init__(self, should_have_aggregate_method, only_warn=False):
        # Pass the flags, not ``self``, to BaseException: with ``args == (self,)``
        # calling str() on the exception recurses until RecursionError.
        super().__init__(should_have_aggregate_method, only_warn)
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn
| def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
| r""" | |||||
| :param tp: int, true positive | |||||
| :param fn: int, false negative | |||||
| :param fp: int, false positive | |||||
| :return: (f, pre, rec) | |||||
| """ | |||||
| pre = tp / (fp + tp + 1e-13) | |||||
| rec = tp / (fn + tp + 1e-13) | |||||
| f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
| return f, pre, rec | |||||
| @@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric: | |||||
| target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | ||||
| 0, 3, 0, 0, 0, 1, 3, 1]) | 0, 3, 0, 0, 0, 1, 3, 1]) | ||||
| metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||||
| metric = ClassifyFPreRecMetric(f_type='macro') | |||||
| metric.update(pred, target) | metric.update(pred, target) | ||||
| result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
| f1_score = 0.1882051282051282 | f1_score = 0.1882051282051282 | ||||
| @@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric: | |||||
| for keys in ['f', 'pre', 'rec']: | for keys in ['f', 'pre', 'rec']: | ||||
| np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | ||||
| metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||||
| metric = ClassifyFPreRecMetric(f_type='micro') | |||||
| metric.update(pred, target) | metric.update(pred, target) | ||||
| result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
| f1_score = 0.21875 | f1_score = 0.21875 | ||||
| @@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric: | |||||
| for keys in ['f', 'pre', 'rec']: | for keys in ['f', 'pre', 'rec']: | ||||
| np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | ||||
| metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||||
| metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro') | |||||
| metric.update(pred, target) | metric.update(pred, target) | ||||
| result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
| ground_truth = { | ground_truth = { | ||||
| @@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric: | |||||
| }) | }) | ||||
| metric_kwargs = { | metric_kwargs = { | ||||
| 'f_type': f_type, | 'f_type': f_type, | ||||
| 'num_class': 5, | |||||
| 'only_gross': False, | 'only_gross': False, | ||||
| 'aggregate_when_get_metric': True | 'aggregate_when_get_metric': True | ||||
| } | } | ||||
| @@ -102,7 +102,8 @@ class TestSpanFPreRecMetric: | |||||
| # bio tag | # bio tag | ||||
| fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | ||||
| fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | ||||
| fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
| fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False, | |||||
| aggregate_when_get_metric=True) | |||||
| bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | ||||
| -0.3782, 0.8240], | -0.3782, 0.8240], | ||||
| [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | ||||
| @@ -1,32 +0,0 @@ | |||||
| import unittest | |||||
| from fastNLP.core.metrics.utils import func_post_proc | |||||
class Metric:
    """Minimal stand-in metric whose two methods echo their arguments back."""

    def accumulate(self, x, y):
        """Return the pair ``(x, y)`` unchanged."""
        pair = (x, y)
        return pair

    def compute(self, x, y):
        """Return the pair ``(x, y)`` unchanged."""
        pair = (x, y)
        return pair
class TestMetricUtil(unittest.TestCase):
    """Tests for ``func_post_proc`` using the stub ``Metric`` defined in this file."""

    def test_func_post_proc(self):
        """Wrapping a method post-processes its result, and wrappers can be stacked."""
        metric = Metric()
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))

        # A second wrap receives the dict produced by the first wrapper.
        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
        self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))

        # The stub only defines `accumulate` and `compute`; wrapping the
        # non-existent `update` would trip the assert inside func_post_proc.
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='compute')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.compute(x=1, y=2))

        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='compute')
        self.assertDictEqual({'1': 1, '2': 2}, metric.compute(x=1, y=2))

    def test_check_accumulate_post_special_local_variable(self):
        """The wrapper is tagged with the ``__wrapped_by_func_post_proc__`` marker."""
        metric = Metric()
        # The marker is set on the wrapped method, not on the metric object.
        self.assertFalse(hasattr(metric.compute, '__wrapped_by_func_post_proc__'))
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='compute')
        self.assertTrue(hasattr(metric.compute, '__wrapped_by_func_post_proc__'))