
bug fix in AccuracyMetric

tags/v0.2.0^2
yh 5 years ago
commit c2d2137500
2 changed files with 85 additions and 29 deletions
  1. fastNLP/core/metrics.py   +24  -23
  2. test/core/test_metrics.py  +61  -6

fastNLP/core/metrics.py  (+24, -23)

@@ -52,15 +52,16 @@ class MetricBase(object):
             value_counter[value].add(key)
         for value, key_set in value_counter.items():
             if len(key_set)>1:
-                raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
+                raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.")
 
         # check consistence between signature and param_map
         func_spect = inspect.getfullargspec(self.evaluate)
         func_args = func_spect.args
         for func_param, input_param in self.param_map.items():
             if func_param not in func_args:
-                raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the "
-                                f"initialization params, or change {get_func_signature(self.evaluate)} signature.")
+                raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
+                                f"initialization parameters, or change the signature of"
+                                f" {get_func_signature(self.evaluate)}.")
 
     def get_metric(self, reset=True):
         raise NotImplemented
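
(Aside, not part of the diff: a hedged sketch of what the sharpened ValueError in this hunk guards against. BadMetric and its parameter names are hypothetical, and it assumes the mapping check shown here runs inside _init_param_map, as the keyword form in AccuracyMetric.__init__ further down suggests.)

    # Hypothetical metric (not in the commit): two evaluate() parameters,
    # `pred` and `logits`, are mapped to the same output key "output",
    # which is exactly the situation the ValueError above reports.
    from fastNLP.core.metrics import MetricBase

    class BadMetric(MetricBase):
        def __init__(self):
            super().__init__()
            self._init_param_map(pred="output", logits="output")

        def evaluate(self, pred, logits):
            pass

    # BadMetric()  # expected to raise ValueError("Several parameters:... are provided with one output output.")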
@@ -134,19 +135,19 @@ class MetricBase(object):
 
 
 class AccuracyMetric(MetricBase):
-    def __init__(self, input=None, target=None, masks=None, seq_lens=None):
+    def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
         super().__init__()
 
-        self._init_param_map(input=input, target=target,
+        self._init_param_map(pred=pred, target=target,
                              masks=masks, seq_lens=seq_lens)
 
         self.total = 0
         self.acc_count = 0
 
-    def evaluate(self, input, target, masks=None, seq_lens=None):
+    def evaluate(self, pred, target, masks=None, seq_lens=None):
         """
 
-        :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+        :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
         :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
@@ -156,41 +157,41 @@ class AccuracyMetric(MetricBase):
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        if not isinstance(input, torch.Tensor):
-            raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
-                            f"got {type(input)}.")
+        if not isinstance(pred, torch.Tensor):
+            raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
+                            f"got {type(pred)}.")
         if not isinstance(target, torch.Tensor):
-            raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
+            raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
 
         if masks is not None and not isinstance(masks, torch.Tensor):
-            raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
+            raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(masks)}.")
         elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
-            raise NameError(f"`seq_lens` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
+            raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")
 
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if input.size()==target.size():
+        if pred.size()==target.size():
             pass
-        elif len(input.size())==len(target.size())+1:
-            input = input.argmax(dim=-1)
+        elif len(pred.size())==len(target.size())+1:
+            pred = pred.argmax(dim=-1)
         else:
-            raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with "
-                               f"size:{input.size()}, target should with size: {input.size()} or "
-                               f"{input.size()[:-1]}, got {target.size()}.")
+            raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have "
+                               f"size:{pred.size()}, target should have size: {pred.size()} or "
+                               f"{pred.size()[:-1]}, got {target.size()}.")
 
-        input = input.float()
+        pred = pred.float()
         target = target.float()
 
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(input, target).float()).item()
-            self.total += np.prod(list(input.size()))
+            self.acc_count += torch.sum(torch.eq(pred, target).float()).item()
+            self.total += np.prod(list(pred.size()))
 
     def get_metric(self, reset=True):
         evaluate_result = {'acc': self.acc_count/self.total}
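
(Aside, not part of the diff: a minimal usage sketch of the renamed API. It assumes the metric(output_dict=..., target_dict=...) call form used in the tests below and the argmax reduction shown in evaluate above.)

    # Sketch only: exercise the renamed `pred` argument of AccuracyMetric.
    import torch
    from fastNLP.core.metrics import AccuracyMetric

    metric = AccuracyMetric()                   # default mapping keeps the keys "pred" and "target"
    output_dict = {"pred": torch.zeros(4, 3)}   # shape [B, n_classes]; reduced with argmax(dim=-1)
    target_dict = {"target": torch.zeros(4)}    # shape [B]
    metric(output_dict=output_dict, target_dict=target_dict)
    print(metric.get_metric())                  # all-zero pred vs all-zero target -> {'acc': 1.0}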


test/core/test_metrics.py  (+61, -6)

@@ -1,17 +1,72 @@
 
 import unittest
 
-class TestOptim(unittest.TestCase):
-    def test_AccuracyMetric(self):
-        from fastNLP.core.metrics import AccuracyMetric
-        import torch
-        import numpy as np
+from fastNLP.core.metrics import AccuracyMetric
+import torch
+import numpy as np
+
+
+class TestAccuracyMetric(unittest.TestCase):
+    def test_AccuracyMetric1(self):
         # (1) only input, targets passed
-        output_dict = {"input": torch.zeros(4, 3)}
+        output_dict = {"pred": torch.zeros(4, 3)}
         target_dict = {'target': torch.zeros(4)}
         metric = AccuracyMetric()
 
         metric(output_dict=output_dict, target_dict=target_dict)
         print(metric.get_metric())
+
+    def test_AccuracyMetric2(self):
+        # (2) with corrupted size
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4)}
+        metric = AccuracyMetric()
+
+        metric(output_dict=output_dict, target_dict=target_dict)
+        print(metric.get_metric())
+
+    def test_AccuracyMetric3(self):
+        # (3) with check=False , the second batch is corrupted size
+        metric = AccuracyMetric()
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(output_dict=output_dict, target_dict=target_dict)
+
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4)}
+        metric(output_dict=output_dict, target_dict=target_dict)
+
+        print(metric.get_metric())
+
+    def test_AccuracyMetric4(self):
+        # (4) with check=True , the second batch is corrupted size
+        metric = AccuracyMetric()
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(output_dict=output_dict, target_dict=target_dict)
+
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4)}
+        metric(output_dict=output_dict, target_dict=target_dict, check=True)
+
+        print(metric.get_metric())
+
+    def test_AccuaryMetric5(self):
+        # (5) check reset
+        metric = AccuracyMetric()
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(output_dict=output_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+
+        output_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)+1}
+        metric(output_dict=output_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc':0})
+
+    def test_AccuaryMetric6(self):
+        # (6) check numpy array is not acceptable
+        metric = AccuracyMetric()
+        output_dict = {"pred": np.zeros((4, 3, 2))}
+        target_dict = {'target': np.zeros((4, 3))}
+        metric(output_dict=output_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
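
(Aside: a sketch of one way to run the new test module with the standard library runner; it assumes the repository root is the working directory and that fastNLP is importable from there.)

    # Discover and run test/core/test_metrics.py.
    import unittest

    suite = unittest.defaultTestLoader.discover("test/core", pattern="test_metrics.py")
    unittest.TextTestRunner(verbosity=2).run(suite)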
