@@ -46,7 +46,7 @@ class MetricBase(object): | |||||
if value is None: | if value is None: | ||||
self.param_map[key] = key | self.param_map[key] = key | ||||
continue | continue | ||||
if isinstance(value, str): | |||||
if not isinstance(value, str): | |||||
raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.") | ||||
self.param_map[key] = value | self.param_map[key] = value | ||||
value_counter[value].add(key) | value_counter[value].add(key) | ||||
@@ -56,17 +56,22 @@ class MetricBase(object): | |||||
# check consistence between signature and param_map | # check consistence between signature and param_map | ||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
func_args = func_spect.args | |||||
func_args = [arg for arg in func_spect.args if arg!='self'] | |||||
for func_param, input_param in self.param_map.items(): | for func_param, input_param in self.param_map.items(): | ||||
if func_param not in func_args: | if func_param not in func_args: | ||||
raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " | ||||
f"initialization parameters, or change the signature of" | f"initialization parameters, or change the signature of" | ||||
f" {get_func_signature(self.evaluate)}.") | f" {get_func_signature(self.evaluate)}.") | ||||
# evaluate should not have varargs. | |||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
f"positional argument.).") | |||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
raise NotImplemented | raise NotImplemented | ||||
def __call__(self, output_dict, target_dict, check=False): | |||||
def __call__(self, pred_dict, target_dict, check=False): | |||||
""" | """ | ||||
This method will call self.evaluate method. | This method will call self.evaluate method. | ||||
@@ -78,7 +83,7 @@ class MetricBase(object): | |||||
Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | Besides, before passing params into self.evaluate, this function will filter out params from output_dict and | ||||
target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering | ||||
will be conducted) | will be conducted) | ||||
:param output_dict: usually the output of forward or prediction function | |||||
:param pred_dict: usually the output of forward or prediction function | |||||
:param target_dict: usually features set as target.. | :param target_dict: usually features set as target.. | ||||
:param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`. | :param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`. | ||||
:return: | :return: | ||||
@@ -89,46 +94,47 @@ class MetricBase(object): | |||||
if not self._checked: | if not self._checked: | ||||
# 1. check consistence between signature and param_map | # 1. check consistence between signature and param_map | ||||
func_spect = inspect.getfullargspec(self.evaluate) | func_spect = inspect.getfullargspec(self.evaluate) | ||||
func_args = func_spect.args | |||||
for func_param, input_param in self.param_map.items(): | |||||
if func_param not in func_args: | |||||
raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.") | |||||
func_args = set([arg for arg in func_spect.args if arg!='self']) | |||||
for func_arg, input_arg in self.param_map.items(): | |||||
if func_arg not in func_args: | |||||
raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.") | |||||
# 2. only part of the param_map are passed, left are not | # 2. only part of the param_map are passed, left are not | ||||
for arg in func_args: | for arg in func_args: | ||||
if arg not in self.param_map: | if arg not in self.param_map: | ||||
self.param_map[arg] = arg #This param does not need mapping. | self.param_map[arg] = arg #This param does not need mapping. | ||||
self._evaluate_args = func_args | self._evaluate_args = func_args | ||||
self._reverse_param_map = {value: key for key, value in self.param_map.items()} | |||||
self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()} | |||||
# need to wrap inputs in dict. | # need to wrap inputs in dict. | ||||
mapped_output_dict = {} | |||||
mapped_pred_dict = {} | |||||
mapped_target_dict = {} | mapped_target_dict = {} | ||||
for func_arg in self._evaluate_args: | |||||
input_arg = self.param_map[func_arg] | |||||
for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())): | |||||
if input_arg in self._reverse_param_map: | if input_arg in self._reverse_param_map: | ||||
mapped_arg = func_arg | |||||
mapped_arg = self._reverse_param_map[input_arg] | |||||
else: | else: | ||||
mapped_arg = input_arg | mapped_arg = input_arg | ||||
if input_arg in output_dict: | |||||
mapped_output_dict[mapped_arg] = output_dict[input_arg] | |||||
if input_arg in pred_dict: | |||||
mapped_pred_dict[mapped_arg] = pred_dict[input_arg] | |||||
if input_arg in target_dict: | if input_arg in target_dict: | ||||
mapped_target_dict[mapped_arg] = target_dict[input_arg] | mapped_target_dict[mapped_arg] = target_dict[input_arg] | ||||
# check duplicated, unused, missing | # check duplicated, unused, missing | ||||
if check or not self._checked: | if check or not self._checked: | ||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict]) | |||||
check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict]) | |||||
for key in check_res._fields: | for key in check_res._fields: | ||||
value = getattr(check_res, key) | value = getattr(check_res, key) | ||||
new_value = list(value) | new_value = list(value) | ||||
for idx, func_param in enumerate(value): | |||||
if func_param in self._reverse_param_map: | |||||
new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})' | |||||
# TODO 这里报错的逻辑应该是怎样的? | |||||
for idx, func_arg in enumerate(value): | |||||
if func_arg in self.param_map: | |||||
new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})' | |||||
else: | else: | ||||
new_value[idx] = func_param | |||||
new_value[idx] = func_arg | |||||
if check_res.missing or check_res.duplicated or check_res.varargs: | if check_res.missing or check_res.duplicated or check_res.varargs: | ||||
raise CheckError(check_res=check_res, | raise CheckError(check_res=check_res, | ||||
func_signature=get_func_signature(self.evaluate)) | func_signature=get_func_signature(self.evaluate)) | ||||
refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict) | |||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | |||||
self.evaluate(**refined_args) | self.evaluate(**refined_args) | ||||
self._checked = True | self._checked = True | ||||
@@ -6,67 +6,123 @@ import torch | |||||
import numpy as np | import numpy as np | ||||
class TestAccuracyMetric(unittest.TestCase): | class TestAccuracyMetric(unittest.TestCase): | ||||
def test_AccuracyMetric1(self): | |||||
# (1) only input, targets passed | |||||
output_dict = {"pred": torch.zeros(4, 3)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
# def test_AccuracyMetric1(self): | |||||
# # (1) only input, targets passed | |||||
# pred_dict = {"pred": torch.zeros(4, 3)} | |||||
# target_dict = {'target': torch.zeros(4)} | |||||
# metric = AccuracyMetric() | |||||
# | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# print(metric.get_metric()) | |||||
# | |||||
# def test_AccuracyMetric2(self): | |||||
# # (2) with corrupted size | |||||
# try: | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4)} | |||||
# metric = AccuracyMetric() | |||||
# | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# print(metric.get_metric()) | |||||
# except Exception as e: | |||||
# print(e) | |||||
# return | |||||
# self.assertTrue(True, False), "No exception catches." | |||||
# | |||||
# def test_AccuracyMetric3(self): | |||||
# # (3) with check=False , the second batch is corrupted size | |||||
# try: | |||||
# metric = AccuracyMetric() | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4, 3)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# | |||||
# print(metric.get_metric()) | |||||
# except Exception as e: | |||||
# print(e) | |||||
# return | |||||
# self.assertTrue(True, False), "No exception catches." | |||||
# | |||||
# def test_AccuracyMetric4(self): | |||||
# # (4) with check=True , the second batch is corrupted size | |||||
# try: | |||||
# metric = AccuracyMetric() | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4, 3)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict, check=True) | |||||
# | |||||
# print(metric.get_metric()) | |||||
# | |||||
# except Exception as e: | |||||
# print(e) | |||||
# return | |||||
# self.assertTrue(True, False), "No exception catches." | |||||
# | |||||
# def test_AccuaryMetric5(self): | |||||
# # (5) check reset | |||||
# metric = AccuracyMetric() | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4, 3)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
# | |||||
# pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'target': torch.zeros(4, 3)+1} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||||
# | |||||
# def test_AccuaryMetric6(self): | |||||
# # (6) check numpy array is not acceptable | |||||
# try: | |||||
# metric = AccuracyMetric() | |||||
# pred_dict = {"pred": np.zeros((4, 3, 2))} | |||||
# target_dict = {'target': np.zeros((4, 3))} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
# except Exception as e: | |||||
# print(e) | |||||
# return | |||||
# self.assertTrue(True, False), "No exception catches." | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
# def test_AccuaryMetric7(self): | |||||
# # (7) check map, match | |||||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||||
# pred_dict = {"predictions": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'targets': torch.zeros(4, 3)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
# | |||||
# def test_AccuaryMetric8(self): | |||||
# # (8) check map, does not match | |||||
# try: | |||||
# metric = AccuracyMetric(pred='predictions', target='targets') | |||||
# pred_dict = {"prediction": torch.zeros(4, 3, 2)} | |||||
# target_dict = {'targets': torch.zeros(4, 3)} | |||||
# metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
# self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
# except Exception as e: | |||||
# print(e) | |||||
# return | |||||
# self.assertTrue(True, False), "No exception catches." | |||||
def test_AccuracyMetric2(self): | |||||
# (2) with corrupted size | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric = AccuracyMetric() | |||||
def test_AccuaryMetric9(self): | |||||
# (9) check map, include unused | |||||
try: | |||||
metric = AccuracyMetric(pred='predictions', target='targets') | |||||
pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | |||||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
except Exception as e: | |||||
print(e) | |||||
return | |||||
self.assertTrue(True, False), "No exception catches." | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
def test_AccuracyMetric3(self): | |||||
# (3) with check=False , the second batch is corrupted size | |||||
metric = AccuracyMetric() | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
print(metric.get_metric()) | |||||
def test_AccuracyMetric4(self): | |||||
# (4) with check=True , the second batch is corrupted size | |||||
metric = AccuracyMetric() | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4)} | |||||
metric(output_dict=output_dict, target_dict=target_dict, check=True) | |||||
print(metric.get_metric()) | |||||
def test_AccuaryMetric5(self): | |||||
# (5) check reset | |||||
metric = AccuracyMetric() | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||||
output_dict = {"pred": torch.zeros(4, 3, 2)} | |||||
target_dict = {'target': torch.zeros(4, 3)+1} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc':0}) | |||||
def test_AccuaryMetric6(self): | |||||
# (6) check numpy array is not acceptable | |||||
metric = AccuracyMetric() | |||||
output_dict = {"pred": np.zeros((4, 3, 2))} | |||||
target_dict = {'target': np.zeros((4, 3))} | |||||
metric(output_dict=output_dict, target_dict=target_dict) | |||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) |