You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

extend.py 2.7 kB

4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer import logging
  5. from tensorlayer.layers.core import Module
  6. __all__ = [
  7. 'ExpandDims',
  8. 'Tile',
  9. ]
  10. class ExpandDims(Module):
  11. """
  12. The :class:`ExpandDims` class inserts a dimension of 1 into a tensor's shape,
  13. see `tf.expand_dims() <https://www.tensorflow.org/api_docs/python/tf/expand_dims>`__ .
  14. Parameters
  15. ----------
  16. axis : int
  17. The dimension index at which to expand the shape of input.
  18. name : str
  19. A unique layer name. If None, a unique name will be automatically assigned.
  20. Examples
  21. --------
  22. >>> x = tl.layers.Input([10, 3], name='in')
  23. >>> y = tl.layers.ExpandDims(axis=-1)(x)
  24. [10, 3, 1]
  25. """
  26. def __init__(
  27. self,
  28. axis,
  29. name=None # 'expand_dims',
  30. ):
  31. super(ExpandDims, self).__init__(name)
  32. self.axis = axis
  33. self.build((None, ))
  34. self._built = True
  35. logging.info("ExpandDims %s: axis: %d" % (self.name, self.axis))
  36. def __repr__(self):
  37. s = '{classname}('
  38. s += 'axis={axis},'
  39. s += 'name={name}'
  40. s += ")"
  41. return s.format(classname=self.__class__.__name__, **self.__dict__)
  42. def build(self, inputs_shape):
  43. self.expand_dims = tl.ops.ExpandDims(axis=self.axis)
  44. # @tf.function
  45. def forward(self, inputs):
  46. outputs = self.expand_dims(inputs)
  47. return outputs
  48. class Tile(Module):
  49. """
  50. The :class:`Tile` class constructs a tensor by tiling a given tensor,
  51. see `tf.tile() <https://www.tensorflow.org/api_docs/python/tf/tile>`__ .
  52. Parameters
  53. ----------
  54. multiples: tensor
  55. Must be one of the following types: int32, int64.
  56. 1-D Length must be the same as the number of dimensions in input.
  57. name : None or str
  58. A unique layer name.
  59. Examples
  60. --------
  61. >>> x = tl.layers.Input([10, 3], name='in')
  62. >>> y = tl.layers.Tile(multiples=[2, 3])(x)
  63. """
  64. def __init__(self, multiples=None, name=None): #'tile'):
  65. super(Tile, self).__init__(name)
  66. self.multiples = multiples
  67. self.build((None, ))
  68. self._built = True
  69. logging.info("Tile %s: multiples: %s" % (self.name, self.multiples))
  70. def __repr__(self):
  71. s = '{classname}('
  72. s += 'multiples={multiples},'
  73. s += 'name={name}'
  74. s += ")"
  75. return s.format(classname=self.__class__.__name__, **self.__dict__)
  76. def build(self, inputs_shape):
  77. self.tile = tl.ops.Tile()
  78. # @tf.function
  79. def forward(self, inputs):
  80. outputs = self.tile(inputs, multiples=self.multiples)
  81. return outputs

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。