You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

custom_transforms.py 7.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. import torch
  2. import random
  3. import numpy as np
  4. from PIL import Image, ImageOps, ImageFilter
  5. class Normalize(object):
  6. """Normalize a tensor image with mean and standard deviation.
  7. Args:
  8. mean (tuple): means for each channel.
  9. std (tuple): standard deviations for each channel.
  10. """
  11. def __init__(self, mean=(0., 0., 0.), std=(1., 1., 1.)):
  12. self.mean = mean
  13. self.std = std
  14. def __call__(self, sample):
  15. img = sample['image']
  16. depth = sample['depth']
  17. mask = sample['label']
  18. img = np.array(img).astype(np.float32)
  19. depth = np.array(depth).astype(np.float32)
  20. mask = np.array(mask).astype(np.float32)
  21. img /= 255.0
  22. img -= self.mean
  23. img /= self.std
  24. # mean and std for original depth images
  25. mean_depth = 0.12176
  26. std_depth = 0.09752
  27. depth /= 255.0
  28. depth -= mean_depth
  29. depth /= std_depth
  30. return {'image': img,
  31. 'depth': depth,
  32. 'label': mask}
  33. class ToTensor(object):
  34. """Convert Image object in sample to Tensors."""
  35. def __call__(self, sample):
  36. # swap color axis because
  37. # numpy image: H x W x C
  38. # torch image: C X H X W
  39. img = sample['image']
  40. depth = sample['depth']
  41. mask = sample['label']
  42. img = np.array(img).astype(np.float32).transpose((2, 0, 1))
  43. depth = np.array(depth).astype(np.float32)
  44. mask = np.array(mask).astype(np.float32)
  45. img = torch.from_numpy(img).float()
  46. depth = torch.from_numpy(depth).float()
  47. mask = torch.from_numpy(mask).float()
  48. return {'image': img,
  49. 'depth': depth,
  50. 'label': mask}
  51. class CropBlackArea(object):
  52. """
  53. crop black area for depth image
  54. """
  55. def __call__(self, sample):
  56. img = sample['image']
  57. depth = sample['depth']
  58. mask = sample['label']
  59. width, height = img.size
  60. left = 140
  61. top = 30
  62. right = 2030
  63. bottom = 900
  64. # crop
  65. img = img.crop((left, top, right, bottom))
  66. depth = depth.crop((left, top, right, bottom))
  67. mask = mask.crop((left, top, right, bottom))
  68. # resize
  69. img = img.resize((width, height), Image.BILINEAR)
  70. depth = depth.resize((width, height), Image.BILINEAR)
  71. mask = mask.resize((width, height), Image.NEAREST)
  72. return {'image': img,
  73. 'depth': depth,
  74. 'label': mask}
  75. class RandomHorizontalFlip(object):
  76. def __call__(self, sample):
  77. img = sample['image']
  78. depth = sample['depth']
  79. mask = sample['label']
  80. if random.random() < 0.5:
  81. img = img.transpose(Image.FLIP_LEFT_RIGHT)
  82. depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
  83. mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
  84. return {'image': img,
  85. 'depth': depth,
  86. 'label': mask}
  87. class RandomRotate(object):
  88. def __init__(self, degree):
  89. self.degree = degree
  90. def __call__(self, sample):
  91. img = sample['image']
  92. depth = sample['depth']
  93. mask = sample['label']
  94. rotate_degree = random.uniform(-1*self.degree, self.degree)
  95. img = img.rotate(rotate_degree, Image.BILINEAR)
  96. depth = depth.rotate(rotate_degree, Image.BILINEAR)
  97. mask = mask.rotate(rotate_degree, Image.NEAREST)
  98. return {'image': img,
  99. 'depth': depth,
  100. 'label': mask}
  101. class RandomGaussianBlur(object):
  102. def __call__(self, sample):
  103. img = sample['image']
  104. depth = sample['depth']
  105. mask = sample['label']
  106. if random.random() < 0.5:
  107. img = img.filter(ImageFilter.GaussianBlur(
  108. radius=random.random()))
  109. return {'image': img,
  110. 'depth': depth,
  111. 'label': mask}
  112. class RandomScaleCrop(object):
  113. def __init__(self, base_size, crop_size, fill=0):
  114. self.base_size = base_size
  115. self.crop_size = crop_size
  116. self.fill = fill
  117. def __call__(self, sample):
  118. img = sample['image']
  119. depth = sample['depth']
  120. mask = sample['label']
  121. # random scale (short edge)
  122. short_size = random.randint(
  123. int(self.base_size * 0.5), int(self.base_size * 2.0))
  124. w, h = img.size
  125. if h > w:
  126. ow = short_size
  127. oh = int(1.0 * h * ow / w)
  128. else:
  129. oh = short_size
  130. ow = int(1.0 * w * oh / h)
  131. img = img.resize((ow, oh), Image.BILINEAR)
  132. depth = depth.resize((ow, oh), Image.BILINEAR)
  133. mask = mask.resize((ow, oh), Image.NEAREST)
  134. # pad crop
  135. if short_size < self.crop_size:
  136. padh = self.crop_size - oh if oh < self.crop_size else 0
  137. padw = self.crop_size - ow if ow < self.crop_size else 0
  138. img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
  139. depth = ImageOps.expand(depth, border=(
  140. 0, 0, padw, padh), fill=0)
  141. mask = ImageOps.expand(mask, border=(
  142. 0, 0, padw, padh), fill=self.fill)
  143. # random crop crop_size
  144. w, h = img.size
  145. x1 = random.randint(0, w - self.crop_size)
  146. y1 = random.randint(0, h - self.crop_size)
  147. img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  148. depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  149. mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  150. return {'image': img,
  151. 'depth': depth,
  152. 'label': mask}
  153. class FixScaleCrop(object):
  154. def __init__(self, crop_size):
  155. self.crop_size = crop_size
  156. def __call__(self, sample):
  157. img = sample['image']
  158. depth = sample['depth']
  159. mask = sample['label']
  160. w, h = img.size
  161. if w > h:
  162. oh = self.crop_size
  163. ow = int(1.0 * w * oh / h)
  164. else:
  165. ow = self.crop_size
  166. oh = int(1.0 * h * ow / w)
  167. img = img.resize((ow, oh), Image.BILINEAR)
  168. depth = depth.resize((ow, oh), Image.BILINEAR)
  169. mask = mask.resize((ow, oh), Image.NEAREST)
  170. # center crop
  171. w, h = img.size
  172. x1 = int(round((w - self.crop_size) / 2.))
  173. y1 = int(round((h - self.crop_size) / 2.))
  174. img = img.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  175. depth = depth.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  176. mask = mask.crop((x1, y1, x1 + self.crop_size, y1 + self.crop_size))
  177. return {'image': img,
  178. 'depth': depth,
  179. 'label': mask}
  180. class FixedResize(object):
  181. def __init__(self, size):
  182. # size: (h, w)
  183. self.size = (size, size)
  184. def __call__(self, sample):
  185. img = sample['image']
  186. depth = sample['depth']
  187. mask = sample['label']
  188. assert img.size == depth.size == mask.size
  189. img = img.resize(self.size, Image.BILINEAR)
  190. depth = depth.resize(self.size, Image.BILINEAR)
  191. mask = mask.resize(self.size, Image.NEAREST)
  192. return {'image': img,
  193. 'depth': depth,
  194. 'label': mask}
  195. class Relabel(object):
  196. def __init__(self, olabel, nlabel):
  197. # change trainid label from olabel to nlabel
  198. self.olabel = olabel
  199. self.nlabel = nlabel
  200. def __call__(self, tensor):
  201. tensor[tensor == self.olabel] = self.nlabel
  202. return tensor