Browse Source

update metric的实现

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
d758a36d4a
7 changed files with 83 additions and 170 deletions
  1. +0
    -2
      fastNLP/core/metrics/__init__.py
  2. +31
    -55
      fastNLP/core/metrics/classify_f1_pre_rec_metric.py
  3. +32
    -49
      fastNLP/core/metrics/span_f1_pre_rec_metric.py
  4. +15
    -27
      fastNLP/core/metrics/utils.py
  5. +3
    -4
      tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py
  6. +2
    -1
      tests/core/metrics/test_span_f1_rec_acc_torch.py
  7. +0
    -32
      tests/core/metrics/test_utils.py

+ 0
- 2
fastNLP/core/metrics/__init__.py View File

@@ -7,7 +7,6 @@ __all__ = [
'TorchBackend',
'SpanFPreRecMetric',
'ClassifyFPreRecMetric',
'func_post_proc'
]

from .metric import Metric
@@ -15,4 +14,3 @@ from .accuracy import Accuracy
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend
from .span_f1_pre_rec_metric import SpanFPreRecMetric
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
from .utils import func_post_proc

+ 31
- 55
fastNLP/core/metrics/classify_f1_pre_rec_metric.py View File

@@ -3,40 +3,24 @@ __all__ = [
]

from typing import Union, List
from collections import defaultdict
from functools import partial
from collections import Counter
import warnings

from .metric import Metric
from .backend import Backend
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.utils.utils import seq_len_to_mask


def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""

:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

return f, pre, rec
from .utils import _compute_f_pre_rec


