You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

yolo3_multiscale.py 29 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import cv2
  16. import numpy as np
  17. import os
  18. import tensorflow as tf
  19. from resnet18 import ResNet18
# Module-level logger, named after this module.
LOG = logging.getLogger(__name__)
# TensorFlow 1.x command-line flag values. This module reads train_url,
# input_shape, learning_rate, obj_threshold, nms_threshold, class_names,
# nas_sequence and label_changed from it; the flags themselves are
# presumably defined by the calling script -- verify.
flags = tf.flags.FLAGS
  22. class Yolo3:
  23. def __init__(self, sess, is_training, config):
  24. LOG.info('is_training: %s' % is_training)
  25. LOG.info('model dir: %s' % flags.train_url)
  26. LOG.info('input_shape: (%d, %d)' % (flags.input_shape[0], flags.input_shape[1]))
  27. LOG.info('learning rate: %f' % float(flags.learning_rate))
  28. self.is_training = is_training
  29. self.model_dir = flags.train_url
  30. self.norm_epsilon = config.norm_epsilon
  31. self.norm_decay = config.norm_decay
  32. self.obj_threshold = float(flags.obj_threshold)
  33. self.nms_threshold = float(flags.nms_threshold)
  34. self.anchors = np.array([float(x) for x in config.anchors]).reshape(-1, 2)
  35. self.class_names = flags.class_names
  36. self.num_classes = len(self.class_names)
  37. self.input_shape = flags.input_shape
  38. self.nas_sequence = flags.nas_sequence
  39. if not os.path.exists(self.model_dir):
  40. os.makedirs(self.model_dir)
  41. print("anchors : ", self.anchors)
  42. print("class_names : ", self.class_names)
  43. if is_training:
  44. self.images = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32, name='images')
  45. else:
  46. self.images = tf.placeholder(shape=[1, self.input_shape[0], self.input_shape[1], 3], dtype=tf.float32,
  47. name='images')
  48. self.image_shape = tf.placeholder(dtype=tf.int32, shape=(2,), name='shapes')
  49. self.bbox_true_13 = tf.placeholder(shape=[None, None, None, 3, self.num_classes + 5], dtype=tf.float32)
  50. self.bbox_true_26 = tf.placeholder(shape=[None, None, None, 3, self.num_classes + 5], dtype=tf.float32)
  51. self.bbox_true_52 = tf.placeholder(shape=[None, None, None, 3, self.num_classes + 5], dtype=tf.float32)
  52. bbox_true = [self.bbox_true_13, self.bbox_true_26, self.bbox_true_52]
  53. features_out, filters_yolo_block, conv_index = self._resnet18(self.images, self.is_training)
  54. self.output = self.yolo_inference(features_out, filters_yolo_block, conv_index, len(self.anchors) / 3,
  55. self.num_classes, self.is_training)
  56. self.loss = self.yolo_loss(self.output, bbox_true, self.anchors, self.num_classes, config.ignore_thresh)
  57. self.global_step = tf.Variable(0, trainable=False)
  58. if self.is_training:
  59. learning_rate = tf.train.exponential_decay(float(flags.learning_rate), self.global_step, 1000, 0.95,
  60. staircase=True)
  61. optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
  62. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  63. with tf.control_dependencies(update_ops):
  64. self.train_op = optimizer.minimize(loss=self.loss, global_step=self.global_step)
  65. else:
  66. self.boxes, self.scores, self.classes = self.yolo_eval(self.output, self.image_shape, config.max_boxes)
  67. self.saver = tf.train.Saver()
  68. ckpt = tf.train.get_checkpoint_state(flags.train_url)
  69. if ckpt and tf.train.checkpoint_exists(ckpt.model_checkpoint_path):
  70. if not flags.label_changed:
  71. print('restore model', ckpt.model_checkpoint_path)
  72. self.saver.restore(sess, ckpt.model_checkpoint_path)
  73. else:
  74. print('restore model', ckpt.model_checkpoint_path)
  75. sess.run(tf.global_variables_initializer())
  76. sess.run(tf.local_variables_initializer())
  77. variables = tf.global_variables()
  78. vars_restore = [var for var in variables if not ("Adam" in var.name
  79. or '25' in var.name
  80. or '33' in var.name
  81. or '41' in var.name)] # or ("yolo" in var.name))]
  82. saver_restore = tf.train.Saver(vars_restore)
  83. saver_restore.restore(sess, ckpt.model_checkpoint_path)
  84. else:
  85. print('initialize model with fresh weights...')
  86. sess.run(tf.global_variables_initializer())
  87. sess.run(tf.local_variables_initializer())
  88. def load_weights(self, sess, fpath):
  89. sess = tf.get_default_session()
  90. variables = sess.graph.get_collection("variables")
  91. data = np.load(fpath)
  92. for v in variables:
  93. vname = v.name.replace(':0', '')
  94. if vname not in data:
  95. print("----------skip %s----------" % vname)
  96. continue
  97. print("assigning %s" % vname)
  98. sess.run(v.assign(data[vname]))
  99. def step(self, sess, batch_data, is_training):
  100. """step, read one batch, generate gradients
  101. """
  102. # Input feed
  103. input_feed = {}
  104. input_feed[self.images] = batch_data['images']
  105. input_feed[self.bbox_true_13] = batch_data['bbox_true_13']
  106. input_feed[self.bbox_true_26] = batch_data['bbox_true_26']
  107. input_feed[self.bbox_true_52] = batch_data['bbox_true_52']
  108. # Output feed: depends on training or test
  109. output_feed = [self.loss] # Loss for this batch.
  110. if is_training:
  111. output_feed.append(self.train_op) # Gradient updates
  112. outputs = sess.run(output_feed, input_feed)
  113. return outputs[0] # loss
  114. def _batch_normalization_layer(self, input_layer, name=None, training=True, norm_decay=0.997, norm_epsilon=1e-5):
  115. """Batch normalization is used for feature map extracted from
  116. convolution layer
  117. :param input_layer: four dimensional tensor of input
  118. :param name: the name of batchnorm layer
  119. :param training: is training or not
  120. :param norm_decay: The decay rate of moving average is calculated
  121. during prediction
  122. :param norm_epsilon: Variance plus a minimal number to prevent
  123. division by 0
  124. :return bn_layer: batch normalization处理之后的feature map
  125. """
  126. bn_layer = tf.layers.batch_normalization(inputs=input_layer,
  127. momentum=norm_decay, epsilon=norm_epsilon, center=True,
  128. scale=True, training=training, name=name, fused=True)
  129. return tf.nn.relu(bn_layer)
  130. # return tf.nn.leaky_relu(bn_layer, alpha = 0.1)
  131. def _conv2d_layer(self, inputs, filters_num, kernel_size, name, use_bias=False, strides=1):
  132. """Use tf.layers.conv2d Reduce the weight and bias matrix
  133. initialization process, as well as convolution plus bias operation
  134. :param inputs: Input variables
  135. :param filters_num: Number of convolution kernels
  136. :param strides: Convolution step
  137. :param name: Convolution layer name
  138. :param training: is a training process or not
  139. :param use_bias: use bias or not
  140. :param kernel_size: the kernels size
  141. :return conv: Feature map after convolution
  142. """
  143. if strides > 1: # modified 0327
  144. inputs = tf.pad(inputs, paddings=[[0, 0], [1, 0], [1, 0], [0, 0]], mode='CONSTANT')
  145. conv = tf.layers.conv2d(inputs=inputs, filters=filters_num,
  146. kernel_size=kernel_size, strides=[strides, strides],
  147. padding=('SAME' if strides == 1 else 'VALID'), # padding = 'SAME', #
  148. use_bias=use_bias,
  149. name=name) # , kernel_initializer = tf.contrib.layers.xavier_initializer()
  150. return conv
  151. def _Residual_block(self, inputs, filters_num, blocks_num, conv_index, training=True, norm_decay=0.997,
  152. norm_epsilon=1e-5):
  153. layer = self._conv2d_layer(inputs, filters_num, kernel_size=3, strides=2, name="conv2d_" + str(conv_index))
  154. layer = self._batch_normalization_layer(layer, name="batch_normalization_" + str(conv_index), training=training,
  155. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  156. conv_index += 1
  157. for _ in range(blocks_num):
  158. shortcut = layer
  159. layer = self._conv2d_layer(layer, filters_num // 2, kernel_size=1, strides=1,
  160. name="conv2d_" + str(conv_index))
  161. layer = self._batch_normalization_layer(layer, name="batch_normalization_" + str(conv_index),
  162. training=training, norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  163. conv_index += 1
  164. layer = self._conv2d_layer(layer, filters_num, kernel_size=3, strides=1, name="conv2d_" + str(conv_index))
  165. layer = self._batch_normalization_layer(layer, name="batch_normalization_" + str(conv_index),
  166. training=training, norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  167. conv_index += 1
  168. layer += shortcut
  169. return layer, conv_index
  170. def _resnet18(self, inputs, training=True):
  171. cnn_model = ResNet18(inputs, training)
  172. for k, v in cnn_model.end_points.items():
  173. print(k)
  174. print(v)
  175. features_out = [cnn_model.end_points['conv5_output'], cnn_model.end_points['conv4_output'],
  176. cnn_model.end_points['conv3_output']]
  177. filters_yolo_block = [256, 128, 64]
  178. conv_index = 19
  179. return features_out, filters_yolo_block, conv_index
  180. def _yolo_block(self, inputs, filters_num, out_filters, conv_index, training=True, norm_decay=0.997,
  181. norm_epsilon=1e-5):
  182. conv = self._conv2d_layer(inputs, filters_num=filters_num, kernel_size=1, strides=1,
  183. name="conv2d_" + str(conv_index))
  184. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  185. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  186. conv_index += 1
  187. conv = self._conv2d_layer(conv, filters_num=filters_num * 2, kernel_size=3, strides=1,
  188. name="conv2d_" + str(conv_index))
  189. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  190. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  191. conv_index += 1
  192. conv = self._conv2d_layer(conv, filters_num=filters_num, kernel_size=1, strides=1,
  193. name="conv2d_" + str(conv_index))
  194. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  195. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  196. conv_index += 1
  197. conv = self._conv2d_layer(conv, filters_num=filters_num * 2, kernel_size=3, strides=1,
  198. name="conv2d_" + str(conv_index))
  199. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  200. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  201. conv_index += 1
  202. conv = self._conv2d_layer(conv, filters_num=filters_num, kernel_size=1, strides=1,
  203. name="conv2d_" + str(conv_index))
  204. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  205. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  206. conv_index += 1
  207. route = conv
  208. conv = self._conv2d_layer(conv, filters_num=filters_num * 2, kernel_size=3, strides=1,
  209. name="conv2d_" + str(conv_index))
  210. conv = self._batch_normalization_layer(conv, name="batch_normalization_" + str(conv_index), training=training,
  211. norm_decay=norm_decay, norm_epsilon=norm_epsilon)
  212. conv_index += 1
  213. conv = self._conv2d_layer(conv, filters_num=out_filters, kernel_size=1, strides=1,
  214. name="conv2d_" + str(conv_index), use_bias=True)
  215. conv_index += 1
  216. return route, conv, conv_index
  217. def yolo_inference(self, features_out, filters_yolo_block, conv_index, num_anchors, num_classes, training=True):
  218. conv = features_out[0]
  219. conv2d_45 = features_out[1]
  220. conv2d_26 = features_out[2]
  221. print('conv : ', conv)
  222. print('conv2d_45 : ', conv2d_45)
  223. print('conv2d_26 : ', conv2d_26)
  224. with tf.variable_scope('yolo'):
  225. conv2d_57, conv2d_59, conv_index = self._yolo_block(conv, filters_yolo_block[0],
  226. num_anchors * (num_classes + 5), conv_index=conv_index,
  227. training=training, norm_decay=self.norm_decay,
  228. norm_epsilon=self.norm_epsilon)
  229. print('conv2d_59 : ', conv2d_59)
  230. print('conv2d_57 : ', conv2d_57)
  231. conv2d_60 = self._conv2d_layer(conv2d_57, filters_num=filters_yolo_block[1], kernel_size=1, strides=1,
  232. name="conv2d_" + str(conv_index))
  233. conv2d_60 = self._batch_normalization_layer(conv2d_60, name="batch_normalization_" + str(conv_index),
  234. training=training, norm_decay=self.norm_decay,
  235. norm_epsilon=self.norm_epsilon)
  236. print('conv2d_60 : ', conv2d_60)
  237. conv_index += 1
  238. upSample_0 = tf.image.resize_nearest_neighbor(conv2d_60,
  239. [2 * tf.shape(conv2d_60)[1], 2 * tf.shape(conv2d_60)[2]],
  240. name='upSample_0')
  241. print('upSample_0 : ', upSample_0)
  242. route0 = tf.concat([upSample_0, conv2d_45], axis=-1, name='route_0')
  243. print('route0 : ', route0)
  244. conv2d_65, conv2d_67, conv_index = self._yolo_block(route0, filters_yolo_block[1],
  245. num_anchors * (num_classes + 5), conv_index=conv_index,
  246. training=training, norm_decay=self.norm_decay,
  247. norm_epsilon=self.norm_epsilon)
  248. print('conv2d_67 : ', conv2d_67)
  249. print('conv2d_65 : ', conv2d_65)
  250. conv2d_68 = self._conv2d_layer(conv2d_65, filters_num=filters_yolo_block[2], kernel_size=1, strides=1,
  251. name="conv2d_" + str(conv_index))
  252. conv2d_68 = self._batch_normalization_layer(conv2d_68, name="batch_normalization_" + str(conv_index),
  253. training=training, norm_decay=self.norm_decay,
  254. norm_epsilon=self.norm_epsilon)
  255. print('conv2d_68 : ', conv2d_68)
  256. conv_index += 1
  257. upSample_1 = tf.image.resize_nearest_neighbor(conv2d_68,
  258. [2 * tf.shape(conv2d_68)[1], 2 * tf.shape(conv2d_68)[2]],
  259. name='upSample_1')
  260. print('upSample_1 : ', upSample_1)
  261. route1 = tf.concat([upSample_1, conv2d_26], axis=-1, name='route_1')
  262. print('route1 : ', route1)
  263. _, conv2d_75, _ = self._yolo_block(route1, filters_yolo_block[2], num_anchors * (num_classes + 5),
  264. conv_index=conv_index, training=training, norm_decay=self.norm_decay,
  265. norm_epsilon=self.norm_epsilon)
  266. print('conv2d_75 : ', conv2d_75)
  267. return [conv2d_59, conv2d_67, conv2d_75]
  268. def yolo_head(self, feats, anchors, num_classes, input_shape, training=True):
  269. num_anchors = len(anchors)
  270. anchors_tensor = tf.reshape(tf.constant(anchors, dtype=tf.float32), [1, 1, 1, num_anchors, 2])
  271. grid_size = tf.shape(feats)[1:3]
  272. predictions = tf.reshape(feats, [-1, grid_size[0], grid_size[1], num_anchors, num_classes + 5])
  273. grid_y = tf.tile(tf.reshape(tf.range(grid_size[0]), [-1, 1, 1, 1]), [1, grid_size[1], 1, 1])
  274. grid_x = tf.tile(tf.reshape(tf.range(grid_size[1]), [1, -1, 1, 1]), [grid_size[0], 1, 1, 1])
  275. grid = tf.concat([grid_x, grid_y], axis=-1)
  276. grid = tf.cast(grid, tf.float32)
  277. box_xy = (tf.sigmoid(predictions[..., :2]) + grid) / tf.cast(grid_size[::-1], tf.float32)
  278. box_wh = tf.exp(predictions[..., 2:4]) * anchors_tensor / input_shape[::-1]
  279. box_confidence = tf.sigmoid(predictions[..., 4:5])
  280. box_class_probs = tf.sigmoid(predictions[..., 5:])
  281. if training == True:
  282. return grid, predictions, box_xy, box_wh
  283. return box_xy, box_wh, box_confidence, box_class_probs
  284. def yolo_boxes_scores(self, feats, anchors, num_classes, input_shape, image_shape):
  285. input_shape = tf.cast(input_shape, tf.float32)
  286. image_shape = tf.cast(image_shape, tf.float32)
  287. box_xy, box_wh, box_confidence, box_class_probs = self.yolo_head(feats, anchors, num_classes, input_shape,
  288. training=False)
  289. box_yx = box_xy[..., ::-1]
  290. box_hw = box_wh[..., ::-1]
  291. new_shape = tf.round(image_shape * tf.reduce_min(input_shape / image_shape))
  292. offset = (input_shape - new_shape) / 2. / input_shape
  293. scale = input_shape / new_shape
  294. box_yx = (box_yx - offset) * scale
  295. box_hw = box_hw * scale
  296. box_min = box_yx - box_hw / 2.
  297. box_max = box_yx + box_hw / 2.
  298. boxes = tf.concat(
  299. [box_min[..., 0:1],
  300. box_min[..., 1:2],
  301. box_max[..., 0:1],
  302. box_max[..., 1:2]],
  303. axis=-1
  304. )
  305. boxes *= tf.concat([image_shape, image_shape], axis=-1)
  306. boxes = tf.reshape(boxes, [-1, 4])
  307. boxes_scores = box_confidence * box_class_probs
  308. boxes_scores = tf.reshape(boxes_scores, [-1, num_classes])
  309. return boxes, boxes_scores
  310. def box_iou(self, box1, box2):
  311. box1 = tf.expand_dims(box1, -2)
  312. box1_xy = box1[..., :2]
  313. box1_wh = box1[..., 2:4]
  314. box1_mins = box1_xy - box1_wh / 2.
  315. box1_maxs = box1_xy + box1_wh / 2.
  316. box2 = tf.expand_dims(box2, 0)
  317. box2_xy = box2[..., :2]
  318. box2_wh = box2[..., 2:4]
  319. box2_mins = box2_xy - box2_wh / 2.
  320. box2_maxs = box2_xy + box2_wh / 2.
  321. intersect_mins = tf.maximum(box1_mins, box2_mins)
  322. intersect_maxs = tf.minimum(box1_maxs, box2_maxs)
  323. intersect_wh = tf.maximum(intersect_maxs - intersect_mins, 0.)
  324. intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
  325. box1_area = box1_wh[..., 0] * box1_wh[..., 1]
  326. box2_area = box2_wh[..., 0] * box2_wh[..., 1]
  327. iou = intersect_area / (box1_area + box2_area - intersect_area)
  328. return iou
  329. def yolo_loss(self, yolo_output, y_true, anchors, num_classes, ignore_thresh=.5):
  330. loss = 0.0
  331. anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  332. input_shape = tf.shape(yolo_output[0])[1: 3] * 32
  333. input_shape = tf.cast(input_shape, tf.float32)
  334. grid_shapes = [tf.cast(tf.shape(yolo_output[l])[1:3], tf.float32) for l in range(3)]
  335. for index in range(3):
  336. object_mask = y_true[index][..., 4:5]
  337. class_probs = y_true[index][..., 5:]
  338. grid, predictions, pred_xy, pred_wh = self.yolo_head(yolo_output[index], anchors[anchor_mask[index]],
  339. num_classes, input_shape, training=True)
  340. pred_box = tf.concat([pred_xy, pred_wh], axis=-1)
  341. raw_true_xy = y_true[index][..., :2] * grid_shapes[index][::-1] - grid
  342. object_mask_bool = tf.cast(object_mask, dtype=tf.bool)
  343. raw_true_wh = tf.log(
  344. tf.where(tf.equal(y_true[index][..., 2:4] / anchors[anchor_mask[index]] * input_shape[::-1], 0),
  345. tf.ones_like(y_true[index][..., 2:4]),
  346. y_true[index][..., 2:4] / anchors[anchor_mask[index]] * input_shape[::-1]))
  347. box_loss_scale = 2 - y_true[index][..., 2:3] * y_true[index][..., 3:4]
  348. ignore_mask = tf.TensorArray(dtype=tf.float32, size=1, dynamic_size=True)
  349. def loop_body(internal_index, ignore_mask):
  350. true_box = tf.boolean_mask(y_true[index][internal_index, ..., 0:4],
  351. object_mask_bool[internal_index, ..., 0])
  352. iou = self.box_iou(pred_box[internal_index], true_box)
  353. best_iou = tf.reduce_max(iou, axis=-1)
  354. ignore_mask = ignore_mask.write(internal_index, tf.cast(best_iou < ignore_thresh, tf.float32))
  355. return internal_index + 1, ignore_mask
  356. _, ignore_mask = tf.while_loop(
  357. lambda internal_index, ignore_mask: internal_index < tf.shape(yolo_output[0])[0], loop_body,
  358. [0, ignore_mask])
  359. ignore_mask = ignore_mask.stack()
  360. ignore_mask = tf.expand_dims(ignore_mask, axis=-1)
  361. xy_loss = object_mask * box_loss_scale * tf.nn.sigmoid_cross_entropy_with_logits(
  362. labels=raw_true_xy,
  363. logits=predictions[..., 0:2])
  364. wh_loss = object_mask * box_loss_scale * 0.5 * tf.square(raw_true_wh - predictions[..., 2:4])
  365. confidence_loss = object_mask * tf.nn.sigmoid_cross_entropy_with_logits(
  366. labels=object_mask,
  367. logits=predictions[..., 4:5]) + (1 - object_mask) * tf.nn.sigmoid_cross_entropy_with_logits(
  368. labels=object_mask,
  369. logits=predictions[..., 4:5]) * ignore_mask
  370. class_loss = object_mask * tf.nn.sigmoid_cross_entropy_with_logits(labels=class_probs,
  371. logits=predictions[..., 5:])
  372. xy_loss = tf.reduce_sum(xy_loss) / tf.cast(tf.shape(yolo_output[0])[0], tf.float32)
  373. wh_loss = tf.reduce_sum(wh_loss) / tf.cast(tf.shape(yolo_output[0])[0], tf.float32)
  374. confidence_loss = tf.reduce_sum(confidence_loss) / tf.cast(tf.shape(yolo_output[0])[0], tf.float32)
  375. class_loss = tf.reduce_sum(class_loss) / tf.cast(tf.shape(yolo_output[0])[0], tf.float32)
  376. loss += xy_loss + wh_loss + confidence_loss + class_loss
  377. return loss
  378. def yolo_eval(self, yolo_outputs, image_shape, max_boxes=20):
  379. with tf.variable_scope('boxes_scores'):
  380. anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  381. boxes = []
  382. box_scores = []
  383. input_shape = tf.shape(yolo_outputs[0])[1: 3] * 32
  384. for i in range(len(yolo_outputs)):
  385. _boxes, _box_scores = self.yolo_boxes_scores(yolo_outputs[i], self.anchors[anchor_mask[i]],
  386. len(self.class_names), input_shape, image_shape)
  387. boxes.append(_boxes)
  388. box_scores.append(_box_scores)
  389. boxes = tf.concat(boxes, axis=0)
  390. box_scores = tf.concat(box_scores, axis=0)
  391. with tf.variable_scope('nms'):
  392. mask = box_scores >= self.obj_threshold
  393. max_boxes_tensor = tf.constant(max_boxes, dtype=tf.int32)
  394. boxes_ = []
  395. scores_ = []
  396. classes_ = []
  397. for c in range(len(self.class_names)):
  398. class_boxes = tf.boolean_mask(boxes, mask[:, c])
  399. class_box_scores = tf.boolean_mask(box_scores[:, c], mask[:, c])
  400. nms_index = tf.image.non_max_suppression(class_boxes, class_box_scores, max_boxes_tensor,
  401. iou_threshold=self.nms_threshold)
  402. class_boxes = tf.gather(class_boxes, nms_index)
  403. class_box_scores = tf.gather(class_box_scores, nms_index)
  404. classes = tf.ones_like(class_box_scores, 'int32') * c
  405. boxes_.append(class_boxes)
  406. scores_.append(class_box_scores)
  407. classes_.append(classes)
  408. with tf.variable_scope('output'):
  409. boxes_ = tf.concat(boxes_, axis=0, name='boxes')
  410. scores_ = tf.concat(scores_, axis=0, name='scores')
  411. classes_ = tf.concat(classes_, axis=0, name='classes')
  412. return boxes_, scores_, classes_
  413. class YoloConfig:
  414. gpu_index = "3"
  415. net_type = 'resnet18'
  416. anchors = [10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, 198, 163, 326]
  417. max_boxes = 50
  418. jitter = 0.3
  419. hue = 0.1
  420. sat = 1.0
  421. cont = 0.8
  422. bri = 0.1
  423. norm_decay = 0.99
  424. norm_epsilon = 1e-5
  425. ignore_thresh = 0.5
  426. class YOLOInference(object):
  427. # pylint: disable=too-many-arguments, too-many-instance-attributes
  428. def __init__(self, sess, pb_model_path, input_shape):
  429. """
  430. initialization
  431. """
  432. self.load_model(sess, pb_model_path)
  433. self.input_shape = input_shape
  434. def load_model(self, sess, pb_model_path):
  435. """
  436. import model and load parameters from pb file
  437. """
  438. logging.info("Import yolo model from pb start .......")
  439. with sess.as_default():
  440. with sess.graph.as_default():
  441. with tf.gfile.FastGFile(pb_model_path, 'rb') as f_handle:
  442. logging.info("ParseFromString start .......")
  443. graph_def = tf.GraphDef()
  444. graph_def.ParseFromString(f_handle.read())
  445. logging.info("ParseFromString end .......")
  446. tf.import_graph_def(graph_def, name='')
  447. logging.info("Import_graph_def end .......")
  448. logging.info("Import yolo model from pb end .......")
  449. # pylint: disable=too-many-locals
  450. # pylint: disable=invalid-name
  451. def predict(self, sess, img_data):
  452. """
  453. prediction for image rectangle by input_feed and output_feed
  454. """
  455. with sess.as_default():
  456. new_image = self.preprocess(img_data, self.input_shape)
  457. input_feed = self.create_input_feed(sess, new_image, img_data)
  458. output_fetch = self.create_output_fetch(sess)
  459. all_classes, all_scores, all_bboxes = sess.run(output_fetch, input_feed)
  460. return all_classes, all_scores, all_bboxes
  461. def create_input_feed(self, sess, new_image, img_data):
  462. """
  463. create input feed data
  464. """
  465. input_feed = {}
  466. input_img_data = sess.graph.get_tensor_by_name('images:0')
  467. input_feed[input_img_data] = new_image
  468. input_img_shape = sess.graph.get_tensor_by_name('shapes:0')
  469. input_feed[input_img_shape] = [img_data.shape[0], img_data.shape[1]]
  470. return input_feed
  471. def create_output_fetch(self, sess):
  472. """
  473. create output fetch tensors
  474. """
  475. output_classes = sess.graph.get_tensor_by_name('output/classes:0')
  476. output_scores = sess.graph.get_tensor_by_name('output/scores:0')
  477. output_boxes = sess.graph.get_tensor_by_name('output/boxes:0')
  478. output_fetch = [output_classes, output_scores, output_boxes]
  479. return output_fetch
  480. def preprocess(self, image, input_shape):
  481. """
  482. resize image with unchanged aspect ratio using padding by opencv
  483. """
  484. # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  485. h, w, _ = image.shape
  486. input_h, input_w = input_shape
  487. scale = min(float(input_w) / float(w), float(input_h) / float(h))
  488. nw = int(w * scale)
  489. nh = int(h * scale)
  490. image = cv2.resize(image, (nw, nh))
  491. new_image = np.zeros((input_h, input_w, 3), np.float32)
  492. new_image.fill(128)
  493. bh, bw, _ = new_image.shape
  494. new_image[int((bh - nh) / 2):(nh + int((bh - nh) / 2)), int((bw - nw) / 2):(nw + int((bw - nw) / 2)), :] = image
  495. new_image /= 255.
  496. new_image = np.expand_dims(new_image, 0) # Add batch dimension.
  497. return new_image