@@ -11,7 +11,7 @@ from fastNLP.core.utils import _build_args | |||
from fastNLP.core.utils import _check_arg_dict_list | |||
from fastNLP.core.utils import get_func_signature | |||
from fastNLP.core.utils import seq_lens_to_masks | |||
from fastNLP.core.utils import CheckRes | |||
class MetricBase(object): | |||
def __init__(self): | |||
@@ -72,6 +72,17 @@ class MetricBase(object): | |||
def get_metric(self, reset=True): | |||
raise NotImplemented | |||
def _fast_call_evaluate(self, pred_dict, target_dict): | |||
""" | |||
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||
such as pred_dict has one element, target_dict has one element | |||
:param pred_dict: | |||
:param target_dict: | |||
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | |||
""" | |||
return False | |||
def __call__(self, pred_dict, target_dict, check=False): | |||
""" | |||
@@ -79,7 +90,7 @@ class MetricBase(object): | |||
Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict | |||
(1) whether self.evaluate has varargs, which is not supported. | |||
(2) whether params needed by self.evaluate is not included in output_dict,target_dict. | |||
(3) whether params needed by self.evaluate duplicate in output_dict, target_dict | |||
(3) whether params needed by self.evaluate duplicate in pred_dict, target_dict | |||
(4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) | |||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | |||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | |||
@@ -92,6 +103,10 @@ class MetricBase(object): | |||
if not callable(self.evaluate): | |||
raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") | |||
if not check: | |||
if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict): | |||
return | |||
if not self._checked: | |||
# 1. check consistence between signature and param_map | |||
func_spect = inspect.getfullargspec(self.evaluate) | |||
@@ -110,28 +125,40 @@ class MetricBase(object): | |||
# need to wrap inputs in dict. | |||
mapped_pred_dict = {} | |||
mapped_target_dict = {} | |||
duplicated = [] | |||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||
not_duplicate_flag = 0 | |||
if input_arg in self._reverse_param_map: | |||
mapped_arg = self._reverse_param_map[input_arg] | |||
not_duplicate_flag += 1 | |||
else: | |||
mapped_arg = input_arg | |||
if input_arg in pred_dict: | |||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if input_arg in target_dict: | |||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | |||
not_duplicate_flag += 1 | |||
if not_duplicate_flag == 3: | |||
duplicated.append(input_arg) | |||
# check duplicated, unused, missing | |||
# missing | |||
if check or not self._checked: | |||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||
for key in check_res._fields: | |||
value = getattr(check_res, key) | |||
new_value = list(value) | |||
# TODO 这里报错的逻辑应该是怎样的? | |||
for idx, func_arg in enumerate(value): | |||
if func_arg in self.param_map: | |||
new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})' | |||
else: | |||
new_value[idx] = func_arg | |||
# only check missing. | |||
missing = check_res.missing | |||
replaced_missing = list(missing) | |||
for idx, func_arg in enumerate(missing): | |||
replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \ | |||
f"in `{get_func_signature(self.evaluate)}`)" | |||
check_res = CheckRes(missing=replaced_missing, | |||
unused=check_res.unused, | |||
duplicated=duplicated, | |||
required=check_res.required, | |||
all_needed=check_res.all_needed, | |||
varargs=check_res.varargs) | |||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||
raise CheckError(check_res=check_res, | |||
func_signature=get_func_signature(self.evaluate)) | |||
@@ -140,6 +167,7 @@ class MetricBase(object): | |||
self.evaluate(**refined_args) | |||
self._checked = True | |||
return | |||
class AccuracyMetric(MetricBase): | |||
def __init__(self, pred=None, target=None, masks=None, seq_lens=None): | |||
@@ -151,6 +179,22 @@ class AccuracyMetric(MetricBase): | |||
self.total = 0 | |||
self.acc_count = 0 | |||
def _fast_call_evaluate(self, pred_dict, target_dict): | |||
""" | |||
Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. | |||
such as pred_dict has one element, target_dict has one element | |||
:param pred_dict: | |||
:param target_dict: | |||
:return: boolean, whether to go on codes in self.__call__(). When False, don't go on. | |||
""" | |||
if len(pred_dict)==1 and len(target_dict)==1: | |||
pred = list(pred_dict.values())[0] | |||
target = list(target_dict.values())[0] | |||
self.evaluate(pred=pred, target=target) | |||
return True | |||
return False | |||
def evaluate(self, pred, target, masks=None, seq_lens=None): | |||
""" | |||
@@ -164,6 +208,7 @@ class AccuracyMetric(MetricBase): | |||
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. | |||
:return: dict({'acc': float}) | |||
""" | |||
#TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value | |||
if not isinstance(pred, torch.Tensor): | |||
raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(pred)}.") | |||
@@ -12,7 +12,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric = AccuracyMetric() | |||
# | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# print(metric.get_metric()) | |||
# | |||
# def test_AccuracyMetric2(self): | |||
@@ -22,7 +22,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric = AccuracyMetric() | |||
# | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# print(metric.get_metric()) | |||
# except Exception as e: | |||
# print(e) | |||
@@ -35,11 +35,11 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# | |||
# print(metric.get_metric()) | |||
# except Exception as e: | |||
@@ -76,7 +76,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# | |||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
# target_dict = {'target': torch.zeros(4, 3)+1} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||
# | |||
# def test_AccuaryMetric6(self): | |||
@@ -85,7 +85,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# metric = AccuracyMetric() | |||
# pred_dict = {"pred": np.zeros((4, 3, 2))} | |||
# target_dict = {'target': np.zeros((4, 3))} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
# print(e) | |||
@@ -97,7 +97,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# | |||
# def test_AccuaryMetric8(self): | |||
@@ -106,6 +106,19 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
# print(e) | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
# def test_AccuaryMetric9(self): | |||
# # (9) check map, include unused | |||
# try: | |||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||
# pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||
# target_dict = {'targets': torch.zeros(4, 3)} | |||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
# except Exception as e: | |||
@@ -113,11 +126,11 @@ class TestAccuracyMetric(unittest.TestCase): | |||
# return | |||
# self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric9(self): | |||
# (9) check map, include unused | |||
def test_AccuaryMetric10(self): | |||
# (10) check _fast_metric | |||
try: | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||
metric = AccuracyMetric() | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
@@ -125,4 +138,3 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||