
data_gen.py 11 kB

import logging
import random

import numpy as np
import tensorflow as tf
from PIL import Image
from matplotlib.colors import rgb_to_hsv, hsv_to_rgb

LOG = logging.getLogger(__name__)
flags = tf.flags.FLAGS

class DataGen(object):

    def __init__(self, config, train_data, valid_data):
        LOG.info("DataGen build start .......")
        self.input_shape = flags.input_shape
        self.batch_size = flags.batch_size
        self.anchors = np.array([float(x) for x in config.anchors]).reshape(-1, 2)
        self.class_names = flags.class_names
        self.num_classes = len(self.class_names)
        self.max_boxes = config.max_boxes

        self.train_curr_index = 0
        self.train_data = train_data
        self.train_data_size = len(self.train_data)
        LOG.info('size of train data is : %d' % self.train_data_size)

        self.val_curr_index = 0
        self.val_data = valid_data
        self.val_data_size = len(self.val_data)
        LOG.info('size of validation data is : %d' % self.val_data_size)

        self.batch_index = 0
        self.cur_shape = flags.input_shape
        LOG.info("DataGen build end .......")

    def next_batch_train(self):
        # multi-scale training: candidate input shapes are the base shape
        # plus/minus multiples of 32 pixels, re-sampled every 25 batches
        multi_scales = [self.input_shape]
        for i in range(1, 3):
            multi_scales.append((self.input_shape[0] - 32 * i, self.input_shape[1] - 32 * i))
            multi_scales.append((self.input_shape[0] + 32 * i, self.input_shape[1] + 32 * i))
        if self.batch_index % 25 == 0:
            self.cur_shape = random.choice(multi_scales)
        self.batch_index += 1
        count, batch_data = self.next_batch(self.train_curr_index, self.train_data, self.train_data_size,
                                            self.cur_shape, True)
        if not count:
            # epoch finished: reset the cursor and reshuffle the training annotations
            self.train_curr_index = 0
            random.shuffle(self.train_data)
            return None
        else:
            self.train_curr_index += count
            batch_data['input_shape'] = self.cur_shape
            return batch_data

    def next_batch_validate(self):
        count, batch_data = self.next_batch(self.val_curr_index, self.val_data, self.val_data_size,
                                            self.input_shape, False)
        if not count:
            self.val_curr_index = 0
            return None
        else:
            self.val_curr_index += count
            return batch_data

    def next_batch(self, curr_index, dataset, data_size, input_shape, is_training):
        count = 0
        img_data_list = []
        box_data_list = []
        while curr_index < data_size:
            if curr_index % 10000 == 0:
                LOG.info("processing label line %d" % curr_index)
            curr_line = dataset[curr_index]
            count += 1
            curr_index += 1
            if len(curr_line.strip()) <= 0:
                LOG.info("skipping empty annotation line......")
                continue
            image_data, box_data = self.read_data(curr_line, input_shape, is_training, self.max_boxes)
            if image_data is None or box_data is None:
                continue
            img_data_list.append(image_data)
            box_data_list.append(box_data)
            if len(img_data_list) >= self.batch_size:
                batch_data = dict()
                batch_data['images'] = np.array(img_data_list)
                bbox_true_13, bbox_true_26, bbox_true_52 = self.preprocess_true_boxes(np.array(box_data_list),
                                                                                      input_shape)
                batch_data['bbox_true_13'] = bbox_true_13
                batch_data['bbox_true_26'] = bbox_true_26
                batch_data['bbox_true_52'] = bbox_true_52
                return count, batch_data
        LOG.info('reached the last line of data')
        return None, None

    def rand(self, a=0., b=1.):
        return np.random.rand() * (b - a) + a

    def read_data(self, annotation_line, input_shape=416, random=True, max_boxes=50, jitter=.3, hue=.1, sat=1.5,
                  val=1.5, proc_img=True):
        """
        random preprocessing for real-time data augmentation
        """
        line = annotation_line.split()
        image = Image.open(line[0])
        iw, ih = image.size
        h, w = input_shape
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

        if not random:
            # resize image
            scale = min(float(w) / float(iw), float(h) / float(ih))
            nw = int(iw * scale)
            nh = int(ih * scale)
            dx = (w - nw) // 2
            dy = (h - nh) // 2
            image_data = 0
            if proc_img:
                image = image.resize((nw, nh), Image.BICUBIC)
                new_image = Image.new('RGB', (w, h), (128, 128, 128))
                new_image.paste(image, (dx, dy))
                image_data = np.array(new_image) / 255.
                # correct boxes
                box_data = np.zeros((max_boxes, 5))
                if len(box) > 0:
                    np.random.shuffle(box)
                    if len(box) > max_boxes:
                        box = box[:max_boxes]
                    box[:, [0, 2]] = box[:, [0, 2]] * scale + dx
                    box[:, [1, 3]] = box[:, [1, 3]] * scale + dy
                    box_data[:len(box)] = box
                return image_data, box_data
            else:
                return None, None

        # resize image
        new_ar = float(w) / float(h) * self.rand(1 - jitter, 1 + jitter) / self.rand(1 - jitter, 1 + jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale * h)
            nw = int(nh * new_ar)
        else:
            nw = int(scale * w)
            nh = int(nw / new_ar)
        image = image.resize((nw, nh), Image.BICUBIC)

        # place image
        dx = int(self.rand(0, w - nw))
        dy = int(self.rand(0, h - nh))
        new_image = Image.new('RGB', (w, h), (128, 128, 128))
        new_image.paste(image, (dx, dy))
        image = new_image

        # flip image or not
        flip = self.rand() < .5
        if flip:
            image = image.transpose(Image.FLIP_LEFT_RIGHT)

        # convert image to gray or not
        gray = self.rand() < .25
        if gray:
            image = image.convert('L').convert('RGB')

        # distort image
        hue = self.rand(-hue, hue)
        sat = self.rand(1, sat) if self.rand() < .5 else 1 / self.rand(1, sat)
        val = self.rand(1, val) if self.rand() < .5 else 1 / self.rand(1, val)
        x = rgb_to_hsv(np.array(image) / 255.)
        x[..., 0] += hue
        x[..., 0][x[..., 0] > 1] -= 1
        x[..., 0][x[..., 0] < 0] += 1
        x[..., 1] *= sat
        x[..., 2] *= val
        x[x > 1] = 1
        x[x < 0] = 0
        image_data = hsv_to_rgb(x)  # numpy array, 0 to 1

        # correct boxes
        box_data = np.zeros((max_boxes, 5))
        if len(box) > 0:
            np.random.shuffle(box)
            box[:, [0, 2]] = box[:, [0, 2]] * nw / iw + dx
            box[:, [1, 3]] = box[:, [1, 3]] * nh / ih + dy
            if flip:
                box[:, [0, 2]] = w - box[:, [2, 0]]
            box[:, 0:2][box[:, 0:2] < 0] = 0
            box[:, 2][box[:, 2] > w] = w
            box[:, 3][box[:, 3] > h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w > 1, box_h > 1)]  # discard invalid boxes
            if len(box) > max_boxes:
                box = box[:max_boxes]
            if len(box) == 0:
                return None, None
            box_data[:len(box)] = box

        return image_data, box_data

    def preprocess_true_boxes(self, true_boxes, in_shape=416):
        """Preprocess the ground truth boxes of the training data.

        :param true_boxes: ground truth boxes, shape [batch, max_boxes, 5],
            each row is x_min, y_min, x_max, y_max, class_id
        :param in_shape: network input shape as (h, w)
        """
        num_layers = self.anchors.shape[0] // 3
        # anchors assigned to the 13x13, 26x26 and 52x52 output scales respectively
        anchor_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
        true_boxes = np.array(true_boxes, dtype='float32')
        input_shape = np.array(in_shape, dtype='int32')
        boxes_xy = (true_boxes[..., 0:2] + true_boxes[..., 2:4]) // 2.
        boxes_wh = true_boxes[..., 2:4] - true_boxes[..., 0:2]
        # normalize centers and sizes to [0, 1]; input_shape is (h, w), hence the [::-1]
        true_boxes[..., 0:2] = boxes_xy / input_shape[::-1]
        true_boxes[..., 2:4] = boxes_wh / input_shape[::-1]
        m = true_boxes.shape[0]
        grid_shapes = [input_shape // 32, input_shape // 16, input_shape // 8]
        y_true = [np.zeros((m, grid_shapes[l][0], grid_shapes[l][1], len(anchor_mask[l]), 5 + self.num_classes),
                           dtype='float32') for l in range(num_layers)]
        # The dimension is expanded so the IOU between every box in each image
        # and all anchors can be computed by broadcasting
        anchors = np.expand_dims(self.anchors, 0)
        anchors_max = anchors / 2.
        anchors_min = -anchors_max
        # Because the boxes were padded earlier, remove all zero rows
        valid_mask = boxes_wh[..., 0] > 0
        for b in range(m):
            wh = boxes_wh[b, valid_mask[b]]
            if len(wh) == 0:
                continue
            # Expand dimensions for broadcasting; wh shape is [box_num, 1, 2]
            wh = np.expand_dims(wh, -2)
            boxes_max = wh / 2.
            boxes_min = -boxes_max
            intersect_min = np.maximum(boxes_min, anchors_min)
            intersect_max = np.minimum(boxes_max, anchors_max)
            intersect_wh = np.maximum(intersect_max - intersect_min, 0.)
            intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
            box_area = wh[..., 0] * wh[..., 1]
            anchor_area = anchors[..., 0] * anchors[..., 1]
            iou = intersect_area / (box_area + anchor_area - intersect_area)
            # Find the anchor with the highest IOU for each ground truth box,
            # then write the box coordinates into the grid cell of the scale
            # that anchor is responsible for
            best_anchor = np.argmax(iou, axis=-1)
            for t, n in enumerate(best_anchor):
                for l in range(num_layers):
                    if n in anchor_mask[l]:
                        i = np.floor(true_boxes[b, t, 0] * grid_shapes[l][1]).astype('int32')
                        j = np.floor(true_boxes[b, t, 1] * grid_shapes[l][0]).astype('int32')
                        k = anchor_mask[l].index(n)
                        c = true_boxes[b, t, 4].astype('int32')
                        y_true[l][b, j, i, k, 0:4] = true_boxes[b, t, 0:4]
                        y_true[l][b, j, i, k, 4] = 1.
                        y_true[l][b, j, i, k, 5 + c] = 1.
        return y_true[0], y_true[1], y_true[2]
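
The generator signals the end of an epoch by returning None from next_batch_train (after reshuffling the training annotations) and from next_batch_validate (after resetting its cursor). For orientation, below is a minimal, hypothetical driver sketch, not part of data_gen.py: it assumes the project has already defined the tf flags (input_shape, batch_size, class_names) and a config object exposing anchors and max_boxes, and that each annotation line follows the format read_data parses (an image path followed by space-separated x_min,y_min,x_max,y_max,class_id groups). The annotation file names and epoch count are placeholders.

# Hypothetical usage sketch, not part of data_gen.py.
# Assumes flags.input_shape, flags.batch_size, flags.class_names and a
# project `config` object with `anchors` and `max_boxes` already exist.
#
# Each annotation line expected by read_data looks like:
#   /path/to/img.jpg 48,240,195,371,11 8,12,352,498,14
with open('train.txt') as f:        # placeholder annotation files
    train_lines = f.readlines()
with open('val.txt') as f:
    val_lines = f.readlines()

gen = DataGen(config, train_lines, val_lines)

for epoch in range(10):             # placeholder epoch count
    while True:
        batch = gen.next_batch_train()
        if batch is None:           # epoch exhausted; data was reshuffled
            break
        cur_h, cur_w = batch['input_shape']   # multi-scale resolution of this batch
        images = batch['images']              # [batch, cur_h, cur_w, 3], floats in [0, 1]
        y_13 = batch['bbox_true_13']          # [batch, cur_h/32, cur_w/32, 3, 5 + num_classes]
        y_26 = batch['bbox_true_26']          # [batch, cur_h/16, cur_w/16, 3, 5 + num_classes]
        y_52 = batch['bbox_true_52']          # [batch, cur_h/8,  cur_w/8,  3, 5 + num_classes]
        # feed images and y_* into the model's training step (model-specific)
    while True:
        val_batch = gen.next_batch_validate()
        if val_batch is None:
            break
        # evaluate the model on val_batch in the same way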