You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

spatial_transformer.py 11 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import numpy as np
  4. from six.moves import xrange
  5. import tensorlayer as tl
  6. from tensorlayer import logging
  7. from tensorlayer.layers.core import Module
# Names exported by ``from <this module> import *``.
__all__ = [
    'transformer',
    'batch_transformer',
    'SpatialTransformer2dAffine',
]
def transformer(U, theta, out_size, name='SpatialTransformer2dAffine'):
    """Spatial Transformer function for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`__,
    see :class:`SpatialTransformer2dAffine` class.

    Builds a regular target grid in normalized coordinates, maps it through the
    per-sample 2x3 affine matrix ``theta`` and bilinearly samples ``U`` at the
    resulting source coordinates.

    Parameters
    ----------
    U : list of float
        The output of a convolutional net should have the
        shape [num_batch, height, width, num_channels].
    theta: float
        The output of the localisation network should be [num_batch, 6]
        (it is reshaped to [num_batch, 2, 3]). Sampling coordinates are
        interpreted in the normalized range [-1, 1], so theta is typically
        squashed to [-1, 1] via tanh.
    out_size: tuple of int
        The size of the output of the network (height, width)
    name: str
        Optional function name. It is not referenced in the function body.

    Returns
    -------
    Tensor
        The transformed tensor, reshaped to
        [num_batch, out_size[0], out_size[1], num_channels].

    References
    ----------
    - `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`__
    - `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`__

    Notes
    -----
    To initialize the network to the identity transform init.

    >>> import numpy as np
    >>> import tensorflow as tf
    >>> # ``theta`` to
    >>> identity = np.array([[1., 0., 0.], [0., 1., 0.]])
    >>> identity = identity.flatten()
    >>> theta = tf.Variable(initial_value=identity)

    """

    def _repeat(x, n_repeats):
        """Repeat every element of ``x`` ``n_repeats`` times and return the result flattened."""
        # Row vector of ones with shape [1, n_repeats].
        rep = tl.transpose(a=tl.expand_dims(tl.ones(shape=tl.stack([
            n_repeats,
        ])), axis=1), perm=[1, 0])
        rep = tl.cast(rep, 'int32')
        # [N, 1] x [1, n_repeats] -> [N, n_repeats]; flattening yields each value n_repeats times in a row.
        x = tl.matmul(tl.reshape(x, (-1, 1)), rep)
        return tl.reshape(x, [-1])

    def _interpolate(im, x, y, out_size):
        """Bilinearly sample ``im`` at the flat normalized coordinates ``x`` / ``y``.

        Returns one row per sampled point with ``channels`` columns.
        """
        # constants
        num_batch, height, width, channels = tl.get_tensor_shape(im)
        x = tl.cast(x, 'float32')
        y = tl.cast(y, 'float32')
        height_f = tl.cast(height, 'float32')
        width_f = tl.cast(width, 'float32')
        out_height = out_size[0]
        out_width = out_size[1]
        zero = tl.zeros([], dtype='int32')
        max_y = tl.cast(height - 1, 'int32')
        max_x = tl.cast(width - 1, 'int32')
        # scale indices from [-1, 1] to [0, width/height]
        x = (x + 1.0) * (width_f) / 2.0
        y = (y + 1.0) * (height_f) / 2.0
        # do sampling
        # Integer corners surrounding each sampling point.
        x0 = tl.cast(tl.floor(x), 'int32')
        x1 = x0 + 1
        y0 = tl.cast(tl.floor(y), 'int32')
        y1 = y0 + 1
        # Clamp the corners so every lookup stays inside the image.
        x0 = tl.clip_by_value(x0, zero, max_x)
        x1 = tl.clip_by_value(x1, zero, max_x)
        y0 = tl.clip_by_value(y0, zero, max_y)
        y1 = tl.clip_by_value(y1, zero, max_y)
        # Flat pixel index = batch offset + row * width + column.
        dim2 = width
        dim1 = width * height
        base = _repeat(tl.range(num_batch) * dim1, out_height * out_width)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1
        # use indices to lookup pixels in the flat image and restore
        # channels dim
        im_flat = tl.reshape(im, tl.stack([-1, channels]))
        im_flat = tl.cast(im_flat, 'float32')
        Ia = tl.gather(im_flat, idx_a)
        Ib = tl.gather(im_flat, idx_b)
        Ic = tl.gather(im_flat, idx_c)
        Id = tl.gather(im_flat, idx_d)
        # and finally calculate interpolated values
        # The weights are computed from the clipped corner coordinates.
        x0_f = tl.cast(x0, 'float32')
        x1_f = tl.cast(x1, 'float32')
        y0_f = tl.cast(y0, 'float32')
        y1_f = tl.cast(y1, 'float32')
        wa = tl.expand_dims(((x1_f - x) * (y1_f - y)), 1)
        wb = tl.expand_dims(((x1_f - x) * (y - y0_f)), 1)
        wc = tl.expand_dims(((x - x0_f) * (y1_f - y)), 1)
        wd = tl.expand_dims(((x - x0_f) * (y - y0_f)), 1)
        output = tl.add_n([wa * Ia, wb * Ib, wc * Ic, wd * Id])
        return output

    def _meshgrid(height, width):
        """Return a [3, height * width] grid whose rows are x, y and ones, with x and y spanning [-1, 1]."""
        # This should be equivalent to:
        #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
        #                         np.linspace(-1, 1, height))
        #  ones = np.ones(np.prod(x_t.shape))
        #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
        x_t = tl.matmul(
            tl.ones(shape=tl.stack([height, 1])),
            tl.transpose(a=tl.expand_dims(tl.linspace(-1.0, 1.0, width), 1), perm=[1, 0])
        )
        y_t = tl.matmul(tl.expand_dims(tl.linspace(-1.0, 1.0, height), 1), tl.ones(shape=tl.stack([1, width])))
        x_t_flat = tl.reshape(x_t, (1, -1))
        y_t_flat = tl.reshape(y_t, (1, -1))
        ones = tl.ones(shape=tl.get_tensor_shape(x_t_flat))
        grid = tl.concat(axis=0, values=[x_t_flat, y_t_flat, ones])
        return grid

    def _transform(theta, input_dim, out_size):
        """Apply the affine matrices ``theta`` to the target grid and sample ``input_dim`` there."""
        num_batch, _, _, num_channels = tl.get_tensor_shape(input_dim)
        theta = tl.reshape(theta, (-1, 2, 3))
        theta = tl.cast(theta, 'float32')
        # grid of (x_t, y_t, 1), eq (1) in ref [1]
        out_height = out_size[0]
        out_width = out_size[1]
        grid = _meshgrid(out_height, out_width)
        # Replicate the same grid once per batch element: [num_batch, 3, out_height * out_width].
        grid = tl.expand_dims(grid, 0)
        grid = tl.reshape(grid, [-1])
        grid = tl.tile(grid, tl.stack([num_batch]))
        grid = tl.reshape(grid, tl.stack([num_batch, 3, -1]))
        # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
        T_g = tl.matmul(theta, grid)
        x_s = tl.slice(T_g, [0, 0, 0], [-1, 1, -1])
        y_s = tl.slice(T_g, [0, 1, 0], [-1, 1, -1])
        x_s_flat = tl.reshape(x_s, [-1])
        y_s_flat = tl.reshape(y_s, [-1])
        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat, out_size)
        output = tl.reshape(input_transformed, tl.stack([num_batch, out_height, out_width, num_channels]))
        return output

    output = _transform(theta, U, out_size)
    return output
  142. def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer2dAffine'):
  143. """Batch Spatial Transformer function for `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`__.
  144. Parameters
  145. ----------
  146. U : list of float
  147. tensor of inputs [batch, height, width, num_channels]
  148. thetas : list of float
  149. a set of transformations for each input [batch, num_transforms, 6]
  150. out_size : list of int
  151. the size of the output [out_height, out_width]
  152. name : str
  153. optional function name
  154. Returns
  155. ------
  156. float
  157. Tensor of size [batch * num_transforms, out_height, out_width, num_channels]
  158. """
  159. # with tf.compat.v1.variable_scope(name):
  160. num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
  161. indices = [[i] * num_transforms for i in xrange(num_batch)]
  162. input_repeated = tl.gather(U, tl.reshape(indices, [-1]))
  163. return transformer(input_repeated, thetas, out_size)
  164. class SpatialTransformer2dAffine(Module):
  165. """The :class:`SpatialTransformer2dAffine` class is a 2D `Spatial Transformer Layer <https://arxiv.org/abs/1506.02025>`__ for
  166. `2D Affine Transformation <https://en.wikipedia.org/wiki/Affine_transformation>`__.
  167. Parameters
  168. -----------
  169. out_size : tuple of int or None
  170. - The size of the output of the network (height, width), the feature maps will be resized by this.
  171. in_channels : int
  172. The number of in channels.
  173. data_format : str
  174. "channel_last" (NHWC, default) or "channels_first" (NCHW).
  175. name : str
  176. - A unique layer name.
  177. References
  178. -----------
  179. - `Spatial Transformer Networks <https://arxiv.org/abs/1506.02025>`__
  180. - `TensorFlow/Models <https://github.com/tensorflow/models/tree/master/transformer>`__
  181. """
  182. def __init__(
  183. self,
  184. out_size=(40, 40),
  185. in_channels=None,
  186. data_format='channel_last',
  187. name=None,
  188. ):
  189. super(SpatialTransformer2dAffine, self).__init__(name)
  190. self.in_channels = in_channels
  191. self.out_size = out_size
  192. self.data_format = data_format
  193. if self.in_channels is not None:
  194. self.build(self.in_channels)
  195. self._built = True
  196. logging.info("SpatialTransformer2dAffine %s" % self.name)
  197. def __repr__(self):
  198. s = '{classname}(out_size={out_size}, '
  199. if self.in_channels is not None:
  200. s += 'in_channels=\'{in_channels}\''
  201. if self.name is not None:
  202. s += ', name=\'{name}\''
  203. s += ')'
  204. return s.format(classname=self.__class__.__name__, **self.__dict__)
  205. def build(self, inputs_shape):
  206. if self.in_channels is None and len(inputs_shape) != 2:
  207. raise AssertionError("The dimension of theta layer input must be rank 2, please reshape or flatten it")
  208. if self.in_channels:
  209. shape = [self.in_channels, 6]
  210. else:
  211. # self.in_channels = inputs_shape[1] # BUG
  212. # shape = [inputs_shape[1], 6]
  213. self.in_channels = inputs_shape[0][-1] # zsdonghao
  214. shape = [self.in_channels, 6]
  215. self.W = self._get_weights("weights", shape=tuple(shape), init=tl.initializers.Zeros())
  216. identity = np.reshape(np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), newshape=(6, ))
  217. self.b = self._get_weights("biases", shape=(6, ), init=tl.initializers.Constant(identity))
  218. def forward(self, inputs):
  219. """
  220. :param inputs: a tuple (theta_input, U).
  221. - theta_input is of size [batch, in_channels]. We will use a :class:`Dense` to
  222. make the theta size to [batch, 6], value range to [0, 1] (via tanh).
  223. - U is the previous layer, which the affine transformation is applied to.
  224. :return: tensor of size [batch, out_size[0], out_size[1], n_channels] after affine transformation,
  225. n_channels is identical to that of U.
  226. """
  227. theta_input, U = inputs
  228. theta = tl.tanh(tl.matmul(theta_input, self.W) + self.b)
  229. outputs = transformer(U, theta, out_size=self.out_size)
  230. # automatically set batch_size and channels
  231. # e.g. [?, 40, 40, ?] --> [64, 40, 40, 1] or [64, 20, 20, 4]
  232. batch_size = theta_input.shape[0]
  233. n_channels = U.shape[-1]
  234. if self.data_format == 'channel_last':
  235. outputs = tl.reshape(outputs, shape=[batch_size, self.out_size[0], self.out_size[1], n_channels])
  236. else:
  237. raise Exception("unimplement data_format {}".format(self.data_format))
  238. return outputs

TensorLayer3.0 是一款兼容多种深度学习框架为计算后端的深度学习库。计划兼容TensorFlow, Pytorch, MindSpore, Paddle.