You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

pix_acc.py 954 B

12345678910111213141516171819202122232425
  1. import torch
  2. def calculate_acc(hist):
  3. """
  4. 计算准确率, 而不是iou
  5. :param hist:
  6. :return:
  7. """
  8. n_class = hist.size()[0]
  9. conf = torch.zeros((n_class, n_class))
  10. for cid in range(n_class):
  11. if torch.sum(hist[:, cid]) > 0:
  12. conf[:, cid] = hist[:, cid] / torch.sum(hist[:, cid])
  13. # 可以看作对于除了背景外的像素点的判断accuracy, 但是比较偏向于判断为某些类的正确率.
  14. # nan表示均判断为背景, 如果存在除背景外的类别, 则正确率为0; 如果不存在, 则表示nan(无结果,若不去除背景,则正确率为1)
  15. overall_acc = torch.sum(torch.diag(hist[1:, 1:])) / torch.sum(hist[1:, :])
  16. # acc为某类预测结果是正确的概率
  17. # nan表示无像素判断为该类, 若存在该类, 则表示正确率为0; 若不存在, 则表示nan(无法判断为该类结果的正确率)
  18. acc = torch.diag(conf)
  19. return overall_acc, acc

基于pytorch lightning的机器学习模板, 用于对机器学习算法进行训练, 验证, 测试等, 目前实现了神经网路, 深度学习, k折交叉, 自动保存训练信息等.

Contributors (1)