class ClassifyFPreRecMetric(Metric):
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None, num_class: int = 0,
def __init__(self, tag_vocab: Vocabulary = None, ignore_labels: List[str] = None,
only_gross: bool = True, f_type='micro', beta=1, backend: Union[str, Backend, None] = 'auto',
aggregate_when_get_metric: bool = None) -> None:
"""

:param tag_vocab:
:param ignore_labels:
:param num_class:
:param only_gross:
:param f_type:
:param beta:
@@ -60,32 +44,15 @@ class ClassifyFPreRecMetric(Metric):

self.tag_vocab = tag_vocab

self._tp = {}
self._fp = {}
self._fn = {}
if tag_vocab:
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._true_positives:
continue
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
elif num_class > 0:
for word in range(num_class):
self._tp[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum',
backend=backend)
self._fn[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum',
backend=backend)
self._fp[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum',
backend=backend)
else:
raise ValueError()
self._tp = Counter()
self._fp = Counter()
self._fn = Counter()

def reset(self):
    """Empty the tp / fp / fn counters in place."""
    # These counters are no longer registered elements, so they have to be
    # cleared by hand.
    for counter in (self._tp, self._fp, self._fn):
        counter.clear()

def get_metric(self) -> dict:
r"""
@@ -94,10 +61,22 @@ class ClassifyFPreRecMetric(Metric):
:return dict evaluate_result: {"acc": float}
"""
evaluate_result = {}

# Collect the counters from every process via all_gather_object and sum them.
# NOTE(review): assumes all_gather_object returns one [tp, fp, fn] list per
# process - confirm against the backend implementation.
if self.aggregate_when_get_metric:
    ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
    tps, fps, fns = zip(*ls)
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            # Counter.update adds counts instead of replacing them.
            c.update(_c)
else:
    # No aggregation: use the local counters directly. The false-negative
    # counter must be self._fn (it was mistakenly bound to self._tp, which
    # corrupted recall and F).
    _tp, _fp, _fn = self._tp, self._fp, self._fn

if not self.only_gross or self.f_type == 'macro':
tags = set(self._fn.keys())
tags.update(set(self._fp.keys()))
tags.update(set(self._tp.keys()))
tags = set(_fn.keys())
tags.update(set(_fp.keys()))
tags.update(set(_tp.keys()))
f_sum = 0
pre_sum = 0
rec_sum = 0
@@ -106,9 +85,9 @@ class ClassifyFPreRecMetric(Metric):
tag_name = self.tag_vocab.to_word(tag)
else:
tag_name = int(tag)
tp = self._tp[tag].get_scalar()
fn = self._fn[tag].get_scalar()
fp = self._fp[tag].get_scalar()
tp = _tp[tag]
fn = _fn[tag]
fp = _fp[tag]
if tp == fn == fp == 0:
continue
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp)
@@ -129,10 +108,7 @@ class ClassifyFPreRecMetric(Metric):
evaluate_result['rec'] = rec_sum / len(tags)

if self.f_type == 'micro':
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(val.get_scalar() for val in self._tp.values()),
sum(val.get_scalar() for val in self._fn.values()),
sum(val.get_scalar() for val in self._fp.values()))
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec


+ 32
- 49
fastNLP/core/metrics/span_f1_pre_rec_metric.py View File

@@ -4,12 +4,12 @@ __all__ = [

from typing import Union, List, Optional
import warnings
from collections import defaultdict
from functools import partial
from collections import Counter

from fastNLP.core.metrics.backend import Backend
from fastNLP.core.metrics.metric import Metric
from fastNLP.core.vocabulary import Vocabulary
from .utils import _compute_f_pre_rec


def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str):
@@ -199,21 +199,6 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels]


def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""

:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

return f, pre, rec


class SpanFPreRecMetric(Metric):

def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None,
@@ -266,32 +251,40 @@ class SpanFPreRecMetric(Metric):
self.only_gross = only_gross
self.tag_vocab = tag_vocab

self._true_positives = {}
self._false_positives = {}
self._false_negatives = {}
for word, _ in tag_vocab:
word = word.lower()
if word != 'o':
word = word[2:]
if word in self._true_positives:
continue
self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend)
self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend)
self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend)
self._tp = Counter()
self._fp = Counter()
self._fn = Counter()

def reset(self):
    """Empty the tp / fp / fn counters in place."""
    for counter in (self._tp, self._fp, self._fn):
        counter.clear()

def get_metric(self) -> dict:
evaluate_result = {}

# Collect the counters from every process via all_gather_object and sum them.
# NOTE(review): assumes all_gather_object returns one [tp, fp, fn] list per
# process - confirm against the backend implementation.
if self.aggregate_when_get_metric:
    ls = self.backend.all_gather_object([self._tp, self._fp, self._fn])
    tps, fps, fns = zip(*ls)
    _tp, _fp, _fn = Counter(), Counter(), Counter()
    for c, cs in zip([_tp, _fp, _fn], [tps, fps, fns]):
        for _c in cs:
            # Counter.update adds counts instead of replacing them.
            c.update(_c)
else:
    # No aggregation: use the local counters directly. The false-negative
    # counter must be self._fn (it was mistakenly bound to self._tp, which
    # corrupted recall and F).
    _tp, _fp, _fn = self._tp, self._fp, self._fn

if not self.only_gross or self.f_type == 'macro':
tags = set(self._false_negatives.keys())
tags.update(self._false_positives.keys())
tags.update(self._true_positives.keys())
tags = set(_fn.keys())
tags.update(_fp.keys())
tags.update(_tp.keys())
f_sum = 0
pre_sum = 0
rec_sum = 0
for tag in tags:
tp = self._true_positives[tag].get_scalar()
fn = self._false_negatives[tag].get_scalar()
fp = self._false_positives[tag].get_scalar()
tp = _tp[tag]
fn = _fn[tag]
fp = _fp[tag]
if tp == fn == fp == 0:
continue

@@ -313,17 +306,7 @@ class SpanFPreRecMetric(Metric):
evaluate_result['rec'] = rec_sum / len(tags)

if self.f_type == 'micro':
tp, fn, fp = [], [], []
for val in self._true_positives.values():
tp.append(val.get_scalar())
for val in self._false_negatives.values():
fn.append(val.get_scalar())
for val in self._false_positives.values():
fp.append(val.get_scalar())
f, pre, rec = _compute_f_pre_rec(self.beta_square,
sum(tp),
sum(fn),
sum(fp))
f, pre, rec = _compute_f_pre_rec(self.beta_square, sum(_tp.values()), sum(_fn.values()), sum(_fp.values()))
evaluate_result['f'] = f
evaluate_result['pre'] = pre
evaluate_result['rec'] = rec
@@ -372,9 +355,9 @@ class SpanFPreRecMetric(Metric):

for span in pred_spans:
if span in gold_spans:
self._true_positives[span[0]] += 1
self._tp[span[0]] += 1
gold_spans.remove(span)
else:
self._false_positives[span[0]] += 1
self._fp[span[0]] += 1
for span in gold_spans:
self._false_negatives[span[0]] += 1
self._fn[span[0]] += 1

+ 15
- 27
fastNLP/core/metrics/utils.py View File

@@ -1,5 +1,4 @@
__all__ = [
'func_post_proc'
]

from typing import Any
@@ -59,34 +58,23 @@ def _is_paddle_metric(metric: Any) -> bool:
return False


def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric':
    """
    Wrap ``metric.{method_name}`` so that its return value is passed through ``fn``
    before being returned to the caller. The method is replaced in place on
    ``metric``.

    :param metric: the metric object whose method is wrapped
    :param fn: callable applied to the return value of the wrapped method
    :param method_name: name of the method on ``metric`` to wrap
    :return: metric (the same object that was passed in)
    """
    original = getattr(metric, method_name, None)
    assert callable(original), \
        f"Parameter `metric` must have a {method_name} function."
    assert callable(fn), "Parameter `fn` must be callable."

    @wraps(original)
    def post_processed(*args, **kwargs):
        return fn(original(*args, **kwargs))

    # Marker that lets callers detect that the method has been wrapped.
    post_processed.__wrapped_by_func_post_proc__ = True
    setattr(metric, method_name, post_processed)
    return metric


class AggregateMethodError(BaseException):
    """
    Signals a problem with the ``aggregate_method`` setting.

    :param should_have_aggregate_method: stored unchanged on the instance
    :param only_warn: stored unchanged on the instance; defaults to False
    """

    def __init__(self, should_have_aggregate_method, only_warn=False):
        # Do not pass ``self`` as the exception argument: with args == (self,),
        # str(exc) resolves to str(args[0]) and recurses until RecursionError.
        super(AggregateMethodError, self).__init__(
            f"should_have_aggregate_method={should_have_aggregate_method}, only_warn={only_warn}"
        )
        self.should_have_aggregate_method = should_have_aggregate_method
        self.only_warn = only_warn


def _compute_f_pre_rec(beta_square, tp, fn, fp):
r"""

