
visualize.py

#! /usr/bin/python
# -*- coding: utf-8 -*-

import os

import imageio
import numpy as np

import tensorlayer as tl
from tensorlayer.lazy_imports import LazyImport

import colorsys, random

cv2 = LazyImport("cv2")

# Uncomment the following line if you got: _tkinter.TclError: no display name and no $DISPLAY environment variable
# import matplotlib
# matplotlib.use('Agg')

__all__ = [
    'read_image', 'read_images', 'save_image', 'save_images', 'draw_boxes_and_labels_to_image',
    'draw_mpii_people_to_image', 'frame', 'CNN2d', 'images2d', 'tsne_embedding', 'draw_weights', 'W',
    'draw_boxes_and_labels_to_image_with_json'
]

def read_image(image, path=''):
    """Read one image.

    Parameters
    -----------
    image : str
        The image file name.
    path : str
        The image folder path.

    Returns
    -------
    numpy.array
        The image.

    """
    return imageio.imread(os.path.join(path, image))

def read_images(img_list, path='', n_threads=10, printable=True):
    """Return all images in a list, given the folder path and the name of each image file.

    Parameters
    -------------
    img_list : list of str
        The image file names.
    path : str
        The image folder path.
    n_threads : int
        The number of images to read per batch of threads.
    printable : boolean
        Whether to print information while reading images.

    Returns
    -------
    list of numpy.array
        The images.

    """
    imgs = []
    for idx in range(0, len(img_list), n_threads):
        b_imgs_list = img_list[idx:idx + n_threads]
        b_imgs = tl.prepro.threading_data(b_imgs_list, fn=read_image, path=path)
        # tl.logging.info(b_imgs.shape)
        imgs.extend(b_imgs)
        if printable:
            tl.logging.info('read %d from %s' % (len(imgs), path))
    return imgs
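
# A minimal usage sketch (hypothetical helper, not part of the original API): it assumes a
# folder 'data/' containing PNG files, lists their names, and hands them to `read_images`,
# which reads them in thread batches of `n_threads`.
def _example_read_folder(path='data/'):
    """Hypothetical demo: read every .png in `path` in thread batches of 10."""
    file_names = sorted(f for f in os.listdir(path) if f.endswith('.png'))
    return read_images(file_names, path=path, n_threads=10, printable=True)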

def save_image(image, image_path='_temp.png'):
    """Save an image.

    Parameters
    -----------
    image : numpy array
        [w, h, c]
    image_path : str
        The save path.

    """
    try:  # RGB
        imageio.imwrite(image_path, image)
    except Exception:  # Greyscale
        imageio.imwrite(image_path, image[:, :, 0])

def save_images(images, size, image_path='_temp.png'):
    """Save multiple images into one single image.

    Parameters
    -----------
    images : numpy array
        (batch, w, h, c)
    size : list of 2 ints
        Row and column number.
        The number of images should be no more than size[0] * size[1].
    image_path : str
        The save path.

    Examples
    ---------
    >>> import numpy as np
    >>> import tensorlayer as tl
    >>> images = np.random.rand(64, 100, 100, 3)
    >>> tl.visualize.save_images(images, [8, 8], 'temp.png')

    """
    if len(images.shape) == 3:  # Greyscale [batch, h, w] --> [batch, h, w, 1]
        images = images[:, :, :, np.newaxis]

    def merge(images, size):
        h, w = images.shape[1], images.shape[2]
        img = np.zeros((h * size[0], w * size[1], 3), dtype=images.dtype)
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img

    def imsave(images, size, path):
        if np.max(images) <= 1 and (-1 <= np.min(images) < 0):
            images = ((images + 1) * 127.5).astype(np.uint8)
        elif np.max(images) <= 1 and np.min(images) >= 0:
            images = (images * 255).astype(np.uint8)
        return imageio.imwrite(path, merge(images, size))

    if len(images) > size[0] * size[1]:
        raise AssertionError("number of images should be no more than size[0] * size[1]; got {}".format(len(images)))

    return imsave(images, size, image_path)

def draw_boxes_and_labels_to_image(
        image, classes, coords, scores, classes_list, is_center=True, is_rescale=True, save_name=None
):
    """Draw bboxes and class labels on an image. Return or save the image with bboxes, example in the docs of ``tl.prepro``.

    Parameters
    -----------
    image : numpy.array
        The RGB image [height, width, channel].
    classes : list of int
        A list of class IDs (int).
    coords : list of list of int
        A list of coordinate lists.
        - Should be [x, y, x2, y2] (up-left and bottom-right format).
        - If [x_center, y_center, w, h], set is_center to True.
    scores : list of float
        A list of scores (float). (Optional)
    classes_list : list of str
        For converting IDs to strings on the image.
    is_center : boolean
        Whether the coordinates are [x_center, y_center, w, h].
        - If coordinates are [x_center, y_center, w, h], set it to True to convert them to [x, y, x2, y2] (up-left and bottom-right) internally.
        - If coordinates are [x1, x2, y1, y2], set it to False.
    is_rescale : boolean
        Whether to rescale the coordinates from ratio format to pixel-unit format.
        - If True, the input coordinates are portions of the width and height, and this API scales them to pixel units internally.
        - If False, feed the coordinates in pixel-unit format.
    save_name : None or str
        The name of the image file (i.e. image.png); if None, the image is not saved.

    Returns
    -------
    numpy.array
        The saved image.

    References
    -----------
    - OpenCV rectangle and putText.
    - `scikit-image <http://scikit-image.org/docs/dev/api/skimage.draw.html#skimage.draw.rectangle>`__.

    """
    if len(coords) != len(classes):
        raise AssertionError("number of coordinates and classes should be equal")

    if len(scores) > 0 and len(scores) != len(classes):
        raise AssertionError("number of scores and classes should be equal")

    # don't change the original image, and avoid error https://stackoverflow.com/questions/30249053/python-opencv-drawing-errors-after-manipulating-array-with-numpy
    image = image.copy()

    imh, imw = image.shape[0:2]
    thick = int((imh + imw) // 430)

    for i, _v in enumerate(coords):
        if is_center:
            x, y, x2, y2 = tl.prepro.obj_box_coord_centroid_to_upleft_butright(coords[i])
        else:
            x, y, x2, y2 = coords[i]

        if is_rescale:  # scale back to pixel units if the coords are portions of the width and height
            x, y, x2, y2 = tl.prepro.obj_box_coord_scale_to_pixelunit([x, y, x2, y2], (imh, imw))

        cv2.rectangle(
            image,
            (int(x), int(y)),
            (int(x2), int(y2)),  # up-left and bottom-right
            [0, 255, 0],
            thick
        )

        cv2.putText(
            image,
            classes_list[classes[i]] + ((" %.2f" % (scores[i])) if (len(scores) != 0) else " "),
            (int(x), int(y)),  # bottom-left corner of the text
            0,
            1.5e-3 * imh,  # bigger = larger font
            [0, 0, 255],  # self.meta['colors'][max_indx],
            int(thick / 2) + 1
        )  # bold

    if save_name is not None:
        # cv2.imwrite('_my.png', image)
        save_image(image, save_name)
    # if len(coords) == 0:
    #     tl.logging.info("draw_boxes_and_labels_to_image: no bboxes exist, cannot draw !")
    return image
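
# A minimal usage sketch (hypothetical helper, not part of the original API): it draws one
# box given in [x_center, y_center, w, h] ratio format on a random RGB image, which is the
# combination handled by is_center=True and is_rescale=True above.
def _example_draw_boxes():
    """Hypothetical demo: one 'person' box around the centre of a 480x640 image."""
    img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
    return draw_boxes_and_labels_to_image(
        img,
        classes=[0],
        coords=[[0.5, 0.5, 0.3, 0.4]],  # [x_center, y_center, w, h] as ratios of width/height
        scores=[0.9],
        classes_list=['person'],
        is_center=True,
        is_rescale=True,
        save_name=None,
    )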

def draw_mpii_pose_to_image(image, poses, save_name='image.png'):
    """Draw people into an image using the MPII dataset format as input, return or save the result image.

    This is an experimental API; it may change in the future.

    Parameters
    -----------
    image : numpy.array
        The RGB image [height, width, channel].
    poses : list of dict
        The people annotations in MPII format, see ``tl.files.load_mpii_pose_dataset``.
    save_name : None or str
        The name of the image file (i.e. image.png); if None, the image is not saved.

    Returns
    --------
    numpy.array
        The saved image.

    Examples
    --------
    >>> import pprint
    >>> import tensorlayer as tl
    >>> img_train_list, ann_train_list, img_test_list, ann_test_list = tl.files.load_mpii_pose_dataset()
    >>> image = tl.vis.read_image(img_train_list[0])
    >>> tl.vis.draw_mpii_pose_to_image(image, ann_train_list[0], 'image.png')
    >>> pprint.pprint(ann_train_list[0])

    References
    -----------
    - `MPII Keypoints and ID <http://human-pose.mpi-inf.mpg.de/#download>`__

    """
    # import skimage
    # don't change the original image, and avoid error https://stackoverflow.com/questions/30249053/python-opencv-drawing-errors-after-manipulating-array-with-numpy
    image = image.copy()

    imh, imw = image.shape[0:2]
    thick = int((imh + imw) // 430)
    # radius = int(image.shape[1] / 500) + 1
    radius = int(thick * 1.5)

    if image.max() < 1:
        image = image * 255

    for people in poses:
        # Pose Keypoints
        joint_pos = people['joint_pos']
        # draw sketch
        # joint id (0 - r ankle, 1 - r knee, 2 - r hip, 3 - l hip, 4 - l knee,
        #           5 - l ankle, 6 - pelvis, 7 - thorax, 8 - upper neck,
        #           9 - head top, 10 - r wrist, 11 - r elbow, 12 - r shoulder,
        #           13 - l shoulder, 14 - l elbow, 15 - l wrist)
        #
        #                   9
        #                   8
        #             12 ** 7 ** 13
        #            *      *      *
        #           11      *      14
        #          *        *        *
        #         10    2 * 6 * 3    15
        #               *       *
        #               1       4
        #               *       *
        #               0       5

        lines = [
            [(0, 1), [100, 255, 100]],
            [(1, 2), [50, 255, 50]],
            [(2, 6), [0, 255, 0]],  # right leg
            [(3, 4), [100, 100, 255]],
            [(4, 5), [50, 50, 255]],
            [(6, 3), [0, 0, 255]],  # left leg
            [(6, 7), [255, 255, 100]],
            [(7, 8), [255, 150, 50]],  # body
            [(8, 9), [255, 200, 100]],  # head
            [(10, 11), [255, 100, 255]],
            [(11, 12), [255, 50, 255]],
            [(12, 8), [255, 0, 255]],  # right hand
            [(8, 13), [0, 255, 255]],
            [(13, 14), [100, 255, 255]],
            [(14, 15), [200, 255, 255]]  # left hand
        ]
        for line in lines:
            start, end = line[0]
            if (start in joint_pos) and (end in joint_pos):
                cv2.line(
                    image,
                    (int(joint_pos[start][0]), int(joint_pos[start][1])),
                    (int(joint_pos[end][0]), int(joint_pos[end][1])),  # start and end points
                    line[1],
                    thick
                )
                # rr, cc, val = skimage.draw.line_aa(int(joint_pos[start][1]), int(joint_pos[start][0]), int(joint_pos[end][1]), int(joint_pos[end][0]))
                # image[rr, cc] = line[1]
        # draw circles
        for pos in joint_pos.items():
            _, pos_loc = pos  # pos_id, pos_loc
            pos_loc = (int(pos_loc[0]), int(pos_loc[1]))
            cv2.circle(image, center=pos_loc, radius=radius, color=(200, 200, 200), thickness=-1)
            # rr, cc = skimage.draw.circle(int(pos_loc[1]), int(pos_loc[0]), radius)
            # image[rr, cc] = [0, 255, 0]

        # Head
        head_rect = people['head_rect']
        if head_rect:  # if head exists
            cv2.rectangle(
                image,
                (int(head_rect[0]), int(head_rect[1])),
                (int(head_rect[2]), int(head_rect[3])),  # up-left and bottom-right
                [0, 180, 0],
                thick
            )

    if save_name is not None:
        # cv2.imwrite(save_name, image)
        save_image(image, save_name)
    return image


draw_mpii_people_to_image = draw_mpii_pose_to_image

def frame(I=None, second=5, saveable=True, name='frame', cmap=None, fig_idx=12836):
    """Display a frame. Make sure OpenAI Gym render() is disabled before using it.

    Parameters
    ----------
    I : numpy.array
        The image.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : str
        A name to save the image, if saveable is True.
    cmap : None or str
        'gray' for greyscale, None for default, etc.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> env = gym.make("Pong-v0")
    >>> observation = env.reset()
    >>> tl.visualize.frame(observation)

    """
    import matplotlib.pyplot as plt
    if saveable is False:
        plt.ion()
    plt.figure(fig_idx)  # show all feature images

    if len(I.shape) and I.shape[-1] == 1:  # (10,10,1) --> (10,10)
        I = I[:, :, 0]

    plt.imshow(I, cmap)
    plt.title(name)
    # plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
    # plt.gca().yaxis.set_major_locator(plt.NullLocator())

    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)

def CNN2d(CNN=None, second=10, saveable=True, name='cnn', fig_idx=3119362):
    """Display a group of RGB or Greyscale CNN masks.

    Parameters
    ----------
    CNN : numpy.array
        The image. e.g. 64 5x5 RGB images can be (5, 5, 3, 64).
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : str
        A name to save the image, if saveable is True.
    fig_idx : int
        The matplotlib figure index.

    Examples
    --------
    >>> tl.visualize.CNN2d(network.all_params[0].eval(), second=10, saveable=True, name='cnn1_mnist', fig_idx=2012)

    """
    import matplotlib.pyplot as plt
    # tl.logging.info(CNN.shape)    # (5, 5, 3, 64)
    # exit()
    n_mask = CNN.shape[3]
    n_row = CNN.shape[0]
    n_col = CNN.shape[1]
    n_color = CNN.shape[2]
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask / row))
    plt.ion()  # active mode
    fig = plt.figure(fig_idx)
    count = 1
    for _ir in range(1, row + 1):
        for _ic in range(1, col + 1):
            if count > n_mask:
                break
            fig.add_subplot(col, row, count)
            # tl.logging.info(CNN[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
            # exit()
            # plt.imshow(
            #         np.reshape(CNN[count-1,:,:,:], (n_row, n_col)),
            #         cmap='gray', interpolation="nearest")   # theano
            if n_color == 1:
                plt.imshow(np.reshape(CNN[:, :, :, count - 1], (n_row, n_col)), cmap='gray', interpolation="nearest")
            elif n_color == 3:
                plt.imshow(
                    np.reshape(CNN[:, :, :, count - 1], (n_row, n_col, n_color)), cmap='gray', interpolation="nearest"
                )
            else:
                raise Exception("Unknown n_color")
            plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)

def images2d(images=None, second=10, saveable=True, name='images', dtype=None, fig_idx=3119362):
    """Display a group of RGB or Greyscale images.

    Parameters
    ----------
    images : numpy.array
        The images.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : str
        A name to save the image, if saveable is True.
    dtype : None or numpy data type
        The data type for displaying the images.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> X_train, y_train, X_test, y_test = tl.files.load_cifar10_dataset(shape=(-1, 32, 32, 3), plotable=False)
    >>> tl.visualize.images2d(X_train[0:100,:,:,:], second=10, saveable=False, name='cifar10', dtype=np.uint8, fig_idx=20212)

    """
    import matplotlib.pyplot as plt
    # tl.logging.info(images.shape)    # (50000, 32, 32, 3)
    # exit()
    if dtype:
        images = np.asarray(images, dtype=dtype)
    n_mask = images.shape[0]
    n_row = images.shape[1]
    n_col = images.shape[2]
    n_color = images.shape[3]
    row = int(np.sqrt(n_mask))
    col = int(np.ceil(n_mask / row))
    plt.ion()  # active mode
    fig = plt.figure(fig_idx)
    count = 1
    for _ir in range(1, row + 1):
        for _ic in range(1, col + 1):
            if count > n_mask:
                break
            fig.add_subplot(col, row, count)
            # tl.logging.info(images[:,:,:,count-1].shape, n_row, n_col)   # (5, 1, 32) 5 5
            # plt.imshow(
            #         np.reshape(images[count-1,:,:,:], (n_row, n_col)),
            #         cmap='gray', interpolation="nearest")   # theano
            if n_color == 1:
                plt.imshow(np.reshape(images[count - 1, :, :], (n_row, n_col)), cmap='gray', interpolation="nearest")
                # plt.title(name)
            elif n_color == 3:
                plt.imshow(images[count - 1, :, :], cmap='gray', interpolation="nearest")
                # plt.title(name)
            else:
                raise Exception("Unknown n_color")
            plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)

def tsne_embedding(embeddings, reverse_dictionary, plot_only=500, second=5, saveable=False, name='tsne', fig_idx=9862):
    """Visualize the embeddings by using t-SNE.

    Parameters
    ----------
    embeddings : numpy.array
        The embedding matrix.
    reverse_dictionary : dictionary
        id_to_word, mapping id to unique word.
    plot_only : int
        The number of examples to plot; chooses the most common words.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    name : str
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> # see 'tutorial_word2vec_basic.py'
    >>> final_embeddings = normalized_embeddings.eval()
    >>> tl.visualize.tsne_embedding(final_embeddings, reverse_dictionary,
    ...                   plot_only=500, second=5, saveable=False, name='tsne')

    """
    import matplotlib.pyplot as plt

    def plot_with_labels(low_dim_embs, labels, figsize=(18, 18), second=5, saveable=True, name='tsne', fig_idx=9862):
        if low_dim_embs.shape[0] < len(labels):
            raise AssertionError("More labels than embeddings")
        if saveable is False:
            plt.ion()
            plt.figure(fig_idx)
        plt.figure(figsize=figsize)  # in inches
        for i, label in enumerate(labels):
            x, y = low_dim_embs[i, :]
            plt.scatter(x, y)
            plt.annotate(label, xy=(x, y), xytext=(5, 2), textcoords='offset points', ha='right', va='bottom')
        if saveable:
            plt.savefig(name + '.pdf', format='pdf')
        else:
            plt.draw()
            plt.pause(second)

    try:
        from sklearn.manifold import TSNE
        from six.moves import xrange

        tsne = TSNE(perplexity=30, n_components=2, init='pca', n_iter=5000)
        # plot_only = 500
        low_dim_embs = tsne.fit_transform(embeddings[:plot_only, :])
        labels = [reverse_dictionary[i] for i in xrange(plot_only)]
        plot_with_labels(low_dim_embs, labels, second=second, saveable=saveable, name=name, fig_idx=fig_idx)
    except ImportError:
        _err = "Please install sklearn and matplotlib to visualize embeddings."
        tl.logging.error(_err)
        raise ImportError(_err)

def draw_weights(W=None, second=10, saveable=True, shape=None, name='mnist', fig_idx=2396512):
    """Visualize every column of the weight matrix as a group of greyscale images.

    Parameters
    ----------
    W : numpy.array
        The weight matrix.
    second : int
        The display second(s) for the image(s), if saveable is False.
    saveable : boolean
        Save or plot the figure.
    shape : a list with 2 int or None
        The shape of the feature image; for MNIST it is [28, 28].
    name : a string
        A name to save the image, if saveable is True.
    fig_idx : int
        matplotlib figure index.

    Examples
    --------
    >>> tl.visualize.draw_weights(network.all_params[0].eval(), second=10, saveable=True, name='weight_of_1st_layer', fig_idx=2012)

    """
    if shape is None:
        shape = [28, 28]

    import matplotlib.pyplot as plt
    if saveable is False:
        plt.ion()
    fig = plt.figure(fig_idx)  # show all feature images
    n_units = W.shape[1]

    num_r = int(np.sqrt(n_units))  # number of units shown per row; e.g. 25 hidden units -> 5 per row
    num_c = int(np.ceil(n_units / num_r))
    count = int(1)
    for _row in range(1, num_r + 1):
        for _col in range(1, num_c + 1):
            if count > n_units:
                break
            fig.add_subplot(num_r, num_c, count)
            # ------------------------------------------------------------
            # plt.imshow(np.reshape(W[:,count-1],(28,28)), cmap='gray')
            # ------------------------------------------------------------
            feature = W[:, count - 1] / np.sqrt((W[:, count - 1]**2).sum())
            # feature[feature < 0.0001] = 0    # value threshold
            # if count == 1 or count == 2:
            #     print(np.mean(feature))
            # if np.std(feature) < 0.03:       # condition threshold
            #     feature = np.zeros_like(feature)
            # if np.mean(feature) < -0.015:    # condition threshold
            #     feature = np.zeros_like(feature)
            plt.imshow(
                np.reshape(feature, (shape[0], shape[1])), cmap='gray', interpolation="nearest"
            )  # , vmin=np.min(feature), vmax=np.max(feature))
            # plt.title(name)
            # ------------------------------------------------------------
            # plt.imshow(np.reshape(W[:,count-1], (np.sqrt(size), np.sqrt(size))), cmap='gray', interpolation="nearest")
            plt.gca().xaxis.set_major_locator(plt.NullLocator())  # disable ticks
            plt.gca().yaxis.set_major_locator(plt.NullLocator())
            count = count + 1
    if saveable:
        plt.savefig(name + '.pdf', format='pdf')
    else:
        plt.draw()
        plt.pause(second)


W = draw_weights

def draw_boxes_and_labels_to_image_with_json(image, json_result, class_list, save_name=None):
    """Draw bboxes and class labels on an image. Return the image with bboxes.

    Parameters
    -----------
    image : numpy.array
        The RGB image [height, width, channel].
    json_result : list of dict
        The object detection result in JSON format.
    class_list : list of str
        For converting IDs to strings on the image.
    save_name : None or str
        The name of the image file (i.e. image.png); if None, the image is not saved.

    Returns
    -------
    numpy.array
        The saved image.

    References
    -----------
    - OpenCV rectangle and putText.
    - `scikit-image <http://scikit-image.org/docs/dev/api/skimage.draw.html#skimage.draw.rectangle>`__.

    """
    image_h, image_w, _ = image.shape
    num_classes = len(class_list)
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
    colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), colors))
    random.seed(0)
    random.shuffle(colors)
    random.seed(None)
    bbox_thick = int(0.6 * (image_h + image_w) / 600)
    fontScale = 0.5

    for bbox_info in json_result:
        image_name = bbox_info['image']
        category_id = bbox_info['category_id']
        if category_id < 0 or category_id > num_classes:
            continue
        bbox = bbox_info['bbox']  # the order of coordinates is [x1, y1, x2, y2]
        score = bbox_info['score']

        bbox_color = colors[category_id]
        c1, c2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
        cv2.rectangle(image, c1, c2, bbox_color, bbox_thick)
        bbox_mess = '%s: %.2f' % (class_list[category_id], score)
        t_size = cv2.getTextSize(bbox_mess, 0, fontScale, thickness=bbox_thick // 2)[0]
        c3 = (c1[0] + t_size[0], c1[1] - t_size[1] - 3)
        cv2.rectangle(image, c1, (np.float32(c3[0]), np.float32(c3[1])), bbox_color, -1)
        cv2.putText(
            image, bbox_mess, (c1[0], np.float32(c1[1] - 2)), cv2.FONT_HERSHEY_SIMPLEX, fontScale, (0, 0, 0),
            bbox_thick // 2, lineType=cv2.LINE_AA
        )

    if save_name is not None:
        save_image(image, save_name)

    return image
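
# A minimal usage sketch (hypothetical helper, not part of the original API): it shows the
# dictionary keys this function reads from each detection ('image', 'category_id',
# 'bbox' as [x1, y1, x2, y2] in pixel units, and 'score').
def _example_draw_boxes_with_json():
    """Hypothetical demo: draw one fake detection on a random RGB image."""
    img = (np.random.rand(480, 640, 3) * 255).astype(np.uint8)
    detections = [{
        'image': 'example.png',
        'category_id': 0,
        'bbox': [100, 120, 300, 400],  # [x1, y1, x2, y2] in pixel units
        'score': 0.87,
    }]
    return draw_boxes_and_labels_to_image_with_json(img, detections, ['person'], save_name=None)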

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computing backends. Planned backends include TensorFlow, PyTorch, MindSpore, and Paddle.
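
As a quick end-to-end sketch of the module above (a hedged example: it assumes the package is importable as tensorlayer and that matplotlib is available for frame), the file-based helpers compose as follows:

    import numpy as np
    import tensorlayer as tl

    # Save a batch of 64 random 100x100 RGB images as a single 8x8 grid image.
    images = np.random.rand(64, 100, 100, 3)
    tl.visualize.save_images(images, [8, 8], '_grid.png')

    # Read the grid back and save it as a single matplotlib figure.
    grid = tl.visualize.read_image('_grid.png')
    tl.visualize.frame(grid, saveable=True, name='_grid_frame')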