You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

resnet18.py 7.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. # Copyright 2021 The KubeEdge Authors.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import tensorflow as tf
# Graph-collection key under which every weight-decayed variable is registered
# (see _conv and _fc, which add their kernels/weights to this collection).
WEIGHT_DECAY_KEY = 'WEIGHT_DECAY'
  20. def _residual_block_first(x, is_training, out_channel, strides, name="unit"):
  21. in_channel = x.get_shape().as_list()[-1]
  22. with tf.variable_scope(name) as scope:
  23. print('\tBuilding residual unit: %s' % scope.name)
  24. # Shortcut connection
  25. if in_channel == out_channel:
  26. print('in_channel == out_channel')
  27. if strides == 1:
  28. shortcut = tf.identity(x)
  29. else:
  30. shortcut = tf.nn.max_pool(x, [1, strides, strides, 1],
  31. [1, strides, strides, 1], 'VALID')
  32. else:
  33. shortcut = _conv(x, 1, out_channel, strides, name='shortcut')
  34. # Residual
  35. x = _conv(x, 3, out_channel, strides, name='conv_1')
  36. x = _bn(x, is_training, name='bn_1')
  37. x = _relu(x, name='relu_1')
  38. print(x)
  39. x = _conv(x, 3, out_channel, 1, name='conv_2')
  40. x = _bn(x, is_training, name='bn_2')
  41. print(x)
  42. # Merge
  43. x = x + shortcut
  44. x = _relu(x, name='relu_2')
  45. print(x)
  46. return x
  47. def _residual_block(x, is_training, name="unit"):
  48. num_channel = x.get_shape().as_list()[-1]
  49. with tf.variable_scope(name) as scope:
  50. print('\tBuilding residual unit: %s' % scope.name)
  51. # Shortcut connection
  52. shortcut = x
  53. # Residual
  54. x = _conv(x, 3, num_channel, 1, name='conv_1')
  55. x = _bn(x, is_training, name='bn_1')
  56. x = _relu(x, name='relu_1')
  57. print(x)
  58. x = _conv(x, 3, num_channel, 1, name='conv_2')
  59. x = _bn(x, is_training, name='bn_2')
  60. print(x)
  61. x = x + shortcut
  62. x = _relu(x, name='relu_2')
  63. print(x)
  64. return x
  65. def _conv(x, filter_size, out_channel, strides, name="conv"):
  66. """
  67. Helper functions(counts FLOPs and number of weights)
  68. """
  69. in_shape = x.get_shape()
  70. with tf.variable_scope(name):
  71. # Main operation: conv2d
  72. kernel = tf.get_variable('kernel',
  73. [filter_size, filter_size, in_shape[3],
  74. out_channel], tf.float32,
  75. initializer=tf.random_normal_initializer(
  76. stddev=np.sqrt(
  77. 2.0 / filter_size /
  78. filter_size / out_channel)))
  79. if kernel not in tf.get_collection(WEIGHT_DECAY_KEY):
  80. tf.add_to_collection(WEIGHT_DECAY_KEY, kernel)
  81. if strides == 1:
  82. conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1],
  83. padding='SAME')
  84. else:
  85. kernel_size_effective = filter_size
  86. pad_total = kernel_size_effective - 1
  87. pad_beg = pad_total // 2
  88. pad_end = pad_total - pad_beg
  89. x = tf.pad(x, [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end],
  90. [0, 0]])
  91. conv = tf.nn.conv2d(x, kernel, [1, strides, strides, 1],
  92. padding='VALID')
  93. return conv
  94. def _fc(x, out_dim, name="fc"):
  95. with tf.variable_scope(name):
  96. # Main operation: fc
  97. with tf.device('/CPU:0'):
  98. w = tf.get_variable('weights', [x.get_shape()[1], out_dim],
  99. tf.float32,
  100. initializer=tf.random_normal_initializer(
  101. stddev=np.sqrt(1.0 / out_dim)))
  102. b = tf.get_variable('biases', [out_dim], tf.float32,
  103. initializer=tf.constant_initializer(0.0))
  104. if w not in tf.get_collection(WEIGHT_DECAY_KEY):
  105. tf.add_to_collection(WEIGHT_DECAY_KEY, w)
  106. fc = tf.nn.bias_add(tf.matmul(x, w), b)
  107. return fc
  108. def _bn(x, is_training, name="bn"):
  109. bn = tf.layers.batch_normalization(inputs=x, momentum=0.99, epsilon=1e-5,
  110. center=True, scale=True,
  111. training=is_training, name=name,
  112. fused=True)
  113. return bn
  114. def _relu(x, name="relu"):
  115. return tf.nn.relu(x, name=name)
  116. class ResNet18(object):
  117. def __init__(self, images, is_training):
  118. self._build_network(images, is_training)
  119. def _build_network(self, images, is_training, num_classes=None):
  120. _counted_scope = []
  121. self.end_points = {}
  122. print('Building resnet18 model')
  123. # filters = [128, 128, 256, 512, 1024]
  124. filters = [64, 64, 128, 256, 512]
  125. kernels = [7, 3, 3, 3, 3]
  126. strides = [2, 0, 2, 2, 2]
  127. # conv1
  128. print('\tBuilding unit: conv1')
  129. with tf.variable_scope('conv1'):
  130. x = _conv(images, kernels[0], filters[0], strides[0])
  131. x = _bn(x, is_training)
  132. x = _relu(x)
  133. print(x)
  134. x = tf.nn.max_pool(x, [1, 3, 3, 1], [1, 2, 2, 1], 'SAME')
  135. print(x)
  136. self.end_points['conv1_output'] = x
  137. # conv2_x
  138. x = _residual_block(x, is_training, name='conv2_1')
  139. x = _residual_block(x, is_training, name='conv2_2')
  140. self.end_points['conv2_output'] = x
  141. # conv3_x
  142. x = _residual_block_first(x, is_training, filters[2], strides[2],
  143. name='conv3_1')
  144. x = _residual_block(x, is_training, name='conv3_2')
  145. self.end_points['conv3_output'] = x
  146. # conv4_x
  147. x = _residual_block_first(x, is_training, filters[3], strides[3],
  148. name='conv4_1')
  149. x = _residual_block(x, is_training, name='conv4_2')
  150. self.end_points['conv4_output'] = x
  151. # conv5_x
  152. x = _residual_block_first(x, is_training, filters[4], strides[4],
  153. name='conv5_1')
  154. x = _residual_block(x, is_training, name='conv5_2')
  155. self.end_points['conv5_output'] = x
  156. # Logit
  157. if num_classes is not None:
  158. with tf.variable_scope('logits') as scope:
  159. print('\tBuilding unit: %s' % scope.name)
  160. x = tf.reduce_mean(x, [1, 2], name="logits_bottleneck")
  161. # x = _fc(x, num_classes) # original resnet18 code used only
  162. # 8 output classes
  163. self.end_points['logits'] = x
  164. # print (self.end_points)
  165. self.model = x
  166. def output(self):
  167. return self.model