|
@@ -115,10 +115,10 @@ class MetricBase(object): |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AccuracyMetric(MetricBase): |
|
|
class AccuracyMetric(MetricBase): |
|
|
def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None): |
|
|
|
|
|
|
|
|
def __init__(self, input=None, targets=None, masks=None, seq_lens=None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self._init_param_map(predictions=predictions, targets=targets, |
|
|
|
|
|
|
|
|
self._init_param_map(input=input, targets=targets, |
|
|
masks=masks, seq_lens=seq_lens) |
|
|
masks=masks, seq_lens=seq_lens) |
|
|
|
|
|
|
|
|
self.total = 0 |
|
|
self.total = 0 |
|
@@ -138,7 +138,7 @@ class AccuracyMetric(MetricBase): |
|
|
:return: dict({'acc': float}) |
|
|
:return: dict({'acc': float}) |
|
|
""" |
|
|
""" |
|
|
if not isinstance(input, torch.Tensor): |
|
|
if not isinstance(input, torch.Tensor): |
|
|
raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
|
|
|
|
|
|
raise NameError(f"`input` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
f"got {type(input)}.") |
|
|
f"got {type(input)}.") |
|
|
if not isinstance(targets, torch.Tensor): |
|
|
if not isinstance(targets, torch.Tensor): |
|
|
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
@@ -157,9 +157,9 @@ class AccuracyMetric(MetricBase): |
|
|
if input.size()==targets.size(): |
|
|
if input.size()==targets.size(): |
|
|
pass |
|
|
pass |
|
|
elif len(input.size())==len(targets.size())+1: |
|
|
elif len(input.size())==len(targets.size())+1: |
|
|
predictions = input.argmax(dim=-1) |
|
|
|
|
|
|
|
|
input = input.argmax(dim=-1) |
|
|
else: |
|
|
else: |
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with " |
|
|
|
|
|
|
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when input with " |
|
|
f"size:{input.size()}, targets should with size: {input.size()} or " |
|
|
f"size:{input.size()}, targets should with size: {input.size()} or " |
|
|
f"{input.size()[:-1]}, got {targets.size()}.") |
|
|
f"{input.size()[:-1]}, got {targets.size()}.") |
|
|
|
|
|
|
|
|