You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate_utils.py 8.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import logging
  18. import sys
  19. import time
  20. import cv2
  21. import numpy as np
  22. import os
  23. import tensorflow as tf
  24. from PIL import Image
  25. from yolo3_multiscale import YOLOInference
  26. LOG = logging.getLogger(__name__)
  27. def add_path(path):
  28. if path not in sys.path:
  29. sys.path.insert(0, path)
  30. def init_yolo(model_path, input_shape):
  31. print('model_path : ', model_path)
  32. # initialize the session and bind the corresponding graph
  33. yolo_graph = tf.Graph()
  34. config = tf.ConfigProto(allow_soft_placement=True)
  35. config.gpu_options.allow_growth = True
  36. config.gpu_options.per_process_gpu_memory_fraction = 0.1
  37. yolo_session = tf.Session(graph=yolo_graph, config=config)
  38. # initialize yoloInference object
  39. yolo_infer = YOLOInference(yolo_session, model_path, input_shape)
  40. return yolo_infer, yolo_session
  41. def validate(model_path, test_dataset, class_names, input_shape=(352, 640)):
  42. yolo_infer, yolo_session = init_yolo(model_path, input_shape)
  43. folder_out = 'result'
  44. if not os.path.exists(folder_out):
  45. os.mkdir(folder_out)
  46. count_img = 0
  47. time_all = 0.0
  48. class_num = len(class_names)
  49. count_correct = [1e-6 for ix in range(class_num)]
  50. count_ground = [1e-6 for ix in range(class_num)]
  51. count_pred = [1e-6 for ix in range(class_num)]
  52. for line in test_dataset:
  53. line = line.strip()
  54. if not line:
  55. print("read line error")
  56. continue
  57. pos = line.find(' ')
  58. if pos == -1:
  59. print('error line : ', line)
  60. continue
  61. img_file = line[:pos]
  62. bbox_list_ground = line[pos + 1:].split(' ')
  63. time_predict, correct, pred, ground = validate_img_file(yolo_infer, yolo_session, img_file,
  64. bbox_list_ground,
  65. folder_out, class_names)
  66. count_correct = [count_correct[ix] + correct[ix] for ix in range(class_num)]
  67. count_pred = [count_pred[ix] + pred[ix] for ix in range(class_num)]
  68. count_ground = [count_ground[ix] + ground[ix] for ix in range(class_num)]
  69. count_img += 1
  70. time_all += time_predict
  71. print('count_correct', count_correct)
  72. print('count_pred', count_pred)
  73. print('count_ground', count_ground)
  74. precision = [float(count_correct[ix]) / float(count_pred[ix]) for ix in range(class_num)]
  75. recall = [float(count_correct[ix]) / float(count_ground[ix]) for ix in range(class_num)]
  76. all_precisions = sum(count_correct) / sum(count_pred)
  77. all_recalls = sum(count_correct) / sum(count_ground)
  78. print('precisions : ', precision)
  79. print('recalls : ', recall)
  80. print('all precisions : ', all_precisions)
  81. print('all recalls : ', all_recalls)
  82. print('average time = %.4f' % (time_all / float(count_img)))
  83. print('total time %.4f for %d images' % (time_all, count_img))
  84. return precision, recall, all_precisions, all_recalls
  85. def validate_img_file(yolo_infer, yolo_session, img_file, bbox_list_ground, folder_out, class_names):
  86. print('validate_img_file : ', img_file)
  87. class_num = len(class_names)
  88. img_data = Image.open(img_file)
  89. height, width = img_data.size
  90. img_data = np.array(img_data)
  91. if len(img_data.shape) == 3:
  92. input_image = img_data
  93. elif len(img_data.shape) == 2:
  94. input_image = Image.new('RGB', (height, width))
  95. input_image = np.array(input_image)
  96. input_image[:, :, 0] = img_data
  97. input_image[:, :, 1] = img_data
  98. input_image[:, :, 2] = img_data
  99. else:
  100. raise ValueError('validate image file should have three channels')
  101. channels = input_image.shape[-1]
  102. if channels != 3:
  103. time_start = time.time()
  104. count_correct = [0 for ix in range(class_num)]
  105. count_ground = [0 for ix in range(class_num)]
  106. count_pred = [0 for ix in range(class_num)]
  107. time_predict = time.time() - time_start
  108. return time_predict, count_correct, count_ground, count_pred
  109. time_start = time.time()
  110. labels, scores, bbox_list_pred = yolo_infer.predict(yolo_session, input_image)
  111. time_predict = time.time() - time_start
  112. colors = 'yellow,blue,green,red'
  113. if folder_out is not None:
  114. img = draw_boxes(img_data, labels, scores, bbox_list_pred, class_names, colors)
  115. img_file = img_file.split("/")[-1]
  116. cv2.imwrite(os.path.join(folder_out, img_file), img)
  117. count_correct = [0 for ix in range(class_num)]
  118. count_ground = [0 for ix in range(class_num)]
  119. count_pred = [0 for ix in range(class_num)]
  120. count_ground_all = len(bbox_list_ground)
  121. count_pred_all = bbox_list_pred.shape[0]
  122. for ix in range(count_ground_all):
  123. class_ground = int(bbox_list_ground[ix].split(',')[4])
  124. count_ground[class_ground] += 1
  125. for iy in range(count_pred_all):
  126. bbox_pred = [bbox_list_pred[iy][1], bbox_list_pred[iy][0], bbox_list_pred[iy][3], bbox_list_pred[iy][2]]
  127. LOG.debug(f'count_pred={count_pred}, labels[iy]={labels[iy]}')
  128. count_pred[labels[iy]] += 1
  129. for ix in range(count_ground_all):
  130. bbox_ground = [int(x) for x in bbox_list_ground[ix].split(',')]
  131. class_ground = bbox_ground[4]
  132. if labels[iy] == class_ground:
  133. iou = calc_iou(bbox_pred, bbox_ground)
  134. if iou >= 0.5:
  135. count_correct[class_ground] += 1
  136. break
  137. return time_predict, count_correct, count_pred, count_ground
  138. def rand(a=0, b=1):
  139. return np.random.rand() * (b - a) + a
  140. def draw_boxes(img, labels, scores, bboxes, class_names, colors):
  141. line_type = 2
  142. text_thickness = 1
  143. box_thickness = 1
  144. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  145. # get color code
  146. colors = colors.split(",")
  147. colors_code = []
  148. for color in colors:
  149. if color == 'green':
  150. colors_code.append((0, 255, 0))
  151. elif color == 'blue':
  152. colors_code.append((255, 0, 0))
  153. elif color == 'yellow':
  154. colors_code.append((0, 255, 255))
  155. else:
  156. colors_code.append((0, 0, 255))
  157. label_dict = {i: label for i, label in enumerate(class_names)}
  158. for i in range(bboxes.shape[0]):
  159. bbox = bboxes[i]
  160. if float("inf") in bbox or float("-inf") in bbox:
  161. continue
  162. label = int(labels[i])
  163. score = "%.2f" % round(scores[i], 2)
  164. text = label_dict.get(label) + ":" + score
  165. p1 = (int(bbox[0]), int(bbox[1]))
  166. p2 = (int(bbox[2]), int(bbox[3]))
  167. if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
  168. continue
  169. cv2.rectangle(img, p1[::-1], p2[::-1], colors_code[labels[i]], box_thickness)
  170. cv2.putText(img, text, (p1[1], p1[0] + 20 * (label + 1)), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 0, 0),
  171. text_thickness, line_type)
  172. return img
  173. def calc_iou(bbox_pred, bbox_ground):
  174. """user-define function for calculating the IOU of two matrixes. The
  175. input parameters are rectangle diagonals
  176. """
  177. x1 = bbox_pred[0]
  178. y1 = bbox_pred[1]
  179. width1 = bbox_pred[2] - bbox_pred[0]
  180. height1 = bbox_pred[3] - bbox_pred[1]
  181. x2 = bbox_ground[0]
  182. y2 = bbox_ground[1]
  183. width2 = bbox_ground[2] - bbox_ground[0]
  184. height2 = bbox_ground[3] - bbox_ground[1]
  185. endx = max(x1 + width1, x2 + width2)
  186. startx = min(x1, x2)
  187. width = width1 + width2 - (endx - startx)
  188. endy = max(y1 + height1, y2 + height2)
  189. starty = min(y1, y2)
  190. height = height1 + height2 - (endy - starty)
  191. if width <= 0 or height <= 0:
  192. iou = 0
  193. else:
  194. area = width * height
  195. area1 = width1 * height1
  196. area2 = width2 * height2
  197. iou = area * 1. / (area1 + area2 - area)
  198. return iou