|
|
@@ -243,12 +243,11 @@ class AccuracyMetric(MetricBase): |
|
|
|
def evaluate(self, pred, target, seq_lens=None): |
|
|
|
""" |
|
|
|
|
|
|
|
:param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be: |
|
|
|
torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes]) |
|
|
|
:param target: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len]) |
|
|
|
:param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be: |
|
|
|
None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
|
:param pred: . Element's shape can be: torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), |
|
|
|
torch.Size([B, max_len, n_classes]) |
|
|
|
:param target: Element's can be: torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), |
|
|
|
torch.Size([B, max_len]) |
|
|
|
:param seq_lens: Element's can be: None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided. |
|
|
|
|
|
|
|
""" |
|
|
|
# TODO 这里报错需要更改,因为pred是啥用户并不知道。需要告知用户真实的value |
|
|
@@ -264,7 +263,7 @@ class AccuracyMetric(MetricBase): |
|
|
|
f"got {type(seq_lens)}.") |
|
|
|
|
|
|
|
if seq_lens is not None: |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens).long() |
|
|
|
masks = seq_lens_to_masks(seq_lens=seq_lens) |
|
|
|
else: |
|
|
|
masks = None |
|
|
|
|
|
|
@@ -277,9 +276,9 @@ class AccuracyMetric(MetricBase): |
|
|
|
f"size:{pred.size()}, target should have size: {pred.size()} or " |
|
|
|
f"{pred.size()[:-1]}, got {target.size()}.") |
|
|
|
|
|
|
|
|
|
|
|
target = target.to(pred) |
|
|
|
if masks is not None: |
|
|
|
self.acc_count += torch.sum(torch.eq(pred, target) * masks).item() |
|
|
|
self.acc_count += torch.sum(torch.eq(pred, target).masked_fill(masks, 0)).item() |
|
|
|
self.total += torch.sum(masks).item() |
|
|
|
else: |
|
|
|
self.acc_count += torch.sum(torch.eq(pred, target)).item() |
|
|
|