
data_gen.py 12 kB

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import logging

import numpy as np
from PIL import Image
import tensorflow as tf
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb

LOG = logging.getLogger(__name__)
flags = tf.flags.FLAGS

class DataGen(object):
    def __init__(self, config, train_data):
        LOG.info("DataGen build start .......")
        self.input_shape = flags.input_shape
        self.batch_size = flags.batch_size
        self.anchors = np.array(
            [float(x) for x in config.anchors]).reshape(-1, 2)
        self.class_names = flags.class_names
        self.num_classes = len(self.class_names)
        self.max_boxes = config.max_boxes
        self.train_curr_index = 0
        self.train_data = train_data
        self.train_data_size = len(self.train_data)
        LOG.info('size of train data is : %d' % self.train_data_size)
        self.batch_index = 0
        self.cur_shape = flags.input_shape
        LOG.info("DataGen build end .......")

    def next_batch_train(self):
        # multi-scale training: candidate resolutions are the base shape
        # plus/minus 32 * i pixels (i = 1, 2)
        multi_scales = [self.input_shape]
        for i in range(1, 3):
            multi_scales.append(
                (self.input_shape[0] - 32 * i,
                 self.input_shape[1] - 32 * i))
            multi_scales.append(
                (self.input_shape[0] + 32 * i,
                 self.input_shape[1] + 32 * i))
        # draw a new resolution every 25 batches
        if self.batch_index % 25 == 0:
            self.cur_shape = random.choice(multi_scales)
        self.batch_index += 1
        count, batch_data = self.next_batch(
            self.train_curr_index,
            self.train_data,
            self.train_data_size,
            self.cur_shape, True
        )
        if not count:
            # end of an epoch: rewind the cursor and reshuffle the data
            self.train_curr_index = 0
            random.shuffle(self.train_data)
            return None
        else:
            self.train_curr_index += count
            batch_data['input_shape'] = self.cur_shape
            return batch_data

    def next_batch(
            self,
            curr_index,
            dataset,
            data_size,
            input_shape,
            is_training):
        count = 0
        img_data_list = []
        box_data_list = []
        while curr_index < data_size:
            if curr_index % 10000 == 0:
                LOG.info("processing label line %d" % curr_index)
            curr_line = dataset[curr_index]
            count += 1
            curr_index += 1
            if len(curr_line.strip()) <= 0:
                LOG.info("skipping empty annotation line......")
                continue
            image_data, box_data = self.read_data(
                curr_line, input_shape, is_training, self.max_boxes)
            if image_data is None or box_data is None:
                continue
            img_data_list.append(image_data)
            box_data_list.append(box_data)
            if len(img_data_list) >= self.batch_size:
                batch_data = dict()
                batch_data['images'] = np.array(img_data_list)
                # encode the ground truth boxes for the three YOLOv3 output
                # scales (13x13, 26x26 and 52x52 grids for a 416x416 input)
                bbox_true_13, bbox_true_26, bbox_true_52 = (
                    self.preprocess_true_boxes(
                        np.array(box_data_list), input_shape
                    )
                )
                batch_data['bbox_true_13'] = bbox_true_13
                batch_data['bbox_true_26'] = bbox_true_26
                batch_data['bbox_true_52'] = bbox_true_52
                return count, batch_data
        LOG.info('reached the last line of data')
        return None, None

    def rand(self, a=0., b=1.):
        return np.random.rand() * (b - a) + a

    def read_data(
            self,
            annotation_line,
            input_shape=416,
            random=True,
            max_boxes=50,
            jitter=.3,
            hue=.1,
            sat=1.5,
            val=1.5,
            proc_img=True):
        """
        random preprocessing for real-time data augmentation
        """
        line = annotation_line.split()
        image = Image.open(line[0])
        iw, ih = image.size
        h, w = input_shape
        box = np.array([np.array(list(map(int, box.split(','))))
                        for box in line[1:]])

        if not random:
            # resize image
            scale = min(float(w) / float(iw), float(h) / float(ih))
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2
            image_data = 0
            if proc_img:
                image = image.resize((nw, nh), Image.BICUBIC)
                new_image = Image.new('RGB', (w, h), (128, 128, 128))
                new_image.paste(image, (dx, dy))
                image_data = np.array(new_image) / 255.
                # correct boxes
                box_data = np.zeros((max_boxes, 5))
                if len(box) > 0:
                    np.random.shuffle(box)
                    if len(box) > max_boxes:
                        box = box[:max_boxes]
                    box[:, [0, 2]] = box[:, [0, 2]] * scale + dx
                    box[:, [1, 3]] = box[:, [1, 3]] * scale + dy
                    box_data[:len(box)] = box
                return image_data, box_data
            else:
                return None, None

        # resize image
        new_ar = (float(w) / float(h)
                  * self.rand(1 - jitter, 1 + jitter)
                  / self.rand(1 - jitter, 1 + jitter))
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # place image
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # convert image to gray or not
        gray = self.rand() < .25
        if gray:
            image = image.convert('L').convert('RGB')

        # distort image
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)
        x = rgb_to_hsv(np.array(image) / 255.)
        x[..., 0] += hue
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x > 1] = 1
        x[x < 0] = 0
        image_data = hsv_to_rgb(x)  # numpy array, 0 to 1

        # correct boxes
        box_data = np.zeros((max_boxes, 5))
        if len(box) > 0:
            np.random.shuffle(box)
            box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
            box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
            if flip:
                box[:, [0, 2]] = w - box[:, [2, 0]]
            box[:, 0:2][box[:, 0:2] < 0] = 0
            box[:, 2][box[:, 2] > w] = w
            box[:, 3][box[:, 3] > h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            # discard invalid box
            box = box[np.logical_and(box_w > 1, box_h > 1)]
            if len(box) > max_boxes:
                box = box[:max_boxes]
        if len(box) == 0:
            return None, None
        box_data[:len(box)] = box
        return image_data, box_data

    def preprocess_true_boxes(self, true_boxes, in_shape=416):
        """Preprocess the ground truth boxes of the training data.

        :param true_boxes: ground truth boxes of shape [batch, max_boxes, 5],
            each box given as x_min, y_min, x_max, y_max, class_id
        """
        num_layers = self.anchors.shape[0] // 3
        anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        true_boxes = np.array(true_boxes, dtype='float32')
        # input_shape = np.array([in_shape, in_shape], dtype='int32')
        input_shape = np.array(in_shape, dtype='int32')
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
        m = true_boxes.shape[0]
        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros(
            (m,
             grid_shapes[layer][0],
             grid_shapes[layer][1],
             len(anchor_mask[layer]),
             5 + self.num_classes),
            dtype='float32') for layer in range(num_layers)
        ]
        # Expand dimensions so that the IoU between every box in an image
        # and all anchors can be computed by broadcasting
        anchors = np.expand_dims(self.anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max
        # Because the boxes were zero-padded earlier, all-zero rows must be
        # masked out
        valid_mask = boxes_wh[..., 0] > 0
        for b in range(m):
            wh = boxes_wh[b, valid_mask[b]]
            if len(wh) == 0:
                continue
            # Expand dimensions for broadcasting
            wh = np.expand_dims(wh, -2)
            # wh shape is [box_num, 1, 2]
            boxes_max = wh / 2.
            boxes_min = -boxes_max
            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)
            # Find the anchor with the highest IoU for each ground truth box,
            # then write that box into the grid cell of the output layer that
            # owns the matching anchor
            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for layer in range(num_layers):
                    if n in anchor_mask[layer]:
                        i = np.floor(
                            true_boxes[b, t, 0] *
                            grid_shapes[layer][1]).astype('int32')
                        j = np.floor(
                            true_boxes[b, t, 1] *
                            grid_shapes[layer][0]).astype('int32')
                        k = anchor_mask[layer].index(n)
                        c = true_boxes[b, t, 4].astype('int32')
                        y_true[layer][b, j, i, k, 0:4] = true_boxes[b, t, 0:4]
                        y_true[layer][b, j, i, k, 4] = 1.
                        y_true[layer][b, j, i, k, 5 + c] = 1.
        return y_true[0], y_true[1], y_true[2]
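
The anchor matching inside preprocess_true_boxes is the least obvious step in this generator, so the standalone numpy sketch below reproduces the same broadcasting trick in isolation: one ground-truth box, given as a width/height pair in pixels, is intersected with all nine anchors at once, and the anchor with the highest IoU decides which of the three output grids receives the label. The anchor sizes and the box below are illustrative values, not taken from this repository's config.

import numpy as np

# nine YOLOv3-style anchors (w, h) in pixels; illustrative values only
anchors = np.array([[10, 13], [16, 30], [33, 23],
                    [30, 61], [62, 45], [59, 119],
                    [116, 90], [156, 198], [373, 326]], dtype='float32')
anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]  # 13x13, 26x26, 52x52 grids

# one ground-truth box as (w, h) in pixels
wh = np.array([[120., 200.]])   # shape (1, 2)
wh = np.expand_dims(wh, -2)     # shape (1, 1, 2), ready for broadcasting

# center the box and the anchors at the origin and intersect them
boxes_max, boxes_min = wh / 2., -wh / 2.
anchors_max, anchors_min = anchors / 2., -anchors / 2.
intersect_wh = np.maximum(
    np.minimum(boxes_max, anchors_max) - np.maximum(boxes_min, anchors_min), 0.)
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
box_area = wh[..., 0] * wh[..., 1]
anchor_area = anchors[..., 0] * anchors[..., 1]
iou = intersect_area / (box_area + anchor_area - intersect_area)  # shape (1, 9)

best_anchor = int(np.argmax(iou, axis=-1)[0])
best_layer = next(idx for idx, mask in enumerate(anchor_mask)
                  if best_anchor in mask)
print(best_anchor, best_layer)  # here: anchor 7 (156x198) -> layer 0, the 13x13 head

In preprocess_true_boxes this index is then used to fill y_true[layer][b, j, i, k], where (j, i) is the grid cell containing the box center and k is the anchor's position within anchor_mask[layer].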