You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate_utils.py 7.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import logging
  5. import sys
  6. import time
  7. import cv2
  8. import numpy as np
  9. import os
  10. import tensorflow as tf
  11. from PIL import Image
  12. from yolo3_multiscale import YOLOInference
  13. LOG = logging.getLogger(__name__)
  14. def add_path(path):
  15. if path not in sys.path:
  16. sys.path.insert(0, path)
  17. def init_yolo(model_path, input_shape):
  18. print('model_path : ', model_path)
  19. # initialize the session and bind the corresponding graph
  20. yolo_graph = tf.Graph()
  21. config = tf.ConfigProto(allow_soft_placement=True)
  22. config.gpu_options.allow_growth = True
  23. config.gpu_options.per_process_gpu_memory_fraction = 0.1
  24. yolo_session = tf.Session(graph=yolo_graph, config=config)
  25. # initialize yoloInference object
  26. yolo_infer = YOLOInference(yolo_session, model_path, input_shape)
  27. return yolo_infer, yolo_session
  28. def validate(model_path, test_dataset, class_names, input_shape=(352, 640)):
  29. yolo_infer, yolo_session = init_yolo(model_path, input_shape)
  30. folder_out = 'result'
  31. if not os.path.exists(folder_out):
  32. os.mkdir(folder_out)
  33. count_img = 0
  34. time_all = 0.0
  35. class_num = len(class_names)
  36. count_correct = [1e-6 for ix in range(class_num)]
  37. count_ground = [1e-6 for ix in range(class_num)]
  38. count_pred = [1e-6 for ix in range(class_num)]
  39. for line in test_dataset:
  40. line = line.strip()
  41. if not line:
  42. print("read line error")
  43. continue
  44. pos = line.find(' ')
  45. if pos == -1:
  46. print('error line : ', line)
  47. continue
  48. img_file = line[:pos]
  49. bbox_list_ground = line[pos + 1:].split(' ')
  50. time_predict, correct, pred, ground = validate_img_file(yolo_infer, yolo_session, img_file,
  51. bbox_list_ground,
  52. folder_out, class_names)
  53. count_correct = [count_correct[ix] + correct[ix] for ix in range(class_num)]
  54. count_pred = [count_pred[ix] + pred[ix] for ix in range(class_num)]
  55. count_ground = [count_ground[ix] + ground[ix] for ix in range(class_num)]
  56. count_img += 1
  57. time_all += time_predict
  58. print('count_correct', count_correct)
  59. print('count_pred', count_pred)
  60. print('count_ground', count_ground)
  61. precision = [float(count_correct[ix]) / float(count_pred[ix]) for ix in range(class_num)]
  62. recall = [float(count_correct[ix]) / float(count_ground[ix]) for ix in range(class_num)]
  63. all_precisions = sum(count_correct) / sum(count_pred)
  64. all_recalls = sum(count_correct) / sum(count_ground)
  65. print('precisions : ', precision)
  66. print('recalls : ', recall)
  67. print('all precisions : ', all_precisions)
  68. print('all recalls : ', all_recalls)
  69. print('average time = %.4f' % (time_all / float(count_img)))
  70. print('total time %.4f for %d images' % (time_all, count_img))
  71. return precision, recall, all_precisions, all_recalls
  72. def validate_img_file(yolo_infer, yolo_session, img_file, bbox_list_ground, folder_out, class_names):
  73. print('validate_img_file : ', img_file)
  74. class_num = len(class_names)
  75. img_data = Image.open(img_file)
  76. height, width = img_data.size
  77. img_data = np.array(img_data)
  78. if len(img_data.shape) == 3:
  79. input_image = img_data
  80. elif len(img_data.shape) == 2:
  81. input_image = Image.new('RGB', (height, width))
  82. input_image = np.array(input_image)
  83. input_image[:, :, 0] = img_data
  84. input_image[:, :, 1] = img_data
  85. input_image[:, :, 2] = img_data
  86. else:
  87. raise ValueError('validate image file should have three channels')
  88. channels = input_image.shape[-1]
  89. if channels != 3:
  90. time_start = time.time()
  91. count_correct = [0 for ix in range(class_num)]
  92. count_ground = [0 for ix in range(class_num)]
  93. count_pred = [0 for ix in range(class_num)]
  94. time_predict = time.time() - time_start
  95. return time_predict, count_correct, count_ground, count_pred
  96. time_start = time.time()
  97. labels, scores, bbox_list_pred = yolo_infer.predict(yolo_session, input_image)
  98. time_predict = time.time() - time_start
  99. colors = 'yellow,blue,green,red'
  100. if folder_out is not None:
  101. img = draw_boxes(img_data, labels, scores, bbox_list_pred, class_names, colors)
  102. img_file = img_file.split("/")[-1]
  103. cv2.imwrite(os.path.join(folder_out, img_file), img)
  104. count_correct = [0 for ix in range(class_num)]
  105. count_ground = [0 for ix in range(class_num)]
  106. count_pred = [0 for ix in range(class_num)]
  107. count_ground_all = len(bbox_list_ground)
  108. count_pred_all = bbox_list_pred.shape[0]
  109. for ix in range(count_ground_all):
  110. class_ground = int(bbox_list_ground[ix].split(',')[4])
  111. count_ground[class_ground] += 1
  112. for iy in range(count_pred_all):
  113. bbox_pred = [bbox_list_pred[iy][1], bbox_list_pred[iy][0], bbox_list_pred[iy][3], bbox_list_pred[iy][2]]
  114. LOG.debug(f'count_pred={count_pred}, labels[iy]={labels[iy]}')
  115. count_pred[labels[iy]] += 1
  116. for ix in range(count_ground_all):
  117. bbox_ground = [int(x) for x in bbox_list_ground[ix].split(',')]
  118. class_ground = bbox_ground[4]
  119. if labels[iy] == class_ground:
  120. iou = calc_iou(bbox_pred, bbox_ground)
  121. if iou >= 0.5:
  122. count_correct[class_ground] += 1
  123. break
  124. return time_predict, count_correct, count_pred, count_ground
  125. def rand(a=0, b=1):
  126. return np.random.rand() * (b - a) + a
  127. def draw_boxes(img, labels, scores, bboxes, class_names, colors):
  128. line_type = 2
  129. text_thickness = 1
  130. box_thickness = 1
  131. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  132. # get color code
  133. colors = colors.split(",")
  134. colors_code = []
  135. for color in colors:
  136. if color == 'green':
  137. colors_code.append((0, 255, 0))
  138. elif color == 'blue':
  139. colors_code.append((255, 0, 0))
  140. elif color == 'yellow':
  141. colors_code.append((0, 255, 255))
  142. else:
  143. colors_code.append((0, 0, 255))
  144. label_dict = {i: label for i, label in enumerate(class_names)}
  145. for i in range(bboxes.shape[0]):
  146. bbox = bboxes[i]
  147. if float("inf") in bbox or float("-inf") in bbox:
  148. continue
  149. label = int(labels[i])
  150. score = "%.2f" % round(scores[i], 2)
  151. text = label_dict.get(label) + ":" + score
  152. p1 = (int(bbox[0]), int(bbox[1]))
  153. p2 = (int(bbox[2]), int(bbox[3]))
  154. if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
  155. continue
  156. cv2.rectangle(img, p1[::-1], p2[::-1], colors_code[labels[i]], box_thickness)
  157. cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
  158. text_thickness, line_type)
  159. return img
  160. def calc_iou(bbox_pred, bbox_ground):
  161. """user-define function for calculating the IOU of two matrixes. The
  162. input parameters are rectangle diagonals
  163. """
  164. x1 = bbox_pred[0]
  165. y1 = bbox_pred[1]
  166. width1 = bbox_pred[2] - bbox_pred[0]
  167. height1 = bbox_pred[3] - bbox_pred[1]
  168. x2 = bbox_ground[0]
  169. y2 = bbox_ground[1]
  170. width2 = bbox_ground[2] - bbox_ground[0]
  171. height2 = bbox_ground[3] - bbox_ground[1]
  172. endx = max(x1 + width1, x2 + width2)
  173. startx = min(x1, x2)
  174. width = width1 + width2 - (endx - startx)
  175. endy = max(y1 + height1, y2 + height2)
  176. starty = min(y1, y2)
  177. height = height1 + height2 - (endy - starty)
  178. if width <= 0 or height <= 0:
  179. iou = 0
  180. else:
  181. area = width * height
  182. area1 = width1 * height1
  183. area2 = width2 * height2
  184. iou = area * 1. / (area1 + area2 - area)
  185. return iou