From 8d7d2b428cce4f7b8c8be12ca74810544c56e048 Mon Sep 17 00:00:00 2001
From: yh
Date: Sun, 2 Dec 2018 14:57:11 +0800
Subject: [PATCH] initial test for AccuracyMetric

---
 fastNLP/core/metrics.py   | 60 ++++++++++++++++++++++++++-------------
 fastNLP/core/utils.py     |  2 +-
 test/core/test_metrics.py | 17 +++++++++++
 3 files changed, 59 insertions(+), 20 deletions(-)
 create mode 100644 test/core/test_metrics.py

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 0d83fe44..6b8386c8 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -54,14 +54,32 @@ class MetricBase(object):
             if len(key_set)>1:
                 raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
 
+        # check consistency between signature and param_map
+        func_spect = inspect.getfullargspec(self.evaluate)
+        func_args = func_spect.args
+        for func_param, input_param in self.param_map.items():
+            if func_param not in func_args:
+                raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the "
+                                f"initialization params, or change {get_func_signature(self.evaluate)} signature.")
+
     def get_metric(self, reset=True):
         raise NotImplemented
 
     def __call__(self, output_dict, target_dict, check=False):
         """
-        :param output_dict:
-        :param target_dict:
-        :param check: boolean,
+
+        This method will call the self.evaluate method.
+        Before calling self.evaluate, it will first check the validity of output_dict, target_dict:
+        (1) whether self.evaluate has varargs, which is not supported.
+        (2) whether params needed by self.evaluate are not included in output_dict, target_dict.
+        (3) whether params needed by self.evaluate are duplicated in output_dict, target_dict.
+        (4) whether params in output_dict, target_dict are not used by evaluate. (Might cause a warning)
+        Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
+        target_dict which are not used by self.evaluate. (but if **kwargs is present in self.evaluate, no filtering
+        will be conducted)
+        :param output_dict: usually the output of the forward or prediction function
+        :param target_dict: usually the features set as target.
+        :param check: boolean, if check is True, it will force checking `varargs, missing, unused, duplicated`.
         :return:
         """
         if not callable(self.evaluate):
@@ -73,7 +91,7 @@ class MetricBase(object):
         func_args = func_spect.args
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
+                raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.")
         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
             if arg not in self.param_map:
@@ -97,8 +115,9 @@
 
         # check duplicated, unused, missing
         if check or not self._checked:
-            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
-            for key, value in check_res.items():
+            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
+            for key in check_res._fields:
+                value = getattr(check_res, key)
                 new_value = list(value)
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
@@ -115,21 +134,21 @@ class MetricBase(object):
 
 
 class AccuracyMetric(MetricBase):
-    def __init__(self, input=None, targets=None, masks=None, seq_lens=None):
+    def __init__(self, input=None, target=None, masks=None, seq_lens=None):
         super().__init__()
 
-        self._init_param_map(input=input, targets=targets,
+        self._init_param_map(input=input, target=target,
                              masks=masks, seq_lens=seq_lens)
 
         self.total = 0
         self.acc_count = 0
 
-    def evaluate(self, input, targets, masks=None, seq_lens=None):
+    def evaluate(self, input, target, masks=None, seq_lens=None):
         """
 
         :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
-        :param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
+        :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
         :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B, max_len], torch.Size([B, max_len])
@@ -140,9 +159,9 @@
         if not isinstance(input, torch.Tensor):
             raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
                             f"got {type(input)}.")
-        if not isinstance(targets, torch.Tensor):
-            raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
-                            f"got {type(targets)}.")
+        if not isinstance(target, torch.Tensor):
+            raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
+                            f"got {type(target)}.")
 
         if masks is not None and not isinstance(masks, torch.Tensor):
             raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
@@ -154,20 +173,23 @@
 
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
-        if input.size()==targets.size():
+        if input.size()==target.size():
             pass
-        elif len(input.size())==len(targets.size())+1:
+        elif len(input.size())==len(target.size())+1:
             input = input.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with "
-                               f"size:{input.size()}, targets should with size: {input.size()} or "
-                               f"{input.size()[:-1]}, got {targets.size()}.")
+                               f"size:{input.size()}, target should with size: {input.size()} or "
+                               f"{input.size()[:-1]}, got {target.size()}.")
+
+        input = input.float()
+        target = target.float()
 
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(input, targets).float()).item()
+            self.acc_count += torch.sum(torch.eq(input, target).float()).item()
             self.total += np.prod(list(input.size()))
 
     def get_metric(self, reset=True):
diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py
index 08640d0f..62f60cf7 100644
--- a/fastNLP/core/utils.py
+++ b/fastNLP/core/utils.py
@@ -123,7 +123,7 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
-    varargs = [] if spect.varargs else [arg for arg in spect.varargs]
+    varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
     return CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
new file mode 100644
index 00000000..b279d7ca
--- /dev/null
+++ b/test/core/test_metrics.py
@@ -0,0 +1,17 @@
+
+import unittest
+
+class TestOptim(unittest.TestCase):
+    def test_AccuracyMetric(self):
+        from fastNLP.core.metrics import AccuracyMetric
+        import torch
+        import numpy as np
+
+        # (1) only input, targets passed
+        output_dict = {"input": torch.zeros(4, 3)}
+        target_dict = {'target': torch.zeros(4)}
+        metric = AccuracyMetric()
+
+        metric(output_dict=output_dict, target_dict=target_dict)
+        print(metric.get_metric())
+
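
Note (not part of the patch): a minimal usage sketch of the new AccuracyMetric, mirroring
test_AccuracyMetric above. The tensor shapes and the printed result are illustrative
assumptions, and get_metric's body is outside the hunks shown here.

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric()  # default mapping: evaluate's `input`/`target` read the same-named dict keys

    # scores of shape [B, n_classes]; evaluate() argmaxes the last dim because
    # len(input.size()) == len(target.size()) + 1
    output_dict = {"input": torch.randn(4, 3)}
    target_dict = {"target": torch.zeros(4)}

    metric(output_dict=output_dict, target_dict=target_dict)  # checks the args, then accumulates acc_count/total
    print(metric.get_metric())  # should report the accumulated accuracy (get_metric body not in this diff)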
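
A second sketch for the masked path, under two assumptions: the param_map passes a
same-named `seq_lens` field through to evaluate() (that logic lives in _init_param_map,
outside this diff), and the token-level shapes below are made up for illustration.

    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric()
    output_dict = {"input": torch.randn(2, 5, 7)}            # [B, max_len, n_classes]
    target_dict = {"target": torch.zeros(2, 5).long(),       # [B, max_len]
                   "seq_lens": torch.LongTensor([5, 3])}     # lengths of the two sequences

    # with seq_lens given and masks=None, evaluate() builds masks via seq_lens_to_masks(),
    # so only positions inside each sequence count towards acc_count/total
    metric(output_dict=output_dict, target_dict=target_dict)
    print(metric.get_metric())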