You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

iouEval.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. import torch
  2. class iouEval:
  3. def __init__(self, nClasses, ignoreIndex=20):
  4. self.nClasses = nClasses
  5. self.ignoreIndex = ignoreIndex if nClasses > ignoreIndex else -1 # if ignoreIndex is larger than nClasses, consider no ignoreIndex
  6. self.reset()
  7. def reset(self):
  8. classes = self.nClasses if self.ignoreIndex == -1 else self.nClasses - 1
  9. self.tp = torch.zeros(classes).double()
  10. self.fp = torch.zeros(classes).double()
  11. self.fn = torch.zeros(classes).double()
  12. self.cdp_obstacle = torch.zeros(1).double()
  13. self.tp_obstacle = torch.zeros(1).double()
  14. self.idp_obstacle = torch.zeros(1).double()
  15. self.tp_nonobstacle = torch.zeros(1).double()
  16. # self.cdi = torch.zeros(1).double()
  17. def addBatch(self, x, y): # x=preds, y=targets
  18. # sizes should be "batch_size x nClasses x H x W"
  19. # cdi = 0
  20. # print ("X is cuda: ", x.is_cuda)
  21. # print ("Y is cuda: ", y.is_cuda)
  22. if (x.is_cuda or y.is_cuda):
  23. x = x.cuda()
  24. y = y.cuda()
  25. # if size is "batch_size x 1 x H x W" scatter to onehot
  26. if (x.size(1) == 1):
  27. x_onehot = torch.zeros(x.size(0), self.nClasses, x.size(2), x.size(3))
  28. if x.is_cuda:
  29. x_onehot = x_onehot.cuda()
  30. x_onehot.scatter_(1, x, 1).float() # dim index src 按照列用1替换0,索引为x
  31. else:
  32. x_onehot = x.float()
  33. if (y.size(1) == 1):
  34. y_onehot = torch.zeros(y.size(0), self.nClasses, y.size(2), y.size(3))
  35. if y.is_cuda:
  36. y_onehot = y_onehot.cuda()
  37. y_onehot.scatter_(1, y, 1).float()
  38. else:
  39. y_onehot = y.float()
  40. if (self.ignoreIndex != -1):
  41. ignores = y_onehot[:, self.ignoreIndex].unsqueeze(1) # 加一维
  42. x_onehot = x_onehot[:, :self.ignoreIndex] # ignoreIndex后的都不要
  43. y_onehot = y_onehot[:, :self.ignoreIndex]
  44. else:
  45. ignores = 0
  46. tpmult = x_onehot * y_onehot # times prediction and gt coincide is 1
  47. tp = torch.sum(torch.sum(torch.sum(tpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3,
  48. keepdim=True).squeeze()
  49. fpmult = x_onehot * (
  50. 1 - y_onehot - ignores) # times prediction says its that class and gt says its not (subtracting cases when its ignore label!)
  51. fp = torch.sum(torch.sum(torch.sum(fpmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3,
  52. keepdim=True).squeeze()
  53. fnmult = (1 - x_onehot) * (y_onehot) # times prediction says its not that class and gt says it is
  54. fn = torch.sum(torch.sum(torch.sum(fnmult, dim=0, keepdim=True), dim=2, keepdim=True), dim=3,
  55. keepdim=True).squeeze()
  56. self.tp += tp.double().cpu()
  57. self.fp += fp.double().cpu()
  58. self.fn += fn.double().cpu()
  59. cdp_obstacle = tpmult[:, 19].sum() # obstacle index 19
  60. tp_obstacle = y_onehot[:, 19].sum()
  61. idp_obstacle = (x_onehot[:, 19] - tpmult[:, 19]).sum()
  62. tp_nonobstacle = (-1*y_onehot+1).sum()
  63. # for i in range(0, x.size(0)):
  64. # if tpmult[i].sum()/(y_onehot[i].sum() + 1e-15) >= 0.5:
  65. # cdi += 1
  66. self.cdp_obstacle += cdp_obstacle.double().cpu()
  67. self.tp_obstacle += tp_obstacle.double().cpu()
  68. self.idp_obstacle += idp_obstacle.double().cpu()
  69. self.tp_nonobstacle += tp_nonobstacle.double().cpu()
  70. # self.cdi += cdi.double().cpu()
  71. def getIoU(self):
  72. num = self.tp
  73. den = self.tp + self.fp + self.fn + 1e-15
  74. iou = num / den
  75. iou_not_zero = list(filter(lambda x: x != 0, iou))
  76. # print(len(iou_not_zero))
  77. iou_mean = sum(iou_not_zero) / len(iou_not_zero)
  78. tfp = self.tp + self.fp + 1e-15
  79. acc = num / tfp
  80. acc_not_zero = list(filter(lambda x: x != 0, acc))
  81. acc_mean = sum(acc_not_zero) / len(acc_not_zero)
  82. return iou_mean, iou, acc_mean, acc # returns "iou mean", "iou per class"
  83. def getObstacleEval(self):
  84. pdr_obstacle = self.cdp_obstacle / (self.tp_obstacle+1e-15)
  85. pfp_obstacle = self.idp_obstacle / (self.tp_nonobstacle+1e-15)
  86. return pdr_obstacle, pfp_obstacle
  87. # Class for colors
  88. class colors:
  89. RED = '\033[31;1m'
  90. GREEN = '\033[32;1m'
  91. YELLOW = '\033[33;1m'
  92. BLUE = '\033[34;1m'
  93. MAGENTA = '\033[35;1m'
  94. CYAN = '\033[36;1m'
  95. BOLD = '\033[1m'
  96. UNDERLINE = '\033[4m'
  97. ENDC = '\033[0m'
  98. # Colored value output if colorized flag is activated.
  99. def getColorEntry(val):
  100. if not isinstance(val, float):
  101. return colors.ENDC
  102. if (val < .20):
  103. return colors.RED
  104. elif (val < .40):
  105. return colors.YELLOW
  106. elif (val < .60):
  107. return colors.BLUE
  108. elif (val < .80):
  109. return colors.CYAN
  110. else:
  111. return colors.GREEN