|
@@ -124,22 +124,22 @@ class AccuracyMetric(MetricBase): |
|
|
self.total = 0 |
|
|
self.total = 0 |
|
|
self.acc_count = 0 |
|
|
self.acc_count = 0 |
|
|
|
|
|
|
|
|
def evaluate(self, predictions, targets, masks=None, seq_lens=None): |
|
|
|
|
|
|
|
|
def evaluate(self, input, targets, masks=None, seq_lens=None): |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
:param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
|
|
|
|
|
|
:param input: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) |
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) |
|
|
:param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
:param targets: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len]) |
|
|
|
|
|
|
|
|
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) |
|
|
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
:param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
None, None, torch.Size([B, max_len], torch.Size([B, max_len]) |
|
|
None, None, torch.Size([B, max_len], torch.Size([B, max_len]) |
|
|
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
:return: dict({'acc': float}) |
|
|
:return: dict({'acc': float}) |
|
|
""" |
|
|
""" |
|
|
if not isinstance(predictions, torch.Tensor): |
|
|
|
|
|
|
|
|
if not isinstance(input, torch.Tensor): |
|
|
raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
raise NameError(f"`predictions` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
f"got {type(predictions)}.") |
|
|
|
|
|
|
|
|
f"got {type(input)}.") |
|
|
if not isinstance(targets, torch.Tensor): |
|
|
if not isinstance(targets, torch.Tensor): |
|
|
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
raise NameError(f"`targets` in {get_func_signature(self.evaluate())} expects torch.Tensor," |
|
|
f"got {type(targets)}.") |
|
|
f"got {type(targets)}.") |
|
@@ -154,21 +154,21 @@ class AccuracyMetric(MetricBase): |
|
|
if masks is None and seq_lens is not None: |
|
|
if masks is None and seq_lens is not None: |
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) |
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens, float=True) |
|
|
|
|
|
|
|
|
if predictions.size()==targets.size(): |
|
|
|
|
|
|
|
|
if input.size()==targets.size(): |
|
|
pass |
|
|
pass |
|
|
elif len(predictions.size())==len(targets.size())+1: |
|
|
|
|
|
predictions = predictions.argmax(dim=-1) |
|
|
|
|
|
|
|
|
elif len(input.size())==len(targets.size())+1: |
|
|
|
|
|
predictions = input.argmax(dim=-1) |
|
|
else: |
|
|
else: |
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with " |
|
|
raise RuntimeError(f"In {get_func_signature(self.evaluate())}, when predictions with " |
|
|
f"size:{predictions.size()}, targets should with size: {predictions.size()} or " |
|
|
|
|
|
f"{predictions.size()[:-1]}, got {targets.size()}.") |
|
|
|
|
|
|
|
|
f"size:{input.size()}, targets should with size: {input.size()} or " |
|
|
|
|
|
f"{input.size()[:-1]}, got {targets.size()}.") |
|
|
|
|
|
|
|
|
if masks is not None: |
|
|
if masks is not None: |
|
|
self.acc_count += torch.sum(torch.eq(predictions, targets).float() * masks.float()).item() |
|
|
|
|
|
|
|
|
self.acc_count += torch.sum(torch.eq(input, targets).float() * masks.float()).item() |
|
|
self.total += torch.sum(masks.float()).item() |
|
|
self.total += torch.sum(masks.float()).item() |
|
|
else: |
|
|
else: |
|
|
self.acc_count += torch.sum(torch.eq(predictions, targets).float()).item() |
|
|
|
|
|
self.total += np.prod(list(torch.size(predictions))) |
|
|
|
|
|
|
|
|
self.acc_count += torch.sum(torch.eq(input, targets).float()).item() |
|
|
|
|
|
self.total += np.prod(list(input.size())) |
|
|
|
|
|
|
|
|
def get_metric(self, reset=True): |
|
|
def get_metric(self, reset=True): |
|
|
evaluate_result = {'acc': self.acc_count/self.total} |
|
|
evaluate_result = {'acc': self.acc_count/self.total} |
|
|