You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

paddle_vision.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import paddle
  4. from . import functional_cv2 as F_cv2
  5. from . import functional_pil as F_pil
  6. import sys
  7. import math
  8. import numbers
  9. import warnings
  10. import collections
  11. import numpy as np
  12. from PIL import Image
  13. from numpy import sin, cos, tan
  14. import paddle
  15. import random
  # Public API of this module, listed in the original export order.
  __all__ = [
      'central_crop', 'to_tensor', 'crop', 'pad', 'resize', 'transpose',
      'hwc_to_chw', 'chw_to_hwc', 'rgb_to_hsv', 'hsv_to_rgb', 'rgb_to_gray',
      'adjust_brightness', 'adjust_contrast', 'adjust_hue', 'adjust_saturation',
      'normalize', 'hflip', 'vflip', 'padtoboundingbox', 'standardize',
      'random_brightness', 'random_contrast', 'random_saturation', 'random_hue',
      'random_crop', 'random_resized_crop', 'random_vflip', 'random_hflip',
      'random_rotation', 'random_shear', 'random_shift', 'random_zoom',
      'random_affine',
  ]
  def _is_pil_image(img):
      """Return True if ``img`` is a ``PIL.Image.Image`` instance."""
      pil_type = Image.Image
      return isinstance(img, pil_type)
  def _is_tensor_image(img):
      """Return True if ``img`` is a ``paddle.Tensor``."""
      tensor_type = paddle.Tensor
      return isinstance(img, tensor_type)
  55. def _is_numpy_image(img):
  56. return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
  def to_tensor(img, data_format='HWC'):
      """Convert ``img`` to a tensor.

      Thin wrapper around ``paddle.vision.functional.to_tensor``; ``data_format``
      is passed through unchanged.
      """
      convert = paddle.vision.functional.to_tensor
      return convert(img, data_format=data_format)
  def _get_image_size(img):
      """Return the ``(height, width)`` of a PIL image or NumPy array.

      PIL reports ``size`` as ``(width, height)``, so the pair is swapped. For
      arrays the first two dimensions are returned (assumes HW / HWC layout --
      TODO confirm for CHW inputs).

      Raises:
          TypeError: if ``img`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(img):
          width, height = img.size
          return height, width
      if _is_numpy_image(img):
          return img.shape[:2]
      raise TypeError("Unexpected type {}".format(type(img)))
  def random_factor(factor, name, center=1, bound=(0, float('inf')), non_negative=True):
      """Draw a random adjustment factor uniformly from a range derived from ``factor``.

      Args:
          factor: a non-negative number ``f`` (range ``[center - f, center + f]``,
              lower end clipped at 0 when ``non_negative``) or a 2-element
              list/tuple ``(min, max)``.
          name: parameter name used in error messages.
          center: centre of the range built from a scalar ``factor``.
          bound: inclusive ``(low, high)`` limits the final range must lie in.
          non_negative: clip the lower end of a scalar-derived range at 0.

      Returns:
          A float sampled with ``np.random.uniform(min, max)``.

      Raises:
          ValueError: if a scalar ``factor`` is negative, or the resulting range
              is not ordered or falls outside ``bound``.
          TypeError: if ``factor`` is neither a number nor a length-2 list/tuple.
      """
      if isinstance(factor, numbers.Number):
          if factor < 0:
              raise ValueError('The input value of {} cannot be negative.'.format(name))
          factor = [center - factor, center + factor]
          if non_negative:
              factor[0] = max(0, factor[0])
      elif not (isinstance(factor, (tuple, list)) and len(factor) == 2):
          raise TypeError("Input of {} should be either a single value, or a list/tuple of " "length 2.".format(name))
      # Validate the final range for BOTH input forms. Previously only the
      # list/tuple form was checked, so e.g. hue=0.8 expanded to [-0.8, 0.8]
      # and slipped past bound=(-0.5, 0.5).
      if not bound[0] <= factor[0] <= factor[1] <= bound[1]:
          raise ValueError(
              "Please check your value range of {} is valid and "
              "within the bound {}.".format(name, bound)
          )
      return np.random.uniform(factor[0], factor[1])
  def central_crop(image, size=None, central_fraction=None):
      """Crop the central region of ``image``.

      ``size`` and ``central_fraction`` are forwarded unchanged to the backend's
      ``center_crop`` (``functional_pil`` for PIL images, ``functional_cv2`` for
      NumPy arrays); at least one of them must be given.

      Raises:
          ValueError: if both ``size`` and ``central_fraction`` are None.
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if size is None and central_fraction is None:
          raise ValueError('central_fraction and size can not be both None')
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.center_crop(image, size, central_fraction)
  def crop(image, offset_height, offset_width, target_height, target_width):
      """Crop a rectangle out of ``image``.

      Dispatches to ``functional_pil.crop`` for PIL images and to
      ``functional_cv2.crop`` for NumPy arrays; offsets and target sizes are
      forwarded unchanged.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.crop(image, offset_height, offset_width, target_height, target_width)
  def pad(image, padding, padding_value, mode):
      """Pad ``image`` using the backend matching its type.

      ``padding``, ``padding_value`` and ``mode`` are forwarded unchanged to
      ``functional_pil.pad`` (PIL) or ``functional_cv2.pad`` (ndarray).

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.pad(image, padding, padding_value, mode)
  def resize(image, size, method):
      """Resize ``image`` to ``size`` using interpolation ``method``.

      Both arguments are forwarded unchanged to ``functional_pil.resize`` (PIL)
      or ``functional_cv2.resize`` (ndarray).

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.resize(image, size, method)
  def transpose(image, order):
      """Permute the axes of ``image`` according to ``order``.

      ``order`` is forwarded unchanged to ``functional_pil.transpose`` (PIL) or
      ``functional_cv2.transpose`` (ndarray).

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.transpose(image, order)
  def hwc_to_chw(image):
      """Convert ``image`` from HWC to CHW layout via the matching backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.hwc_to_chw(image)
  def chw_to_hwc(image):
      """Convert ``image`` from CHW to HWC layout via the matching backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.chw_to_hwc(image)
  def rgb_to_hsv(image):
      """Convert an RGB ``image`` to HSV via the matching backend.

      Only PIL images and 3-D ndarrays are accepted.

      Raises:
          TypeError: for any other input.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif isinstance(image, np.ndarray) and image.ndim == 3:
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
      return backend.rgb_to_hsv(image)
  def hsv_to_rgb(image):
      """Convert an HSV ``image`` back to RGB via the matching backend.

      Only PIL images and 3-D ndarrays are accepted.

      Raises:
          TypeError: for any other input.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif isinstance(image, np.ndarray) and image.ndim == 3:
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
      return backend.hsv_to_rgb(image)
  def rgb_to_gray(image, num_output_channels):
      """Convert an RGB ``image`` to grayscale via the matching backend.

      ``num_output_channels`` is forwarded unchanged. Only PIL images and 3-D
      ndarrays are accepted.

      Raises:
          TypeError: for any other input.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif isinstance(image, np.ndarray) and image.ndim == 3:
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
      return backend.rgb_to_gray(image, num_output_channels)
  def adjust_brightness(image, brightness_factor):
      """Adjust the brightness of ``image`` by ``brightness_factor``.

      The factor is forwarded unchanged to the PIL or cv2 backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.adjust_brightness(image, brightness_factor)
  def adjust_contrast(image, contrast_factor):
      """Adjust the contrast of ``image`` by ``contrast_factor``.

      The factor is forwarded unchanged to the PIL or cv2 backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.adjust_contrast(image, contrast_factor)
  def adjust_hue(image, hue_factor):
      """Shift the hue of ``image`` by ``hue_factor``.

      The factor is forwarded unchanged to the PIL or cv2 backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.adjust_hue(image, hue_factor)
  def adjust_saturation(image, saturation_factor):
      """Adjust the saturation of ``image`` by ``saturation_factor``.

      The factor is forwarded unchanged to the PIL or cv2 backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.adjust_saturation(image, saturation_factor)
  def hflip(image):
      """Flip ``image`` horizontally via the matching backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.hflip(image)
  def vflip(image):
      """Flip ``image`` vertically via the matching backend.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.vflip(image)
  def padtoboundingbox(image, offset_height, offset_width, target_height, target_width, padding_value):
      """Pad ``image`` into a ``target_height`` x ``target_width`` box.

      All geometry and ``padding_value`` arguments are forwarded unchanged to
      ``functional_pil.padtoboundingbox`` or ``functional_cv2.padtoboundingbox``.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      return backend.padtoboundingbox(
          image, offset_height, offset_width, target_height, target_width, padding_value
      )
  def normalize(image, mean, std, data_format):
      """Normalize ``image`` channel-wise as ``(image - mean) / std``.

      Args:
          image: ``paddle.Tensor``, PIL image, or anything ``paddle.to_tensor``
              accepts. Non-tensors are converted (PIL via ``np.asarray``), and
              the tensor is cast to float32.
          mean: a number (repeated for every channel) or a list/tuple with one
              value per channel.
          std: same form as ``mean``.
          data_format: 'CHW' (channel count read from dim 0) or 'HWC' (dim 2).

      Returns:
          A float32 ``paddle.Tensor`` equal to ``(image - mean) / std``.

      Raises:
          ValueError: if ``data_format`` is not 'CHW' or 'HWC', or if a
              list/tuple ``mean``/``std`` length differs from the channel count.
      """
      # An unknown data_format previously left num_channels unbound (surfacing
      # as UnboundLocalError) or silently skipped the reshape; fail early.
      if data_format not in ('CHW', 'HWC'):
          raise ValueError("data_format should be 'CHW' or 'HWC'. But got {}.".format(data_format))
      if not _is_tensor_image(image):
          if _is_pil_image(image):
              image = np.asarray(image)
          image = paddle.to_tensor(image)
      image = image.astype('float32')
      num_channels = image.shape[0] if data_format == 'CHW' else image.shape[2]
      if isinstance(mean, numbers.Number):
          mean = (mean, ) * num_channels
      elif isinstance(mean, (list, tuple)):
          if len(mean) != num_channels:
              raise ValueError("Length of mean must be 1 or equal to the number of channels({0}).".format(num_channels))
      if isinstance(std, numbers.Number):
          std = (std, ) * num_channels
      elif isinstance(std, (list, tuple)):
          if len(std) != num_channels:
              raise ValueError("Length of std must be 1 or equal to the number of channels({0}).".format(num_channels))
      if data_format == 'CHW':
          # Shape (C, 1, 1) so the per-channel values broadcast over H and W.
          std = np.array(std).reshape((-1, 1, 1))
          mean = np.array(mean).reshape((-1, 1, 1))
      else:
          # Shape (C,) broadcasts against the trailing channel axis of HWC.
          mean = np.array(mean)
          std = np.array(std)
      mean = paddle.to_tensor(mean).astype('float32')
      std = paddle.to_tensor(std).astype('float32')
      return (image - mean) / std
  def standardize(image):
      """Linearly scale ``image`` to zero mean and unit variance.

      Mirrors ``tf.image.per_image_standardization``: returns
      ``(x - mean) / max(stddev, 1 / sqrt(N))``, where ``N`` is the number of
      elements. The lower bound on the divisor avoids division by zero for
      uniform images.

      Args:
          image: ``paddle.Tensor``, PIL image, or anything ``paddle.to_tensor``
              accepts; non-tensors are converted and everything is cast to
              float32.

      Returns:
          A float32 ``paddle.Tensor`` of the same shape.
      """
      if not _is_tensor_image(image):
          if _is_pil_image(image):
              image = np.asarray(image)
          image = paddle.to_tensor(image)
      image = image.astype('float32')
      num_pixels = paddle.to_tensor(image.size, dtype='float32')
      image_mean = paddle.mean(image)
      # TensorFlow uses the population standard deviation; paddle.std defaults
      # to the unbiased (N-1) estimator, which differs and yields NaN for a
      # single-element input.
      stddev = paddle.std(image, unbiased=False)
      min_stddev = 1.0 / paddle.sqrt(num_pixels)
      adjusted_stddev = paddle.maximum(stddev, min_stddev)
      return (image - image_mean) / adjusted_stddev
  def random_brightness(image, brightness_factor):
      """Apply a randomly drawn brightness adjustment to ``image``.

      Args:
          image: PIL image or 2-/3-D ndarray.
          brightness_factor: a non-negative number ``b`` (factor drawn uniformly
              from ``[max(0, 1 - b), 1 + b]``) or a ``(min, max)`` list/tuple.

      Returns:
          The adjusted image, produced by the PIL or cv2 backend.

      Raises:
          TypeError: for an unsupported image type or malformed factor.
          ValueError: for a negative or out-of-range factor (see ``random_factor``).
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      drawn = random_factor(brightness_factor, name='brightness')
      return backend.adjust_brightness(image, drawn)
  def random_contrast(image, contrast_factor):
      """Apply a randomly drawn contrast adjustment to ``image``.

      ``contrast_factor`` is a non-negative number ``c`` (range
      ``[max(0, 1 - c), 1 + c]``) or a ``(min, max)`` list/tuple; the factor is
      drawn by ``random_factor``.

      Raises:
          TypeError: for an unsupported image type or malformed factor.
          ValueError: for a negative or out-of-range factor.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      drawn = random_factor(contrast_factor, name='contrast')
      return backend.adjust_contrast(image, drawn)
  def random_saturation(image, saturation_factor):
      """Apply a randomly drawn saturation adjustment to ``image``.

      ``saturation_factor`` is a non-negative number ``s`` (range
      ``[max(0, 1 - s), 1 + s]``) or a ``(min, max)`` list/tuple; the factor is
      drawn by ``random_factor``.

      Raises:
          TypeError: for an unsupported image type or malformed factor.
          ValueError: for a negative or out-of-range factor.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      drawn = random_factor(saturation_factor, name='saturation')
      return backend.adjust_saturation(image, drawn)
  def random_hue(image, hue_factor):
      """Apply a randomly drawn hue shift to ``image``.

      ``hue_factor`` is a non-negative number ``h`` (range ``[-h, h]``) or a
      ``(min, max)`` list/tuple within ``[-0.5, 0.5]``; the shift is drawn by
      ``random_factor`` centred at 0 without clipping at zero.

      Raises:
          TypeError: for an unsupported image type or malformed factor.
          ValueError: for a negative or out-of-range factor.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      drawn = random_factor(hue_factor, name='hue', center=0, bound=(-0.5, 0.5), non_negative=False)
      return backend.adjust_hue(image, drawn)
  def random_crop(image, size, padding, pad_if_needed, fill, padding_mode):
      """Crop a randomly positioned window of ``size`` out of ``image``.

      Args:
          image: PIL image or 2-/3-D ndarray.
          size: int (square crop) or 2-element list/tuple ``(height, width)``.
          padding: if not None, applied first via ``pad(image, padding, fill,
              padding_mode)``.
          pad_if_needed: when truthy, pad an undersized width with
              ``(deficit, 0)`` and an undersized height with ``(0, deficit)``
              (assumes the backend reads a 2-tuple as horizontal/vertical
              padding -- TODO confirm).
          fill, padding_mode: forwarded to ``pad``.

      Returns:
          The window produced by ``crop``; its top-left corner is drawn with
          ``random.randint``.

      Raises:
          TypeError: for an unsupported image type.
          ValueError: for a malformed ``size`` or when the (padded) image is
              still smaller than the crop.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      if isinstance(size, int):
          size = (size, size)
      elif not (isinstance(size, (tuple, list)) and len(size) == 2):
          raise ValueError('Size should be a int or a list/tuple with length of 2. ' 'But got {}'.format(size))
      if padding is not None:
          image = pad(image, padding, fill, padding_mode)
      height, width = _get_image_size(image)
      if pad_if_needed:
          if width < size[1]:
              image = pad(image, (size[1] - width, 0), fill, padding_mode)
          if height < size[0]:
              image = pad(image, (0, size[0] - height), fill, padding_mode)
      height, width = _get_image_size(image)
      target_height, target_width = size
      if height < target_height or width < target_width:
          raise ValueError(
              'Crop size {} should be smaller than input image size {}. '.format((target_height, target_width), (height, width))
          )
      top = random.randint(0, height - target_height)
      left = random.randint(0, width - target_width)
      return crop(image, top, left, target_height, target_width)
  def random_resized_crop(image, size, scale, ratio, interpolation):
      """Crop a region of random area and aspect ratio, then resize it to ``size``.

      Args:
          image: PIL image or 2-/3-D ndarray.
          size: int (square output) or 2-element list/tuple, passed to ``resize``.
          scale: ``(min, max)`` fraction of the image area to crop.
          ratio: ``(min, max)`` aspect ratio (width / height) of the crop; both
              values must be positive.
          interpolation: forwarded to ``resize``.

      Returns:
          The cropped and resized image.

      Raises:
          TypeError: for an unsupported image type or malformed size/scale/ratio.
          ValueError: if scale/ratio are not ordered ``(min, max)`` or ratio is
              not positive.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      if isinstance(size, int):
          size = (size, size)
      elif not (isinstance(size, (list, tuple)) and len(size) == 2):
          raise TypeError('Size should be a int or a list/tuple with length of 2. But got {}.'.format(size))
      if not (isinstance(scale, (list, tuple)) and len(scale) == 2):
          raise TypeError('Scale should be a list/tuple with length of 2. But got {}.'.format(scale))
      if not (isinstance(ratio, (list, tuple)) and len(ratio) == 2):
          # Previously mislabelled as "Scale" (copy-paste error).
          raise TypeError('Ratio should be a list/tuple with length of 2. But got {}.'.format(ratio))
      if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
          raise ValueError("Scale and ratio should be of kind (min, max)")
      if ratio[0] <= 0:
          # math.log below would otherwise fail with an opaque "math domain error".
          raise ValueError('Ratio values should be positive. But got {}.'.format(ratio))

      def _get_param(image, scale, ratio):
          """Return a crop box ``(top, left, height, width)``.

          Makes up to 10 random draws of an area fraction (uniform in ``scale``)
          and an aspect ratio (uniform in log space over ``ratio``). If none fits
          inside the image, falls back to a centred crop whose aspect ratio is
          clamped to the ``ratio`` range.
          """
          height, width = _get_image_size(image)
          area = height * width
          log_ratio = tuple(math.log(x) for x in ratio)
          for _ in range(10):
              target_area = np.random.uniform(*scale) * area
              aspect_ratio = math.exp(np.random.uniform(*log_ratio))
              w = int(round(math.sqrt(target_area * aspect_ratio)))
              h = int(round(math.sqrt(target_area / aspect_ratio)))
              if 0 < w <= width and 0 < h <= height:
                  i = random.randint(0, height - h)
                  j = random.randint(0, width - w)
                  return i, j, h, w
          # Fallback to central crop
          in_ratio = float(width) / float(height)
          if in_ratio < min(ratio):
              w = width
              h = int(round(w / min(ratio)))
          elif in_ratio > max(ratio):
              h = height
              w = int(round(h * max(ratio)))
          else:
              # return whole image
              w = width
              h = height
          i = (height - h) // 2
          j = (width - w) // 2
          return i, j, h, w

      offset_height, offset_width, target_height, target_width = _get_param(image, scale, ratio)
      image = crop(image, offset_height, offset_width, target_height, target_width)
      image = resize(image, size, interpolation)
      return image
  def random_vflip(image, prob):
      """Vertically flip ``image`` when ``random.random() < prob``; otherwise return it unchanged."""
      should_flip = random.random() < prob
      return vflip(image) if should_flip else image
  def random_hflip(image, prob):
      """Horizontally flip ``image`` when ``random.random() < prob``; otherwise return it unchanged."""
      should_flip = random.random() < prob
      return hflip(image) if should_flip else image
  def random_rotation(image, degrees, interpolation, expand, center, fill):
      """Rotate ``image`` by an angle drawn uniformly from ``degrees``.

      Args:
          degrees: non-negative number ``d`` (range ``(-d, d)``) or a
              ``(min, max)`` list/tuple.
          interpolation, expand, center, fill: forwarded unchanged to the
              backend's ``rotate``.

      Raises:
          TypeError: for an unsupported image type.
          ValueError: for a negative scalar, wrong length, or unordered range.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      if isinstance(degrees, numbers.Number):
          if degrees < 0:
              raise ValueError('If degrees is a single number, it must be positive.' 'But got {}'.format(degrees))
          degrees = (-degrees, degrees)
      elif not (isinstance(degrees, (list, tuple)) and len(degrees) == 2):
          raise ValueError('If degrees is a list/tuple, it must be length of 2.' 'But got {}'.format(degrees))
      elif degrees[0] > degrees[1]:
          raise ValueError('if degrees is a list/tuple, it should be (min, max).')
      low, high = degrees
      angle = np.random.uniform(low, high)
      return backend.rotate(image, angle, interpolation, expand, center, fill)
  def random_shear(image, degrees, interpolation, fill):
      """Apply a random shear to ``image`` via the backend's ``random_shear``.

      ``degrees`` is normalised to a 4-element sequence before dispatch:
      a number ``d`` becomes ``(-d, d, 0, 0)``, a 2-element list/tuple is padded
      with ``(0, 0)``, and a 4-element one is passed through. (Presumably
      x-range then y-range -- TODO confirm against the backend.)

      Raises:
          TypeError: for an unsupported image type.
          ValueError: if ``degrees`` has any other form.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      is_sequence = isinstance(degrees, (list, tuple))
      if isinstance(degrees, numbers.Number):
          degrees = (-degrees, degrees, 0, 0)
      elif is_sequence and len(degrees) == 2:
          degrees = (degrees[0], degrees[1], 0, 0)
      elif not (is_sequence and len(degrees) == 4):
          raise ValueError(
              'degrees should be a single number or a list/tuple with length in (2 ,4).'
              'But got {}'.format(degrees)
          )
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.random_shear(image, degrees, interpolation, fill)
  def random_shift(image, shift, interpolation, fill):
      """Apply a random translation to ``image`` via the backend's ``random_shift``.

      ``shift`` must be a 2-element list/tuple; it and the remaining arguments
      are forwarded unchanged.

      Raises:
          TypeError: for an unsupported image type.
          ValueError: if ``shift`` is not a length-2 list/tuple.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      shift_ok = isinstance(shift, (tuple, list)) and len(shift) == 2
      if not shift_ok:
          raise ValueError('Shift should be a list/tuple with length of 2.' 'But got {}'.format(shift))
      return backend.random_shift(image, shift, interpolation, fill)
  def random_zoom(image, zoom, interpolation, fill):
      """Apply a random zoom to ``image`` via the backend's ``random_zoom``.

      ``zoom`` must be a 2-element list/tuple with ``0 <= zoom[0] <= zoom[1]``.

      Raises:
          TypeError: for an unsupported image type.
          ValueError: for a malformed or unordered/negative ``zoom``.
      """
      if _is_pil_image(image):
          backend = F_pil
      elif _is_numpy_image(image):
          backend = F_cv2
      else:
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      if not isinstance(zoom, (tuple, list)) or len(zoom) != 2:
          raise ValueError('Zoom should be a list/tuple with length of 2.' 'But got {}'.format(zoom))
      low, high = zoom
      if not 0 <= low <= high:
          raise ValueError('Zoom values should be positive, and zoom[1] should be greater than zoom[0].')
      return backend.random_zoom(image, zoom, interpolation, fill)
  def random_affine(image, degrees, shift, zoom, shear, interpolation, fill):
      """Apply a random affine transform to ``image``.

      No argument validation happens here; every parameter is forwarded
      unchanged to ``functional_pil.random_affine`` or
      ``functional_cv2.random_affine``.

      Raises:
          TypeError: if ``image`` is neither a PIL image nor a 2-/3-D ndarray.
      """
      if not (_is_pil_image(image) or _is_numpy_image(image)):
          raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
      backend = F_pil if _is_pil_image(image) else F_cv2
      return backend.random_affine(image, degrees, shift, zoom, shear, interpolation, fill)

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。