You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

base_dense.py 3.9 kB

4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #! /usr/bin/python
  2. # -*- coding: utf-8 -*-
  3. import tensorlayer as tl
  4. from tensorlayer import logging
  5. from tensorlayer.layers.core import Module
  6. __all__ = [
  7. 'Dense',
  8. ]
  9. class Dense(Module):
  10. """The :class:`Dense` class is a fully connected layer.
  11. Parameters
  12. ----------
  13. n_units : int
  14. The number of units of this layer.
  15. act : activation function
  16. The activation function of this layer.
  17. W_init : initializer
  18. The initializer for the weight matrix.
  19. b_init : initializer or None
  20. The initializer for the bias vector. If None, skip biases.
  21. in_channels: int
  22. The number of channels of the previous layer.
  23. If None, it will be automatically detected when the layer is forwarded for the first time.
  24. name : None or str
  25. A unique layer name. If None, a unique name will be automatically generated.
  26. Examples
  27. --------
  28. With TensorLayer
  29. >>> net = tl.layers.Input([100, 50], name='input')
  30. >>> dense = tl.layers.Dense(n_units=800, act=tl.ReLU, in_channels=50, name='dense_1')
  31. >>> print(dense)
  32. Dense(n_units=800, relu, in_channels='50', name='dense_1')
  33. >>> tensor = tl.layers.Dense(n_units=800, act=tl.ReLU, name='dense_2')(net)
  34. >>> print(tensor)
  35. tf.Tensor([...], shape=(100, 800), dtype=float32)
  36. Notes
  37. -----
  38. If the layer input has more than two axes, it needs to be flatten by using :class:`Flatten`.
  39. """
  40. def __init__(
  41. self,
  42. n_units,
  43. act=None,
  44. W_init=tl.initializers.truncated_normal(stddev=0.05),
  45. b_init=tl.initializers.constant(value=0.0),
  46. in_channels=None,
  47. name=None, # 'dense',
  48. ):
  49. super(Dense, self).__init__(name, act=act)
  50. self.n_units = n_units
  51. self.W_init = W_init
  52. self.b_init = b_init
  53. self.in_channels = in_channels
  54. if self.in_channels is not None:
  55. self.build(self.in_channels)
  56. self._built = True
  57. logging.info(
  58. "Dense %s: %d %s" %
  59. (self.name, self.n_units, self.act.__class__.__name__ if self.act is not None else 'No Activation')
  60. )
  61. def __repr__(self):
  62. actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
  63. s = ('{classname}(n_units={n_units}, ' + actstr)
  64. if self.in_channels is not None:
  65. s += ', in_channels=\'{in_channels}\''
  66. if self.name is not None:
  67. s += ', name=\'{name}\''
  68. s += ')'
  69. return s.format(classname=self.__class__.__name__, **self.__dict__)
  70. def build(self, inputs_shape):
  71. if self.in_channels is None and len(inputs_shape) != 2:
  72. raise AssertionError("The input dimension must be rank 2, please reshape or flatten it")
  73. if self.in_channels:
  74. shape = [self.in_channels, self.n_units]
  75. else:
  76. self.in_channels = inputs_shape[1]
  77. shape = [inputs_shape[1], self.n_units]
  78. self.W = self._get_weights("weights", shape=tuple(shape), init=self.W_init)
  79. self.b_init_flag = False
  80. if self.b_init:
  81. self.b = self._get_weights("biases", shape=(self.n_units, ), init=self.b_init)
  82. self.b_init_flag = True
  83. self.bias_add = tl.ops.BiasAdd()
  84. self.act_init_flag = False
  85. if self.act:
  86. self.act_init_flag = True
  87. self.matmul = tl.ops.MatMul()
  88. def forward(self, inputs):
  89. if self._forward_state == False:
  90. if self._built == False:
  91. self.build(tl.get_tensor_shape(inputs))
  92. self._built = True
  93. self._forward_state = True
  94. z = self.matmul(inputs, self.W)
  95. if self.b_init_flag:
  96. z = self.bias_add(z, self.b)
  97. if self.act_init_flag:
  98. z = self.act(z)
  99. return z

TensorLayer 3.0 是一款支持多种深度学习框架作为计算后端的深度学习库,计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。