
fix bug in MetricBase

tags/v0.2.0^2
yh · 5 years ago
commit 234ceb6fa3

2 changed files with 144 additions and 82 deletions:
  1. fastNLP/core/metrics.py (+27 -21)
  2. test/core/test_metrics.py (+117 -61)

fastNLP/core/metrics.py (+27 -21)

@@ -46,7 +46,7 @@ class MetricBase(object):
             if value is None:
                 self.param_map[key] = key
                 continue
-            if isinstance(value, str):
+            if not isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
             value_counter[value].add(key)
@@ -56,17 +56,22 @@ class MetricBase(object):

         # check consistence between signature and param_map
         func_spect = inspect.getfullargspec(self.evaluate)
-        func_args = func_spect.args
+        func_args = [arg for arg in func_spect.args if arg!='self']
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
                 raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
                                 f"initialization parameters, or change the signature of"
                                 f" {get_func_signature(self.evaluate)}.")

+        # evaluate should not have varargs.
+        if func_spect.varargs:
+            raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
+                            f"positional argument.).")

     def get_metric(self, reset=True):
         raise NotImplemented

-    def __call__(self, output_dict, target_dict, check=False):
+    def __call__(self, pred_dict, target_dict, check=False):
         """

         This method will call self.evaluate method.
@@ -78,7 +83,7 @@ class MetricBase(object):
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
         target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
         will be conducted)
-        :param output_dict: usually the output of forward or prediction function
+        :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
         :param check: boolean, if check is True, it will force check `varargs, missing, unsed, duplicated`.
         :return:
@@ -89,46 +94,47 @@ class MetricBase(object):
         if not self._checked:
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
-            func_args = func_spect.args
-            for func_param, input_param in self.param_map.items():
-                if func_param not in func_args:
-                    raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.")
+            func_args = set([arg for arg in func_spect.args if arg!='self'])
+            for func_arg, input_arg in self.param_map.items():
+                if func_arg not in func_args:
+                    raise NameError(f"`{func_arg}` not in {get_func_signature(self.evaluate)}.")

             # 2. only part of the param_map are passed, left are not
             for arg in func_args:
                 if arg not in self.param_map:
                     self.param_map[arg] = arg #This param does not need mapping.
             self._evaluate_args = func_args
-            self._reverse_param_map = {value: key for key, value in self.param_map.items()}
+            self._reverse_param_map = {input_arg: func_arg for func_arg, input_arg in self.param_map.items()}

         # need to wrap inputs in dict.
-        mapped_output_dict = {}
+        mapped_pred_dict = {}
         mapped_target_dict = {}
-        for func_arg in self._evaluate_args:
-            input_arg = self.param_map[func_arg]
+        for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
             if input_arg in self._reverse_param_map:
-                mapped_arg = func_arg
+                mapped_arg = self._reverse_param_map[input_arg]
             else:
                 mapped_arg = input_arg
-            if input_arg in output_dict:
-                mapped_output_dict[mapped_arg] = output_dict[input_arg]
+            if input_arg in pred_dict:
+                mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
             if input_arg in target_dict:
                 mapped_target_dict[mapped_arg] = target_dict[input_arg]

         # check duplicated, unused, missing
         if check or not self._checked:
-            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
+            check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
             for key in check_res._fields:
                 value = getattr(check_res, key)
                 new_value = list(value)
-                for idx, func_param in enumerate(value):
-                    if func_param in self._reverse_param_map:
-                        new_value[idx] = self._reverse_param_map[func_param] + f'(assign to {func_param})'
+                # TODO: what should the error-reporting logic be here?
+                for idx, func_arg in enumerate(value):
+                    if func_arg in self.param_map:
+                        new_value[idx] = self.param_map[func_arg] + f'(try to get value from {self.param_map[func_arg]})'
                     else:
-                        new_value[idx] = func_param
+                        new_value[idx] = func_arg
             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res,
                                  func_signature=get_func_signature(self.evaluate))
-        refined_args = _build_args(self.evaluate, **mapped_output_dict, **mapped_target_dict)
+        refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)

         self.evaluate(**refined_args)
         self._checked = True
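
Note: the remapping that this commit settles on in __call__ can be summarized standalone. The sketch below is a minimal illustration, not fastNLP code; `remap` is a hypothetical helper that mirrors the new loop over the union of the input-dict keys and the reversed param_map.

# Minimal standalone sketch of the remapping logic above; `remap` is a
# hypothetical helper, not part of fastNLP.
def remap(param_map, pred_dict, target_dict):
    # param_map: evaluate()'s argument name -> key expected in the input dicts
    reverse_param_map = {input_arg: func_arg for func_arg, input_arg in param_map.items()}
    mapped_pred_dict, mapped_target_dict = {}, {}
    for input_arg in set(list(pred_dict.keys()) + list(target_dict.keys())):
        # rename keys that have a mapping; pass the rest through unchanged
        mapped_arg = reverse_param_map.get(input_arg, input_arg)
        if input_arg in pred_dict:
            mapped_pred_dict[mapped_arg] = pred_dict[input_arg]
        if input_arg in target_dict:
            mapped_target_dict[mapped_arg] = target_dict[input_arg]
    return mapped_pred_dict, mapped_target_dict

# With AccuracyMetric(pred='predictions', target='targets'), param_map is
# {'pred': 'predictions', 'target': 'targets'}; unmapped keys survive and are
# later filtered out (or flagged as unused) by the check step.
pred, target = remap({'pred': 'predictions', 'target': 'targets'},
                     {'predictions': [1, 0], 'unused': 1}, {'targets': [1, 0]})
assert pred == {'pred': [1, 0], 'unused': 1} and target == {'target': [1, 0]}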


test/core/test_metrics.py (+117 -61)

@@ -6,67 +6,123 @@ import torch
 import numpy as np

 class TestAccuracyMetric(unittest.TestCase):
-    def test_AccuracyMetric1(self):
-        # (1) only input, targets passed
-        output_dict = {"pred": torch.zeros(4, 3)}
-        target_dict = {'target': torch.zeros(4)}
-        metric = AccuracyMetric()
+    # def test_AccuracyMetric1(self):
+    #     # (1) only input, targets passed
+    #     pred_dict = {"pred": torch.zeros(4, 3)}
+    #     target_dict = {'target': torch.zeros(4)}
+    #     metric = AccuracyMetric()
+    #
+    #     metric(pred_dict=pred_dict, target_dict=target_dict)
+    #     print(metric.get_metric())
+    #
+    # def test_AccuracyMetric2(self):
+    #     # (2) with corrupted size
+    #     try:
+    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #         target_dict = {'target': torch.zeros(4)}
+    #         metric = AccuracyMetric()
+    #
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #         print(metric.get_metric())
+    #     except Exception as e:
+    #         print(e)
+    #         return
+    #     self.assertTrue(True, False), "No exception catches."
+    #
+    # def test_AccuracyMetric3(self):
+    #     # (3) with check=False , the second batch is corrupted size
+    #     try:
+    #         metric = AccuracyMetric()
+    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #         target_dict = {'target': torch.zeros(4, 3)}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #
+    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #         target_dict = {'target': torch.zeros(4)}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #
+    #         print(metric.get_metric())
+    #     except Exception as e:
+    #         print(e)
+    #         return
+    #     self.assertTrue(True, False), "No exception catches."
+    #
+    # def test_AccuracyMetric4(self):
+    #     # (4) with check=True , the second batch is corrupted size
+    #     try:
+    #         metric = AccuracyMetric()
+    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #         target_dict = {'target': torch.zeros(4, 3)}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #
+    #         pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #         target_dict = {'target': torch.zeros(4)}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
+    #
+    #         print(metric.get_metric())
+    #
+    #     except Exception as e:
+    #         print(e)
+    #         return
+    #     self.assertTrue(True, False), "No exception catches."
+    #
+    # def test_AccuaryMetric5(self):
+    #     # (5) check reset
+    #     metric = AccuracyMetric()
+    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #     target_dict = {'target': torch.zeros(4, 3)}
+    #     metric(pred_dict=pred_dict, target_dict=target_dict)
+    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    #
+    #     pred_dict = {"pred": torch.zeros(4, 3, 2)}
+    #     target_dict = {'target': torch.zeros(4, 3)+1}
+    #     metric(pred_dict=pred_dict, target_dict=target_dict)
+    #     self.assertDictEqual(metric.get_metric(), {'acc':0})
+    #
+    # def test_AccuaryMetric6(self):
+    #     # (6) check numpy array is not acceptable
+    #     try:
+    #         metric = AccuracyMetric()
+    #         pred_dict = {"pred": np.zeros((4, 3, 2))}
+    #         target_dict = {'target': np.zeros((4, 3))}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    #     except Exception as e:
+    #         print(e)
+    #         return
+    #     self.assertTrue(True, False), "No exception catches."

-        metric(output_dict=output_dict, target_dict=target_dict)
-        print(metric.get_metric())
+    # def test_AccuaryMetric7(self):
+    #     # (7) check map, match
+    #     metric = AccuracyMetric(pred='predictions', target='targets')
+    #     pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+    #     target_dict = {'targets': torch.zeros(4, 3)}
+    #     metric(pred_dict=pred_dict, target_dict=target_dict)
+    #     self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    #
+    # def test_AccuaryMetric8(self):
+    #     # (8) check map, does not match
+    #     try:
+    #         metric = AccuracyMetric(pred='predictions', target='targets')
+    #         pred_dict = {"prediction": torch.zeros(4, 3, 2)}
+    #         target_dict = {'targets': torch.zeros(4, 3)}
+    #         metric(pred_dict=pred_dict, target_dict=target_dict)
+    #         self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    #     except Exception as e:
+    #         print(e)
+    #         return
+    #     self.assertTrue(True, False), "No exception catches."

-    def test_AccuracyMetric2(self):
-        # (2) with corrupted size
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4)}
-        metric = AccuracyMetric()
+    def test_AccuaryMetric9(self):
+        # (9) check map, include unused
+        try:
+            metric = AccuracyMetric(pred='predictions', target='targets')
+            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
+            target_dict = {'targets': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+            self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."

-        metric(output_dict=output_dict, target_dict=target_dict)
-        print(metric.get_metric())

-    def test_AccuracyMetric3(self):
-        # (3) with check=False , the second batch is corrupted size
-        metric = AccuracyMetric()
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3)}
-        metric(output_dict=output_dict, target_dict=target_dict)
-
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4)}
-        metric(output_dict=output_dict, target_dict=target_dict)
-
-        print(metric.get_metric())
-
-    def test_AccuracyMetric4(self):
-        # (4) with check=True , the second batch is corrupted size
-        metric = AccuracyMetric()
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3)}
-        metric(output_dict=output_dict, target_dict=target_dict)
-
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4)}
-        metric(output_dict=output_dict, target_dict=target_dict, check=True)
-
-        print(metric.get_metric())
-
-    def test_AccuaryMetric5(self):
-        # (5) check reset
-        metric = AccuracyMetric()
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3)}
-        metric(output_dict=output_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 1})
-
-        output_dict = {"pred": torch.zeros(4, 3, 2)}
-        target_dict = {'target': torch.zeros(4, 3)+1}
-        metric(output_dict=output_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc':0})
-
-    def test_AccuaryMetric6(self):
-        # (6) check numpy array is not acceptable
-        metric = AccuracyMetric()
-        output_dict = {"pred": np.zeros((4, 3, 2))}
-        target_dict = {'target': np.zeros((4, 3))}
-        metric(output_dict=output_dict, target_dict=target_dict)
-        self.assertDictEqual(metric.get_metric(), {'acc': 1})
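
Note: for the mapped path that the commented-out test_AccuaryMetric7 above exercises, expected usage after this commit would look roughly like the sketch below. It is a hedged example, not library documentation: it assumes the fastNLP.core.metrics.AccuracyMetric import path and the {'acc': ...} result format shown in these tests.

# Hedged usage sketch based on test_AccuaryMetric7 above.
import torch
from fastNLP.core.metrics import AccuracyMetric

metric = AccuracyMetric(pred='predictions', target='targets')
pred_dict = {'predictions': torch.zeros(4, 3, 2)}  # mapped to evaluate()'s `pred`
target_dict = {'targets': torch.zeros(4, 3)}       # mapped to evaluate()'s `target`
metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric())  # expected per the test: {'acc': 1}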
