From 50f1c28b74c0cbd1595bdd3580ae7ec40afef007 Mon Sep 17 00:00:00 2001 From: yh Date: Sun, 2 Dec 2018 14:29:11 +0800 Subject: [PATCH] metric bug fix --- fastNLP/core/metrics.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 6b5fcb3c..0d83fe44 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -115,10 +115,10 @@ class MetricBase(object): class AccuracyMetric(MetricBase): - def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None): + def __init__(self, input=None, targets=None, masks=None, seq_lens=None): super().__init__() - self._init_param_map(predictions=predictions, targets=targets, + self._init_param_map(input=input, targets=targets, masks=masks, seq_lens=seq_lens) self.total = 0 @@ -138,7 +138,7 @@ class AccuracyMetric(MetricBase): :return: dict({'acc': float}) """ if not isinstance(input, torch.Tensor): - raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor," + raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor," f"got {type(input)}.") if not isinstance(targets, torch.Tensor): raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," @@ -157,9 +157,9 @@ class AccuracyMetric(MetricBase): if input.size()==targets.size(): pass elif len(input.size())==len(targets.size())+1: - predictions = input.argmax(dim=-1) + input = input.argmax(dim=-1) else: - raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with " + raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with " f"size:{input.size()}, targets should with size: {input.size()} or " f"{input.size()[:-1]}, got {targets.size()}.")