You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

lambda_layers.py 10 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorflow as tf
  4. from tensorlayer import logging
  5. from tensorlayer.files import utils
  6. from tensorlayer.layers.core import Module
# Public API of this module: the two user-defined-function wrapper layers.
__all__ = [
    'Lambda',
    'ElementwiseLambda',
]
  11. class Lambda(Module):
  12. """A layer that takes a user-defined function using Lambda.
  13. If the function has trainable weights, the weights should be provided.
  14. Remember to make sure the weights provided when the layer is constructed are SAME as
  15. the weights used when the layer is forwarded.
  16. For multiple inputs see :class:`ElementwiseLambda`.
  17. Parameters
  18. ----------
  19. fn : function
  20. The function that applies to the inputs (e.g. tensor from the previous layer).
  21. fn_weights : list
  22. The trainable weights for the function if any. Optional.
  23. fn_args : dict
  24. The arguments for the function if any. Optional.
  25. name : str or None
  26. A unique layer name.
  27. Examples
  28. ---------
  29. Non-parametric and non-args case:
  30. This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
  31. >>> x = tl.layers.Input([8, 3], name='input')
  32. >>> y = tl.layers.Lambda(lambda x: 2*x, name='lambda')(x)
  33. Non-parametric and with args case:
  34. This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
  35. >>> def customize_func(x, foo=42): # x is the inputs, foo is an argument
  36. >>> return foo * x
  37. >>> x = tl.layers.Input([8, 3], name='input')
  38. >>> lambdalayer = tl.layers.Lambda(customize_func, fn_args={'foo': 2}, name='lambda')(x)
  39. Any function with outside variables:
  40. This case has not been supported in Model.save() / Model.load() yet.
  41. Please avoid using Model.save() / Model.load() to save / load models that contain such Lambda layer. Instead, you may use Model.save_weights() / Model.load_weights() to save / load model weights.
  42. Note: In this case, fn_weights should be a list, and then the trainable weights in this Lambda layer can be added into the weights of the whole model.
  43. >>> a = tl.ops.Variable(1.0)
  44. >>> def func(x):
  45. >>> return x + a
  46. >>> x = tl.layers.Input([8, 3], name='input')
  47. >>> y = tl.layers.Lambda(func, fn_weights=[a], name='lambda')(x)
  48. Parametric case, merge other wrappers into TensorLayer:
  49. This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
  50. >>> layers = [
  51. >>> tl.layers.Dense(10, act=tl.Relu),
  52. >>> tl.layers.Dense(5, act=tl.Relu),
  53. >>> tl.layers.Dense(1, activation=tf.identity)
  54. >>> ]
  55. >>> perceptron = tl.layers.SequentialLayer(layers)
  56. >>> # in order to compile keras model and get trainable_variables of the keras model
  57. >>> _ = perceptron(np.random.random([100, 5]).astype(np.float32))
  58. >>>
  59. >>> class CustomizeModel(tl.layers.Module):
  60. >>> def __init__(self):
  61. >>> super(CustomizeModel, self).__init__()
  62. >>> self.dense = tl.layers.Dense(in_channels=1, n_units=5)
  63. >>> self.lambdalayer = tl.layers.Lambda(perceptron, perceptron.trainable_variables)
  64. >>>
  65. >>> def forward(self, x):
  66. >>> z = self.dense(x)
  67. >>> z = self.lambdalayer(z)
  68. >>> return z
  69. >>>
  70. >>> optimizer = tl.optimizers.Adam(learning_rate=0.1)
  71. >>> model = CustomizeModel()
  72. >>> model.set_train()
  73. >>>
  74. >>> for epoch in range(50):
  75. >>> with tf.GradientTape() as tape:
  76. >>> pred_y = model(data_x)
  77. >>> loss = tl.cost.mean_squared_error(pred_y, data_y)
  78. >>>
  79. >>> gradients = tape.gradient(loss, model.trainable_weights)
  80. >>> optimizer.apply_gradients(zip(gradients, model.trainable_weights))
  81. """
  82. def __init__(
  83. self,
  84. fn,
  85. fn_weights=None,
  86. fn_args=None,
  87. name=None,
  88. ):
  89. super(Lambda, self).__init__(name=name)
  90. self.fn = fn
  91. self._trainable_weights = fn_weights if fn_weights is not None else []
  92. self.fn_args = fn_args if fn_args is not None else {}
  93. try:
  94. fn_name = repr(self.fn)
  95. except:
  96. fn_name = 'name not available'
  97. logging.info("Lambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._trainable_weights)))
  98. self.build()
  99. self._built = True
  100. def __repr__(self):
  101. s = '{classname}('
  102. s += 'fn={fn_name},'
  103. s += 'len_weights={len_weights},'
  104. s += 'name=\'{name}\''
  105. s += ')'
  106. try:
  107. fn_name = repr(self.fn)
  108. except:
  109. fn_name = 'name not available'
  110. return s.format(
  111. classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._trainable_weights),
  112. **self.__dict__
  113. )
  114. def build(self, inputs_shape=None):
  115. pass
  116. def forward(self, inputs, **kwargs):
  117. if len(kwargs) == 0:
  118. outputs = self.fn(inputs, **self.fn_args)
  119. else:
  120. outputs = self.fn(inputs, **kwargs)
  121. return outputs
  122. def get_args(self):
  123. init_args = {}
  124. if isinstance(self.fn, tf.keras.layers.Layer) or isinstance(self.fn, tf.keras.Model):
  125. init_args.update({"layer_type": "keraslayer"})
  126. init_args["fn"] = utils.save_keras_model(self.fn)
  127. init_args["fn_weights"] = None
  128. if len(self._nodes) == 0:
  129. init_args["keras_input_shape"] = []
  130. else:
  131. init_args["keras_input_shape"] = self._nodes[0].in_tensors[0].get_shape().as_list()
  132. else:
  133. init_args = {"layer_type": "normal"}
  134. return init_args
  135. class ElementwiseLambda(Module):
  136. """A layer that use a custom function to combine multiple :class:`Layer` inputs.
  137. If the function has trainable weights, the weights should be provided.
  138. Remember to make sure the weights provided when the layer is constructed are SAME as
  139. the weights used when the layer is forwarded.
  140. Parameters
  141. ----------
  142. fn : function
  143. The function that applies to the inputs (e.g. tensor from the previous layer).
  144. fn_weights : list
  145. The trainable weights for the function if any. Optional.
  146. fn_args : dict
  147. The arguments for the function if any. Optional.
  148. name : str or None
  149. A unique layer name.
  150. Examples
  151. --------
  152. Non-parametric and with args case
  153. This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
  154. >>> def func(noise, mean, std, foo=42):
  155. >>> return mean + noise * tf.exp(std * 0.5) + foo
  156. >>> noise = tl.layers.Input([100, 1])
  157. >>> mean = tl.layers.Input([100, 1])
  158. >>> std = tl.layers.Input([100, 1])
  159. >>> out = tl.layers.ElementwiseLambda(fn=func, fn_args={'foo': 84}, name='elementwiselambda')([noise, mean, std])
  160. Non-parametric and non-args case
  161. This case is supported in the Model.save() / Model.load() to save / load the whole model architecture and weights(optional).
  162. >>> noise = tl.layers.Input([100, 1])
  163. >>> mean = tl.layers.Input([100, 1])
  164. >>> std = tl.layers.Input([100, 1])
  165. >>> out = tl.layers.ElementwiseLambda(fn=lambda x, y, z: x + y * tf.exp(z * 0.5), name='elementwiselambda')([noise, mean, std])
  166. Any function with outside variables
  167. This case has not been supported in Model.save() / Model.load() yet.
  168. Please avoid using Model.save() / Model.load() to save / load models that contain such ElementwiseLambda layer. Instead, you may use Model.save_weights() / Model.load_weights() to save / load model weights.
  169. Note: In this case, fn_weights should be a list, and then the trainable weights in this ElementwiseLambda layer can be added into the weights of the whole model.
  170. >>> vara = [tf.Variable(1.0)]
  171. >>> def func(noise, mean, std):
  172. >>> return mean + noise * tf.exp(std * 0.5) + vara
  173. >>> noise = tl.layers.Input([100, 1])
  174. >>> mean = tl.layers.Input([100, 1])
  175. >>> std = tl.layers.Input([100, 1])
  176. >>> out = tl.layers.ElementwiseLambda(fn=func, fn_weights=vara, name='elementwiselambda')([noise, mean, std])
  177. """
  178. def __init__(
  179. self,
  180. fn,
  181. fn_weights=None,
  182. fn_args=None,
  183. name=None, #'elementwiselambda',
  184. ):
  185. super(ElementwiseLambda, self).__init__(name=name)
  186. self.fn = fn
  187. self._trainable_weights = fn_weights if fn_weights is not None else []
  188. self.fn_args = fn_args if fn_args is not None else {}
  189. try:
  190. fn_name = repr(self.fn)
  191. except:
  192. fn_name = 'name not available'
  193. logging.info(
  194. "ElementwiseLambda %s: func: %s, len_weights: %s" % (self.name, fn_name, len(self._trainable_weights))
  195. )
  196. self.build()
  197. self._built = True
  198. def __repr__(self):
  199. s = '{classname}('
  200. s += 'fn={fn_name},'
  201. s += 'len_weights={len_weights},'
  202. s += 'name=\'{name}\''
  203. s += ')'
  204. try:
  205. fn_name = repr(self.fn)
  206. except:
  207. fn_name = 'name not available'
  208. return s.format(
  209. classname=self.__class__.__name__, fn_name=fn_name, len_weights=len(self._trainable_weights),
  210. **self.__dict__
  211. )
  212. def build(self, inputs_shape=None):
  213. # do nothing
  214. # the weights of the function are provided when the Lambda layer is constructed
  215. pass
  216. # @tf.function
  217. def forward(self, inputs, **kwargs):
  218. if not isinstance(inputs, list):
  219. raise TypeError(
  220. "The inputs should be a list of values which corresponds with the customised lambda function."
  221. )
  222. if len(kwargs) == 0:
  223. outputs = self.fn(*inputs, **self.fn_args)
  224. else:
  225. outputs = self.fn(*inputs, **kwargs)
  226. return outputs

TensorLayer3.0 是一款兼容多种深度学习框架为计算后端的深度学习库。计划兼容TensorFlow, Pytorch, MindSpore, Paddle.