You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

stack.py 3.1 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer import logging
  5. from tensorlayer.layers.core import Module
  6. __all__ = [
  7. 'Stack',
  8. 'UnStack',
  9. ]
  10. class Stack(Module):
  11. """
  12. The :class:`Stack` class is a layer for stacking a list of rank-R tensors into one rank-(R+1) tensor, see `tf.stack() <https://www.tensorflow.org/api_docs/python/tf/stack>`__.
  13. Parameters
  14. ----------
  15. axis : int
  16. New dimension along which to stack.
  17. name : str
  18. A unique layer name.
  19. Examples
  20. ---------
  21. >>> import tensorlayer as tl
  22. >>> ni = tl.layers.Input([10, 784], name='input')
  23. >>> net1 = tl.layers.Dense(10, name='dense1')(ni)
  24. >>> net2 = tl.layers.Dense(10, name='dense2')(ni)
  25. >>> net3 = tl.layers.Dense(10, name='dense3')(ni)
  26. >>> net = tl.layers.Stack(axis=1, name='stack')([net1, net2, net3])
  27. (10, 3, 10)
  28. """
  29. def __init__(
  30. self,
  31. axis=1,
  32. name=None, #'stack',
  33. ):
  34. super().__init__(name)
  35. self.axis = axis
  36. self.build(None)
  37. self._built = True
  38. logging.info("Stack %s: axis: %d" % (self.name, self.axis))
  39. def __repr__(self):
  40. s = '{classname}(axis={axis}'
  41. if self.name is not None:
  42. s += ', name=\'{name}\''
  43. s += ')'
  44. return s.format(classname=self.__class__.__name__, **self.__dict__)
  45. def build(self, inputs_shape):
  46. self.stack = tl.ops.Stack(axis=self.axis)
  47. def forward(self, inputs):
  48. outputs = self.stack(inputs)
  49. return outputs
  50. class UnStack(Module):
  51. """
  52. The :class:`UnStack` class is a layer for unstacking the given dimension of a rank-R tensor into rank-(R-1) tensors., see `tf.unstack() <https://www.tensorflow.org/api_docs/python/tf/unstack>`__.
  53. Parameters
  54. ----------
  55. num : int or None
  56. The length of the dimension axis. Automatically inferred if None (the default).
  57. axis : int
  58. Dimension along which axis to concatenate.
  59. name : str
  60. A unique layer name.
  61. Returns
  62. -------
  63. list of :class:`Layer`
  64. The list of layer objects unstacked from the input.
  65. Examples
  66. --------
  67. >>> ni = tl.layers.Input([4, 10], name='input')
  68. >>> nn = tl.layers.Dense(n_units=5)(ni)
  69. >>> nn = tl.layers.UnStack(axis=1)(nn) # unstack in channel axis
  70. >>> len(nn) # 5
  71. >>> nn[0].shape # (4,)
  72. """
  73. def __init__(self, num=None, axis=0, name=None): #'unstack'):
  74. super().__init__(name)
  75. self.num = num
  76. self.axis = axis
  77. self.build(None)
  78. self._built = True
  79. logging.info("UnStack %s: num: %s axis: %d" % (self.name, self.num, self.axis))
  80. def __repr__(self):
  81. s = '{classname}(num={num}, axis={axis}'
  82. if self.name is not None:
  83. s += ', name=\'{name}\''
  84. s += ')'
  85. return s.format(classname=self.__class__.__name__, **self.__dict__)
  86. def build(self, inputs_shape):
  87. self.unstack = tl.ops.Unstack(num=self.num, axis=self.axis)
  88. def forward(self, inputs):
  89. outputs = self.unstack(inputs)
  90. return outputs

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。