|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
def accuracy(output, target, topk=(1,)):
    """Compute top-k classification accuracy for each k in ``topk``.

    A sample counts as correct for a given k when its target class is among
    the k highest-scoring entries of ``output`` along dimension 1.

    Args:
        output: Score tensor; classes are ranked along dim 1 via ``topk``.
            Presumably shaped (batch, num_classes) -- confirm at call sites.
        target: Either class indices (shape (batch,) or (batch, 1)) or a
            one-hot / soft-label tensor of shape (batch, C) with C > 1, in
            which case the argmax along dim 1 is used as the label.
        topk: Iterable of k values to evaluate.

    Returns:
        dict mapping ``"acc{k}"`` to the fraction (0.0-1.0) of the batch
        whose target is within the top-k predictions.

    Raises:
        ZeroDivisionError: if ``target`` has an empty batch dimension.
        torch's ``topk`` error if ``max(topk)`` exceeds ``output.size(1)``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk best-scoring classes, sorted best-first, then
    # transposed so row i holds each sample's (i+1)-th best prediction.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()

    # One-hot / soft-label case: reduce to class indices. A (batch, 1) column
    # already holds indices -- taking its argmax would turn every label into 0.
    if target.ndimension() > 1 and target.size(1) > 1:
        target = target.max(1)[1]

    # reshape (rather than view) also accepts tensors that cannot be viewed.
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        # Rows [:k] cover the top-k guesses; each sample matches at most once.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
|