
initial test for AccuracyMetric

tags/v0.2.0^2
yh · 7 years ago · commit 8d7d2b428c
3 changed files with 59 additions and 20 deletions:

  1. fastNLP/core/metrics.py    +41 -19
  2. fastNLP/core/utils.py      +1  -1
  3. test/core/test_metrics.py  +17 -0

fastNLP/core/metrics.py (+41 -19)

@@ -54,14 +54,32 @@ class MetricBase(object):
         if len(key_set)>1:
             raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
 
+        # check consistency between the signature of self.evaluate and param_map
+        func_spect = inspect.getfullargspec(self.evaluate)
+        func_args = func_spect.args
+        for func_param, input_param in self.param_map.items():
+            if func_param not in func_args:
+                raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the "
+                                f"initialization params, or change the {get_func_signature(self.evaluate)} signature.")
+
+    def get_metric(self, reset=True):
+        raise NotImplementedError
+
     def __call__(self, output_dict, target_dict, check=False):
         """
-        :param output_dict:
-        :param target_dict:
-        :param check: boolean,
-
+        This method will call self.evaluate. Before calling self.evaluate, it first checks the validity of
+        output_dict and target_dict:
+            (1) whether self.evaluate has varargs, which is not supported;
+            (2) whether params needed by self.evaluate are missing from output_dict and target_dict;
+            (3) whether params needed by self.evaluate are duplicated across output_dict and target_dict;
+            (4) whether params in output_dict and target_dict are unused by self.evaluate (might cause a warning).
+        Besides, before passing params into self.evaluate, this function filters out params from output_dict and
+        target_dict which are not used by self.evaluate. (If **kwargs is present in self.evaluate, no filtering
+        is conducted.)
+        :param output_dict: usually the output of the forward or prediction function
+        :param target_dict: usually the fields set as target
+        :param check: boolean; if check is True, it will force-check `varargs, missing, unused, duplicated`.
         :return:
         """
         if not callable(self.evaluate):
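
Note on the filtering rule described in the new docstring: entries that self.evaluate does not accept are dropped before the call, unless evaluate declares **kwargs. A minimal standalone sketch of that rule (filter_args is a hypothetical helper, not fastNLP API):

    import inspect

    def filter_args(func, data):
        # Keep only the entries of `data` that `func` accepts as named parameters.
        spec = inspect.getfullargspec(func)
        if spec.varkw is not None:  # func declares **kwargs: no filtering
            return data
        return {k: v for k, v in data.items() if k in spec.args}

    def evaluate(input, target, masks=None, seq_lens=None):  # signature as in AccuracyMetric below
        return input, target

    print(filter_args(evaluate, {"input": 1, "target": 2, "extra": 3}))
    # -> {'input': 1, 'target': 2}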
@@ -73,7 +91,7 @@ class MetricBase(object):
         func_args = func_spect.args
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"{func_param} not in {get_func_signature(self.evaluate)}.")
+                raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}.")
         # 2. only part of the param_map are passed, left are not
         for arg in func_args:
             if arg not in self.param_map:
@@ -97,8 +115,9 @@ class MetricBase(object):
 
         # check duplicated, unused, missing
         if check or not self._checked:
-            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
-            for key, value in check_res.items():
+            check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_target_dict])
+            for key in check_res._fields:
+                value = getattr(check_res, key)
                 new_value = list(value)
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
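
The loop rewrite reflects that _check_arg_dict_list returns CheckRes, a namedtuple (its construction is visible in the utils.py hunk below), which has no .items(); fields are enumerated via ._fields and read with getattr. A self-contained illustration, with the varargs field assumed from the surrounding code:

    from collections import namedtuple

    CheckRes = namedtuple("CheckRes", ["missing", "unused", "duplicated", "varargs"])
    res = CheckRes(missing=["target"], unused=[], duplicated=[], varargs=[])

    for key in res._fields:        # ('missing', 'unused', 'duplicated', 'varargs')
        value = getattr(res, key)  # same pattern as the new loop above
        print(key, value)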
@@ -115,21 +134,21 @@ class MetricBase(object):
 
 
 class AccuracyMetric(MetricBase):
-    def __init__(self, input=None, targets=None, masks=None, seq_lens=None):
+    def __init__(self, input=None, target=None, masks=None, seq_lens=None):
         super().__init__()
 
-        self._init_param_map(input=input, targets=targets,
+        self._init_param_map(input=input, target=target,
                              masks=masks, seq_lens=seq_lens)
 
         self.total = 0
         self.acc_count = 0
 
-    def evaluate(self, input, targets, masks=None, seq_lens=None):
+    def evaluate(self, input, target, masks=None, seq_lens=None):
         """
 
         :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
-        :param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
+        :param target: List of (torch.Tensor, or numpy.ndarray). Elements can be:
                 torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
         :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B, max_len], torch.Size([B, max_len])
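
As a usage note for the renamed constructor arguments: they appear to act as aliases from evaluate's parameter names to the keys of the dicts passed to __call__. A hypothetical example (the mapping direction is assumed from the param_map handling above):

    from fastNLP.core.metrics import AccuracyMetric

    # If a model's forward returns {"pred": ...} and the target field is named
    # "label", the aliases would let evaluate find `input` and `target`.
    # Assumed semantics: param_map == {"input": "pred", "target": "label"}.
    metric = AccuracyMetric(input="pred", target="label")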
@@ -140,9 +159,9 @@ class AccuracyMetric(MetricBase):
         if not isinstance(input, torch.Tensor):
             raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
                             f"got {type(input)}.")
-        if not isinstance(targets, torch.Tensor):
-            raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
-                            f"got {type(targets)}.")
+        if not isinstance(target, torch.Tensor):
+            raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
+                            f"got {type(target)}.")
 
         if masks is not None and not isinstance(masks, torch.Tensor):
             raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
@@ -154,20 +173,23 @@ class AccuracyMetric(MetricBase):
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if input.size()==targets.size():
+        if input.size()==target.size():
             pass
-        elif len(input.size())==len(targets.size())+1:
+        elif len(input.size())==len(target.size())+1:
             input = input.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with "
-                               f"size:{input.size()}, targets should with size: {input.size()} or "
-                               f"{input.size()[:-1]}, got {targets.size()}.")
+                               f"size:{input.size()}, target should be of size: {input.size()} or "
+                               f"{input.size()[:-1]}, got {target.size()}.")
+
+        input = input.float()
+        target = target.float()
 
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(input, targets).float()).item()
+            self.acc_count += torch.sum(torch.eq(input, target).float()).item()
             self.total += np.prod(list(input.size()))
 
     def get_metric(self, reset=True):
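
The shape handling above can be exercised in isolation: when input has one more dimension than target, it is argmax-ed over the class dimension, then compared elementwise under the mask. A minimal sketch with made-up tensors:

    import torch

    pred = torch.tensor([[[0.1, 0.9], [0.8, 0.2], [0.7, 0.3]]])  # [B=1, max_len=3, n_classes=2]
    target = torch.tensor([[1., 0., 1.]])                        # [B, max_len]
    mask = torch.tensor([[1., 1., 0.]])                          # last position is padding

    pred = pred.argmax(dim=-1).float()  # extra dim -> reduce over classes
    correct = torch.sum(torch.eq(pred, target).float() * mask).item()  # 2.0
    total = torch.sum(mask).item()                                     # 2.0
    print(correct / total)                                             # 1.0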


fastNLP/core/utils.py (+1 -1)

@@ -123,7 +123,7 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
-    varargs = [] if spect.varargs else [arg for arg in spect.varargs]
+    varargs = [] if not spect.varargs else [arg for arg in spect.varargs]
     return CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,


test/core/test_metrics.py (+17 -0)

@@ -0,0 +1,17 @@
+
+import unittest
+
+class TestAccuracyMetric(unittest.TestCase):
+    def test_AccuracyMetric(self):
+        from fastNLP.core.metrics import AccuracyMetric
+        import torch
+        import numpy as np
+
+        # (1) only input and target passed
+        output_dict = {"input": torch.zeros(4, 3)}
+        target_dict = {'target': torch.zeros(4)}
+        metric = AccuracyMetric()
+
+        metric(output_dict=output_dict, target_dict=target_dict)
+        print(metric.get_metric())
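
Given the evaluate logic shown above, the expected value is predictable: torch.zeros(4, 3) has one extra dimension, so it is argmax-ed to torch.zeros(4), which matches the all-zero target at every position, giving accuracy 1.0. A possible assertion to follow the print (get_metric's return format is truncated in this diff; a dict of named floats such as {"acc": 1.0} is assumed):

    # Hypothetical continuation inside test_AccuracyMetric; assumes get_metric()
    # returns a dict of floats, e.g. {"acc": 1.0}.
    result = metric.get_metric()
    for value in result.values():
        assert abs(value - 1.0) < 1e-9  # all-zero input matches all-zero target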

