You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

merge.py 4.5 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer import logging
  5. from tensorlayer.layers.core import Module
# Names exported by this module when it is star-imported.
__all__ = [
    'Concat',
    'Elementwise',
]
  10. class Concat(Module):
  11. """A layer that concats multiple tensors according to given axis.
  12. Parameters
  13. ----------
  14. concat_dim : int
  15. The dimension to concatenate.
  16. name : None or str
  17. A unique layer name.
  18. Examples
  19. ----------
  20. >>> class CustomModel(Module):
  21. >>> def __init__(self):
  22. >>> super(CustomModel, self).__init__(name="custom")
  23. >>> self.dense1 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu1_1')
  24. >>> self.dense2 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu2_1')
  25. >>> self.concat = tl.layers.Concat(concat_dim=1, name='concat_layer')
  26. >>> def forward(self, inputs):
  27. >>> d1 = self.dense1(inputs)
  28. >>> d2 = self.dense2(inputs)
  29. >>> outputs = self.concat([d1, d2])
  30. >>> return outputs
  31. """
  32. def __init__(
  33. self,
  34. concat_dim=-1,
  35. name=None, #'concat',
  36. ):
  37. super(Concat, self).__init__(name)
  38. self.concat_dim = concat_dim
  39. self.build(None)
  40. self._built = True
  41. logging.info("Concat %s: concat_dim: %d" % (self.name, concat_dim))
  42. def __repr__(self):
  43. s = ('{classname}(concat_dim={concat_dim})')
  44. return s.format(classname=self.__class__.__name__, **self.__dict__)
  45. def build(self, inputs_shape):
  46. self.concat = tl.ops.Concat(self.concat_dim)
  47. # @tf.function
  48. def forward(self, inputs):
  49. """
  50. prev_layer : list of :class:`Layer`
  51. List of layers to concatenate.
  52. """
  53. outputs = self.concat(inputs)
  54. return outputs
  55. class Elementwise(Module):
  56. """A layer that combines multiple :class:`Layer` that have the same output shapes
  57. according to an element-wise operation.
  58. If the element-wise operation is complicated, please consider to use :class:`ElementwiseLambda`.
  59. Parameters
  60. ----------
  61. combine_fn : a TensorFlow element-wise combine function
  62. e.g. AND is ``tl.minimum`` ; OR is ``tl.maximum`` ; ADD is ``tl.add`` ; MUL is ``tl.multiply`` and so on.
  63. See `TensorFlow Math API <https://www.tensorflow.org/versions/master/api_docs/python/math_ops.html#math>`__ .
  64. If the combine function is more complicated, please consider to use :class:`ElementwiseLambda`.
  65. act : activation function
  66. The activation function of this layer.
  67. name : None or str
  68. A unique layer name.
  69. Examples
  70. --------
  71. >>> class CustomModel(tl.models.Model):
  72. >>> def __init__(self):
  73. >>> super(CustomModel, self).__init__(name="custom")
  74. >>> self.dense1 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu1_1')
  75. >>> self.dense2 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu2_1')
  76. >>> self.element = tl.layers.Elementwise(combine_fn=tl.minimum, name='minimum', act=tl.identity)
  77. >>> def forward(self, inputs):
  78. >>> d1 = self.dense1(inputs)
  79. >>> d2 = self.dense2(inputs)
  80. >>> outputs = self.element([d1, d2])
  81. >>> return outputs
  82. """
  83. def __init__(
  84. self,
  85. combine_fn=tl.ops.minimum,
  86. act=None,
  87. name=None, #'elementwise',
  88. ):
  89. super(Elementwise, self).__init__(name, act=act)
  90. self.combine_fn = combine_fn
  91. self.combine_fn_str = str(combine_fn).split(' ')[1]
  92. self.build(None)
  93. self._built = True
  94. logging.info(
  95. "Elementwise %s: fn: %s act: %s" %
  96. (self.name, combine_fn.__name__, ('No Activation' if self.act is None else self.act.__class__.__name__))
  97. )
  98. def __repr__(self):
  99. actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
  100. s = ('{classname}(combine_fn={combine_fn_str}, ' + actstr)
  101. if self.name is not None:
  102. s += ', name=\'{name}\''
  103. s += ')'
  104. return s.format(classname=self.__class__.__name__, **self.__dict__)
  105. def build(self, inputs_shape):
  106. pass
  107. # @tf.function
  108. def forward(self, inputs):
  109. outputs = inputs[0]
  110. for input in inputs[1:]:
  111. outputs = self.combine_fn(outputs, input)
  112. if self.act:
  113. outputs = self.act(outputs)
  114. return outputs

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 PaddlePaddle。