
metric bug fix

commit a90a62ab9b (tags/v0.2.0^2)
yh, 6 years ago
2 changed files with 14 additions and 14 deletions
  1. fastNLP/core/losses.py  +1 -1
  2. fastNLP/core/metrics.py  +13 -13

fastNLP/core/losses.py  +1 -1

@@ -112,7 +112,7 @@ class L1Loss(LossBase):
 
 
 class BCELoss(LossBase):
-    def __init__(self):
+    def __init__(self, input=None, target=None):
         super(BCELoss, self).__init__()
         self.get_loss = F.binary_cross_entropy

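As a sanity check on what this class wires up: get_loss is plain F.binary_cross_entropy, and the new input/target keyword arguments only change the constructor signature here. A minimal sketch against stock PyTorch (tensors and values below are illustrative, not from this commit):

    # BCELoss.get_loss is F.binary_cross_entropy, so the class's loss
    # computation reduces to this plain functional call.
    import torch
    import torch.nn.functional as F

    pred = torch.sigmoid(torch.randn(4))          # probabilities in (0, 1)
    target = torch.tensor([1.0, 0.0, 1.0, 0.0])   # binary labels as floats
    print(F.binary_cross_entropy(pred, target).item())
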
fastNLP/core/metrics.py  +13 -13

@@ -124,22 +124,22 @@ class AccuracyMetric(MetricBase):
         self.total = 0
         self.acc_count = 0
 
-    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
+    def evaluate(self, input, targets, masks=None, seq_lens=None):
         """
 
-        :param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+        :param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
         :param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len])
+                torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
         :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B, max_len], torch.Size([B, max_len])
         :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
         """
-        if not isinstance(predictions, torch.Tensor):
+        if not isinstance(input, torch.Tensor):
             raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
-                            f"got {type(predictions)}.")
+                            f"got {type(input)}.")
         if not isinstance(targets, torch.Tensor):
             raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor,"
                             f"got {type(targets)}.")
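For reference, a sketch of the two prediction shapes the docstring names for the no-sequence case, torch.Size([B,]) and torch.Size([B, n_classes]); the latter is reduced with argmax in the next hunk. Tensors below are illustrative:

    # Illustrative tensors for the docstring's shape contract (B=4, n_classes=3).
    import torch

    targets = torch.tensor([0, 2, 2, 1])   # torch.Size([B,])
    hard = torch.tensor([0, 2, 1, 1])      # torch.Size([B,]): class indices
    soft = torch.randn(4, 3)               # torch.Size([B, n_classes]): scores

    # The extra trailing class dimension is what the argmax branch removes:
    print(hard.shape, soft.argmax(dim=-1).shape)   # both torch.Size([4])
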
@@ -154,21 +154,21 @@ class AccuracyMetric(MetricBase):
         if masks is None and seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
 
-        if predictions.size()==targets.size():
+        if input.size()==targets.size():
             pass
-        elif len(predictions.size())==len(targets.size())+1:
-            predictions = predictions.argmax(dim=-1)
+        elif len(input.size())==len(targets.size())+1:
+            predictions = input.argmax(dim=-1)
         else:
             raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with "
-                               f"size:{predictions.size()}, targets should with size: {predictions.size()} or "
-                               f"{predictions.size()[:-1]}, got {targets.size()}.")
+                               f"size:{input.size()}, targets should with size: {input.size()} or "
+                               f"{input.size()[:-1]}, got {targets.size()}.")
 
         if masks is not None:
-            self.acc_count += torch.sum(torch.eq(predictions, targets).float() * masks.float()).item()
+            self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item()
             self.total += torch.sum(masks.float()).item()
         else:
-            self.acc_count += torch.sum(torch.eq(predictions, targets).float()).item()
-            self.total += np.prod(list(torch.size(predictions)))
+            self.acc_count += torch.sum(torch.eq(input, targets).float()).item()
+            self.total += np.prod(list(input.size()))
 
     def get_metric(self, reset=True):
         evaluate_result = {'acc': self.acc_count/self.total}
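The "metric bug fix" of the commit title is the last change: torch.size(predictions) is not a PyTorch API (size() is a Tensor method; torch.Size is a class), so the unmasked branch raised an AttributeError before this commit. A minimal sketch of the fixed accounting, assuming unmasked torch.Size([B,]) inputs with illustrative values:

    # Fixed unmasked branch (variable named `input` to mirror the diff,
    # even though it shadows the Python builtin).
    import numpy as np
    import torch

    input = torch.tensor([0, 2, 2, 1])    # predictions, already argmax-reduced
    targets = torch.tensor([0, 2, 1, 1])

    acc_count = torch.sum(torch.eq(input, targets).float()).item()  # 3.0
    total = np.prod(list(input.size()))                             # 4
    print({'acc': acc_count / total})                               # {'acc': 0.75}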

