diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index f823cc52..6401d731 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -11,7 +11,7 @@ from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import seq_lens_to_masks - +from fastNLP.core.utils import CheckRes class MetricBase(object): def __init__(self): @@ -72,6 +72,17 @@ class MetricBase(object): def get_metric(self, reset=True): raise NotImplemented + def _fast_call_evaluate(self, pred_dict, target_dict): + """ + + Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. + such as pred_dict has one element, target_dict has one element + :param pred_dict: + :param target_dict: + :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. + """ + return False + def __call__(self, pred_dict, target_dict, check=False): """ @@ -79,7 +90,7 @@ class MetricBase(object): Before calling self.evaluate, it will first check the validity ofoutput_dict, target_dict (1) whether self.evaluate has varargs, which is not supported. (2) whether params needed by self.evaluate is not included in output_dict,target_dict. - (3) whether params needed by self.evaluate duplicate in output_dict, target_dict + (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning) Besides, before passing params into self.evaluate, this function will filter out params from output_dict and target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering @@ -92,6 +103,10 @@ class MetricBase(object): if not callable(self.evaluate): raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.") + if not check: + if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict): + return + if not self._checked: # 1. check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) @@ -110,28 +125,40 @@ class MetricBase(object): # need to wrap inputs in dict. mapped_pred_dict = {} mapped_target_dict = {} + duplicated = [] for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): + not_duplicate_flag = 0 if input_arg in self._reverse_param_map: mapped_arg = self._reverse_param_map[input_arg] + not_duplicate_flag += 1 else: mapped_arg = input_arg if input_arg in pred_dict: mapped_pred_dict[mapped_arg] = pred_dict[input_arg] + not_duplicate_flag += 1 if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] + not_duplicate_flag += 1 + if not_duplicate_flag == 3: + duplicated.append(input_arg) - # check duplicated, unused, missing + # missing if check or not self._checked: check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) - for key in check_res._fields: - value = getattr(check_res, key) - new_value = list(value) - # TODO 这里报错的逻辑应该是怎样的? - for idx, func_arg in enumerate(value): - if func_arg in self.param_map: - new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})' - else: - new_value[idx] = func_arg + # only check missing. + missing = check_res.missing + replaced_missing = list(missing) + for idx, func_arg in enumerate(missing): + replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \ + f"in `{get_func_signature(self.evaluate)}`)" + + check_res = CheckRes(missing=replaced_missing, + unused=check_res.unused, + duplicated=duplicated, + required=check_res.required, + all_needed=check_res.all_needed, + varargs=check_res.varargs) + if check_res.missing or check_res.duplicated or check_res.varargs: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.evaluate)) @@ -140,6 +167,7 @@ class MetricBase(object): self.evaluate(**refined_args) self._checked = True + return class AccuracyMetric(MetricBase): def __init__(self, pred=None, target=None, masks=None, seq_lens=None): @@ -151,6 +179,22 @@ class AccuracyMetric(MetricBase): self.total = 0 self.acc_count = 0 + def _fast_call_evaluate(self, pred_dict, target_dict): + """ + + Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map. + such as pred_dict has one element, target_dict has one element + :param pred_dict: + :param target_dict: + :return: boolean, whether to go on codes in self.__call__(). When False, don't go on. + """ + if len(pred_dict)==1 and len(target_dict)==1: + pred = list(pred_dict.values())[0] + target = list(target_dict.values())[0] + self.evaluate(pred=pred, target=target) + return True + return False + def evaluate(self, pred, target, masks=None, seq_lens=None): """ @@ -164,6 +208,7 @@ class AccuracyMetric(MetricBase): None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. :return: dict({'acc': float}) """ + #TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value if not isinstance(pred, torch.Tensor): raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(pred)}.") diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index c6a8523e..ffc11401 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -12,7 +12,7 @@ class TestAccuracyMetric(unittest.TestCase): # target_dict = {'target': torch.zeros(4)} # metric = AccuracyMetric() # - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # print(metric.get_metric()) # # def test_AccuracyMetric2(self): @@ -22,7 +22,7 @@ class TestAccuracyMetric(unittest.TestCase): # target_dict = {'target': torch.zeros(4)} # metric = AccuracyMetric() # - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # print(metric.get_metric()) # except Exception as e: # print(e) @@ -35,11 +35,11 @@ class TestAccuracyMetric(unittest.TestCase): # metric = AccuracyMetric() # pred_dict = {"pred": torch.zeros(4, 3, 2)} # target_dict = {'target': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # # pred_dict = {"pred": torch.zeros(4, 3, 2)} # target_dict = {'target': torch.zeros(4)} - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # # print(metric.get_metric()) # except Exception as e: @@ -76,7 +76,7 @@ class TestAccuracyMetric(unittest.TestCase): # # pred_dict = {"pred": torch.zeros(4, 3, 2)} # target_dict = {'target': torch.zeros(4, 3)+1} - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # self.assertDictEqual(metric.get_metric(), {'acc':0}) # # def test_AccuaryMetric6(self): @@ -85,7 +85,7 @@ class TestAccuracyMetric(unittest.TestCase): # metric = AccuracyMetric() # pred_dict = {"pred": np.zeros((4, 3, 2))} # target_dict = {'target': np.zeros((4, 3))} - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # self.assertDictEqual(metric.get_metric(), {'acc': 1}) # except Exception as e: # print(e) @@ -97,7 +97,7 @@ class TestAccuracyMetric(unittest.TestCase): # metric = AccuracyMetric(pred='predictions', target='targets') # pred_dict = {"predictions": torch.zeros(4, 3, 2)} # target_dict = {'targets': torch.zeros(4, 3)} - # metric(pred_dict=pred_dict, target_dict=target_dict) + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) # self.assertDictEqual(metric.get_metric(), {'acc': 1}) # # def test_AccuaryMetric8(self): @@ -106,6 +106,19 @@ class TestAccuracyMetric(unittest.TestCase): # metric = AccuracyMetric(pred='predictions', target='targets') # pred_dict = {"prediction": torch.zeros(4, 3, 2)} # target_dict = {'targets': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) + # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." + + # def test_AccuaryMetric9(self): + # # (9) check map, include unused + # try: + # metric = AccuracyMetric(pred='predictions', target='targets') + # pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} + # target_dict = {'targets': torch.zeros(4, 3)} # metric(pred_dict=pred_dict, target_dict=target_dict) # self.assertDictEqual(metric.get_metric(), {'acc': 1}) # except Exception as e: @@ -113,11 +126,11 @@ class TestAccuracyMetric(unittest.TestCase): # return # self.assertTrue(True, False), "No exception catches." - def test_AccuaryMetric9(self): - # (9) check map, include unused + def test_AccuaryMetric10(self): + # (10) check _fast_metric try: - metric = AccuracyMetric(pred='predictions', target='targets') - pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} + metric = AccuracyMetric() + pred_dict = {"predictions": torch.zeros(4, 3, 2)} target_dict = {'targets': torch.zeros(4, 3)} metric(pred_dict=pred_dict, target_dict=target_dict) self.assertDictEqual(metric.get_metric(), {'acc': 1}) @@ -125,4 +138,3 @@ class TestAccuracyMetric(unittest.TestCase): print(e) return self.assertTrue(True, False), "No exception catches." -