You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

utils.py 8.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258
  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. import torch
  4. def decode_seg_map_sequence(label_masks, dataset='pascal'):
  5. rgb_masks = []
  6. for label_mask in label_masks:
  7. rgb_mask = decode_segmap(label_mask, dataset)
  8. rgb_masks.append(rgb_mask)
  9. rgb_masks = torch.from_numpy(np.array(rgb_masks).transpose([0, 3, 1, 2])) # change for val
  10. return rgb_masks
  11. def decode_segmap(label_mask, dataset, plot=False):
  12. """Decode segmentation class labels into a color image
  13. Args:
  14. label_mask (np.ndarray): an (M,N) array of integer values denoting
  15. the class label at each spatial location.
  16. plot (bool, optional): whether to show the resulting color image
  17. in a figure.
  18. Returns:
  19. (np.ndarray, optional): the resulting decoded color image.
  20. """
  21. if dataset == 'pascal' or dataset == 'coco':
  22. n_classes = 21
  23. label_colours = get_pascal_labels()
  24. elif dataset == 'cityscapes':
  25. n_classes = 24
  26. label_colours = get_cityscapes_labels()
  27. elif dataset == 'target':
  28. n_classes = 24
  29. label_colours = get_cityscapes_labels()
  30. elif dataset == 'cityrand':
  31. n_classes = 19
  32. label_colours = get_cityscapes_labels()
  33. elif dataset == 'citylostfound':
  34. n_classes = 20
  35. label_colours = get_citylostfound_labels()
  36. elif dataset == 'xrlab':
  37. n_classes = 25
  38. label_colours = get_cityscapes_labels()
  39. elif dataset == 'e1':
  40. n_classes = 24
  41. label_colours = get_cityscapes_labels()
  42. elif dataset == 'mapillary':
  43. n_classes = 24
  44. label_colours = get_cityscapes_labels()
  45. else:
  46. raise NotImplementedError
  47. r = label_mask.copy()
  48. g = label_mask.copy()
  49. b = label_mask.copy()
  50. for ll in range(0, n_classes):
  51. r[label_mask == ll] = label_colours[ll, 0]
  52. g[label_mask == ll] = label_colours[ll, 1]
  53. b[label_mask == ll] = label_colours[ll, 2]
  54. rgb = np.zeros((label_mask.shape[0], label_mask.shape[1], 3)) # change for val
  55. # rgb = torch.ByteTensor(3, label_mask.shape[0], label_mask.shape[1]).fill_(0)
  56. rgb[:, :, 0] = r / 255.0
  57. rgb[:, :, 1] = g / 255.0
  58. rgb[:, :, 2] = b / 255.0
  59. # r = torch.from_numpy(r)
  60. # g = torch.from_numpy(g)
  61. # b = torch.from_numpy(b)
  62. rgb[:, :, 0] = r / 255.0
  63. rgb[:, :, 1] = g / 255.0
  64. rgb[:, :, 2] = b / 255.0
  65. if plot:
  66. plt.imshow(rgb)
  67. plt.show()
  68. else:
  69. return rgb
  70. def encode_segmap(mask):
  71. """Encode segmentation label images as pascal classes
  72. Args:
  73. mask (np.ndarray): raw segmentation label image of dimension
  74. (M, N, 3), in which the Pascal classes are encoded as colours.
  75. Returns:
  76. (np.ndarray): class map with dimensions (M,N), where the value at
  77. a given location is the integer denoting the class index.
  78. """
  79. mask = mask.astype(int)
  80. label_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=np.int16)
  81. for ii, label in enumerate(get_pascal_labels()):
  82. label_mask[np.where(np.all(mask == label, axis=-1))[:2]] = ii
  83. label_mask = label_mask.astype(int)
  84. return label_mask
  85. def get_cityscapes_labels():
  86. return np.array([
  87. [128, 64, 128],
  88. [244, 35, 232],
  89. [70, 70, 70],
  90. [102, 102, 156],
  91. [190, 153, 153],
  92. [153, 153, 153],
  93. [250, 170, 30],
  94. [220, 220, 0],
  95. [107, 142, 35],
  96. [152, 251, 152],
  97. [0, 130, 180],
  98. [220, 20, 60],
  99. [255, 0, 0],
  100. [0, 0, 142],
  101. [0, 0, 70],
  102. [0, 60, 100],
  103. [0, 80, 100],
  104. [0, 0, 230],
  105. [119, 11, 32],
  106. [119, 11, 119],
  107. [128, 64, 64],
  108. [102, 10, 156],
  109. [102, 102, 15],
  110. [10, 102, 156],
  111. [10, 102, 156],
  112. [10, 102, 156],
  113. [10, 102, 156]])
  114. def get_citylostfound_labels():
  115. return np.array([
  116. [128, 64, 128],
  117. [244, 35, 232],
  118. [70, 70, 70],
  119. [102, 102, 156],
  120. [190, 153, 153],
  121. [153, 153, 153],
  122. [250, 170, 30],
  123. [220, 220, 0],
  124. [107, 142, 35],
  125. [152, 251, 152],
  126. [0, 130, 180],
  127. [220, 20, 60],
  128. [255, 0, 0],
  129. [0, 0, 142],
  130. [0, 0, 70],
  131. [0, 60, 100],
  132. [0, 80, 100],
  133. [0, 0, 230],
  134. [119, 11, 32],
  135. [111, 74, 0]])
  136. def get_pascal_labels():
  137. """Load the mapping that associates pascal classes with label colors
  138. Returns:
  139. np.ndarray with dimensions (21, 3)
  140. """
  141. return np.asarray([[0, 0, 0], [128, 0, 0], [0, 128, 0], [128, 128, 0],
  142. [0, 0, 128], [128, 0, 128], [0, 128, 128], [128, 128, 128],
  143. [64, 0, 0], [192, 0, 0], [64, 128, 0], [192, 128, 0],
  144. [64, 0, 128], [192, 0, 128], [64, 128, 128], [192, 128, 128],
  145. [0, 64, 0], [128, 64, 0], [0, 192, 0], [128, 192, 0],
  146. [0, 64, 128]])
  147. def colormap_bdd(n):
  148. cmap=np.zeros([n, 3]).astype(np.uint8)
  149. cmap[0,:] = np.array([128, 64, 128])
  150. cmap[1,:] = np.array([244, 35, 232])
  151. cmap[2,:] = np.array([ 70, 70, 70])
  152. cmap[3,:] = np.array([102, 102, 156])
  153. cmap[4,:] = np.array([190, 153, 153])
  154. cmap[5,:] = np.array([153, 153, 153])
  155. cmap[6,:] = np.array([250, 170, 30])
  156. cmap[7,:] = np.array([220, 220, 0])
  157. cmap[8,:] = np.array([107, 142, 35])
  158. cmap[9,:] = np.array([152, 251, 152])
  159. cmap[10,:]= np.array([70, 130, 180])
  160. cmap[11,:]= np.array([220, 20, 60])
  161. cmap[12,:]= np.array([255, 0, 0])
  162. cmap[13,:]= np.array([0, 0, 142])
  163. cmap[14,:]= np.array([0, 0, 70])
  164. cmap[15,:]= np.array([0, 60, 100])
  165. cmap[16,:]= np.array([0, 80, 100])
  166. cmap[17,:]= np.array([0, 0, 230])
  167. cmap[18,:]= np.array([119, 11, 32])
  168. cmap[19,:]= np.array([111, 74, 0]) #多加了一类small obstacle
  169. # supplemented class rgb
  170. cmap[20, :] = np.array([0, 10, 128])
  171. cmap[21, :] = np.array([0, 100, 232])
  172. cmap[22, :] = np.array([250, 70, 70])
  173. cmap[23, :] = np.array([0, 102, 156])
  174. cmap[24, :] = np.array([190, 153, 0])
  175. cmap[25, :] = np.array([153, 153, 255])
  176. cmap[26, :] = np.array([250, 170, 0])
  177. cmap[27, :] = np.array([150, 220, 0])
  178. cmap[28, :] = np.array([107, 0, 35])
  179. cmap[29, :] = np.array([0, 251, 152])
  180. cmap[30, :] = np.array([70, 255, 180])
  181. return cmap
  182. def colormap_bdd0(n):
  183. cmap=np.zeros([n, 3]).astype(np.uint8)
  184. cmap[0,:] = np.array([0, 0, 0])
  185. cmap[1,:] = np.array([70, 130, 180])
  186. cmap[2,:] = np.array([70, 70, 70])
  187. cmap[3,:] = np.array([128, 64, 128])
  188. cmap[4,:] = np.array([244, 35, 232])
  189. cmap[5,:] = np.array([64, 64, 128])
  190. cmap[6,:] = np.array([107, 142, 35])
  191. cmap[7,:] = np.array([153, 153, 153])
  192. cmap[8,:] = np.array([0, 0, 142])
  193. cmap[9,:] = np.array([220, 220, 0])
  194. cmap[10,:]= np.array([220, 20, 60])
  195. cmap[11,:]= np.array([119, 11, 32])
  196. cmap[12,:]= np.array([0, 0, 230])
  197. cmap[13,:]= np.array([250, 170, 160])
  198. cmap[14,:]= np.array([128, 64, 64])
  199. cmap[15,:]= np.array([250, 170, 30])
  200. cmap[16,:]= np.array([152, 251, 152])
  201. cmap[17,:]= np.array([255, 0, 0])
  202. cmap[18,:]= np.array([0, 0, 70])
  203. cmap[19,:]= np.array([0, 60, 100]) #small obstacle
  204. cmap[20,:]= np.array([0, 80, 100])
  205. cmap[21,:]= np.array([102, 102, 156])
  206. cmap[22,:]= np.array([102, 102, 156])
  207. return cmap
  208. class Colorize:
  209. def __init__(self, n=31): # n = nClasses
  210. # self.cmap = colormap(256)
  211. self.cmap = colormap_bdd(256)
  212. self.cmap[n] = self.cmap[-1]
  213. self.cmap = torch.from_numpy(self.cmap[:n])
  214. def __call__(self, gray_image):
  215. size = gray_image.size()
  216. # print(size)
  217. color_images = torch.ByteTensor(size[0], 3, size[1], size[2]).fill_(0)
  218. # color_image = torch.ByteTensor(3, size[0], size[1]).fill_(0)
  219. # for label in range(1, len(self.cmap)):
  220. for i in range(color_images.shape[0]):
  221. for label in range(0, len(self.cmap)):
  222. mask = gray_image[0] == label
  223. # mask = gray_image == label
  224. color_images[i][0][mask] = self.cmap[label][0]
  225. color_images[i][1][mask] = self.cmap[label][1]
  226. color_images[i][2][mask] = self.cmap[label][2]
  227. return color_images