You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

shape.py 6.3 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. from tensorlayer import logging
  4. from tensorlayer.layers.core import Module
  5. import tensorlayer as tl
# Public API of this module: the four shape-manipulation layers.
__all__ = [
    'Flatten',
    'Reshape',
    'Transpose',
    'Shuffle',
]
  12. class Flatten(Module):
  13. """A layer that reshapes high-dimension input into a vector.
  14. Then we often apply Dense, RNN, Concat and etc on the top of a flatten layer.
  15. [batch_size, mask_row, mask_col, n_mask] ---> [batch_size, mask_row * mask_col * n_mask]
  16. Parameters
  17. ----------
  18. name : None or str
  19. A unique layer name.
  20. Examples
  21. --------
  22. >>> x = tl.layers.Input([8, 4, 3], name='input')
  23. >>> y = tl.layers.Flatten(name='flatten')(x)
  24. [8, 12]
  25. """
  26. def __init__(self, name=None): #'flatten'):
  27. super(Flatten, self).__init__(name)
  28. self.build()
  29. self._built = True
  30. logging.info("Flatten %s:" % (self.name))
  31. def __repr__(self):
  32. s = '{classname}('
  33. s += 'name=\'{name}\''
  34. s += ')'
  35. return s.format(classname=self.__class__.__name__, **self.__dict__)
  36. def build(self, inputs_shape=None):
  37. self.flatten_reshape = tl.ops.FlattenReshape()
  38. # @tf.function
  39. def forward(self, inputs):
  40. outputs = self.flatten_reshape(inputs)
  41. return outputs
  42. class Reshape(Module):
  43. """A layer that reshapes a given tensor.
  44. Parameters
  45. ----------
  46. shape : tuple of int
  47. The output shape, see ``tf.reshape``.
  48. name : str
  49. A unique layer name.
  50. Examples
  51. --------
  52. >>> x = tl.layers.Input([8, 4, 3], name='input')
  53. >>> y = tl.layers.Reshape(shape=[-1, 12], name='reshape')(x)
  54. (8, 12)
  55. """
  56. def __init__(self, shape, name=None): #'reshape'):
  57. super(Reshape, self).__init__(name)
  58. self.shape = shape
  59. logging.info("Reshape %s" % (self.name))
  60. self.build()
  61. self._built = True
  62. def __repr__(self):
  63. s = '{classname}('
  64. s += 'shape={shape},'
  65. s += 'name=\'{name}\''
  66. s += ')'
  67. return s.format(classname=self.__class__.__name__, **self.__dict__)
  68. def build(self, inputs_shape=None):
  69. self.reshape = tl.ops.Reshape(self.shape)
  70. def forward(self, inputs):
  71. outputs = self.reshape(inputs)
  72. return outputs
  73. class Transpose(Module):
  74. """A layer that transposes the dimension of a tensor.
  75. See `tf.transpose() <https://www.tensorflow.org/api_docs/python/tf/transpose>`__ .
  76. Parameters
  77. ----------
  78. perm: list of int
  79. The permutation of the dimensions, similar with ``numpy.transpose``.
  80. If None, it is set to (n-1...0), where n is the rank of the input tensor.
  81. conjugate: bool
  82. By default False. If True, returns the complex conjugate of complex numbers (and transposed)
  83. For example [[1+1j, 2+2j]] --> [[1-1j], [2-2j]]
  84. name : str
  85. A unique layer name.
  86. Examples
  87. ----------
  88. >>> x = tl.layers.Input([8, 4, 3], name='input')
  89. >>> y = tl.layers.Transpose(perm=[0, 2, 1], conjugate=False, name='trans')(x)
  90. (8, 3, 4)
  91. """
  92. def __init__(self, perm=None, conjugate=False, name=None): #'transpose'):
  93. super(Transpose, self).__init__(name)
  94. self.perm = perm
  95. self.conjugate = conjugate
  96. logging.info("Transpose %s: perm: %s, conjugate: %s" % (self.name, self.perm, self.conjugate))
  97. self.build()
  98. self._built = True
  99. def __repr__(self):
  100. s = '{classname}('
  101. s += 'perm={perm},'
  102. s += 'conjugate={conjugate},'
  103. s += 'name=\'{name}\''
  104. s += ')'
  105. return s.format(classname=self.__class__.__name__, **self.__dict__)
  106. def build(self, inputs_shape=None):
  107. self.transpose = tl.ops.Transpose(perm=self.perm, conjugate=self.conjugate)
  108. # @tf.function
  109. def forward(self, inputs):
  110. outputs = self.transpose(a=inputs)
  111. return outputs
  112. class Shuffle(Module):
  113. """A layer that shuffle a 2D image [batch, height, width, channel], see `here <https://arxiv.org/abs/1707.01083>`__.
  114. Parameters
  115. ----------
  116. group: int
  117. The number of groups.
  118. name : str
  119. A unique layer name.
  120. Examples
  121. --------
  122. >>> x = tl.layers.Input([1, 16, 16, 8], name='input')
  123. >>> y = tl.layers.Shuffle(group=2, name='shuffle')(x)
  124. (1, 16, 16, 8)
  125. """
  126. def __init__(self, group, in_channels=None, name=None): #'reshape'):
  127. super(Shuffle, self).__init__(name)
  128. self.group = group
  129. self.inchannels = in_channels
  130. logging.info("Shuffle %s" % (self.name))
  131. self.build()
  132. self._built = True
  133. def __repr__(self):
  134. s = '{classname}('
  135. s += 'group={group},'
  136. s += 'name=\'{name}\''
  137. s += ')'
  138. return s.format(classname=self.__class__.__name__, **self.__dict__)
  139. def build(self, inputs_shape=None):
  140. self.transpose = tl.ops.Transpose([0, 1, 2, 4, 3])
  141. inputs_shape = self.inchannels
  142. if tl.BACKEND == 'mindspore' and inputs_shape == None:
  143. raise ValueError("Do you forget to pass the keyword argument 'in_channels")
  144. if tl.BACKEND == 'mindspore':
  145. h, w, in_channel = inputs_shape[1:]
  146. if in_channel % self.group != 0:
  147. raise ValueError(
  148. "The in_channel must be a multiple of the number of groups. The in_channel got %d and the number of groups is %d."
  149. % (in_channel, self.group)
  150. )
  151. self.reshape1 = tl.ops.Reshape([-1, h, w, in_channel // self.group, self.group])
  152. self.reshape2 = tl.ops.Reshape([-1, h, w, in_channel])
  153. def forward(self, inputs):
  154. if tl.BACKEND == 'tensorflow':
  155. in_shape = tl.get_tensor_shape(inputs)
  156. h, w, in_channel = in_shape[1:]
  157. reshape1 = tl.ops.Reshape([-1, h, w, in_channel // self.group, self.group])
  158. temp = reshape1(inputs)
  159. temp = self.transpose(temp)
  160. reshape2 = tl.ops.Reshape([-1, h, w, in_channel])
  161. outputs = reshape2(temp)
  162. else:
  163. temp = self.reshape1(inputs)
  164. temp = self.transpose(temp)
  165. outputs = self.reshape2(temp)
  166. return outputs

TensorLayer 3.0 是一款支持多种深度学习框架作为计算后端的深度学习库,计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。