
evaluate.py 1.5 kB

"""
Evaluate the predictions stored in the given dataset folder. None of the metrics include the background class.
"""
import numpy as np
from PIL import Image
from os.path import join
from network_module.iou import IOU
from network_module.pix_acc import calculate_acc


def evalute(n_classes, dataset_path, verbose=False):
    iou = IOU(n_classes)
    # Each non-empty line names one image that exists under both prediction/ and labels/
    with open(join(dataset_path, 'test_dataset_list.txt').replace('\\', '/')) as f:
        test_list = [line.strip() for line in f if line.strip()]
    for name in test_list:
        pred = np.array(Image.open(join(dataset_path, 'prediction', name).replace('\\', '/')))
        label = np.array(Image.open(join(dataset_path, 'labels', name).replace('\\', '/')))
        # Pairs with different pixel counts cannot be compared element-wise, so skip them
        if label.size != pred.size:
            print('Skipping {:s}: pred len {:d} != label len {:d}'.format(
                name, pred.size, label.size))
            continue
        iou.add_data(pred, label)
    # Must be called before iou_loss.forward, because forward clears hist
    overall_acc, acc = calculate_acc(iou.hist)
    mIoU, IoUs = iou.get_miou()
    if verbose:
        # Per-class IoU, printed only in verbose mode
        for ind_class in range(n_classes):
            print('===>' + str(ind_class) + ':\t' + str(IoUs[ind_class].float()))
    print('===> mIoU: ' + str(mIoU))
    print('===> overall accuracy:', overall_acc)
    print('===> accuracy of each class:', acc)
    return mIoU, IoUs, overall_acc, acc


if __name__ == "__main__":
    evalute(9, './dataset/MFNet(RGB-T)-mini', verbose=True)
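
The `IOU` class and `calculate_acc` come from `network_module`, whose source is not shown here. As a rough illustration of how such metrics are commonly derived from a confusion matrix, the sketch below assumes that `hist` is an `n_classes x n_classes` matrix with ground-truth classes as rows and predicted classes as columns, and that class 0 is the background class the docstring says is excluded. The helpers `fast_hist`, `miou_from_hist`, and `acc_from_hist` are hypothetical names for this example only; they are not part of this repository and may differ from what `network_module` actually does.

```python
import numpy as np


def fast_hist(pred, label, n_classes):
    """Hypothetical helper: build an n_classes x n_classes confusion matrix.

    Rows index the ground-truth class, columns the predicted class (assumed layout).
    Pixels whose label or prediction lies outside [0, n_classes) are ignored.
    """
    pred = np.asarray(pred).astype(np.int64).ravel()
    label = np.asarray(label).astype(np.int64).ravel()
    valid = (label >= 0) & (label < n_classes) & (pred >= 0) & (pred < n_classes)
    idx = n_classes * label[valid] + pred[valid]
    return np.bincount(idx, minlength=n_classes ** 2).reshape(n_classes, n_classes)


def miou_from_hist(hist, ignore_background=True):
    """Per-class IoU = TP / (TP + FP + FN); the mean skips class 0 if it is background."""
    tp = np.diag(hist).astype(np.float64)
    fp = hist.sum(axis=0) - tp  # predicted as class c but labelled as another class
    fn = hist.sum(axis=1) - tp  # labelled as class c but predicted as another class
    denom = tp + fp + fn
    ious = np.divide(tp, denom, out=np.full_like(tp, np.nan), where=denom > 0)
    start = 1 if ignore_background else 0
    return np.nanmean(ious[start:]), ious


def acc_from_hist(hist, ignore_background=True):
    """Overall and per-class pixel accuracy over pixels whose label is not background."""
    start = 1 if ignore_background else 0
    tp = np.diag(hist)[start:].astype(np.float64)
    support = hist[start:].sum(axis=1).astype(np.float64)  # labelled pixels per class
    overall = tp.sum() / support.sum() if support.sum() > 0 else float('nan')
    per_class = np.divide(tp, support, out=np.full_like(tp, np.nan), where=support > 0)
    return overall, per_class


if __name__ == "__main__":
    # Toy 2x3 masks with 3 classes (0 = background, by assumption)
    pred = np.array([[0, 1, 2], [2, 1, 0]])
    label = np.array([[0, 1, 1], [2, 2, 0]])
    hist = fast_hist(pred, label, 3)
    print(hist)
    print(miou_from_hist(hist))
    print(acc_from_hist(hist))
```

Accumulating a single confusion matrix across all test images, instead of averaging per-image scores, weights every pixel equally; this appears to match how the script above feeds each image into `iou.add_data` before computing the metrics once.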

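A machine learning template based on PyTorch Lightning for training, validating, and testing machine learning algorithms. It currently supports neural networks, deep learning, k-fold cross-validation, automatic saving of training information, and more.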
