|
- # Copyright (c) Microsoft Corporation.
- # Licensed under the MIT license.
-
-
- import os
- import torch
- import torch.distributed as dist
-
-
def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy for each requested value of k.

    Args:
        output: Score tensor whose dim 0 is the batch and dim 1 the classes
            (``topk`` is taken along dim 1).
        target: Class indices of length ``batch``, or, when it has more than
            one dimension, a per-class score / one-hot tensor whose argmax
            over dim 1 is used as the label.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one scalar float tensor per k in ``topk``, holding the
        fraction (0..1) of samples whose label is among the top-k predicted
        classes. A k larger than the number of classes is clamped to the
        number of classes, so it yields 1.0 when every label is a valid
        class index (instead of making ``topk`` raise).
    """
    # Clamp so that e.g. top-5 on a problem with fewer than 5 classes works.
    maxk = min(max(topk), output.size(1))
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    # (maxk, batch): row i holds every sample's i-th highest-scoring class.
    pred = pred.t()
    # one-hot / soft-label case: recover class indices via argmax
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.reshape(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # Slicing beyond maxk rows simply returns all rows.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(1.0 / batch_size))
    return res
-
-
def reduce_metrics(metrics):
    """Reduce every metric tensor in ``metrics`` to a Python number.

    Each value is passed through ``reduce_tensor`` and converted with
    ``.item()``; the result is a new dict with the same keys.
    """
    reduced = {}
    for name, value in metrics.items():
        reduced[name] = reduce_tensor(value).item()
    return reduced
-
-
def reduce_tensor(tensor):
    """Return the sum of all elements of ``tensor`` as a 0-dim tensor.

    NOTE: despite the name, no cross-process reduction happens here; the
    previous ``dist.all_reduce`` + divide-by-``WORLD_SIZE`` averaging is
    disabled and only a local ``torch.sum`` is performed.
    """
    return torch.sum(tensor)
|