diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py index 4ab0ed36..82bca331 100644 --- a/fastNLP/core/metrics/__init__.py +++ b/fastNLP/core/metrics/__init__.py @@ -7,7 +7,6 @@ __all__ = [ 'TorchBackend', 'SpanFPreRecMetric', 'ClassifyFPreRecMetric', - 'func_post_proc' ] from .metric import Metric @@ -15,4 +14,3 @@ from .accuracy import Accuracy from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend from .span_f1_pre_rec_metric import SpanFPreRecMetric from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric -from .utils import func_post_proc diff --git a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py index c030d257..87b022c9 100644 --- a/fastNLP/core/metrics/classify_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/classify_f1_pre_rec_metric.py @@ -3,40 +3,24 @@ __all__ = [ ] from typing import Union, List -from collections import defaultdict -from functools import partial +from collections import Counter import warnings from .metric import Metric from .backend import Backend from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.utils.utils import seq_len_to_mask - - -def _compute_f_pre_rec(beta_square, tp, fn, fp): - r""" - - :param tp: int, true positive - :param fn: int, false negative - :param fp: int, false positive - :return: (f, pre, rec) - """ - pre = tp / (fp + tp + 1e-13) - rec = tp / (fn + tp + 1e-13) - f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) - - return f, pre, rec +from .utils import _compute_f_pre_rec class ClassifyFPreRecMetric(Metric): - def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, + def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = None) -> None: """ :param tag_vocab: :param ignore_labels: - :param num_class: :param only_gross: :param f_type: :param beta: @@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric): self.tag_vocab = tag_vocab - self._tp = {} - self._fp = {} - self._fn = {} - if tag_vocab: - for word, _ in tag_vocab: - word = word.lower() - if word != 'o': - word = word[2:] - if word in self._true_positives: - continue - self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', - backend=backend) - self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', - backend=backend) - self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', - backend=backend) - elif num_class > 0: - for word in range(num_class): - self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', - backend=backend) - self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', - backend=backend) - self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', - backend=backend) - else: - raise ValueError() + self._tp = Counter() + self._fp = Counter() + self._fn = Counter() + + def reset(self): + # 由于不是 element 了,需要自己手动清零一下 + self._tp.clear() + self._fp.clear() + self._fn.clear() def get_metric(self) -> dict: r""" @@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric): :return dict evaluate_result: {"acc": float} """ evaluate_result = {} + + # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 + if self.aggregate_when_get_metric: + ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) + tps, fps, fns = zip(*ls) + _tp, _fp, _fn = Counter(), Counter(), Counter() + for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): + for _c in cs: + c.update(_c) + else: + _tp, _fp, _fn = self._tp, self._fp, self._tp + if not self.only_gross or self.f_type == 'macro': - tags = set(self._fn.keys()) - tags.update(set(self._fp.keys())) - tags.update(set(self._tp.keys())) + tags = set(_fn.keys()) + tags.update(set(_fp.keys())) + tags.update(set(_tp.keys())) f_sum = 0 pre_sum = 0 rec_sum = 0 @@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric): tag_name = self.tag_vocab.to_word(tag) else: tag_name = int(tag) - tp = self._tp[tag].get_scalar() - fn = self._fn[tag].get_scalar() - fp = self._fp[tag].get_scalar() + tp = _tp[tag] + fn = _fn[tag] + fp = _fp[tag] if tp == fn == fp == 0: continue f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) @@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric): evaluate_result['rec'] = rec_sum / len(tags) if self.f_type == 'micro': - f, pre, rec = _compute_f_pre_rec(self.beta_square, - sum(val.get_scalar() for val in self._tp.values()), - sum(val.get_scalar() for val in self._fn.values()), - sum(val.get_scalar() for val in self._fp.values())) + f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) evaluate_result['f'] = f evaluate_result['pre'] = pre evaluate_result['rec'] = rec diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py index a49914a5..d847da41 100644 --- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -4,12 +4,12 @@ __all__ = [ from typing import Union, List, Optional import warnings -from collections import defaultdict -from functools import partial +from collections import Counter from fastNLP.core.metrics.backend import Backend from fastNLP.core.metrics.metric import Metric from fastNLP.core.vocabulary import Vocabulary +from .utils import _compute_f_pre_rec def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): @@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None): return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] -def _compute_f_pre_rec(beta_square, tp, fn, fp): - r""" - - :param tp: int, true positive - :param fn: int, false negative - :param fp: int, false positive - :return: (f, pre, rec) - """ - pre = tp / (fp + tp + 1e-13) - rec = tp / (fn + tp + 1e-13) - f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) - - return f, pre, rec - - class SpanFPreRecMetric(Metric): def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, @@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric): self.only_gross = only_gross self.tag_vocab = tag_vocab - self._true_positives = {} - self._false_positives = {} - self._false_negatives = {} - for word, _ in tag_vocab: - word = word.lower() - if word != 'o': - word = word[2:] - if word in self._true_positives: - continue - self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) - self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) - self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) + self._tp = Counter() + self._fp = Counter() + self._fn = Counter() + + def reset(self): + self._tp.clear() + self._fp.clear() + self._fn.clear() def get_metric(self) -> dict: evaluate_result = {} + + # 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 + if self.aggregate_when_get_metric: + ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) + tps, fps, fns = zip(*ls) + _tp, _fp, _fn = Counter(), Counter(), Counter() + for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): + for _c in cs: + c.update(_c) + else: + _tp, _fp, _fn = self._tp, self._fp, self._tp + if not self.only_gross or self.f_type == 'macro': - tags = set(self._false_negatives.keys()) - tags.update(self._false_positives.keys()) - tags.update(self._true_positives.keys()) + tags = set(_fn.keys()) + tags.update(_fp.keys()) + tags.update(_tp.keys()) f_sum = 0 pre_sum = 0 rec_sum = 0 for tag in tags: - tp = self._true_positives[tag].get_scalar() - fn = self._false_negatives[tag].get_scalar() - fp = self._false_positives[tag].get_scalar() + tp = _tp[tag] + fn = _fn[tag] + fp = _fp[tag] if tp == fn == fp == 0: continue @@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric): evaluate_result['rec'] = rec_sum / len(tags) if self.f_type == 'micro': - tp, fn, fp = [], [], [] - for val in self._true_positives.values(): - tp.append(val.get_scalar()) - for val in self._false_negatives.values(): - fn.append(val.get_scalar()) - for val in self._false_positives.values(): - fp.append(val.get_scalar()) - f, pre, rec = _compute_f_pre_rec(self.beta_square, - sum(tp), - sum(fn), - sum(fp)) + f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) evaluate_result['f'] = f evaluate_result['pre'] = pre evaluate_result['rec'] = rec @@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric): for span in pred_spans: if span in gold_spans: - self._true_positives[span[0]] += 1 + self._tp[span[0]] += 1 gold_spans.remove(span) else: - self._false_positives[span[0]] += 1 + self._fp[span[0]] += 1 for span in gold_spans: - self._false_negatives[span[0]] += 1 + self._fn[span[0]] += 1 diff --git a/fastNLP/core/metrics/utils.py b/fastNLP/core/metrics/utils.py index beafd6f4..ce6f618b 100644 --- a/fastNLP/core/metrics/utils.py +++ b/fastNLP/core/metrics/utils.py @@ -1,5 +1,4 @@ __all__ = [ - 'func_post_proc' ] from typing import Any @@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool: return False -def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': - """ - 将fn函数作用包裹在 metric 对象的 {method_name} 方法上,使得 metric.{method_name} 函数的返回结果先经过 fn 函数处理 - 后再返回。注意对 metric 的 {method_name} 函数的修改是 inplace 的。 - - :param metric: metric对象 - :param fn: 作用于 metric 的 accumulate 方法的返回值 - :param method_name: 一般来说,对于 - :return: metric - """ - assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ - f"Parameter `metric` must have a {method_name} function." - assert callable(fn), "Parameter `fn` must be callable." - - func = getattr(metric, method_name) - - @wraps(func) - def wrap_method(*args, **kwargs): - res = func(*args, **kwargs) - return fn(res) - - wrap_method.__wrapped_by_func_post_proc__ = True - setattr(metric, method_name, wrap_method) - return metric - - class AggregateMethodError(BaseException): def __init__(self, should_have_aggregate_method, only_warn=False): super(AggregateMethodError, self).__init__(self) self.should_have_aggregate_method = should_have_aggregate_method self.only_warn = only_warn + + +def _compute_f_pre_rec(beta_square, tp, fn, fp): + r""" + + :param tp: int, true positive + :param fn: int, false negative + :param fp: int, false positive + :return: (f, pre, rec) + """ + pre = tp / (fp + tp + 1e-13) + rec = tp / (fn + tp + 1e-13) + f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) + + return f, pre, rec diff --git a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py index 268adbd3..6dd7596e 100644 --- a/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py +++ b/tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py @@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric: target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, 0, 3, 0, 0, 0, 1, 3, 1]) - metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) + metric = ClassifyFPreRecMetric(f_type='macro') metric.update(pred, target) result_dict = metric.get_metric() f1_score = 0.1882051282051282 @@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric: for keys in ['f', 'pre', 'rec']: np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) - metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) + metric = ClassifyFPreRecMetric(f_type='micro') metric.update(pred, target) result_dict = metric.get_metric() f1_score = 0.21875 @@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric: for keys in ['f', 'pre', 'rec']: np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) - metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) + metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro') metric.update(pred, target) result_dict = metric.get_metric() ground_truth = { @@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric: }) metric_kwargs = { 'f_type': f_type, - 'num_class': 5, 'only_gross': False, 'aggregate_when_get_metric': True } diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py index f0a420d9..105afea9 100644 --- a/tests/core/metrics/test_span_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -102,7 +102,8 @@ class TestSpanFPreRecMetric: # bio tag fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) - fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) + fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False, + aggregate_when_get_metric=True) bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, -0.3782, 0.8240], [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, diff --git a/tests/core/metrics/test_utils.py b/tests/core/metrics/test_utils.py deleted file mode 100644 index 6a443df0..00000000 --- a/tests/core/metrics/test_utils.py +++ /dev/null @@ -1,32 +0,0 @@ -import unittest -from fastNLP.core.metrics.utils import func_post_proc - - -class Metric: - def accumulate(self, x, y): - return x, y - - def compute(self, x, y): - return x, y - - -class TestMetricUtil(unittest.TestCase): - def test_func_post_proc(self): - metric = Metric() - metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate') - self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2)) - - func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate') - self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2)) - - metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') - self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2)) - - func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update') - self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2)) - - def test_check_accumulate_post_special_local_variable(self): - metric = Metric() - self.assertFalse(hasattr(metric, '__wrapped_by_fn__')) - metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') - self.assertTrue(hasattr(metric, '__wrapped_by_fn__'))