diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 6b8386c8..ee074feb 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -52,15 +52,16 @@ class MetricBase(object): value_counter[value].add(key) for value, key_set in value_counter.items(): if len(key_set)>1: - raise ValueError(f"Several params:{key_set} are provided with one output {value}.") + raise ValueError(f"Several parameters:{key_set} are provided with one output {value}.") # check consistence between signature and param_map func_spect = inspect.getfullargspec(self.evaluate) func_args = func_spect.args for func_param, input_param in self.param_map.items(): if func_param not in func_args: - raise NameError(f"`{func_param}` not in {get_func_signature(self.evaluate)}. Please check the " - f"initialization params, or change {get_func_signature(self.evaluate)} signature.") + raise NameError(f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the " + f"initialization parameters, or change the signature of" + f" {get_func_signature(self.evaluate)}.") def get_metric(self, reset=True): raise NotImplemented @@ -134,19 +135,19 @@ class MetricBase(object): class AccuracyMetric(MetricBase): - def __init__(self, input=None, target=None, masks=None, seq_lens=None): + def __init__(self, pred=None, target=None, masks=None, seq_lens=None): super().__init__() - self._init_param_map(input=input, target=target, + self._init_param_map(pred=pred, target=target, masks=masks, seq_lens=seq_lens) self.total = 0 self.acc_count = 0 - def evaluate(self, input, target, masks=None, seq_lens=None): + def evaluate(self, pred, target, masks=None, seq_lens=None): """ - :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: + :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) @@ -156,41 +157,41 @@ class AccuracyMetric(MetricBase): None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. :return: dict({'acc': float}) """ - if not isinstance(input, torch.Tensor): - raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor," - f"got {type(input)}.") + if not isinstance(pred, torch.Tensor): + raise TypeError(f"`pred` in {get_func_signature(self.evaluate)} must be torch.Tensor," + f"got {type(pred)}.") if not isinstance(target, torch.Tensor): - raise NameError(f"`target` in {get_func_signature(self.evaluate())} expects torch.Tensor," + raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(target)}.") if masks is not None and not isinstance(masks, torch.Tensor): - raise NameError(f"`masks` in {get_func_signature(self.evaluate())} expects torch.Tensor," + raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(masks)}.") elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor): - raise NameError(f"`seq_lens` in {get_func_signature(self.evaluate())} expects torch.Tensor," + raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor," f"got {type(seq_lens)}.") if masks is None and seq_lens is not None: masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) - if input.size()==target.size(): + if pred.size()==target.size(): pass - elif len(input.size())==len(target.size())+1: - input = input.argmax(dim=-1) + elif len(pred.size())==len(target.size())+1: + pred = pred.argmax(dim=-1) else: - raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with " - f"size:{input.size()}, target should with size: {input.size()} or " - f"{input.size()[:-1]}, got {target.size()}.") + raise RuntimeError(f"In {get_func_signature(self.evaluate)}, when pred have " + f"size:{pred.size()}, target should have size: {pred.size()} or " + f"{pred.size()[:-1]}, got {target.size()}.") - input = input.float() + pred = pred.float() target = target.float() if masks is not None: - self.acc_count += torch.sum(torch.eq(input, target).float() * masks.float()).item() + self.acc_count += torch.sum(torch.eq(pred, target).float() * masks.float()).item() self.total += torch.sum(masks.float()).item() else: - self.acc_count += torch.sum(torch.eq(input, target).float()).item() - self.total += np.prod(list(input.size())) + self.acc_count += torch.sum(torch.eq(pred, target).float()).item() + self.total += np.prod(list(pred.size())) def get_metric(self, reset=True): evaluate_result = {'acc': self.acc_count/self.total} diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py index b279d7ca..bad3ebba 100644 --- a/test/core/test_metrics.py +++ b/test/core/test_metrics.py @@ -1,17 +1,72 @@ import unittest -class TestOptim(unittest.TestCase): - def test_AccuracyMetric(self): - from fastNLP.core.metrics import AccuracyMetric - import torch - import numpy as np +from fastNLP.core.metrics import AccuracyMetric +import torch +import numpy as np +class TestAccuracyMetric(unittest.TestCase): + def test_AccuracyMetric1(self): # (1) only input, targets passed - output_dict = {"input": torch.zeros(4, 3)} + output_dict = {"pred": torch.zeros(4, 3)} target_dict = {'target': torch.zeros(4)} metric = AccuracyMetric() metric(output_dict=output_dict, target_dict=target_dict) print(metric.get_metric()) + def test_AccuracyMetric2(self): + # (2) with corrupted size + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric = AccuracyMetric() + + metric(output_dict=output_dict, target_dict=target_dict) + print(metric.get_metric()) + + def test_AccuracyMetric3(self): + # (3) with check=False , the second batch is corrupted size + metric = AccuracyMetric() + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(output_dict=output_dict, target_dict=target_dict) + + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric(output_dict=output_dict, target_dict=target_dict) + + print(metric.get_metric()) + + def test_AccuracyMetric4(self): + # (4) with check=True , the second batch is corrupted size + metric = AccuracyMetric() + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(output_dict=output_dict, target_dict=target_dict) + + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4)} + metric(output_dict=output_dict, target_dict=target_dict, check=True) + + print(metric.get_metric()) + + def test_AccuaryMetric5(self): + # (5) check reset + metric = AccuracyMetric() + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)} + metric(output_dict=output_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) + + output_dict = {"pred": torch.zeros(4, 3, 2)} + target_dict = {'target': torch.zeros(4, 3)+1} + metric(output_dict=output_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc':0}) + + def test_AccuaryMetric6(self): + # (6) check numpy array is not acceptable + metric = AccuracyMetric() + output_dict = {"pred": np.zeros((4, 3, 2))} + target_dict = {'target': np.zeros((4, 3))} + metric(output_dict=output_dict, target_dict=target_dict) + self.assertDictEqual(metric.get_metric(), {'acc': 1}) \ No newline at end of file