:param tp: int, true positive
:param fn: int, false negative
:param fp: int, false positive
:return: (f, pre, rec)
"""
pre = tp / (fp + tp + 1e-13)
rec = tp / (fn + tp + 1e-13)
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13)

return f, pre, rec

+ 3
- 4
tests/core/metrics/test_classify_f1_pre_rec_metric_torch.py View File

@@ -67,7 +67,7 @@ class TestClassfiyFPreRecMetric:
target = torch.tensor([0, 2, 4, 1, 4, 0, 1, 3, 3, 3, 1, 3, 4, 4, 3, 4, 0, 2, 4, 4, 3, 4, 4, 3,
0, 3, 0, 0, 0, 1, 3, 1])

metric = ClassifyFPreRecMetric(f_type='macro', num_class=5)
metric = ClassifyFPreRecMetric(f_type='macro')
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.1882051282051282
@@ -78,7 +78,7 @@ class TestClassfiyFPreRecMetric:
for keys in ['f', 'pre', 'rec']:
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

metric = ClassifyFPreRecMetric(f_type='micro', num_class=5)
metric = ClassifyFPreRecMetric(f_type='micro')
metric.update(pred, target)
result_dict = metric.get_metric()
f1_score = 0.21875
@@ -89,7 +89,7 @@ class TestClassfiyFPreRecMetric:
for keys in ['f', 'pre', 'rec']:
np.allclose(result_dict[keys], ground_truth[keys], atol=0.000001)

metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro', num_class=5)
metric = ClassifyFPreRecMetric(only_gross=False, f_type='macro')
metric.update(pred, target)
result_dict = metric.get_metric()
ground_truth = {
@@ -157,7 +157,6 @@ class TestClassfiyFPreRecMetric:
})
metric_kwargs = {
'f_type': f_type,
'num_class': 5,
'only_gross': False,
'aggregate_when_get_metric': True
}


+ 2
- 1
tests/core/metrics/test_span_f1_rec_acc_torch.py View File

@@ -102,7 +102,8 @@ class TestSpanFPreRecMetric:
# bio tag
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels))
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False,
aggregate_when_get_metric=True)
bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
-0.3782, 0.8240],
[-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563,


+ 0
- 32
tests/core/metrics/test_utils.py View File

@@ -1,32 +0,0 @@
import unittest
from fastNLP.core.metrics.utils import func_post_proc


class Metric:
    """Minimal stand-in metric whose methods echo their arguments back as a tuple."""

    def accumulate(self, x, y):
        pair = (x, y)
        return pair

    def compute(self, x, y):
        pair = (x, y)
        return pair


class TestMetricUtil(unittest.TestCase):
    """Tests for ``func_post_proc`` using the local stub ``Metric``."""

    def test_func_post_proc(self):
        metric = Metric()
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='accumulate')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.accumulate(x=1, y=2))

        # Wrapping twice composes: the second fn receives the dict produced by the first.
        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='accumulate')
        self.assertDictEqual({'1': 1, '2': 2}, metric.accumulate(x=1, y=2))

        # The stub only defines ``accumulate`` and ``compute``; wrapping a
        # missing ``update`` method would trip the assertion in func_post_proc.
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='compute')
        self.assertDictEqual({'x': 1, 'y': 2}, metric.compute(x=1, y=2))

        func_post_proc(metric, lambda o: {'1': o['x'], '2': o['y']}, method_name='compute')
        self.assertDictEqual({'1': 1, '2': 2}, metric.compute(x=1, y=2))

    def test_check_accumulate_post_special_local_variable(self):
        metric = Metric()
        # func_post_proc sets the marker on the wrapped method, not on the
        # metric object, under the name ``__wrapped_by_func_post_proc__``.
        self.assertFalse(hasattr(metric.compute, '__wrapped_by_func_post_proc__'))
        metric = func_post_proc(metric, lambda o: {'x': o[0], 'y': o[1]}, method_name='compute')
        self.assertTrue(hasattr(metric.compute, '__wrapped_by_func_post_proc__'))

Loading…
Cancel
Save