@@ -7,7 +7,6 @@ __all__ = [ | |||
'TorchBackend', | |||
'SpanFPreRecMetric', | |||
'ClassifyFPreRecMetric', | |||
'func_post_proc' | |||
] | |||
from .metric import Metric | |||
@@ -15,4 +14,3 @@ from .accuracy import Accuracy | |||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | |||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | |||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | |||
from .utils import func_post_proc |
@@ -3,40 +3,24 @@ __all__ = [ | |||
] | |||
from typing import Union, List | |||
from collections import defaultdict | |||
from functools import partial | |||
from collections import Counter | |||
import warnings | |||
from .metric import Metric | |||
from .backend import Backend | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.utils.utils import seq_len_to_mask | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
r""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
from .utils import _compute_f_pre_rec | |||
class ClassifyFPreRecMetric(Metric): | |||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0, | |||
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = None) -> None: | |||
""" | |||
:param tag_vocab: | |||
:param ignore_labels: | |||
:param num_class: | |||
:param only_gross: | |||
:param f_type: | |||
:param beta: | |||
@@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric): | |||
self.tag_vocab = tag_vocab | |||
self._tp = {} | |||
self._fp = {} | |||
self._fn = {} | |||
if tag_vocab: | |||
for word, _ in tag_vocab: | |||
word = word.lower() | |||
if word != 'o': | |||
word = word[2:] | |||
if word in self._true_positives: | |||
continue | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
elif num_class > 0: | |||
for word in range(num_class): | |||
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', | |||
backend=backend) | |||
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', | |||
backend=backend) | |||
else: | |||
raise ValueError() | |||
self._tp = Counter() | |||
self._fp = Counter() | |||
self._fn = Counter() | |||
def reset(self): | |||
# 由于不是 element 了,需要自己手动清零一下 | |||
self._tp.clear() | |||
self._fp.clear() | |||
self._fn.clear() | |||
def get_metric(self) -> dict: | |||
r""" | |||
@@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric): | |||
:return dict evaluate_result: {"acc": float} | |||
""" | |||
evaluate_result = {} | |||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||
if self.aggregate_when_get_metric: | |||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||
tps, fps, fns = zip(*ls) | |||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
for _c in cs: | |||
c.update(_c) | |||
else: | |||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._fn.keys()) | |||
tags.update(set(self._fp.keys())) | |||
tags.update(set(self._tp.keys())) | |||
tags = set(_fn.keys()) | |||
tags.update(set(_fp.keys())) | |||
tags.update(set(_tp.keys())) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
@@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric): | |||
tag_name = self.tag_vocab.to_word(tag) | |||
else: | |||
tag_name = int(tag) | |||
tp = self._tp[tag].get_scalar() | |||
fn = self._fn[tag].get_scalar() | |||
fp = self._fp[tag].get_scalar() | |||
tp = _tp[tag] | |||
fn = _fn[tag] | |||
fp = _fp[tag] | |||
if tp == fn == fp == 0: | |||
continue | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
@@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric): | |||
evaluate_result['rec'] = rec_sum / len(tags) | |||
if self.f_type == 'micro': | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(val.get_scalar() for val in self._tp.values()), | |||
sum(val.get_scalar() for val in self._fn.values()), | |||
sum(val.get_scalar() for val in self._fp.values())) | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
@@ -4,12 +4,12 @@ __all__ = [ | |||
from typing import Union, List, Optional | |||
import warnings | |||
from collections import defaultdict | |||
from functools import partial | |||
from collections import Counter | |||
from fastNLP.core.metrics.backend import Backend | |||
from fastNLP.core.metrics.metric import Metric | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from .utils import _compute_f_pre_rec | |||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | |||
@@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
r""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
class SpanFPreRecMetric(Metric): | |||
def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, | |||
@@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric): | |||
self.only_gross = only_gross | |||
self.tag_vocab = tag_vocab | |||
self._true_positives = {} | |||
self._false_positives = {} | |||
self._false_negatives = {} | |||
for word, _ in tag_vocab: | |||
word = word.lower() | |||
if word != 'o': | |||
word = word[2:] | |||
if word in self._true_positives: | |||
continue | |||
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) | |||
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) | |||
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) | |||
self._tp = Counter() | |||
self._fp = Counter() | |||
self._fn = Counter() | |||
def reset(self): | |||
self._tp.clear() | |||
self._fp.clear() | |||
self._fn.clear() | |||
def get_metric(self) -> dict: | |||
evaluate_result = {} | |||
# 通过 all_gather_object 将各个卡上的结果收集过来,并加和。 | |||
if self.aggregate_when_get_metric: | |||
ls = self.backend.all_gather_object([self._tp, self._fp, self._fn]) | |||
tps, fps, fns = zip(*ls) | |||
_tp, _fp, _fn = Counter(), Counter(), Counter() | |||
for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]): | |||
for _c in cs: | |||
c.update(_c) | |||
else: | |||
_tp, _fp, _fn = self._tp, self._fp, self._tp | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._false_negatives.keys()) | |||
tags.update(self._false_positives.keys()) | |||
tags.update(self._true_positives.keys()) | |||
tags = set(_fn.keys()) | |||
tags.update(_fp.keys()) | |||
tags.update(_tp.keys()) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
for tag in tags: | |||
tp = self._true_positives[tag].get_scalar() | |||
fn = self._false_negatives[tag].get_scalar() | |||
fp = self._false_positives[tag].get_scalar() | |||
tp = _tp[tag] | |||
fn = _fn[tag] | |||
fp = _fp[tag] | |||
if tp == fn == fp == 0: | |||
continue | |||
@@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric): | |||
evaluate_result['rec'] = rec_sum / len(tags) | |||
if self.f_type == 'micro': | |||
tp, fn, fp = [], [], [] | |||
for val in self._true_positives.values(): | |||
tp.append(val.get_scalar()) | |||
for val in self._false_negatives.values(): | |||
fn.append(val.get_scalar()) | |||
for val in self._false_positives.values(): | |||
fp.append(val.get_scalar()) | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(tp), | |||
sum(fn), | |||
sum(fp)) | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
@@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric): | |||
for span in pred_spans: | |||
if span in gold_spans: | |||
self._true_positives[span[0]] += 1 | |||
self._tp[span[0]] += 1 | |||
gold_spans.remove(span) | |||
else: | |||
self._false_positives[span[0]] += 1 | |||
self._fp[span[0]] += 1 | |||
for span in gold_spans: | |||
self._false_negatives[span[0]] += 1 | |||
self._fn[span[0]] += 1 |
@@ -1,5 +1,4 @@ | |||
__all__ = [ | |||
'func_post_proc' | |||
] | |||
from typing import Any | |||
@@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool: | |||
return False | |||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': | |||
""" | |||
将fn函数作用包裹在 metric 对象的 {method_name} 方法上,使得 metric.{method_name} 函数的返回结果先经过 fn 函数处理 | |||
后再返回。注意对 metric 的 {method_name} 函数的修改是 inplace 的。 | |||
:param metric: metric对象 | |||
:param fn: 作用于 metric 的 accumulate 方法的返回值 | |||
:param method_name: 一般来说,对于 | |||
:return: metric | |||
""" | |||
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ | |||
f"Parameter `metric` must have a {method_name} function." | |||
assert callable(fn), "Parameter `fn` must be callable." | |||
func = getattr(metric, method_name) | |||
@wraps(func) | |||
def wrap_method(*args, **kwargs): | |||
res = func(*args, **kwargs) | |||
return fn(res) | |||
wrap_method.__wrapped_by_func_post_proc__ = True | |||
setattr(metric, method_name, wrap_method) | |||
return metric | |||
class AggregateMethodError(BaseException): | |||
def __init__(self, should_have_aggregate_method, only_warn=False): | |||
super(AggregateMethodError, self).__init__(self) | |||
self.should_have_aggregate_method = should_have_aggregate_method | |||
self.only_warn = only_warn | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
r""" | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||
return f, pre, rec |
@@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric: | |||
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3, | |||
0, 3, 0, 0, 0, 1, 3, 1]) | |||
metric = ClassifyFPreRecMetric(f_type='macro', num_class=5) | |||
metric = ClassifyFPreRecMetric(f_type='macro') | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.1882051282051282 | |||
@@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric: | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(f_type='micro', num_class=5) | |||
metric = ClassifyFPreRecMetric(f_type='micro') | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
f1_score = 0.21875 | |||
@@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric: | |||
for keys in ['f', 'pre', 'rec']: | |||
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001) | |||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5) | |||
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro') | |||
metric.update(pred, target) | |||
result_dict = metric.get_metric() | |||
ground_truth = { | |||
@@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric: | |||
}) | |||
metric_kwargs = { | |||
'f_type': f_type, | |||
'num_class': 5, | |||
'only_gross': False, | |||
'aggregate_when_get_metric': True | |||
} | |||
@@ -102,7 +102,8 @@ class TestSpanFPreRecMetric: | |||
# bio tag | |||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) | |||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False, | |||
aggregate_when_get_metric=True) | |||
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, | |||
-0.3782, 0.8240], | |||
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, | |||
@@ -1,32 +0,0 @@ | |||
import unittest | |||
from fastNLP.core.metrics.utils import func_post_proc | |||
class Metric: | |||
def accumulate(self, x, y): | |||
return x, y | |||
def compute(self, x, y): | |||
return x, y | |||
class TestMetricUtil(unittest.TestCase): | |||
def test_func_post_proc(self): | |||
metric = Metric() | |||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate') | |||
self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2)) | |||
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate') | |||
self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2)) | |||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') | |||
self.assertDictEqual({'x': 1, 'y': 2}, metric.update(x=1, y=2)) | |||
func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='update') | |||
self.assertDictEqual({'1': 1, '2': 2}, metric.update(x=1, y=2)) | |||
def test_check_accumulate_post_special_local_variable(self): | |||
metric = Metric() | |||
self.assertFalse(hasattr(metric, '__wrapped_by_fn__')) | |||
metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='update') | |||
self.assertTrue(hasattr(metric, '__wrapped_by_fn__')) |