You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

accuracy.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. from tqdm import tqdm
  2. from sedna.common.class_factory import ClassType, ClassFactory
  3. from utils.args import EvaluationArguments
  4. from utils.metrics import Evaluator
  5. from dataloaders import make_data_loader
  # Metric callables exported by this module; both are also registered with
  # ClassFactory under ClassType.GENERAL by their decorators.
  __all__ = (
      "accuracy",
      "robo_accuracy",
  )
  @ClassFactory.register(ClassType.GENERAL)
  def accuracy(y_true, y_pred, **kwargs):
      """Score per-batch predictions against the test labels; return CPA.

      Builds the test loader from ``y_true`` with ``make_data_loader`` and feeds
      each batch's labels, together with the prediction at the same batch index,
      into an ``Evaluator``.  Prints CPA, mIoU and FWIoU, then returns CPA.

      Args:
          y_true: Test data handed to ``make_data_loader`` as ``test_data``.
          y_pred: Indexable sequence of per-batch predictions; ``y_pred[i]`` is
              paired with the i-th batch yielded by the test loader.
          **kwargs: Accepted for metric-interface compatibility; unused.

      Returns:
          The result of ``Evaluator.Pixel_Accuracy_Class()`` (CPA).

      Raises:
          Whatever ``y_pred[i]`` raises when ``y_pred`` has fewer entries than
          the loader has batches (``IndexError`` for a list).
      """
      args = EvaluationArguments()
      _, _, test_loader = make_data_loader(args, test_data=y_true)
      evaluator = Evaluator(args.num_class)

      # Predictions are precomputed in y_pred, so scoring only needs the labels:
      # images/depth are not touched, and the labels are not shipped to the GPU
      # just to be copied straight back to the CPU for the numpy evaluator.
      for i, (sample, _img_path) in enumerate(tqdm(test_loader, desc='\r')):
          target = sample['label']
          # Label values above the last valid class id are rewritten to 255 —
          # presumably the ignore value understood by Evaluator; confirm in
          # utils.metrics.
          target[target > evaluator.num_class - 1] = 255
          evaluator.add_batch(target.cpu().numpy(), y_pred[i])

      cpa = evaluator.Pixel_Accuracy_Class()
      miou = evaluator.Mean_Intersection_over_Union()
      fwiou = evaluator.Frequency_Weighted_Intersection_over_Union()
      print("CPA:{}, mIoU:{}, fwIoU: {}".format(cpa, miou, fwiou))
      return cpa
  @ClassFactory.register(ClassType.GENERAL)
  def robo_accuracy(y_true, y_pred, **kwargs):
      """Same metric as ``accuracy``, for predictions wrapped in an outer container.

      ``y_pred[0]`` is taken as the per-batch prediction sequence.  Loader
      construction, evaluation, printing and the return value are all delegated
      to ``accuracy`` so the two registered metrics cannot drift apart.

      Args:
          y_true: Test data handed to ``make_data_loader`` as ``test_data``.
          y_pred: Container whose first element is the indexable sequence of
              per-batch predictions.
          **kwargs: Forwarded to ``accuracy``; unused there.

      Returns:
          The CPA value returned by ``accuracy``.
      """
      # Relies on ClassFactory.register returning the decorated function itself,
      # so the module-level name ``accuracy`` is directly callable.
      return accuracy(y_true, y_pred[0], **kwargs)