You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

validate_utils.py 8.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import time
  20. import logging
  21. import cv2
  22. import numpy as np
  23. import tensorflow as tf
  24. from PIL import Image
  25. from yolo3_multiscale import YOLOInference
  26. LOG = logging.getLogger(__name__)
  27. def add_path(path):
  28. if path not in sys.path:
  29. sys.path.insert(0, path)
  30. def init_yolo(model_path, input_shape):
  31. print('model_path : ', model_path)
  32. # initialize the session and bind the corresponding graph
  33. yolo_graph = tf.Graph()
  34. config = tf.ConfigProto(allow_soft_placement=True)
  35. config.gpu_options.allow_growth = True
  36. config.gpu_options.per_process_gpu_memory_fraction = 0.1
  37. yolo_session = tf.Session(graph=yolo_graph, config=config)
  38. # initialize yoloInference object
  39. yolo_infer = YOLOInference(yolo_session, model_path, input_shape)
  40. return yolo_infer, yolo_session
  41. def validate(model_path, test_dataset, class_names, input_shape=(352, 640)):
  42. yolo_infer, yolo_session = init_yolo(model_path, input_shape)
  43. folder_out = 'result'
  44. if not os.path.exists(folder_out):
  45. os.mkdir(folder_out)
  46. count_img = 0
  47. time_all = 0.0
  48. class_num = len(class_names)
  49. count_correct = [1e-6 for ix in range(class_num)]
  50. count_ground = [1e-6 for ix in range(class_num)]
  51. count_pred = [1e-6 for ix in range(class_num)]
  52. for line in test_dataset:
  53. line = line.strip()
  54. if not line:
  55. print("read line error")
  56. continue
  57. pos = line.find(' ')
  58. if pos == -1:
  59. print('error line : ', line)
  60. continue
  61. img_file = line[:pos]
  62. bbox_list_ground = line[pos + 1:].split(' ')
  63. time_predict, correct, pred, ground = validate_img_file(
  64. yolo_infer, yolo_session, img_file,
  65. bbox_list_ground, folder_out, class_names
  66. )
  67. count_correct = [count_correct[ix] + correct[ix]
  68. for ix in range(class_num)]
  69. count_pred = [count_pred[ix] + pred[ix] for ix in range(class_num)]
  70. count_ground = [count_ground[ix] + ground[ix]
  71. for ix in range(class_num)]
  72. count_img += 1
  73. time_all += time_predict
  74. print('count_correct', count_correct)
  75. print('count_pred', count_pred)
  76. print('count_ground', count_ground)
  77. precision = [float(count_correct[ix]) / float(count_pred[ix])
  78. for ix in range(class_num)]
  79. recall = [float(count_correct[ix]) / float(count_ground[ix])
  80. for ix in range(class_num)]
  81. all_precisions = sum(count_correct) / sum(count_pred)
  82. all_recalls = sum(count_correct) / sum(count_ground)
  83. print('precisions : ', precision)
  84. print('recalls : ', recall)
  85. print('all precisions : ', all_precisions)
  86. print('all recalls : ', all_recalls)
  87. print('average time = %.4f' % (time_all / float(count_img)))
  88. print('total time %.4f for %d images' % (time_all, count_img))
  89. return precision, recall, all_precisions, all_recalls
  90. def validate_img_file(yolo_infer, yolo_session, img_file,
  91. bbox_list_ground, folder_out, class_names):
  92. print('validate_img_file : ', img_file)
  93. class_num = len(class_names)
  94. img_data = Image.open(img_file)
  95. height, width = img_data.size
  96. img_data = np.array(img_data)
  97. if len(img_data.shape) == 3:
  98. input_image = img_data
  99. elif len(img_data.shape) == 2:
  100. input_image = Image.new('RGB', (height, width))
  101. input_image = np.array(input_image)
  102. input_image[:, :, 0] = img_data
  103. input_image[:, :, 1] = img_data
  104. input_image[:, :, 2] = img_data
  105. else:
  106. raise ValueError('validate image file should have three channels')
  107. channels = input_image.shape[-1]
  108. if channels != 3:
  109. time_start = time.time()
  110. count_correct = [0 for ix in range(class_num)]
  111. count_ground = [0 for ix in range(class_num)]
  112. count_pred = [0 for ix in range(class_num)]
  113. time_predict = time.time() - time_start
  114. return time_predict, count_correct, count_ground, count_pred
  115. time_start = time.time()
  116. labels, scores, bbox_list_pred = yolo_infer.predict(
  117. yolo_session, input_image)
  118. time_predict = time.time() - time_start
  119. colors = 'yellow,blue,green,red'
  120. if folder_out is not None:
  121. img = draw_boxes(
  122. img_data,
  123. labels,
  124. scores,
  125. bbox_list_pred,
  126. class_names,
  127. colors)
  128. img_file = img_file.split("/")[-1]
  129. cv2.imwrite(os.path.join(folder_out, img_file), img)
  130. count_correct = [0 for ix in range(class_num)]
  131. count_ground = [0 for ix in range(class_num)]
  132. count_pred = [0 for ix in range(class_num)]
  133. count_ground_all = len(bbox_list_ground)
  134. count_pred_all = bbox_list_pred.shape[0]
  135. for ix in range(count_ground_all):
  136. class_ground = int(bbox_list_ground[ix].split(',')[4])
  137. count_ground[class_ground] += 1
  138. for iy in range(count_pred_all):
  139. bbox_pred = [
  140. bbox_list_pred[iy][1],
  141. bbox_list_pred[iy][0],
  142. bbox_list_pred[iy][3],
  143. bbox_list_pred[iy][2]]
  144. LOG.debug(f'count_pred={count_pred}, labels[iy]={labels[iy]}')
  145. count_pred[labels[iy]] += 1
  146. for ix in range(count_ground_all):
  147. bbox_ground = [int(x) for x in bbox_list_ground[ix].split(',')]
  148. class_ground = bbox_ground[4]
  149. if labels[iy] == class_ground:
  150. iou = calc_iou(bbox_pred, bbox_ground)
  151. if iou >= 0.5:
  152. count_correct[class_ground] += 1
  153. break
  154. return time_predict, count_correct, count_pred, count_ground
  155. def rand(a=0, b=1):
  156. return np.random.rand() * (b - a) + a
  157. def draw_boxes(img, labels, scores, bboxes, class_names, colors):
  158. line_type = 2
  159. text_thickness = 1
  160. box_thickness = 1
  161. img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
  162. # get color code
  163. colors = colors.split(",")
  164. colors_code = []
  165. for color in colors:
  166. if color == 'green':
  167. colors_code.append((0, 255, 0))
  168. elif color == 'blue':
  169. colors_code.append((255, 0, 0))
  170. elif color == 'yellow':
  171. colors_code.append((0, 255, 255))
  172. else:
  173. colors_code.append((0, 0, 255))
  174. label_dict = {i: label for i, label in enumerate(class_names)}
  175. for i in range(bboxes.shape[0]):
  176. bbox = bboxes[i]
  177. if float("inf") in bbox or float("-inf") in bbox:
  178. continue
  179. label = int(labels[i])
  180. score = "%.2f" % round(scores[i], 2)
  181. text = label_dict.get(label) + ":" + score
  182. p1 = (int(bbox[0]), int(bbox[1]))
  183. p2 = (int(bbox[2]), int(bbox[3]))
  184. if (p2[0] - p1[0] < 1) or (p2[1] - p1[1] < 1):
  185. continue
  186. if max(p1) > 2 ** 31 or max(p2) > 2 ** 31:
  187. continue
  188. if min(p1) < - (2 ** 31) or max(p2) < - (2 ** 31):
  189. continue
  190. cv2.rectangle(img, p1[::-1], p2[::-1],
  191. colors_code[labels[i]], box_thickness)
  192. cv2.putText(img,
  193. text,
  194. (p1[1],
  195. p1[0] + 20 * (label + 1)),
  196. cv2.FONT_HERSHEY_SIMPLEX,
  197. 0.6, (255, 0, 0),
  198. text_thickness,
  199. line_type)
  200. return img
  201. def calc_iou(bbox_pred, bbox_ground):
  202. """user-define function for calculating the IOU of two matrixes. The
  203. input parameters are rectangle diagonals
  204. """
  205. x1 = bbox_pred[0]
  206. y1 = bbox_pred[1]
  207. width1 = bbox_pred[2] - bbox_pred[0]
  208. height1 = bbox_pred[3] - bbox_pred[1]
  209. x2 = bbox_ground[0]
  210. y2 = bbox_ground[1]
  211. width2 = bbox_ground[2] - bbox_ground[0]
  212. height2 = bbox_ground[3] - bbox_ground[1]
  213. endx = max(x1 + width1, x2 + width2)
  214. startx = min(x1, x2)
  215. width = width1 + width2 - (endx - startx)
  216. endy = max(y1 + height1, y2 + height2)
  217. starty = min(y1, y2)
  218. height = height1 + height2 - (endy - starty)
  219. if width <= 0 or height <= 0:
  220. iou = 0
  221. else:
  222. area = width * height
  223. area1 = width1 * height1
  224. area2 = width2 * height2
  225. iou = area * 1. / (area1 + area2 - area)
  226. return iou