You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

functional_cv2.py 23 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import numpy as np
  4. from numpy import sin, cos, tan
  5. import math
  6. import numbers
  7. import importlib
  8. def try_import(module_name):
  9. """Try importing a module, with an informative error message on failure."""
  10. install_name = module_name
  11. if module_name.find('.') > -1:
  12. install_name = module_name.split('.')[0]
  13. if module_name == 'cv2':
  14. install_name = 'opencv-python'
  15. try:
  16. mod = importlib.import_module(module_name)
  17. return mod
  18. except ImportError:
  19. err_msg = (
  20. "Failed importing {}. This likely means that some paddle modules "
  21. "require additional dependencies that have to be "
  22. "manually installed (usually with `pip install {}`). "
  23. ).format(module_name, install_name)
  24. raise ImportError(err_msg)
  25. def crop(image, offset_height, offset_width, target_height, target_width):
  26. image_height, image_width = image.shape[0:2]
  27. if offset_width < 0:
  28. raise ValueError('offset_width must be >0.')
  29. if offset_height < 0:
  30. raise ValueError('offset_height must be >0.')
  31. if target_height < 0:
  32. raise ValueError('target_height must be >0.')
  33. if target_width < 0:
  34. raise ValueError('target_width must be >0.')
  35. if offset_width + target_width > image_width:
  36. raise ValueError('offset_width + target_width must be <= image width.')
  37. if offset_height + target_height > image_height:
  38. raise ValueError('offset_height + target_height must be <= image height.')
  39. return image[offset_height:offset_height + target_height, offset_width:offset_width + target_width]
  40. def center_crop(image, size, central_fraction):
  41. image_height, image_width = image.shape[0:2]
  42. if size is not None:
  43. if not isinstance(size, (int, list, tuple)) or (isinstance(size, (list, tuple)) and len(size) != 2):
  44. raise TypeError(
  45. "Size should be a single integer or a list/tuple (h, w) of length 2.But"
  46. "got {}.".format(size)
  47. )
  48. if isinstance(size, int):
  49. target_height = size
  50. target_width = size
  51. else:
  52. target_height = size[0]
  53. target_width = size[1]
  54. elif central_fraction is not None:
  55. if central_fraction <= 0.0 or central_fraction > 1.0:
  56. raise ValueError('central_fraction must be within (0, 1]')
  57. target_height = int(central_fraction * image_height)
  58. target_width = int(central_fraction * image_width)
  59. crop_top = int(round((image_height - target_height) / 2.))
  60. crop_left = int(round((image_width - target_width) / 2.))
  61. return crop(image, crop_top, crop_left, target_height, target_width)
  62. def pad(image, padding, padding_value, mode):
  63. if isinstance(padding, int):
  64. top = bottom = left = right = padding
  65. elif isinstance(padding, (tuple, list)):
  66. if len(padding) == 2:
  67. left = right = padding[0]
  68. top = bottom = padding[1]
  69. elif len(padding) == 4:
  70. left = padding[0]
  71. top = padding[1]
  72. right = padding[2]
  73. bottom = padding[3]
  74. else:
  75. raise TypeError("The size of the padding list or tuple should be 2 or 4." "But got {}".format(padding))
  76. else:
  77. raise TypeError("Padding can be any of: a number, a tuple or list of size 2 or 4." "But got {}".format(padding))
  78. if mode not in ['constant', 'edge', 'reflect', 'symmetric']:
  79. raise ValueError("Padding mode should be 'constant', 'edge', 'reflect', or 'symmetric'.")
  80. cv2 = try_import('cv2')
  81. _cv2_pad_from_str = {
  82. 'constant': cv2.BORDER_CONSTANT,
  83. 'edge': cv2.BORDER_REPLICATE,
  84. 'reflect': cv2.BORDER_REFLECT_101,
  85. 'symmetric': cv2.BORDER_REFLECT
  86. }
  87. if len(image.shape) == 3 and image.shape[2] == 1:
  88. return cv2.copyMakeBorder(
  89. image, top=top, bottom=bottom, left=left, right=right, borderType=_cv2_pad_from_str[mode],
  90. value=padding_value
  91. )[:, :, np.newaxis]
  92. else:
  93. return cv2.copyMakeBorder(
  94. image, top=top, bottom=bottom, left=left, right=right, borderType=_cv2_pad_from_str[mode],
  95. value=padding_value
  96. )
  97. def resize(image, size, method):
  98. if not (isinstance(size, int) or (isinstance(size, (list, tuple)) and len(size) == 2)):
  99. raise TypeError('Size should be a single number or a list/tuple (h, w) of length 2.' 'Got {}.'.format(size))
  100. if method not in ('nearest', 'bilinear', 'area', 'bicubic' 'lanczos'):
  101. raise ValueError(
  102. "Unknown resize method! resize method must be in "
  103. "(\'nearest\',\'bilinear\',\'bicubic\',\'area\',\'lanczos\')"
  104. )
  105. cv2 = try_import('cv2')
  106. _cv2_interp_from_str = {
  107. 'nearest': cv2.INTER_NEAREST,
  108. 'bilinear': cv2.INTER_LINEAR,
  109. 'area': cv2.INTER_AREA,
  110. 'bicubic': cv2.INTER_CUBIC,
  111. 'lanczos': cv2.INTER_LANCZOS4
  112. }
  113. h, w = image.shape[:2]
  114. if isinstance(size, int):
  115. if (w <= h and w == size) or (h <= w and h == size):
  116. return image
  117. if w < h:
  118. target_w = size
  119. target_h = int(size * h / w)
  120. else:
  121. target_h = size
  122. target_w = int(size * w / h)
  123. size = (target_h, target_w)
  124. output = cv2.resize(image, dsize=(size[1], size[0]), interpolation=_cv2_interp_from_str[method])
  125. if len(image.shape) == 3 and image.shape[2] == 1:
  126. return output[:, :, np.newaxis]
  127. else:
  128. return output
  129. def transpose(image, order):
  130. if not (isinstance(order, (list, tuple)) and len(order) == 3):
  131. raise TypeError("Order must be a list/tuple of length 3." "But got {}.".format(order))
  132. image_shape = image.shape
  133. if len(image_shape) == 2:
  134. image = image[..., np.newaxis]
  135. return image.transpose(order)
  136. def hwc_to_chw(image):
  137. image_shape = image.shape
  138. if len(image_shape) == 2:
  139. image = image[..., np.newaxis]
  140. return image.transpose((2, 0, 1))
  141. def chw_to_hwc(image):
  142. image_shape = image.shape
  143. if len(image_shape) == 2:
  144. image = image[..., np.newaxis]
  145. return image.transpose((1, 2, 0))
  146. def rgb_to_hsv(image):
  147. cv2 = try_import('cv2')
  148. image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
  149. return image
  150. def hsv_to_rgb(image):
  151. cv2 = try_import('cv2')
  152. image = cv2.cvtColor(image, cv2.COLOR_HSV2RGB)
  153. return image
  154. def rgb_to_gray(image, num_output_channels):
  155. cv2 = try_import('cv2')
  156. if num_output_channels == 1:
  157. image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis]
  158. elif num_output_channels == 3:
  159. image = np.broadcast_to(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)[:, :, np.newaxis], image.shape)
  160. else:
  161. raise ValueError('num_output_channels should be either 1 or 3')
  162. return image
  163. def adjust_brightness(image, brightness_factor):
  164. if brightness_factor < 0:
  165. raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))
  166. cv2 = try_import('cv2')
  167. table = np.array([i * brightness_factor for i in range(0, 256)]).clip(0, 255).astype('uint8')
  168. if len(image.shape) == 3 and image.shape[2] == 1:
  169. return cv2.LUT(image, table)[:, :, np.newaxis]
  170. else:
  171. return cv2.LUT(image, table)
  172. def adjust_contrast(image, contrast_factor):
  173. """Adjusts contrast of an image.
  174. Args:
  175. img (np.array): Image to be adjusted.
  176. contrast_factor (float): How much to adjust the contrast. Can be any
  177. non negative number. 0 gives a solid gray image, 1 gives the
  178. original image while 2 increases the contrast by a factor of 2.
  179. Returns:
  180. np.array: Contrast adjusted image.
  181. """
  182. if contrast_factor < 0:
  183. raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))
  184. cv2 = try_import('cv2')
  185. table = np.array([(i - 127) * contrast_factor + 127 for i in range(0, 256)]).clip(0, 255).astype('uint8')
  186. if len(image.shape) == 3 and image.shape[2] == 1:
  187. return cv2.LUT(image, table)[:, :, np.newaxis]
  188. else:
  189. return cv2.LUT(image, table)
  190. def adjust_hue(image, hue_factor):
  191. """Adjusts hue of an image.
  192. The image hue is adjusted by converting the image to HSV and
  193. cyclically shifting the intensities in the hue channel (H).
  194. The image is then converted back to original image mode.
  195. `hue_factor` is the amount of shift in H channel and must be in the
  196. interval `[-0.5, 0.5]`.
  197. Args:
  198. image (PIL.Image): PIL Image to be adjusted.
  199. hue_factor (float): How much to shift the hue channel. Should be in
  200. [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
  201. HSV space in positive and negative direction respectively.
  202. 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
  203. with complementary colors while 0 gives the original image.
  204. Returns:
  205. PIL.Image: Hue adjusted image.
  206. """
  207. if not (-0.5 <= hue_factor <= 0.5):
  208. raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))
  209. cv2 = try_import('cv2')
  210. dtype = image.dtype
  211. image = image.astype(np.uint8)
  212. hsv_img = cv2.cvtColor(image, cv2.COLOR_RGB2HSV_FULL)
  213. h, s, v = cv2.split(hsv_img)
  214. alpha = np.random.uniform(hue_factor, hue_factor)
  215. h = h.astype(np.uint8)
  216. # uint8 addition take cares of rotation across boundaries
  217. with np.errstate(over="ignore"):
  218. h += np.uint8(alpha * 255)
  219. hsv_img = cv2.merge([h, s, v])
  220. return cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB_FULL).astype(dtype)
  221. def adjust_saturation(image, saturation_factor):
  222. """Adjusts color saturation of an image.
  223. Args:
  224. image (np.array): Image to be adjusted.
  225. saturation_factor (float): How much to adjust the saturation. 0 will
  226. give a black and white image, 1 will give the original image while
  227. 2 will enhance the saturation by a factor of 2.
  228. Returns:
  229. np.array: Saturation adjusted image.
  230. """
  231. if saturation_factor < 0:
  232. raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))
  233. cv2 = try_import('cv2')
  234. dtype = image.dtype
  235. image = image.astype(np.float32)
  236. alpha = np.random.uniform(saturation_factor, saturation_factor)
  237. gray_img = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  238. gray_img = gray_img[..., np.newaxis]
  239. img = image * alpha + gray_img * (1 - alpha)
  240. return img.clip(0, 255).astype(dtype)
  241. def hflip(image):
  242. """Horizontally flips the given image.
  243. Args:
  244. image (np.array): Image to be flipped.
  245. Returns:
  246. np.array: Horizontall flipped image.
  247. """
  248. cv2 = try_import('cv2')
  249. return cv2.flip(image, 1)
  250. def vflip(image):
  251. """Vertically flips the given np.array.
  252. Args:
  253. image (np.array): Image to be flipped.
  254. Returns:
  255. np.array: Vertically flipped image.
  256. """
  257. cv2 = try_import('cv2')
  258. if len(image.shape) == 3 and image.shape[2] == 1:
  259. return cv2.flip(image, 0)[:, :, np.newaxis]
  260. else:
  261. return cv2.flip(image, 0)
  262. def padtoboundingbox(image, offset_height, offset_width, target_height, target_width, padding_value):
  263. '''
  264. Parameters
  265. ----------
  266. image:
  267. A np.array image to be padded size of (target_width, target_height)
  268. offset_height:
  269. Number of rows of padding_values to add on top.
  270. offset_width:
  271. Number of columns of padding_values to add on the left.
  272. target_height:
  273. Height of output image.
  274. target_width:
  275. Width of output image.
  276. padding_value:
  277. value to pad
  278. Returns:
  279. np.array image: padded image
  280. -------
  281. '''
  282. if offset_height < 0:
  283. raise ValueError('offset_height must be >= 0')
  284. if offset_width < 0:
  285. raise ValueError('offset_width must be >= 0')
  286. height, width = image.shape[:2]
  287. after_padding_width = target_width - offset_width - width
  288. after_padding_height = target_height - offset_height - height
  289. if after_padding_height < 0:
  290. raise ValueError('image height must be <= target - offset')
  291. if after_padding_width < 0:
  292. raise ValueError('image width must be <= target - offset')
  293. return pad(
  294. image, padding=(offset_width, offset_height, after_padding_width, after_padding_height),
  295. padding_value=padding_value, mode='constant'
  296. )
  297. def rotate(img, angle, interpolation, expand, center, fill):
  298. """Rotates the image by angle.
  299. Args:
  300. img (np.array): Image to be rotated.
  301. angle (float or int): In degrees degrees counter clockwise order.
  302. interpolation (int|str, optional): Interpolation method. If omitted, or if the
  303. image has only one channel, it is set to cv2.INTER_NEAREST.
  304. when use cv2 backend, support method are as following:
  305. - "nearest": cv2.INTER_NEAREST,
  306. - "bilinear": cv2.INTER_LINEAR,
  307. - "bicubic": cv2.INTER_CUBIC
  308. expand (bool, optional): Optional expansion flag.
  309. If true, expands the output image to make it large enough to hold the entire rotated image.
  310. If false or omitted, make the output image the same size as the input image.
  311. Note that the expand flag assumes rotation around the center and no translation.
  312. center (2-tuple, optional): Optional center of rotation.
  313. Origin is the upper left corner.
  314. Default is the center of the image.
  315. fill (3-tuple or int): RGB pixel fill value for area outside the rotated image.
  316. If int, it is used for all channels respectively.
  317. Returns:
  318. np.array: Rotated image.
  319. """
  320. cv2 = try_import('cv2')
  321. _cv2_interp_from_str = {
  322. 'nearest': cv2.INTER_NEAREST,
  323. 'bilinear': cv2.INTER_LINEAR,
  324. 'area': cv2.INTER_AREA,
  325. 'bicubic': cv2.INTER_CUBIC,
  326. 'lanczos': cv2.INTER_LANCZOS4
  327. }
  328. h, w, c = img.shape
  329. if isinstance(fill, numbers.Number):
  330. fill = (fill, ) * c
  331. elif not (isinstance(fill, (list, tuple)) and len(fill) == c):
  332. raise ValueError(
  333. 'If fill should be a single number or a list/tuple with length of image channels.'
  334. 'But got {}'.format(fill)
  335. )
  336. if center is None:
  337. center = (w / 2.0, h / 2.0)
  338. M = cv2.getRotationMatrix2D(center, angle, 1)
  339. if expand:
  340. def transform(x, y, matrix):
  341. (a, b, c, d, e, f) = matrix
  342. return a * x + b * y + c, d * x + e * y + f
  343. # calculate output size
  344. xx = []
  345. yy = []
  346. angle = -math.radians(angle)
  347. expand_matrix = [
  348. round(math.cos(angle), 15),
  349. round(math.sin(angle), 15),
  350. 0.0,
  351. round(-math.sin(angle), 15),
  352. round(math.cos(angle), 15),
  353. 0.0,
  354. ]
  355. post_trans = (0, 0)
  356. expand_matrix[2], expand_matrix[5] = transform(
  357. -center[0] - post_trans[0], -center[1] - post_trans[1], expand_matrix
  358. )
  359. expand_matrix[2] += center[0]
  360. expand_matrix[5] += center[1]
  361. for x, y in ((0, 0), (w, 0), (w, h), (0, h)):
  362. x, y = transform(x, y, expand_matrix)
  363. xx.append(x)
  364. yy.append(y)
  365. nw = math.ceil(max(xx)) - math.floor(min(xx))
  366. nh = math.ceil(max(yy)) - math.floor(min(yy))
  367. M[0, 2] += (nw - w) * 0.5
  368. M[1, 2] += (nh - h) * 0.5
  369. w, h = int(nw), int(nh)
  370. if len(img.shape) == 3 and img.shape[2] == 1:
  371. return cv2.warpAffine(img, M, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)[:, :,
  372. np.newaxis]
  373. else:
  374. return cv2.warpAffine(img, M, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)
  375. def get_affine_matrix(center, angle, translate, scale, shear):
  376. rot = math.radians(angle)
  377. sx, sy = [math.radians(s) for s in shear]
  378. cx, cy = center
  379. tx, ty = translate
  380. # RSS without scaling
  381. a = math.cos(rot - sy) / math.cos(sy)
  382. b = -math.cos(rot - sy) * math.tan(sx) / math.cos(sy) - math.sin(rot)
  383. c = math.sin(rot - sy) / math.cos(sy)
  384. d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
  385. # Inverted rotation matrix with scale and shear
  386. # det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
  387. matrix = [d, -b, 0.0, -c, a, 0.0]
  388. matrix = [x / scale for x in matrix]
  389. # Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
  390. matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
  391. matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
  392. # Apply center translation: C * RSS^-1 * C^-1 * T^-1
  393. matrix[2] += cx
  394. matrix[5] += cy
  395. return matrix
  396. def random_shear(image, degrees, interpolation, fill):
  397. cv2 = try_import('cv2')
  398. _cv2_interp_from_str = {
  399. 'nearest': cv2.INTER_NEAREST,
  400. 'bilinear': cv2.INTER_LINEAR,
  401. 'area': cv2.INTER_AREA,
  402. 'bicubic': cv2.INTER_CUBIC,
  403. 'lanczos': cv2.INTER_LANCZOS4
  404. }
  405. h, w, c = image.shape
  406. if isinstance(fill, numbers.Number):
  407. fill = (fill, ) * c
  408. elif not (isinstance(fill, (list, tuple)) and len(fill) == c):
  409. raise ValueError(
  410. 'If fill should be a single number or a list/tuple with length of image channels.'
  411. 'But got {}'.format(fill)
  412. )
  413. center = (w / 2.0, h / 2.0)
  414. shear = [-np.random.uniform(degrees[0], degrees[1]), -np.random.uniform(degrees[2], degrees[3])]
  415. matrix = get_affine_matrix(center=center, angle=0, translate=(0, 0), scale=1.0, shear=shear)
  416. matrix = np.asarray(matrix).reshape((2, 3))
  417. if len(image.shape) == 3 and image.shape[2] == 1:
  418. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation],
  419. borderValue=fill)[:, :, np.newaxis]
  420. else:
  421. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)
  422. def random_shift(image, shift, interpolation, fill):
  423. cv2 = try_import('cv2')
  424. _cv2_interp_from_str = {
  425. 'nearest': cv2.INTER_NEAREST,
  426. 'bilinear': cv2.INTER_LINEAR,
  427. 'area': cv2.INTER_AREA,
  428. 'bicubic': cv2.INTER_CUBIC,
  429. 'lanczos': cv2.INTER_LANCZOS4
  430. }
  431. h, w, c = image.shape
  432. if isinstance(fill, numbers.Number):
  433. fill = (fill, ) * c
  434. elif not (isinstance(fill, (list, tuple)) and len(fill) == c):
  435. raise ValueError(
  436. 'If fill should be a single number or a list/tuple with length of image channels.'
  437. 'But got {}'.format(fill)
  438. )
  439. hrg = shift[0]
  440. wrg = shift[1]
  441. tx = -np.random.uniform(-hrg, hrg) * w
  442. ty = -np.random.uniform(-wrg, wrg) * h
  443. center = (w / 2.0, h / 2.0)
  444. matrix = get_affine_matrix(center=center, angle=0, translate=(tx, ty), scale=1.0, shear=(0, 0))
  445. matrix = np.asarray(matrix).reshape((2, 3))
  446. if len(image.shape) == 3 and image.shape[2] == 1:
  447. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation],
  448. borderValue=fill)[:, :, np.newaxis]
  449. else:
  450. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)
  451. def random_zoom(image, zoom, interpolation, fill):
  452. cv2 = try_import('cv2')
  453. _cv2_interp_from_str = {
  454. 'nearest': cv2.INTER_NEAREST,
  455. 'bilinear': cv2.INTER_LINEAR,
  456. 'area': cv2.INTER_AREA,
  457. 'bicubic': cv2.INTER_CUBIC,
  458. 'lanczos': cv2.INTER_LANCZOS4
  459. }
  460. h, w, c = image.shape
  461. if isinstance(fill, numbers.Number):
  462. fill = (fill, ) * c
  463. elif not (isinstance(fill, (list, tuple)) and len(fill) == c):
  464. raise ValueError(
  465. 'If fill should be a single number or a list/tuple with length of image channels.'
  466. 'But got {}'.format(fill)
  467. )
  468. scale = 1 / np.random.uniform(zoom[0], zoom[1])
  469. center = (w / 2.0, h / 2.0)
  470. matrix = get_affine_matrix(center=center, angle=0, translate=(0, 0), scale=scale, shear=(0, 0))
  471. matrix = np.asarray(matrix).reshape((2, 3))
  472. if len(image.shape) == 3 and image.shape[2] == 1:
  473. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation],
  474. borderValue=fill)[:, :, np.newaxis]
  475. else:
  476. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)
  477. def random_affine(image, degrees, shift, zoom, shear, interpolation, fill):
  478. cv2 = try_import('cv2')
  479. _cv2_interp_from_str = {
  480. 'nearest': cv2.INTER_NEAREST,
  481. 'bilinear': cv2.INTER_LINEAR,
  482. 'area': cv2.INTER_AREA,
  483. 'bicubic': cv2.INTER_CUBIC,
  484. 'lanczos': cv2.INTER_LANCZOS4
  485. }
  486. h, w, c = image.shape
  487. if isinstance(fill, numbers.Number):
  488. fill = (fill, ) * c
  489. elif not (isinstance(fill, (list, tuple)) and len(fill) == c):
  490. raise ValueError(
  491. 'If fill should be a single number or a list/tuple with length of image channels.'
  492. 'But got {}'.format(fill)
  493. )
  494. center = (w / 2.0, h / 2.0)
  495. angle = -float(np.random.uniform(degrees[0], degrees[1]))
  496. if shift is not None:
  497. max_dx = float(shift[0] * h)
  498. max_dy = float(shift[1] * w)
  499. tx = -int(round(np.random.uniform(-max_dx, max_dx)))
  500. ty = -int(round(np.random.uniform(-max_dy, max_dy)))
  501. shift = [tx, ty]
  502. else:
  503. shift = [0, 0]
  504. if zoom is not None:
  505. scale = 1 / np.random.uniform(zoom[0], zoom[1])
  506. else:
  507. scale = 1.0
  508. shear_x = shear_y = 0.0
  509. if shear is not None:
  510. shear_x = float(np.random.uniform(shear[0], shear[1]))
  511. if len(shear) == 4:
  512. shear_y = float(np.random.uniform(shear[2], shear[3]))
  513. shear = (-shear_x, -shear_y)
  514. matrix = get_affine_matrix(center=center, angle=angle, translate=shift, scale=scale, shear=shear)
  515. matrix = np.asarray(matrix).reshape((2, 3))
  516. if len(image.shape) == 3 and image.shape[2] == 1:
  517. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation],
  518. borderValue=fill)[:, :, np.newaxis]
  519. else:
  520. return cv2.warpAffine(image, matrix, (w, h), flags=_cv2_interp_from_str[interpolation], borderValue=fill)

TensorLayer3.0 是一款兼容多种深度学习框架为计算后端的深度学习库。计划兼容TensorFlow, Pytorch, MindSpore, Paddle.