@@ -7,7 +7,6 @@ __all__ = [ | |||||
'TorchBackend', | 'TorchBackend', | ||||
'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
'func_post_proc' | |||||
] | ] | ||||
from .metric import Metric | from .metric import Metric | ||||
@@ -15,4 +14,3 @@ from .accuracy import Accuracy | |||||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | ||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | from .span_f1_pre_rec_metric import SpanFPreRecMetric | ||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | ||||
from .utils import func_post_proc |
@@ -3,40 +3,24 @@ __all__ = [ | |||||
] | ] | ||||
from typing import Union, List | from typing import Union, List | ||||
from collections import defaultdict | |||||
from functools import partial | |||||
from collections import Counter | |||||
import warnings | import warnings | ||||
from .metric import Metric | from .metric import Metric | ||||
from .backend import Backend | from .backend import Backend | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from fastNLP.core.utils.utils import seq_len_to_mask | from fastNLP.core.utils.utils import seq_len_to_mask | ||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
r""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
from .utils import _compute_f_pre_rec | |||||
class ClassifyFPreRecMetric(Metric): | class ClassifyFPreRecMetric(Metric): | ||||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, | |||||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | ||||
aggregate_when_get_metric: bool = None) -> None: | aggregate_when_get_metric: bool = None) -> None: | ||||
""" | """ | ||||
:param tag_vocab: | :param tag_vocab: | ||||
:param ignore_labels: | :param ignore_labels: | ||||
:param num_class: | |||||
:param only_gross: | :param only_gross: | ||||
:param f_type: | :param f_type: | ||||
:param beta: | :param beta: | ||||
@@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric): | |||||
self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
self._tp = {} | |||||
self._fp = {} | |||||
self._fn = {} | |||||
if tag_vocab: | |||||
for word, _ in tag_vocab: | |||||
word = word.lower() | |||||
if word != 'o': | |||||
word = word[2:] | |||||
if word in self._true_positives: | |||||
continue | |||||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
elif num_class > 0: | |||||
for word in range(num_class): | |||||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||||
backend=backend) | |||||
else: | |||||
raise ValueError() | |||||
self._tp = Counter() | |||||
self._fp = Counter() | |||||
self._fn = Counter() | |||||
def reset(self): | |||||
# 由于不是 element 了,需要自己手动清零一下 | |||||
self._tp.clear() | |||||
self._fp.clear() | |||||
self._fn.clear() | |||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
r""" | r""" | ||||
@@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric): | |||||
:return dict evaluate_result: {"acc": float} | :return dict evaluate_result: {"acc": float} | ||||
""" | """ | ||||
evaluate_result = {} | evaluate_result = {} | ||||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||||
if self.aggregate_when_get_metric: | |||||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
else: | |||||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(self._fn.keys()) | |||||
tags.update(set(self._fp.keys())) | |||||
tags.update(set(self._tp.keys())) | |||||
tags = set(_fn.keys()) | |||||
tags.update(set(_fp.keys())) | |||||
tags.update(set(_tp.keys())) | |||||
f_sum = 0 | f_sum = 0 | ||||
pre_sum = 0 | pre_sum = 0 | ||||
rec_sum = 0 | rec_sum = 0 | ||||
@@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric): | |||||
tag_name = self.tag_vocab.to_word(tag) | tag_name = self.tag_vocab.to_word(tag) | ||||
else: | else: | ||||
tag_name = int(tag) | tag_name = int(tag) | ||||
tp = self._tp[tag].get_scalar() | |||||
fn = self._fn[tag].get_scalar() | |||||
fp = self._fp[tag].get_scalar() | |||||
tp = _tp[tag] | |||||
fn = _fn[tag] | |||||
fp = _fp[tag] | |||||
if tp == fn == fp == 0: | if tp == fn == fp == 0: | ||||
continue | continue | ||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | ||||
@@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric): | |||||
evaluate_result['rec'] = rec_sum / len(tags) | evaluate_result['rec'] = rec_sum / len(tags) | ||||
if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
sum(val.get_scalar() for val in self._tp.values()), | |||||
sum(val.get_scalar() for val in self._fn.values()), | |||||
sum(val.get_scalar() for val in self._fp.values())) | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||||
evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
@@ -4,12 +4,12 @@ __all__ = [ | |||||
from typing import Union, List, Optional | from typing import Union, List, Optional | ||||
import warnings | import warnings | ||||
from collections import defaultdict | |||||
from functools import partial | |||||
from collections import Counter | |||||
from fastNLP.core.metrics.backend import Backend | from fastNLP.core.metrics.backend import Backend | ||||
from fastNLP.core.metrics.metric import Metric | from fastNLP.core.metrics.metric import Metric | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from .utils import _compute_f_pre_rec | |||||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | ||||
@@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | ||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
r""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec | |||||
class SpanFPreRecMetric(Metric): | class SpanFPreRecMetric(Metric): | ||||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | ||||
@@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric): | |||||
self.only_gross = only_gross | self.only_gross = only_gross | ||||
self.tag_vocab = tag_vocab | self.tag_vocab = tag_vocab | ||||
self._true_positives = {} | |||||
self._false_positives = {} | |||||
self._false_negatives = {} | |||||
for word, _ in tag_vocab: | |||||
word = word.lower() | |||||
if word != 'o': | |||||
word = word[2:] | |||||
if word in self._true_positives: | |||||
continue | |||||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||||
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) | |||||
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) | |||||
self._tp = Counter() | |||||
self._fp = Counter() | |||||
self._fn = Counter() | |||||
def reset(self): | |||||
self._tp.clear() | |||||
self._fp.clear() | |||||
self._fn.clear() | |||||
def get_metric(self) -> dict: | def get_metric(self) -> dict: | ||||
evaluate_result = {} | evaluate_result = {} | ||||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||||
if self.aggregate_when_get_metric: | |||||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||||
tps, fps, fns = zip(*ls) | |||||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||||
for _c in cs: | |||||
c.update(_c) | |||||
else: | |||||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||||
if not self.only_gross or self.f_type == 'macro': | if not self.only_gross or self.f_type == 'macro': | ||||
tags = set(self._false_negatives.keys()) | |||||
tags.update(self._false_positives.keys()) | |||||
tags.update(self._true_positives.keys()) | |||||
tags = set(_fn.keys()) | |||||
tags.update(_fp.keys()) | |||||
tags.update(_tp.keys()) | |||||
f_sum = 0 | f_sum = 0 | ||||
pre_sum = 0 | pre_sum = 0 | ||||
rec_sum = 0 | rec_sum = 0 | ||||
for tag in tags: | for tag in tags: | ||||
tp = self._true_positives[tag].get_scalar() | |||||
fn = self._false_negatives[tag].get_scalar() | |||||
fp = self._false_positives[tag].get_scalar() | |||||
tp = _tp[tag] | |||||
fn = _fn[tag] | |||||
fp = _fp[tag] | |||||
if tp == fn == fp == 0: | if tp == fn == fp == 0: | ||||
continue | continue | ||||
@@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric): | |||||
evaluate_result['rec'] = rec_sum / len(tags) | evaluate_result['rec'] = rec_sum / len(tags) | ||||
if self.f_type == 'micro': | if self.f_type == 'micro': | ||||
tp, fn, fp = [], [], [] | |||||
for val in self._true_positives.values(): | |||||
tp.append(val.get_scalar()) | |||||
for val in self._false_negatives.values(): | |||||
fn.append(val.get_scalar()) | |||||
for val in self._false_positives.values(): | |||||
fp.append(val.get_scalar()) | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||||
sum(tp), | |||||
sum(fn), | |||||
sum(fp)) | |||||
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||||
evaluate_result['f'] = f | evaluate_result['f'] = f | ||||
evaluate_result['pre'] = pre | evaluate_result['pre'] = pre | ||||
evaluate_result['rec'] = rec | evaluate_result['rec'] = rec | ||||
@@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric): | |||||
for span in pred_spans: | for span in pred_spans: | ||||
if span in gold_spans: | if span in gold_spans: | ||||
self._true_positives[span[0]] += 1 | |||||
self._tp[span[0]] += 1 | |||||
gold_spans.remove(span) | gold_spans.remove(span) | ||||
else: | else: | ||||
self._false_positives[span[0]] += 1 | |||||
self._fp[span[0]] += 1 | |||||
for span in gold_spans: | for span in gold_spans: | ||||
self._false_negatives[span[0]] += 1 | |||||
self._fn[span[0]] += 1 |
@@ -1,5 +1,4 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'func_post_proc' | |||||
] | ] | ||||
from typing import Any | from typing import Any | ||||
@@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool: | |||||
return False | return False | ||||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': | |||||
""" | |||||
将fn函数作用包裹在 metric 对象的 {method_name} 方法上,使得 metric.{method_name} 函数的返回结果先经过 fn 函数处理 | |||||
后再返回。注意对 metric 的 {method_name} 函数的修改是 inplace 的。 | |||||
:param metric: metric对象 | |||||
:param fn: 作用于 metric 的 accumulate 方法的返回值 | |||||
:param method_name: 一般来说,对于 | |||||
:return: metric | |||||
""" | |||||
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ | |||||
f"Parameter `metric` must have a {method_name} function." | |||||
assert callable(fn), "Parameter `fn` must be callable." | |||||
func = getattr(metric, method_name) | |||||
@wraps(func) | |||||
def wrap_method(*args, **kwargs): | |||||
res = func(*args, **kwargs) | |||||
return fn(res) | |||||
wrap_method.__wrapped_by_func_post_proc__ = True | |||||
setattr(metric, method_name, wrap_method) | |||||
return metric | |||||
class AggregateMethodError(BaseException): | class AggregateMethodError(BaseException): | ||||
def __init__(self, should_have_aggregate_method, only_warn=False): | def __init__(self, should_have_aggregate_method, only_warn=False): | ||||
super(AggregateMethodError, self).__init__(self) | super(AggregateMethodError, self).__init__(self) | ||||
self.should_have_aggregate_method = should_have_aggregate_method | self.should_have_aggregate_method = should_have_aggregate_method | ||||
self.only_warn = only_warn | self.only_warn = only_warn | ||||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||||
r""" | |||||
:param tp: int, true positive | |||||
:param fn: int, false negative | |||||
:param fp: int, false positive | |||||
:return: (f, pre, rec) | |||||
""" | |||||
pre = tp / (fp + tp + 1e-13) | |||||
rec = tp / (fn + tp + 1e-13) | |||||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||||
return f, pre, rec |
@@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric: | |||||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | ||||
0, 3, 0, 0, 0, 1, 3, 1]) | 0, 3, 0, 0, 0, 1, 3, 1]) | ||||
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||||
metric = ClassifyFPreRecMetric(f_type='macro') | |||||
metric.update(pred, target) | metric.update(pred, target) | ||||
result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
f1_score = 0.1882051282051282 | f1_score = 0.1882051282051282 | ||||
@@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric: | |||||
for keys in ['f', 'pre', 'rec']: | for keys in ['f', 'pre', 'rec']: | ||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | ||||
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||||
metric = ClassifyFPreRecMetric(f_type='micro') | |||||
metric.update(pred, target) | metric.update(pred, target) | ||||
result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
f1_score = 0.21875 | f1_score = 0.21875 | ||||
@@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric: | |||||
for keys in ['f', 'pre', 'rec']: | for keys in ['f', 'pre', 'rec']: | ||||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | ||||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro') | |||||
metric.update(pred, target) | metric.update(pred, target) | ||||
result_dict = metric.get_metric() | result_dict = metric.get_metric() | ||||
ground_truth = { | ground_truth = { | ||||
@@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric: | |||||
}) | }) | ||||
metric_kwargs = { | metric_kwargs = { | ||||
'f_type': f_type, | 'f_type': f_type, | ||||
'num_class': 5, | |||||
'only_gross': False, | 'only_gross': False, | ||||
'aggregate_when_get_metric': True | 'aggregate_when_get_metric': True | ||||
} | } | ||||
@@ -102,7 +102,8 @@ class TestSpanFPreRecMetric: | |||||
# bio tag | # bio tag | ||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | ||||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | ||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False, | |||||
aggregate_when_get_metric=True) | |||||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | ||||
-0.3782, 0.8240], | -0.3782, 0.8240], | ||||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | ||||
@@ -1,32 +0,0 @@ | |||||
import unittest | |||||
from fastNLP.core.metrics.utils import func_post_proc | |||||
class Metric: | |||||
def accumulate(self, x, y): | |||||
return x, y | |||||
def compute(self, x, y): | |||||
return x, y | |||||
class TestMetricUtil(unittest.TestCase): | |||||
def test_func_post_proc(self): | |||||
metric = Metric() | |||||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate') | |||||
self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2)) | |||||
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate') | |||||
self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2)) | |||||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') | |||||
self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2)) | |||||
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update') | |||||
self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2)) | |||||
def test_check_accumulate_post_special_local_variable(self): | |||||
metric = Metric() | |||||
self.assertFalse(hasattr(metric, '__wrapped_by_fn__')) | |||||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') | |||||
self.assertTrue(hasattr(metric, '__wrapped_by_fn__')) |