From 234ceb6fa3c6eb12372c58c5b8b79530332b4119 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 2 Dec 2018 16:39:28 +0800 Subject: [PATCH] fix bug in MetricBase --- fastNLP/core/metrics.py | 48 +++++----- test/core/test_metrics.py | 178 +++++++++++++++++++++++++------------- 2 files changed, 144 insertions(+), 82 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index ee074feb..595783f7 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -46,7 +46,7 @@ class MetricBase(object): if value is None: self.param_map[key] = key continue - if isinstance(value, str): + if not isinstance(value, str): raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") self.param_map[key] = value value_counter[value].add(key) @@ -56,17 +56,22 @@ class MetricBase(object): # check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) - func_args = func_spect.args + func_args = [arg for arg in func_spect.args if arg!='self'] for func_param, input_param in self.param_map.items(): if func_param not in func_args: raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " f"initialization parameters, or change the signature of" f" {get_func_signature(self.evaluate)}.") + # evaluate should not have varargs. + if func_spect.varargs: + raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " + f"positional argument.).") + def get_metric(self, reset=True): raise NotImplemented - def __call__(self, output_dict, target_dict, check=False): + def __call__(self, pred_dict, target_dict, check=False): """ This method will call self.evaluate method. @@ -78,7 +83,7 @@ class MetricBase(object): Besides, before passing params into self.evaluate, this function will filter out params from output_dict and target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering will be conducted) - :param output_dict: usually the output of forward or prediction function + :param pred_dict: usually the output of forward or prediction function :param target_dict: usually features set as target.. :param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`. :return: @@ -89,46 +94,47 @@ class MetricBase(object): if not self._checked: # 1. check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) - func_args = func_spect.args - for func_param, input_param in self.param_map.items(): - if func_param not in func_args: - raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.") + func_args = set([arg for arg in func_spect.args if arg!='self']) + for func_arg, input_arg in self.param_map.items(): + if func_arg not in func_args: + raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") + # 2. only part of the param_map are passed, left are not for arg in func_args: if arg not in self.param_map: self.param_map[arg] = arg #This param does not need mapping. self._evaluate_args = func_args - self._reverse_param_map = {value: key for key, value in self.param_map.items()} + self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} # need to wrap inputs in dict. - mapped_output_dict = {} + mapped_pred_dict = {} mapped_target_dict = {} - for func_arg in self._evaluate_args: - input_arg = self.param_map[func_arg] + for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): if input_arg in self._reverse_param_map: - mapped_arg = func_arg + mapped_arg = self._reverse_param_map[input_arg] else: mapped_arg = input_arg - if input_arg in output_dict: - mapped_output_dict[mapped_arg] = output_dict[input_arg] + if input_arg in pred_dict: + mapped_pred_dict[mapped_arg] = pred_dict[input_arg] if input_arg in target_dict: mapped_target_dict[mapped_arg] = target_dict[input_arg] # check duplicated, unused, missing if check or not self._checked: - check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict]) + check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) for key in check_res._fields: value = getattr(check_res, key) new_value = list(value) - for idx, func_param in enumerate(value): - if func_param in self._reverse_param_map: - new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})' + # TODO 这里报错的逻辑应该是怎样的? + for idx, func_arg in enumerate(value): + if func_arg in self.param_map: + new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})' else: - new_value[idx] = func_param + new_value[idx] = func_arg if check_res.missing or check_res.duplicated or check_res.varargs: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.evaluate)) - refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict) + refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) self.evaluate(**refined_args) self._checked = True diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index bad3ebba..c6a8523e 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -6,67 +6,123 @@ import torch import numpy as np class TestAccuracyMetric(unittest.TestCase): - def test_AccuracyMetric1(self): - # (1) only input, targets passed - output_dict = {"pred": torch.zeros(4, 3)} - target_dict = {'target': torch.zeros(4)} - metric = AccuracyMetric() + # def test_AccuracyMetric1(self): + # # (1) only input, targets passed + # pred_dict = {"pred": torch.zeros(4, 3)} + # target_dict = {'target': torch.zeros(4)} + # metric = AccuracyMetric() + # + # metric(pred_dict=pred_dict, target_dict=target_dict) + # print(metric.get_metric()) + # + # def test_AccuracyMetric2(self): + # # (2) with corrupted size + # try: + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4)} + # metric = AccuracyMetric() + # + # metric(pred_dict=pred_dict, target_dict=target_dict) + # print(metric.get_metric()) + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." + # + # def test_AccuracyMetric3(self): + # # (3) with check=False , the second batch is corrupted size + # try: + # metric = AccuracyMetric() + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # + # print(metric.get_metric()) + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." + # + # def test_AccuracyMetric4(self): + # # (4) with check=True , the second batch is corrupted size + # try: + # metric = AccuracyMetric() + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4)} + # metric(pred_dict=pred_dict, target_dict=target_dict, check=True) + # + # print(metric.get_metric()) + # + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." + # + # def test_AccuaryMetric5(self): + # # (5) check reset + # metric = AccuracyMetric() + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + # + # pred_dict = {"pred": torch.zeros(4, 3, 2)} + # target_dict = {'target': torch.zeros(4, 3)+1} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # self.assertDictEqual(metric.get_metric(), {'acc':0}) + # + # def test_AccuaryMetric6(self): + # # (6) check numpy array is not acceptable + # try: + # metric = AccuracyMetric() + # pred_dict = {"pred": np.zeros((4, 3, 2))} + # target_dict = {'target': np.zeros((4, 3))} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." - metric(output_dict=output_dict, target_dict=target_dict) - print(metric.get_metric()) + # def test_AccuaryMetric7(self): + # # (7) check map, match + # metric = AccuracyMetric(pred='predictions', target='targets') + # pred_dict = {"predictions": torch.zeros(4, 3, 2)} + # target_dict = {'targets': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + # + # def test_AccuaryMetric8(self): + # # (8) check map, does not match + # try: + # metric = AccuracyMetric(pred='predictions', target='targets') + # pred_dict = {"prediction": torch.zeros(4, 3, 2)} + # target_dict = {'targets': torch.zeros(4, 3)} + # metric(pred_dict=pred_dict, target_dict=target_dict) + # self.assertDictEqual(metric.get_metric(), {'acc': 1}) + # except Exception as e: + # print(e) + # return + # self.assertTrue(True, False), "No exception catches." - def test_AccuracyMetric2(self): - # (2) with corrupted size - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4)} - metric = AccuracyMetric() + def test_AccuaryMetric9(self): + # (9) check map, include unused + try: + metric = AccuracyMetric(pred='predictions', target='targets') + pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} + target_dict = {'targets': torch.zeros(4, 3)} + metric(pred_dict=pred_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + except Exception as e: + print(e) + return + self.assertTrue(True, False), "No exception catches." - metric(output_dict=output_dict, target_dict=target_dict) - print(metric.get_metric()) - - def test_AccuracyMetric3(self): - # (3) with check=False , the second batch is corrupted size - metric = AccuracyMetric() - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)} - metric(output_dict=output_dict, target_dict=target_dict) - - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4)} - metric(output_dict=output_dict, target_dict=target_dict) - - print(metric.get_metric()) - - def test_AccuracyMetric4(self): - # (4) with check=True , the second batch is corrupted size - metric = AccuracyMetric() - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)} - metric(output_dict=output_dict, target_dict=target_dict) - - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4)} - metric(output_dict=output_dict, target_dict=target_dict, check=True) - - print(metric.get_metric()) - - def test_AccuaryMetric5(self): - # (5) check reset - metric = AccuracyMetric() - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)} - metric(output_dict=output_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 1}) - - output_dict = {"pred": torch.zeros(4, 3, 2)} - target_dict = {'target': torch.zeros(4, 3)+1} - metric(output_dict=output_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc':0}) - - def test_AccuaryMetric6(self): - # (6) check numpy array is not acceptable - metric = AccuracyMetric() - output_dict = {"pred": np.zeros((4, 3, 2))} - target_dict = {'target': np.zeros((4, 3))} - metric(output_dict=output_dict, target_dict=target_dict) - self.assertDictEqual(metric.get_metric(), {'acc': 1}) \ No newline at end of file