
quan_conv.py 6.2 kB

#! /usr/bin/python
# -*- coding: utf-8 -*-

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.layers.core import Module
from tensorlayer.layers.utils import (quantize_active_overflow, quantize_weight_overflow)

__all__ = ['QuanConv2d']


class QuanConv2d(Module):
    """The :class:`QuanConv2d` class is a quantized convolutional layer without BN, whose weights are 'bitW' bits
    and whose inputs (the outputs of the previous layer) are 'bitA' bits during inference.

    Note that the bias vector is not quantized.

    Parameters
    ----------
    n_filter : int
        The number of filters.
    filter_size : tuple of int
        The filter size (height, width).
    strides : tuple of int
        The sliding window strides of corresponding input dimensions.
        It must be in the same order as the ``shape`` parameter.
    act : activation function
        The activation function of this layer.
    padding : str
        The padding algorithm type: "SAME" or "VALID".
    bitW : int
        The number of bits for this layer's weights.
    bitA : int
        The number of bits for the output of the previous layer.
    use_gemm : boolean
        If True, use gemm instead of ``tf.matmul`` for inference.
        TODO: support gemm.
    data_format : str
        "channels_last" (NHWC, default) or "channels_first" (NCHW).
    dilation_rate : tuple of int
        Specifying the dilation rate to use for dilated convolution.
    W_init : initializer
        The initializer for the weight matrix.
    b_init : initializer or None
        The initializer for the bias vector. If None, skip biases.
    in_channels : int
        The number of input channels.
    name : None or str
        A unique layer name.

    Examples
    --------
    With TensorLayer

    >>> net = tl.layers.Input([8, 12, 12, 64], name='input')
    >>> quanconv2d = tl.layers.QuanConv2d(
    ...     n_filter=32, filter_size=(5, 5), strides=(1, 1), act=tl.ReLU, padding='SAME', name='quancnn2d'
    ... )(net)
    >>> print(quanconv2d)
    >>> output shape : (8, 12, 12, 32)

    """

    def __init__(
        self,
        bitW=8,
        bitA=8,
        n_filter=32,
        filter_size=(3, 3),
        strides=(1, 1),
        act=None,
        padding='SAME',
        use_gemm=False,
        data_format="channels_last",
        dilation_rate=(1, 1),
        W_init=tl.initializers.truncated_normal(stddev=0.02),
        b_init=tl.initializers.constant(value=0.0),
        in_channels=None,
        name=None  # 'quan_cnn2d',
    ):
        super().__init__(name, act=act)
        self.bitW = bitW
        self.bitA = bitA
        self.n_filter = n_filter
        self.filter_size = filter_size
        self.strides = self._strides = strides
        self.padding = padding
        self.use_gemm = use_gemm
        self.data_format = data_format
        self.dilation_rate = self._dilation_rate = dilation_rate
        self.W_init = W_init
        self.b_init = b_init
        self.in_channels = in_channels

        # If the input channel count is known up front, build eagerly.
        if self.in_channels:
            self.build(None)
            self._built = True

        logging.info(
            "QuanConv2d %s: n_filter: %d filter_size: %s strides: %s pad: %s act: %s" % (
                self.name, n_filter, str(filter_size), str(strides), padding,
                self.act.__class__.__name__ if self.act is not None else 'No Activation'
            )
        )

        if self.use_gemm:
            raise Exception("TODO. The current version uses tf.matmul for inference.")

        if len(self.strides) != 2:
            raise ValueError("len(strides) should be 2.")

    def __repr__(self):
        actstr = self.act.__name__ if self.act is not None else 'No Activation'
        s = (
            '{classname}(in_channels={in_channels}, out_channels={n_filter}, kernel_size={filter_size}'
            ', strides={strides}, padding={padding}'
        )
        if self.dilation_rate != (1, ) * len(self.dilation_rate):
            s += ', dilation={dilation_rate}'
        if self.b_init is None:
            s += ', bias=False'
        s += (', ' + actstr)
        if self.name is not None:
            s += ', name=\'{name}\''
        s += ')'
        return s.format(classname=self.__class__.__name__, **self.__dict__)

    def build(self, inputs_shape):
        # Normalize data_format and expand strides/dilations to 4-D form.
        if self.data_format == 'channels_last':
            self.data_format = 'NHWC'
            if self.in_channels is None:
                self.in_channels = inputs_shape[-1]
            self._strides = [1, self._strides[0], self._strides[1], 1]
            self._dilation_rate = [1, self._dilation_rate[0], self._dilation_rate[1], 1]
        elif self.data_format == 'channels_first':
            self.data_format = 'NCHW'
            if self.in_channels is None:
                self.in_channels = inputs_shape[1]
            self._strides = [1, 1, self._strides[0], self._strides[1]]
            self._dilation_rate = [1, 1, self._dilation_rate[0], self._dilation_rate[1]]
        else:
            raise Exception("data_format should be either channels_last or channels_first")

        self.filter_shape = (self.filter_size[0], self.filter_size[1], self.in_channels, self.n_filter)
        self.W = self._get_weights("filters", shape=self.filter_shape, init=self.W_init)
        if self.b_init:
            self.b = self._get_weights("biases", shape=(self.n_filter, ), init=self.b_init)
            self.bias_add = tl.ops.BiasAdd(data_format=self.data_format)
        self.conv2d = tl.ops.Conv2D(
            strides=self.strides, padding=self.padding, data_format=self.data_format, dilations=self._dilation_rate
        )

    def forward(self, inputs):
        if self._forward_state == False:
            if self._built == False:
                self.build(tl.get_tensor_shape(inputs))
                self._built = True
            self._forward_state = True

        # Quantize the incoming activations to bitA bits and the weights to
        # bitW bits before the convolution; the bias is left full-precision.
        inputs = quantize_active_overflow(inputs, self.bitA)
        W_ = quantize_weight_overflow(self.W, self.bitW)
        outputs = self.conv2d(inputs, W_)
        if self.b_init:
            outputs = self.bias_add(outputs, self.b)
        if self.act:
            outputs = self.act(outputs)
        return outputs
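The forward pass quantizes both the activations and the weights before convolving, via quantize_active_overflow and quantize_weight_overflow. As a rough illustration of what such k-bit quantizers typically compute, here is a minimal NumPy sketch of DoReFa-style uniform quantization; it is an assumption for illustration only, not the library's exact "overflow" implementation, which may clip and rescale differently.

import numpy as np

def quantize_k(x, k):
    # Uniformly quantize values assumed to lie in [0, 1] onto a grid of
    # 2**k evenly spaced levels (illustrative, not the library code).
    n = float(2 ** k - 1)
    return np.round(x * n) / n

def quantize_weight_dorefa(w, k):
    # DoReFa-style weight quantization (illustrative): squash weights with
    # tanh, rescale into [0, 1], quantize to k bits, map back to [-1, 1].
    w = np.tanh(w)
    w = w / (2 * np.max(np.abs(w))) + 0.5
    return 2 * quantize_k(w, k) - 1

With the default bitW=8 and bitA=8, each value is snapped to one of 256 evenly spaced levels, which keeps quantization error small while still enabling low-bit inference.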

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. Planned backends: TensorFlow, PyTorch, MindSpore, and Paddle.
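A minimal sketch of how a backend is typically selected in TensorLayer 3.0: an environment variable is set before tensorlayer is imported. The TL_BACKEND variable name and the accepted values below are assumptions based on the planned backend list; check the project README for the authoritative spelling.

import os

# Assumed selector: must be set before the first `import tensorlayer`.
os.environ['TL_BACKEND'] = 'tensorflow'  # e.g. 'tensorflow', 'mindspore', 'paddle'

import tensorlayer as tl  # layers such as QuanConv2d now run on the selected backend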