From a90a62ab9bad71670e6ac580d3be9336a44ce169 Mon Sep 17 00:00:00 2001
From: yh
Date: Sun, 2 Dec 2018 14:28:44 +0800
Subject: [PATCH] metric bug fix

---
 fastNLP/core/losses.py  |  2 +-
 fastNLP/core/metrics.py | 26 +++++++++++++-------------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 564eb7ce..b1628ec8 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -112,7 +112,7 @@ class L1Loss(LossBase):
 
 
 class BCELoss(LossBase):
-    def __init__(self):
+    def __init__(self, input=None, target=None):
         super(BCELoss, self).__init__()
         self.get_loss = F.binary_cross_entropy
 
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 5296b0bf..6b5fcb3c 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -124,22 +124,22 @@ class AccuracyMetric(MetricBase):
         self.total = 0
         self.acc_count = 0
 
-    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
+    def evaluate(self, input, targets, masks=None, seq_lens=None):
         """
 
-        :param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+        :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
         :param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len])
+                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
         :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B, max_len], torch.Size([B, max_len])
         :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        if not isinstance(predictions, torch.Tensor):
+        if not isinstance(input, torch.Tensor):
             raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
-                            f"got {type(predictions)}.")
+                            f"got {type(input)}.")
         if not isinstance(targets, torch.Tensor):
             raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
                             f"got {type(targets)}.")
@@ -154,21 +154,21 @@ class AccuracyMetric(MetricBase):
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if predictions.size()==targets.size():
+        if input.size()==targets.size():
             pass
-        elif len(predictions.size())==len(targets.size())+1:
-            predictions = predictions.argmax(dim=-1)
+        elif len(input.size())==len(targets.size())+1:
+            input = input.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with "
-                               f"size:{predictions.size()}, targets should with size: {predictions.size()} or "
-                               f"{predictions.size()[:-1]}, got {targets.size()}.")
+                               f"size:{input.size()}, targets should with size: {input.size()} or "
+                               f"{input.size()[:-1]}, got {targets.size()}.")
 
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(predictions, targets).float() * masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(predictions, targets).float()).item()
-            self.total += np.prod(list(torch.size(predictions)))
+            self.acc_count += torch.sum(torch.eq(input, targets).float()).item()
+            self.total += np.prod(list(input.size()))
 
     def get_metric(self, reset=True):
         evaluate_result = {'acc': self.acc_count/self.total}
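
Note (not part of the patch): a minimal sketch of what the fixed branches of
AccuracyMetric.evaluate() compute, with invented tensor shapes and values for
illustration. The unmasked branch previously called torch.size(predictions),
which does not exist in torch (size() is a method on tensors, not a module
function); the fix queries the tensor itself.

    # Illustrative sketch only; shapes and values are made up.
    import numpy as np
    import torch

    input = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # [B=3, n_classes=2]
    targets = torch.tensor([1, 0, 0])                           # [B=3]
    masks = torch.tensor([1, 1, 0])                             # last position is padding

    # input carries one extra (class) dimension, so reduce it with argmax,
    # mirroring the elif branch in evaluate().
    input = input.argmax(dim=-1)                                # tensor([1, 0, 1])

    # Masked branch: count matches only where the mask is on.
    acc_count = torch.sum(torch.eq(input, targets).float() * masks.float()).item()
    total = torch.sum(masks.float()).item()
    print(acc_count / total)                                    # 1.0; the masked-out mismatch is ignored

    # Unmasked branch: total is the full element count, via the fixed .size() call.
    total_unmasked = np.prod(list(input.size()))                # 3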