You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mindspore_vision.py 21 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import mindspore as ms
  4. from . import functional_cv2 as F_cv2
  5. from . import functional_pil as F_pil
  6. import mindspore.ops as P
  7. from mindspore.numpy import std
  8. from PIL import Image
  9. import PIL
  10. import numpy as np
  11. import numbers
  12. import random
  13. import math
  14. __all__ = [
  15. 'central_crop',
  16. 'to_tensor',
  17. 'crop',
  18. 'pad',
  19. 'resize',
  20. 'transpose',
  21. 'hwc_to_chw',
  22. 'chw_to_hwc',
  23. 'rgb_to_hsv',
  24. 'hsv_to_rgb',
  25. 'rgb_to_gray',
  26. 'adjust_brightness',
  27. 'adjust_contrast',
  28. 'adjust_hue',
  29. 'adjust_saturation',
  30. 'normalize',
  31. 'hflip',
  32. 'vflip',
  33. 'padtoboundingbox',
  34. 'standardize',
  35. 'random_brightness',
  36. 'random_contrast',
  37. 'random_saturation',
  38. 'random_hue',
  39. 'random_crop',
  40. 'random_resized_crop',
  41. 'random_vflip',
  42. 'random_hflip',
  43. 'random_rotation',
  44. 'random_shear',
  45. 'random_shift',
  46. 'random_zoom',
  47. 'random_affine',
  48. ]
  49. def _is_pil_image(image):
  50. return isinstance(image, Image.Image)
  51. def _is_tensor_image(image):
  52. return isinstance(image, ms.Tensor)
  53. def _is_numpy_image(image):
  54. return isinstance(image, np.ndarray) and (image.ndim in {2, 3})
  55. def _get_image_size(img):
  56. if _is_pil_image(img):
  57. return img.size[::-1]
  58. elif _is_numpy_image(img):
  59. return img.shape[:2]
  60. else:
  61. raise TypeError("Unexpected type {}".format(type(img)))
  62. def random_factor(factor, name, center=1, bound=(0, float('inf')), non_negative=True):
  63. if isinstance(factor, numbers.Number):
  64. if factor < 0:
  65. raise ValueError('The input value of {} cannot be negative.'.format(name))
  66. factor = [center - factor, center + factor]
  67. if non_negative:
  68. factor[0] = max(0, factor[0])
  69. elif isinstance(factor, (tuple, list)) and len(factor) == 2:
  70. if not bound[0] <= factor[0] <= factor[1] <= bound[1]:
  71. raise ValueError(
  72. "Please check your value range of {} is valid and "
  73. "within the bound {}.".format(name, bound)
  74. )
  75. else:
  76. raise TypeError("Input of {} should be either a single value, or a list/tuple of " "length 2.".format(name))
  77. factor = np.random.uniform(factor[0], factor[1])
  78. return factor
  79. def to_tensor(image, data_format='HWC'):
  80. if not (_is_pil_image(image) or _is_numpy_image(image)):
  81. raise TypeError('image should be PIL Image or ndarray. Got {}'.format(type(image)))
  82. image = np.asarray(image).astype('float32')
  83. if image.ndim == 2:
  84. image = image[:, :, None]
  85. if data_format == 'CHW':
  86. image = np.transpose(image, (2, 0, 1))
  87. image = image / 255.
  88. else:
  89. image = image / 255.
  90. return image
  91. def central_crop(image, size=None, central_fraction=None):
  92. if size is None and central_fraction is None:
  93. raise ValueError('central_fraction and size can not be both None')
  94. if not (_is_pil_image(image) or _is_numpy_image(image)):
  95. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  96. if _is_pil_image(image):
  97. return F_pil.center_crop(image, size, central_fraction)
  98. else:
  99. return F_cv2.center_crop(image, size, central_fraction)
  100. def crop(image, offset_height, offset_width, target_height, target_width):
  101. if not (_is_pil_image(image) or _is_numpy_image(image)):
  102. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  103. if _is_pil_image(image):
  104. return F_pil.crop(image, offset_height, offset_width, target_height, target_width)
  105. else:
  106. return F_cv2.crop(image, offset_height, offset_width, target_height, target_width)
  107. def pad(image, padding, padding_value, mode):
  108. if not (_is_pil_image(image) or _is_numpy_image(image)):
  109. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  110. if _is_pil_image(image):
  111. return F_pil.pad(image, padding, padding_value, mode)
  112. else:
  113. return F_cv2.pad(image, padding, padding_value, mode)
  114. def resize(image, size, method):
  115. if not (_is_pil_image(image) or _is_numpy_image(image)):
  116. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  117. if _is_pil_image(image):
  118. return F_pil.resize(image, size, method)
  119. else:
  120. return F_cv2.resize(image, size, method)
  121. def transpose(image, order):
  122. if not (_is_pil_image(image) or _is_numpy_image(image)):
  123. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  124. if _is_pil_image(image):
  125. return F_pil.transpose(image, order)
  126. else:
  127. return F_cv2.transpose(image, order)
  128. def hwc_to_chw(image):
  129. if not (_is_pil_image(image) or _is_numpy_image(image)):
  130. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  131. if _is_pil_image(image):
  132. return F_pil.hwc_to_chw(image)
  133. else:
  134. return F_cv2.hwc_to_chw(image)
  135. def chw_to_hwc(image):
  136. if not (_is_pil_image(image) or _is_numpy_image(image)):
  137. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  138. if _is_pil_image(image):
  139. return F_pil.chw_to_hwc(image)
  140. else:
  141. return F_cv2.chw_to_hwc(image)
  142. def rgb_to_hsv(image):
  143. if not (_is_pil_image(image) or isinstance(image, np.ndarray) and (image.ndim == 3)):
  144. raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
  145. if _is_pil_image(image):
  146. return F_pil.rgb_to_hsv(image)
  147. else:
  148. return F_cv2.rgb_to_hsv(image)
  149. def hsv_to_rgb(image):
  150. if not (_is_pil_image(image) or isinstance(image, np.ndarray) and (image.ndim == 3)):
  151. raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
  152. if _is_pil_image(image):
  153. return F_pil.hsv_to_rgb(image)
  154. else:
  155. return F_cv2.hsv_to_rgb(image)
  156. def rgb_to_gray(image, num_output_channels):
  157. if not (_is_pil_image(image) or isinstance(image, np.ndarray) and (image.ndim == 3)):
  158. raise TypeError('image should be PIL Image or ndarray with dim=3. Got {}'.format(type(image)))
  159. if _is_pil_image(image):
  160. return F_pil.rgb_to_gray(image, num_output_channels)
  161. else:
  162. return F_cv2.rgb_to_gray(image, num_output_channels)
  163. def adjust_brightness(image, brightness_factor):
  164. if not (_is_pil_image(image) or _is_numpy_image(image)):
  165. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  166. if _is_pil_image(image):
  167. return F_pil.adjust_brightness(image, brightness_factor)
  168. else:
  169. return F_cv2.adjust_brightness(image, brightness_factor)
  170. def adjust_contrast(image, contrast_factor):
  171. if not (_is_pil_image(image) or _is_numpy_image(image)):
  172. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  173. if _is_pil_image(image):
  174. return F_pil.adjust_contrast(image, contrast_factor)
  175. else:
  176. return F_cv2.adjust_contrast(image, contrast_factor)
  177. def adjust_hue(image, hue_factor):
  178. if not (_is_pil_image(image) or _is_numpy_image(image)):
  179. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  180. if _is_pil_image(image):
  181. return F_pil.adjust_hue(image, hue_factor)
  182. else:
  183. return F_cv2.adjust_hue(image, hue_factor)
  184. def adjust_saturation(image, saturation_factor):
  185. if not (_is_pil_image(image) or _is_numpy_image(image)):
  186. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  187. if _is_pil_image(image):
  188. return F_pil.adjust_saturation(image, saturation_factor)
  189. else:
  190. return F_cv2.adjust_saturation(image, saturation_factor)
  191. def hflip(image):
  192. if not (_is_pil_image(image) or _is_numpy_image(image)):
  193. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  194. if _is_pil_image(image):
  195. return F_pil.hflip(image)
  196. else:
  197. return F_cv2.hflip(image)
  198. def vflip(image):
  199. if not (_is_pil_image(image) or _is_numpy_image(image)):
  200. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  201. if _is_pil_image(image):
  202. return F_pil.vflip(image)
  203. else:
  204. return F_cv2.vflip(image)
  205. def padtoboundingbox(image, offset_height, offset_width, target_height, target_width, padding_value):
  206. if not (_is_pil_image(image) or _is_numpy_image(image)):
  207. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  208. if _is_pil_image(image):
  209. return F_pil.padtoboundingbox(image, offset_height, offset_width, target_height, target_width, padding_value)
  210. else:
  211. return F_cv2.padtoboundingbox(image, offset_height, offset_width, target_height, target_width, padding_value)
  212. def normalize(image, mean, std, data_format):
  213. if _is_pil_image(image):
  214. image = np.asarray(image)
  215. image = image.astype('float32')
  216. if data_format == 'CHW':
  217. num_channels = image.shape[0]
  218. elif data_format == 'HWC':
  219. num_channels = image.shape[2]
  220. if isinstance(mean, numbers.Number):
  221. mean = (mean, ) * num_channels
  222. elif isinstance(mean, (list, tuple)):
  223. if len(mean) != num_channels:
  224. raise ValueError("Length of mean must be 1 or equal to the number of channels({0}).".format(num_channels))
  225. if isinstance(std, numbers.Number):
  226. std = (std, ) * num_channels
  227. elif isinstance(std, (list, tuple)):
  228. if len(std) != num_channels:
  229. raise ValueError("Length of std must be 1 or equal to the number of channels({0}).".format(num_channels))
  230. mean = np.array(mean, dtype=image.dtype)
  231. std = np.array(std, dtype=image.dtype)
  232. if data_format == 'CHW':
  233. image = (image - mean[None, None, :]) / std[None, None, :]
  234. elif data_format == 'HWC':
  235. image = (image - mean[None, None, :]) / std[None, None, :]
  236. return image
  237. def standardize(image):
  238. '''
  239. Reference to tf.image.per_image_standardization().
  240. Linearly scales each image in image to have mean 0 and variance 1.
  241. '''
  242. if _is_pil_image(image):
  243. image = np.asarray(image)
  244. image = image.astype('float32')
  245. num_pixels = image.size
  246. image_mean = np.mean(image, keep_dims=False)
  247. stddev = np.std(image, keep_dims=False)
  248. min_stddev = 1.0 / np.sqrt(num_pixels)
  249. adjusted_stddev = np.maximum(stddev, min_stddev)
  250. return (image - image_mean) / adjusted_stddev
  251. def random_brightness(image, brightness_factor):
  252. '''
  253. Perform a random brightness on the input image.
  254. Parameters
  255. ----------
  256. image:
  257. Input images to adjust random brightness
  258. brightness_factor:
  259. Brightness adjustment factor (default=(1, 1)). Cannot be negative.
  260. If it is a float, the factor is uniformly chosen from the range [max(0, 1-brightness), 1+brightness].
  261. If it is a sequence, it should be [min, max] for the range.
  262. Returns:
  263. Adjusted image.
  264. -------
  265. '''
  266. if not (_is_pil_image(image) or _is_numpy_image(image)):
  267. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  268. brightness_factor = random_factor(brightness_factor, name='brightness')
  269. if _is_pil_image(image):
  270. return F_pil.adjust_brightness(image, brightness_factor)
  271. else:
  272. return F_cv2.adjust_brightness(image, brightness_factor)
  273. def random_contrast(image, contrast_factor):
  274. if not (_is_pil_image(image) or _is_numpy_image(image)):
  275. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  276. contrast_factor = random_factor(contrast_factor, name='contrast')
  277. if _is_pil_image(image):
  278. return F_pil.adjust_contrast(image, contrast_factor)
  279. else:
  280. return F_cv2.adjust_contrast(image, contrast_factor)
  281. def random_saturation(image, saturation_factor):
  282. if not (_is_pil_image(image) or _is_numpy_image(image)):
  283. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  284. saturation_factor = random_factor(saturation_factor, name='saturation')
  285. if _is_pil_image(image):
  286. return F_pil.adjust_saturation(image, saturation_factor)
  287. else:
  288. return F_cv2.adjust_saturation(image, saturation_factor)
  289. def random_hue(image, hue_factor):
  290. if not (_is_pil_image(image) or _is_numpy_image(image)):
  291. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  292. hue_factor = random_factor(hue_factor, name='hue', center=0, bound=(-0.5, 0.5), non_negative=False)
  293. if _is_pil_image(image):
  294. return F_pil.adjust_hue(image, hue_factor)
  295. else:
  296. return F_cv2.adjust_hue(image, hue_factor)
  297. def random_crop(image, size, padding, pad_if_needed, fill, padding_mode):
  298. if not (_is_pil_image(image) or _is_numpy_image(image)):
  299. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  300. if isinstance(size, int):
  301. size = (size, size)
  302. elif isinstance(size, (tuple, list)) and len(size) == 2:
  303. size = size
  304. else:
  305. raise ValueError('Size should be a int or a list/tuple with length of 2. ' 'But got {}'.format(size))
  306. height, width = _get_image_size(image)
  307. if padding is not None:
  308. image = pad(image, padding, fill, padding_mode)
  309. if pad_if_needed and height < size[0]:
  310. image = pad(image, (0, height - size[0]), fill, padding_mode)
  311. if pad_if_needed and width < size[1]:
  312. image = pad(image, (width - size[1], 0), fill, padding_mode)
  313. height, width = _get_image_size(image)
  314. target_height, target_width = size
  315. if height < target_height or width < target_width:
  316. raise ValueError(
  317. 'Crop size {} should be smaller than input image size {}. '.format(
  318. (target_height, target_width), (height, width)
  319. )
  320. )
  321. offset_height = random.randint(0, height - target_height)
  322. offset_width = random.randint(0, width - target_width)
  323. return crop(image, offset_height, offset_width, target_height, target_width)
  324. def random_resized_crop(image, size, scale, ratio, interpolation):
  325. if not (_is_pil_image(image) or _is_numpy_image(image)):
  326. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  327. if isinstance(size, int):
  328. size = (size, size)
  329. elif isinstance(size, (list, tuple)) and len(size) == 2:
  330. size = size
  331. else:
  332. raise TypeError('Size should be a int or a list/tuple with length of 2.' 'But got {}.'.format(size))
  333. if not (isinstance(scale, (list, tuple)) and len(scale) == 2):
  334. raise TypeError('Scale should be a list/tuple with length of 2.' 'But got {}.'.format(scale))
  335. if not (isinstance(ratio, (list, tuple)) and len(ratio) == 2):
  336. raise TypeError('Scale should be a list/tuple with length of 2.' 'But got {}.'.format(ratio))
  337. if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
  338. raise ValueError("Scale and ratio should be of kind (min, max)")
  339. def _get_param(image, scale, ratio):
  340. height, width = _get_image_size(image)
  341. area = height * width
  342. log_ratio = tuple(math.log(x) for x in ratio)
  343. for _ in range(10):
  344. target_area = np.random.uniform(*scale) * area
  345. aspect_ratio = math.exp(np.random.uniform(*log_ratio))
  346. w = int(round(math.sqrt(target_area * aspect_ratio)))
  347. h = int(round(math.sqrt(target_area / aspect_ratio)))
  348. if 0 < w <= width and 0 < h <= height:
  349. i = random.randint(0, height - h)
  350. j = random.randint(0, width - w)
  351. return i, j, h, w
  352. # Fallback to central crop
  353. in_ratio = float(width) / float(height)
  354. if in_ratio < min(ratio):
  355. w = width
  356. h = int(round(w / min(ratio)))
  357. elif in_ratio > max(ratio):
  358. h = height
  359. w = int(round(h * max(ratio)))
  360. else:
  361. # return whole image
  362. w = width
  363. h = height
  364. i = (height - h) // 2
  365. j = (width - w) // 2
  366. return i, j, h, w
  367. offset_height, offset_width, target_height, target_width = _get_param(image, scale, ratio)
  368. image = crop(image, offset_height, offset_width, target_height, target_width)
  369. image = resize(image, size, interpolation)
  370. return image
  371. def random_vflip(image, prob):
  372. if random.random() < prob:
  373. return vflip(image)
  374. return image
  375. def random_hflip(image, prob):
  376. if random.random() < prob:
  377. return hflip(image)
  378. return image
  379. def random_rotation(image, degrees, interpolation, expand, center, fill):
  380. if not (_is_pil_image(image) or _is_numpy_image(image)):
  381. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  382. if isinstance(degrees, numbers.Number):
  383. if degrees < 0:
  384. raise ValueError('If degrees is a single number, it must be positive.' 'But got {}'.format(degrees))
  385. degrees = (-degrees, degrees)
  386. elif not (isinstance(degrees, (list, tuple)) and len(degrees) == 2):
  387. raise ValueError('If degrees is a list/tuple, it must be length of 2.' 'But got {}'.format(degrees))
  388. else:
  389. if degrees[0] > degrees[1]:
  390. raise ValueError('if degrees is a list/tuple, it should be (min, max).')
  391. angle = np.random.uniform(degrees[0], degrees[1])
  392. if _is_pil_image(image):
  393. return F_pil.rotate(image, angle, interpolation, expand, center, fill)
  394. else:
  395. return F_cv2.rotate(image, angle, interpolation, expand, center, fill)
  396. def random_shear(image, degrees, interpolation, fill):
  397. if not (_is_pil_image(image) or _is_numpy_image(image)):
  398. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  399. if isinstance(degrees, numbers.Number):
  400. degrees = (-degrees, degrees, 0, 0)
  401. elif isinstance(degrees, (list, tuple)) and (len(degrees) == 2 or len(degrees) == 4):
  402. if len(degrees) == 2:
  403. degrees = (degrees[0], degrees[1], 0, 0)
  404. else:
  405. raise ValueError(
  406. 'degrees should be a single number or a list/tuple with length in (2 ,4).'
  407. 'But got {}'.format(degrees)
  408. )
  409. if _is_pil_image(image):
  410. return F_pil.random_shear(image, degrees, interpolation, fill)
  411. else:
  412. return F_cv2.random_shear(image, degrees, interpolation, fill)
  413. def random_shift(image, shift, interpolation, fill):
  414. if not (_is_pil_image(image) or _is_numpy_image(image)):
  415. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  416. if not (isinstance(shift, (tuple, list)) and len(shift) == 2):
  417. raise ValueError('Shift should be a list/tuple with length of 2.' 'But got {}'.format(shift))
  418. if _is_pil_image(image):
  419. return F_pil.random_shift(image, shift, interpolation, fill)
  420. else:
  421. return F_cv2.random_shift(image, shift, interpolation, fill)
  422. def random_zoom(image, zoom, interpolation, fill):
  423. if not (_is_pil_image(image) or _is_numpy_image(image)):
  424. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  425. if not (isinstance(zoom, (tuple, list)) and len(zoom) == 2):
  426. raise ValueError('Zoom should be a list/tuple with length of 2.' 'But got {}'.format(zoom))
  427. if not (0 <= zoom[0] <= zoom[1]):
  428. raise ValueError('Zoom values should be positive, and zoom[1] should be greater than zoom[0].')
  429. if _is_pil_image(image):
  430. return F_pil.random_zoom(image, zoom, interpolation, fill)
  431. else:
  432. return F_cv2.random_zoom(image, zoom, interpolation, fill)
  433. def random_affine(image, degrees, shift, zoom, shear, interpolation, fill):
  434. if not (_is_pil_image(image) or _is_numpy_image(image)):
  435. raise TypeError('image should be PIL Image or ndarray with dim=[2 or 3]. Got {}'.format(type(image)))
  436. if _is_pil_image(image):
  437. return F_pil.random_affine(image, degrees, shift, zoom, shear, interpolation, fill)
  438. else:
  439. return F_cv2.random_affine(image, degrees, shift, zoom, shear, interpolation, fill)

TensorLayer3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。