@@ -11,7 +11,7 @@ from fastNLP.core.utils import _build_args
from fastNLP.core.utils import _check_arg_dict_list
from fastNLP.core.utils import get_func_signature
from fastNLP.core.utils import seq_lens_to_masks
from fastNLP.core.utils import CheckRes

class MetricBase(object):
    def __init__(self):
@@ -72,6 +72,17 @@ class MetricBase(object):
    def get_metric(self, reset=True):
        raise NotImplementedError
    def _fast_call_evaluate(self, pred_dict, target_dict):
        """
        Only used as an inner function. When pred_dict and target_dict are unambiguous
        (e.g. each contains exactly one element), users do not need to pass a key map.

        :param pred_dict: dict of predictions
        :param target_dict: dict of targets
        :return: bool. True means evaluate() was already called via this fast path and
            self.__call__() should return immediately; False means fall through to the
            regular checked call.
        """
        return False
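A subclass enables this fast path by overriding the hook and calling evaluate() itself when the inputs are unambiguous. A minimal sketch, assuming the module path `fastNLP.core.metrics` and a hypothetical `ToyMetric` (AccuracyMetric's real override appears further down in this diff):

```python
from fastNLP.core.metrics import MetricBase  # assumed module path

class ToyMetric(MetricBase):
    # `ToyMetric` and this evaluate() signature are hypothetical examples.
    def evaluate(self, pred, target):
        pass  # accumulate statistics here

    def _fast_call_evaluate(self, pred_dict, target_dict):
        # Unambiguous case: exactly one entry on each side, so the
        # key-mapping and signature checks can be bypassed entirely.
        if len(pred_dict) == 1 and len(target_dict) == 1:
            self.evaluate(pred=list(pred_dict.values())[0],
                          target=list(target_dict.values())[0])
            return True
        return False
```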
    def __call__(self, pred_dict, target_dict, check=False):
        """
@@ -79,7 +90,7 @@ class MetricBase(object):
        Before calling self.evaluate, it will first check the validity of output_dict, target_dict:
        (1) whether self.evaluate has varargs, which is not supported.
        (2) whether params needed by self.evaluate are not included in output_dict, target_dict.
        (3) whether params needed by self.evaluate duplicate in output_dict, target_dict
        (3) whether params needed by self.evaluate duplicate in pred_dict, target_dict
        (4) whether params in output_dict, target_dict are not used by evaluate. (Might cause a warning.)
        Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
        target_dict which are not used in self.evaluate. (But if **kwargs is present in self.evaluate, no filtering
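As a concrete instance of check (2), mirroring the commented-out test_AccuaryMetric8 below: when the key map points at a key that is absent from pred_dict, the call is expected to raise a CheckError naming the missing parameter. A sketch, with the import paths assumed:

```python
import torch
from fastNLP.core.metrics import AccuracyMetric  # assumed module path

metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {"prediction": torch.zeros(4, 3, 2)}  # note: 'prediction', not 'predictions'
target_dict = {'targets': torch.zeros(4, 3)}
try:
    metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
except Exception as e:
    # Expected to mention `predictions`(assign to `pred` in `...evaluate(...)`)
    print(e)
```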
@@ -92,6 +103,10 @@ class MetricBase(object):
        if not callable(self.evaluate):
            raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

        if not check:
            if self._fast_call_evaluate(pred_dict=pred_dict, target_dict=target_dict):
                return

        if not self._checked:
            # 1. check consistency between signature and param_map
            func_spect = inspect.getfullargspec(self.evaluate)
@@ -110,28 +125,40 @@ class MetricBase(object):
        # need to wrap inputs in dict.
        mapped_pred_dict = {}
        mapped_target_dict = {}
        duplicated = []
        for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
            not_duplicate_flag = 0
            if input_arg in self._reverse_param_map:
                mapped_arg = self._reverse_param_map[input_arg]
                not_duplicate_flag += 1
            else:
                mapped_arg = input_arg
            if input_arg in pred_dict:
                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
                not_duplicate_flag += 1
            if input_arg in target_dict:
                mapped_target_dict[mapped_arg] = target_dict[input_arg]
                not_duplicate_flag += 1
            if not_duplicate_flag == 3:
                duplicated.append(input_arg)
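This loop renames user-provided keys back to evaluate()'s parameter names via `_reverse_param_map`, and flags a key as duplicated only when it is mapped and present in both pred_dict and target_dict at once (flag reaches 3). A standalone, runnable sketch of the same renaming pass, with a hypothetical reverse map:

```python
# reverse_param_map maps user keys -> evaluate() parameter names;
# this example map is hypothetical, not from the diff.
reverse_param_map = {'predictions': 'pred', 'targets': 'target'}
pred_dict = {'predictions': [0, 1]}
target_dict = {'targets': [0, 1]}

mapped_pred, mapped_target, duplicated = {}, {}, []
for key in set(pred_dict) | set(target_dict):
    flag = 0
    if key in reverse_param_map:
        mapped_arg = reverse_param_map[key]
        flag += 1
    else:
        mapped_arg = key
    if key in pred_dict:
        mapped_pred[mapped_arg] = pred_dict[key]
        flag += 1
    if key in target_dict:
        mapped_target[mapped_arg] = target_dict[key]
        flag += 1
    if flag == 3:  # mapped AND present in both dicts
        duplicated.append(key)

print(mapped_pred, mapped_target, duplicated)
# -> {'pred': [0, 1]} {'target': [0, 1]} []
```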
        # check duplicated, unused, missing
        # missing
        if check or not self._checked:
            check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
            for key in check_res._fields:
                value = getattr(check_res, key)
                new_value = list(value)
                # TODO: what should the error-reporting logic be here?
                for idx, func_arg in enumerate(value):
                    if func_arg in self.param_map:
                        new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})'
                    else:
                        new_value[idx] = func_arg
            # only check missing.
            missing = check_res.missing
            replaced_missing = list(missing)
            for idx, func_arg in enumerate(missing):
                replaced_missing[idx] = f"`{self.param_map[func_arg]}`" + f"(assign to `{func_arg}` " \
                                        f"in `{get_func_signature(self.evaluate)}`)"

            check_res = CheckRes(missing=replaced_missing,
                                 unused=check_res.unused,
                                 duplicated=duplicated,
                                 required=check_res.required,
                                 all_needed=check_res.all_needed,
                                 varargs=check_res.varargs)

            if check_res.missing or check_res.duplicated or check_res.varargs:
                raise CheckError(check_res=check_res,
                                 func_signature=get_func_signature(self.evaluate))
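To see what the rewritten missing entries look like, here is a sketch of the same string construction in isolation; the param_map contents and the signature string are hypothetical stand-ins for what `get_func_signature` would return:

```python
# Sketch of the missing-parameter message rewriting above.
param_map = {'pred': 'predictions', 'target': 'targets'}  # hypothetical
missing = ['pred']
func_signature = "AccuracyMetric.evaluate(self, pred, target, masks=None, seq_lens=None)"

replaced_missing = [
    f"`{param_map[arg]}`" + f"(assign to `{arg}` in `{func_signature}`)"
    for arg in missing
]
print(replaced_missing[0])
# -> `predictions`(assign to `pred` in `AccuracyMetric.evaluate(...)`)
```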
@@ -140,6 +167,7 @@ class MetricBase(object):
        self.evaluate(**refined_args)
        self._checked = True

        return

class AccuracyMetric(MetricBase):
    def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
@@ -151,6 +179,22 @@ class AccuracyMetric(MetricBase):
        self.total = 0
        self.acc_count = 0
    def _fast_call_evaluate(self, pred_dict, target_dict):
        """
        Only used as an inner function. When pred_dict and target_dict are unambiguous
        (e.g. each contains exactly one element), users do not need to pass a key map.

        :param pred_dict: dict of predictions
        :param target_dict: dict of targets
        :return: bool. True means evaluate() was already called via this fast path and
            self.__call__() should return immediately; False means fall through to the
            regular checked call.
        """
        if len(pred_dict) == 1 and len(target_dict) == 1:
            pred = list(pred_dict.values())[0]
            target = list(target_dict.values())[0]
            self.evaluate(pred=pred, target=target)
            return True
        return False
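With this override in place, single-entry dicts evaluate without any key map, whatever the keys are called; the new test_AccuaryMetric10 at the bottom of this diff exercises exactly this. A usage sketch (import path assumed):

```python
import torch
from fastNLP.core.metrics import AccuracyMetric  # assumed module path

metric = AccuracyMetric()
# One entry per dict: the keys' names are irrelevant on the fast path.
metric(pred_dict={"my_logits": torch.zeros(4, 3, 2)},
       target_dict={"my_labels": torch.zeros(4, 3)})
print(metric.get_metric())  # the diff's test10 expects {'acc': 1}
```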
    def evaluate(self, pred, target, masks=None, seq_lens=None):
        """
@@ -164,6 +208,7 @@ class AccuracyMetric(MetricBase):
            None, None, torch.Size([B]), torch.Size([B]). Ignored if masks are provided.

        :return: dict({'acc': float})
        """
        # TODO: the error message here needs to change, because users may not know
        #  what `pred` is; we should tell them the actual value (key) involved.
        if not isinstance(pred, torch.Tensor):
            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor, "
                            f"got {type(pred)}.")
@@ -12,7 +12,7 @@ class TestAccuracyMetric(unittest.TestCase):
    #     target_dict = {'target': torch.zeros(4)}
    #     metric = AccuracyMetric()
    #
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     print(metric.get_metric())
    #
    # def test_AccuracyMetric2(self):
@@ -22,7 +22,7 @@ class TestAccuracyMetric(unittest.TestCase):
    #     target_dict = {'target': torch.zeros(4)}
    #     metric = AccuracyMetric()
    #
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     print(metric.get_metric())
    #     except Exception as e:
    #         print(e)
@@ -35,11 +35,11 @@ class TestAccuracyMetric(unittest.TestCase):
    #     metric = AccuracyMetric()
    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #     target_dict = {'target': torch.zeros(4, 3)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #
    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #     target_dict = {'target': torch.zeros(4)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #
    #     print(metric.get_metric())
    #     except Exception as e:
@@ -76,7 +76,7 @@ class TestAccuracyMetric(unittest.TestCase):
    #
    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
    #     target_dict = {'target': torch.zeros(4, 3) + 1}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 0})
    #
    # def test_AccuaryMetric6(self):
@@ -85,7 +85,7 @@ class TestAccuracyMetric(unittest.TestCase):
    #     metric = AccuracyMetric()
    #     pred_dict = {"pred": np.zeros((4, 3, 2))}
    #     target_dict = {'target': np.zeros((4, 3))}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #     except Exception as e:
    #         print(e)
@@ -97,7 +97,7 @@ class TestAccuracyMetric(unittest.TestCase):
    #     metric = AccuracyMetric(pred='predictions', target='targets')
    #     pred_dict = {"predictions": torch.zeros(4, 3, 2)}
    #     target_dict = {'targets': torch.zeros(4, 3)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict)
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #
    # def test_AccuaryMetric8(self):
@@ -106,6 +106,19 @@ class TestAccuracyMetric(unittest.TestCase):
    #     metric = AccuracyMetric(pred='predictions', target='targets')
    #     pred_dict = {"prediction": torch.zeros(4, 3, 2)}
    #     target_dict = {'targets': torch.zeros(4, 3)}
    #     metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #     except Exception as e:
    #         print(e)
    #         return
    #     self.fail("No exception was raised.")
    # def test_AccuaryMetric9(self):
    #     # (9) check map, include unused
    #     try:
    #         metric = AccuracyMetric(pred='predictions', target='targets')
    #         pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
    #         target_dict = {'targets': torch.zeros(4, 3)}
    #         metric(pred_dict=pred_dict, target_dict=target_dict)
    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
    #     except Exception as e:
@@ -113,11 +126,11 @@ class TestAccuracyMetric(unittest.TestCase):
    #         return
    #     self.fail("No exception was raised.")
    def test_AccuaryMetric9(self):
        # (9) check map, include unused
    def test_AccuaryMetric10(self):
        # (10) check _fast_metric
        try:
            metric = AccuracyMetric(pred='predictions', target='targets')
            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused': 1}
            metric = AccuracyMetric()
            pred_dict = {"predictions": torch.zeros(4, 3, 2)}
            target_dict = {'targets': torch.zeros(4, 3)}
            metric(pred_dict=pred_dict, target_dict=target_dict)
            self.assertDictEqual(metric.get_metric(), {'acc': 1})
@@ -125,4 +138,3 @@ class TestAccuracyMetric(unittest.TestCase):
            print(e)
            return
        self.fail("No exception was raised.")