|
12345678910111213141516171819202122232425 |
- import torch
-
-
- def calculate_acc(hist):
- """
- 计算准确率, 而不是iou
-
- :param hist:
- :return:
- """
- n_class = hist.size()[0]
- conf = torch.zeros((n_class, n_class))
- for cid in range(n_class):
- if torch.sum(hist[:, cid]) > 0:
- conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])
-
- # 可以看作对于除了背景外的像素点的判断accuracy, 但是比较偏向于判断为某些类的正确率.
- # nan表示均判断为背景, 如果存在除背景外的类别, 则正确率为0; 如果不存在, 则表示nan(无结果,若不去除背景,则正确率为1)
- overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :])
-
- # acc为某类预测结果是正确的概率
- # nan表示无像素判断为该类, 若存在该类, 则表示正确率为0; 若不存在, 则表示nan(无法判断为该类结果的正确率)
- acc = torch.diag(conf)
-
- return overall_acc, acc
|