Browse Source

* add _fast_call_evaluate mechanism in MetricBase

tags/v0.2.0^2
yh 5 years ago
parent
commit
d19850b397
2 changed files with 81 additions and 24 deletions
  1. +57
    -12
      fastNLP/core/metrics.py
  2. +24
    -12
      test/core/test_metrics.py

+ 57
- 12
fastNLP/core/metrics.py View File

@@ -11,7 +11,7 @@ from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.utils import CheckRes

class MetricBase(object):
def __init__(self):
@@ -72,6 +72,17 @@ class MetricBase(object):
def get_metric(self, reset=True):
raise NotImplemented

def _fast_call_evaluate(self, pred_dict, target_dict):
"""

Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element
:param pred_dict:
:param target_dict:
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
"""
return False

def __call__(self, pred_dict, target_dict, check=False):
"""

@@ -79,7 +90,7 @@ class MetricBase(object):
Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict
(1) whether self.evaluate has varargs, which is not supported.
(2) whether params needed by self.evaluate is not included in output_dict,target_dict.
(3) whether params needed by self.evaluate duplicate in output_dict, target_dict
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
@@ -92,6 +103,10 @@ class MetricBase(object):
if not callable(self.evaluate):
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

if not check:
if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
return

if not self._checked:
# 1. check consistence between signature and param_map
func_spect = inspect.getfullargspec(self.evaluate)
@@ -110,28 +125,40 @@ class MetricBase(object):
# need to wrap inputs in dict.
mapped_pred_dict = {}
mapped_target_dict = {}
duplicated = []
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
not_duplicate_flag = 0
if input_arg in self._reverse_param_map:
mapped_arg = self._reverse_param_map[input_arg]
not_duplicate_flag += 1
else:
mapped_arg = input_arg
if input_arg in pred_dict:
mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
not_duplicate_flag += 1
if input_arg in target_dict:
mapped_target_dict[mapped_arg] = target_dict[input_arg]
not_duplicate_flag += 1
if not_duplicate_flag == 3:
duplicated.append(input_arg)

# check duplicated, unused, missing
# missing
if check or not self._checked:
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
for key in check_res._fields:
value = getattr(check_res, key)
new_value = list(value)
# TODO 这里报错的逻辑应该是怎样的?
for idx, func_arg in enumerate(value):
if func_arg in self.param_map:
new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})'
else:
new_value[idx] = func_arg
# only check missing.
missing = check_res.missing
replaced_missing = list(missing)
for idx, func_arg in enumerate(missing):
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
f"in `{get_func_signature(self.evaluate)}`)"

check_res = CheckRes(missing=replaced_missing,
unused=check_res.unused,
duplicated=duplicated,
required=check_res.required,
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
@@ -140,6 +167,7 @@ class MetricBase(object):
self.evaluate(**refined_args)
self._checked = True

return

class AccuracyMetric(MetricBase):
def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
@@ -151,6 +179,22 @@ class AccuracyMetric(MetricBase):
self.total = 0
self.acc_count = 0

def _fast_call_evaluate(self, pred_dict, target_dict):
"""

Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element
:param pred_dict:
:param target_dict:
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
"""
if len(pred_dict)==1 and len(target_dict)==1:
pred = list(pred_dict.values())[0]
target = list(target_dict.values())[0]
self.evaluate(pred=pred, target=target)
return True
return False

def evaluate(self, pred, target, masks=None, seq_lens=None):
"""

@@ -164,6 +208,7 @@ class AccuracyMetric(MetricBase):
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
:return: dict({'acc': float})
"""
#TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value
if not isinstance(pred, torch.Tensor):
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
f"got {type(pred)}.")


+ 24
- 12
test/core/test_metrics.py View File

@@ -12,7 +12,7 @@ class TestAccuracyMetric(unittest.TestCase):
# target_dict = {'target': torch.zeros(4)}
# metric = AccuracyMetric()
#
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# print(metric.get_metric())
#
# def test_AccuracyMetric2(self):
@@ -22,7 +22,7 @@ class TestAccuracyMetric(unittest.TestCase):
# target_dict = {'target': torch.zeros(4)}
# metric = AccuracyMetric()
#
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# print(metric.get_metric())
# except Exception as e:
# print(e)
@@ -35,11 +35,11 @@ class TestAccuracyMetric(unittest.TestCase):
# metric = AccuracyMetric()
# pred_dict = {"pred": torch.zeros(4, 3, 2)}
# target_dict = {'target': torch.zeros(4, 3)}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
#
# pred_dict = {"pred": torch.zeros(4, 3, 2)}
# target_dict = {'target': torch.zeros(4)}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
#
# print(metric.get_metric())
# except Exception as e:
@@ -76,7 +76,7 @@ class TestAccuracyMetric(unittest.TestCase):
#
# pred_dict = {"pred": torch.zeros(4, 3, 2)}
# target_dict = {'target': torch.zeros(4, 3)+1}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# self.assertDictEqual(metric.get_metric(), {'acc':0})
#
# def test_AccuaryMetric6(self):
@@ -85,7 +85,7 @@ class TestAccuracyMetric(unittest.TestCase):
# metric = AccuracyMetric()
# pred_dict = {"pred": np.zeros((4, 3, 2))}
# target_dict = {'target': np.zeros((4, 3))}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# self.assertDictEqual(metric.get_metric(), {'acc': 1})
# except Exception as e:
# print(e)
@@ -97,7 +97,7 @@ class TestAccuracyMetric(unittest.TestCase):
# metric = AccuracyMetric(pred='predictions', target='targets')
# pred_dict = {"predictions": torch.zeros(4, 3, 2)}
# target_dict = {'targets': torch.zeros(4, 3)}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# self.assertDictEqual(metric.get_metric(), {'acc': 1})
#
# def test_AccuaryMetric8(self):
@@ -106,6 +106,19 @@ class TestAccuracyMetric(unittest.TestCase):
# metric = AccuracyMetric(pred='predictions', target='targets')
# pred_dict = {"prediction": torch.zeros(4, 3, 2)}
# target_dict = {'targets': torch.zeros(4, 3)}
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
# self.assertDictEqual(metric.get_metric(), {'acc': 1})
# except Exception as e:
# print(e)
# return
# self.assertTrue(True, False), "No exception catches."

# def test_AccuaryMetric9(self):
# # (9) check map, include unused
# try:
# metric = AccuracyMetric(pred='predictions', target='targets')
# pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
# target_dict = {'targets': torch.zeros(4, 3)}
# metric(pred_dict=pred_dict, target_dict=target_dict)
# self.assertDictEqual(metric.get_metric(), {'acc': 1})
# except Exception as e:
@@ -113,11 +126,11 @@ class TestAccuracyMetric(unittest.TestCase):
# return
# self.assertTrue(True, False), "No exception catches."

def test_AccuaryMetric9(self):
# (9) check map, include unused
def test_AccuaryMetric10(self):
# (10) check _fast_metric
try:
metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
metric = AccuracyMetric()
pred_dict = {"predictions": torch.zeros(4, 3, 2)}
target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1})
@@ -125,4 +138,3 @@ class TestAccuracyMetric(unittest.TestCase):
print(e)
return
self.assertTrue(True, False), "No exception catches."


Loading…
Cancel
Save