
quan_dense_bn.py

#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module
from tensorflow.python.training import moving_averages
from tensorlayer.layers.utils import (
    quantize_active_overflow, quantize_weight_overflow, mean_var_with_update, w_fold, bias_fold
)

__all__ = [
    'QuanDenseWithBN',
]
class QuanDenseWithBN(Module):
    """The :class:`QuanDenseWithBN` class is a quantized fully connected layer with batch normalization.
    Its weights are quantized to 'bitW' bits and the output of the previous layer to 'bitA' bits during inference.

    # TODO The QuanDenseWithBN layer only supports the TensorFlow backend.

    Parameters
    ----------
    n_units : int
        The number of units of this layer.
    act : activation function
        The activation function of this layer.
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        A larger value is suggested for larger datasets.
    epsilon : float
        A small value added to the variance to avoid division by zero.
    is_train : boolean
        Whether the layer is used for training or inference.
    beta_init : initializer or None
        The initializer for beta. If None, beta is skipped.
        You usually should not skip beta unless you know what you are doing.
    gamma_init : initializer or None
        The initializer for gamma. If None, gamma is skipped.
    bitW : int
        The number of bits used for this layer's weights.
    bitA : int
        The number of bits used for the output of the previous layer.
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inference. (TODO).
    W_init : initializer
        The initializer for the weight matrix.
    W_init_args : dictionary
        The arguments for the weight matrix initializer.
    in_channels : int
        The number of channels of the previous layer.
        If None, it will be automatically detected when the layer is forwarded for the first time.
    name : a str
        A unique layer name.

    Examples
    ---------
    >>> import tensorlayer as tl
    >>> net = tl.layers.Input([50, 256])
    >>> net = tl.layers.QuanDenseWithBN(128, act='relu', name='qdbn1')(net)
    >>> net = tl.layers.QuanDenseWithBN(256, act='relu', name='qdbn2')(net)

    """
    def __init__(
        self,
        n_units=100,
        act=None,
        decay=0.9,
        epsilon=1e-5,
        is_train=False,
        bitW=8,
        bitA=8,
        gamma_init=tl.initializers.truncated_normal(stddev=0.05),
        beta_init=tl.initializers.truncated_normal(stddev=0.05),
        use_gemm=False,
        W_init=tl.initializers.truncated_normal(stddev=0.05),
        W_init_args=None,
        in_channels=None,
        name=None,  # 'quan_dense_with_bn',
    ):
        super(QuanDenseWithBN, self).__init__(act=act, W_init_args=W_init_args, name=name)
        self.n_units = n_units
        self.decay = decay
        self.epsilon = epsilon
        self.is_train = is_train
        self.bitW = bitW
        self.bitA = bitA
        self.gamma_init = gamma_init
        self.beta_init = beta_init
        self.use_gemm = use_gemm
        self.W_init = W_init
        self.in_channels = in_channels

        if self.in_channels is not None:
            self.build((None, self.in_channels))
            self._built = True

        logging.info(
            "QuanDenseWithBN %s: %d %s" %
            (self.name, n_units, self.act.__class__.__name__ if self.act is not None else 'No Activation')
        )
    def __repr__(self):
        actstr = self.act.__class__.__name__ if self.act is not None else 'No Activation'
        s = ('{classname}(n_units={n_units}, ' + actstr)
        s += ', bitW={bitW}, bitA={bitA}'
        if self.in_channels is not None:
            s += ', in_channels=\'{in_channels}\''
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)
    def build(self, inputs_shape):
        if self.in_channels is None and len(inputs_shape) != 2:
            raise Exception("The input dimension must be rank 2, please reshape or flatten it")
        if self.in_channels is None:
            self.in_channels = inputs_shape[1]

        if self.use_gemm:
            raise Exception("TODO. The current version uses tf.matmul for inference.")

        n_in = inputs_shape[-1]
        self.W = self._get_weights("weights", shape=(n_in, self.n_units), init=self.W_init)

        para_bn_shape = (self.n_units, )
        if self.gamma_init:
            self.scale_para = self._get_weights("gamma_weights", shape=para_bn_shape, init=self.gamma_init)
        else:
            self.scale_para = None

        if self.beta_init:
            self.offset_para = self._get_weights("beta_weights", shape=para_bn_shape, init=self.beta_init)
        else:
            self.offset_para = None

        self.moving_mean = self._get_weights(
            "moving_mean", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
        )
        self.moving_variance = self._get_weights(
            "moving_variance", shape=para_bn_shape, init=tl.initializers.constant(1.0), trainable=False
        )
    def forward(self, inputs):
        if self._forward_state is False:
            if self._built is False:
                self.build(tl.get_tensor_shape(inputs))
                self._built = True
            self._forward_state = True

        x = inputs
        # Quantize the activations coming from the previous layer to bitA bits.
        inputs = quantize_active_overflow(inputs, self.bitA)
        # Batch statistics are computed on the full-precision pre-activation.
        mid_out = tl.ops.matmul(x, self.W)

        mean, variance = tl.ops.moments(x=mid_out, axes=list(range(len(mid_out.get_shape()) - 1)))

        update_moving_mean = moving_averages.assign_moving_average(
            self.moving_mean, mean, self.decay, zero_debias=False
        )  # if zero_debias=True, has bias
        update_moving_variance = moving_averages.assign_moving_average(
            self.moving_variance, variance, self.decay, zero_debias=False
        )  # if zero_debias=True, has bias

        if self.is_train:
            mean, var = mean_var_with_update(update_moving_mean, update_moving_variance, mean, variance)
        else:
            mean, var = self.moving_mean, self.moving_variance

        # Fold the BN scale into the weights, then quantize the folded weights to bitW bits.
        _w_fold = w_fold(self.W, self.scale_para, var, self.epsilon)
        W = quantize_weight_overflow(_w_fold, self.bitW)

        outputs = tl.ops.matmul(inputs, W)

        if self.beta_init:
            # Fold the BN shift into the bias and add it to the outputs.
            _bias_fold = bias_fold(self.offset_para, self.scale_para, mean, var, self.epsilon)
            outputs = tl.ops.bias_add(outputs, _bias_fold)

        if self.act:
            outputs = self.act(outputs)
        return outputs
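For reference, the w_fold and bias_fold helpers imported from tensorlayer.layers.utils perform standard batch-norm folding: the BN scale is merged into the weight matrix and the BN shift into the bias, so inference needs only a matmul and a bias add. Below is a minimal sketch of the underlying formulas, assuming the usual folding definitions; the names w_fold_sketch and bias_fold_sketch are illustrative, not the library API.

    import tensorflow as tf

    def w_fold_sketch(w, gamma, var, epsilon):
        # Scale each output column of the weights by gamma / sqrt(var + eps).
        # w: (n_in, n_units), gamma/var: (n_units,), broadcast over the last axis.
        return w * gamma / tf.sqrt(var + epsilon)

    def bias_fold_sketch(beta, gamma, mean, var, epsilon):
        # The folded bias absorbs the BN mean subtraction:
        # beta - gamma * mean / sqrt(var + eps).
        return beta - gamma * mean / tf.sqrt(var + epsilon)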

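As a rough mental model, quantize_active_overflow and quantize_weight_overflow appear to map a tensor onto 2**bits - 1 uniform levels spanning its observed min/max range. The sketch below illustrates that assumption only; the real helpers in tensorlayer.layers.utils also route gradients straight through tf.round for training, which is omitted here, and quantize_overflow_sketch is an illustrative name, not the library API.

    import tensorflow as tf

    def quantize_overflow_sketch(x, bits):
        # Uniform quantization over the tensor's observed range with 2**bits - 1 levels.
        n = float(2 ** bits - 1)
        lo = tf.reduce_min(x)
        hi = tf.reduce_max(x)
        step = (hi - lo) / n
        # Snap each value to the nearest level, then map back to the original range.
        return tf.round((x - lo) / step) * step + lo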
TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computing backends. Planned backends: TensorFlow, PyTorch, MindSpore, and Paddle.