resnet.py
#! /usr/bin/python
# -*- coding: utf-8 -*-
"""ResNet for ImageNet.

# Reference:
- [Deep Residual Learning for Image Recognition](
    https://arxiv.org/abs/1512.03385) (CVPR 2016 Best Paper Award)
"""

import os

import tensorlayer as tl
from tensorlayer import logging
from tensorlayer.files import (assign_weights, maybe_download_and_extract)
from tensorlayer.layers import (BatchNorm, Conv2d, Dense, Elementwise, GlobalMeanPool2d, Input, MaxPool2d)
from tensorlayer.layers import Module, SequentialLayer

__all__ = [
    'ResNet50',
]

block_names = ['2a', '2b', '2c', '3a', '3b', '3c', '3d', '4a', '4b', '4c', '4d', '4e', '4f', '5a', '5b', '5c'
              ] + ['avg_pool', 'fc1000']
block_filters = [[64, 64, 256], [128, 128, 512], [256, 256, 1024], [512, 512, 2048]]
in_channels_conv = [64, 256, 512, 1024]
in_channels_identity = [256, 512, 1024, 2048]
henorm = tl.initializers.he_normal()
class identity_block(Module):
    """The identity block, where there is no conv layer at the shortcut.

    Parameters
    ----------
    kernel_size : int
        The kernel size of the middle conv layer at the main path.
    n_filters : list of integers
        The numbers of filters for the 3 conv layers at the main path.
    stage : int
        Current stage label.
    block : str
        Current block label.

    Returns
    -------
    Output tensor of this block.
    """

    def __init__(self, kernel_size, n_filters, stage, block):
        super(identity_block, self).__init__()
        filters1, filters2, filters3 = n_filters
        _in_channels = in_channels_identity[stage - 2]
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        self.conv1 = Conv2d(filters1, (1, 1), W_init=henorm, name=conv_name_base + '2a', in_channels=_in_channels)
        self.bn1 = BatchNorm(name=bn_name_base + '2a', act='relu', num_features=filters1)
        ks = (kernel_size, kernel_size)
        self.conv2 = Conv2d(
            filters2, ks, padding='SAME', W_init=henorm, name=conv_name_base + '2b', in_channels=filters1
        )
        self.bn2 = BatchNorm(name=bn_name_base + '2b', act='relu', num_features=filters2)
        self.conv3 = Conv2d(filters3, (1, 1), W_init=henorm, name=conv_name_base + '2c', in_channels=filters2)
        self.bn3 = BatchNorm(name=bn_name_base + '2c', num_features=filters3)
        self.add = Elementwise(tl.add, act='relu')

    def forward(self, inputs):
        output = self.conv1(inputs)
        output = self.bn1(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = self.conv3(output)
        output = self.bn3(output)
        result = self.add([output, inputs])
        return result
class conv_block(Module):
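    """The conv block, where there is a conv layer at the shortcut.

    Parameters
    ----------
    kernel_size : int
        The kernel size of the middle conv layer at the main path.
    n_filters : list of integers
        The numbers of filters for the 3 conv layers at the main path.
    stage : int
        Current stage label.
    block : str
        Current block label.
    strides : tuple of int
        Strides of the first conv layer and of the shortcut conv layer.
    """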
    def __init__(self, kernel_size, n_filters, stage, block, strides=(2, 2)):
        super(conv_block, self).__init__()
        filters1, filters2, filters3 = n_filters
        _in_channels = in_channels_conv[stage - 2]
        conv_name_base = 'res' + str(stage) + block + '_branch'
        bn_name_base = 'bn' + str(stage) + block + '_branch'

        self.conv1 = Conv2d(
            filters1, (1, 1), strides=strides, W_init=henorm, name=conv_name_base + '2a', in_channels=_in_channels
        )
        self.bn1 = BatchNorm(name=bn_name_base + '2a', act='relu', num_features=filters1)
        ks = (kernel_size, kernel_size)
        self.conv2 = Conv2d(
            filters2, ks, padding='SAME', W_init=henorm, name=conv_name_base + '2b', in_channels=filters1
        )
        self.bn2 = BatchNorm(name=bn_name_base + '2b', act='relu', num_features=filters2)
        self.conv3 = Conv2d(filters3, (1, 1), W_init=henorm, name=conv_name_base + '2c', in_channels=filters2)
        self.bn3 = BatchNorm(name=bn_name_base + '2c', num_features=filters3)
        self.shortcut_conv = Conv2d(
            filters3, (1, 1), strides=strides, W_init=henorm, name=conv_name_base + '1', in_channels=_in_channels
        )
        self.shortcut_bn = BatchNorm(name=bn_name_base + '1', num_features=filters3)
        self.add = Elementwise(tl.add, act='relu')

    def forward(self, inputs):
        output = self.conv1(inputs)
        output = self.bn1(output)
        output = self.conv2(output)
        output = self.bn2(output)
        output = self.conv3(output)
        output = self.bn3(output)
        shortcut = self.shortcut_conv(inputs)
        shortcut = self.shortcut_bn(shortcut)
        result = self.add([output, shortcut])
        return result
class ResNet50_model(Module):

    def __init__(self, end_with='fc1000', n_classes=1000):
        super(ResNet50_model, self).__init__()
        self.end_with = end_with
        self.n_classes = n_classes

        self.conv1 = Conv2d(64, (7, 7), in_channels=3, strides=(2, 2), padding='SAME', W_init=henorm, name='conv1')
        self.bn_conv1 = BatchNorm(name='bn_conv1', act="relu", num_features=64)
        self.max_pool1 = MaxPool2d((3, 3), strides=(2, 2), name='max_pool1')
        self.res_layer = self.make_layer()

    def forward(self, inputs):
        z = self.conv1(inputs)
        z = self.bn_conv1(z)
        z = self.max_pool1(z)
        z = self.res_layer(z)
        return z

    def make_layer(self):
        layer_list = []
        for i, block_name in enumerate(block_names):
            if len(block_name) == 2:
                stage = int(block_name[0])
                block = block_name[1]
                if block == 'a':
                    strides = (1, 1) if stage == 2 else (2, 2)
                    layer_list.append(
                        conv_block(3, block_filters[stage - 2], stage=stage, block=block, strides=strides)
                    )
                else:
                    layer_list.append(identity_block(3, block_filters[stage - 2], stage=stage, block=block))
            elif block_name == 'avg_pool':
                layer_list.append(GlobalMeanPool2d(name='avg_pool'))
            elif block_name == 'fc1000':
                layer_list.append(Dense(self.n_classes, name='fc1000', in_channels=2048))

            if block_name == self.end_with:
                break
        return SequentialLayer(layer_list)
def ResNet50(pretrained=False, end_with='fc1000', n_classes=1000):
    """Pre-trained ResNet50 model. Input shape [?, 224, 224, 3].

    To use the pretrained model, the input images should be in BGR format and the ImageNet mean
    [103.939, 116.779, 123.68] should be subtracted from them.

    Parameters
    ----------
    pretrained : boolean
        Whether to load the pretrained weights. Default False.
    end_with : str
        The end point of the model: one of the block names ('2a' ... '5c'), 'avg_pool' or 'fc1000'.
        Default ``fc1000``, i.e. the whole model.
    n_classes : int
        Number of classes in the final prediction.

    Examples
    ---------
    Classify ImageNet classes, see `tutorial_models_resnet50.py`
    TODO Modify the usage example according to the model storage location

    >>> # get the whole model with pretrained weights
    >>> resnet = ResNet50(pretrained=True)
    >>> # use for inferencing
    >>> output = resnet(img1)
    >>> prob = tl.ops.softmax(output)[0].numpy()

    Extract the features before the fc layer:

    >>> resnet = ResNet50(pretrained=True, end_with='5c')
    >>> output = resnet(img1)

    Returns
    -------
    ResNet50 model.
    """
    network = ResNet50_model(end_with=end_with, n_classes=n_classes)

    if pretrained:
        restore_params(network)

    return network
def restore_params(network, path='models'):
    logging.info("Restore pre-trained parameters")
    maybe_download_and_extract(
        'resnet50_weights_tf_dim_ordering_tf_kernels.h5',
        path,
        'https://github.com/fchollet/deep-learning-models/releases/download/v0.2/',
    )
    try:
        import h5py
    except Exception:
        raise ImportError('h5py is required to load the pretrained weights')

    f = h5py.File(os.path.join(path, 'resnet50_weights_tf_dim_ordering_tf_kernels.h5'), 'r')

    # TODO Update parameter loading
    # for layer in network.all_layers:
    #     if len(layer.all_weights) == 0:
    #         continue
    #     w_names = list(f[layer.name])
    #     params = [f[layer.name][n][:] for n in w_names]
    #     # if 'bn' in layer.name:
    #     #     params = [x.reshape(1, 1, 1, -1) for x in params]
    #     assign_weights(params, layer)
    #     del params

    f.close()
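
The docstring of ResNet50 above says the pretrained weights expect BGR input with the ImageNet mean subtracted. The following is a minimal, hedged inference sketch of that preprocessing. It assumes `resnet.py` is importable, uses a random array as a placeholder image, and treats `set_eval()` as the TensorLayer 3 Module call for switching BatchNorm to inference mode.

# Minimal preprocessing/inference sketch (assumptions noted above).
import numpy as np
import tensorlayer as tl

from resnet import ResNet50          # assumes this file is on the import path

resnet = ResNet50(pretrained=False)  # switch to pretrained=True once restore_params is updated
resnet.set_eval()                    # assumed TL3 API: put BatchNorm into inference mode

img = np.random.uniform(0, 255, (224, 224, 3)).astype(np.float32)  # placeholder 224x224 RGB image
img = img[:, :, ::-1]                                               # RGB -> BGR
img -= np.array([103.939, 116.779, 123.68], dtype=np.float32)       # subtract ImageNet BGR mean
batch = img[np.newaxis, :, :, :]                                     # shape [1, 224, 224, 3]

output = resnet(batch)               # depending on the backend, `batch` may need to be converted to a tensor first
prob = tl.ops.softmax(output)[0]     # class probabilities, as in the docstring example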

TensorLayer 3.0 is a deep learning library that supports multiple deep learning frameworks as computational backends. Planned backends include TensorFlow, PyTorch, MindSpore, and Paddle.
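
A brief sketch of backend selection, assuming the `TL_BACKEND` environment variable used by TensorLayer 3 (the exact backend name strings are assumptions; check the TensorLayer 3 documentation):

# Sketch: choose the computational backend before importing tensorlayer.
import os
os.environ['TL_BACKEND'] = 'tensorflow'  # assumed values: 'tensorflow', 'torch', 'mindspore', 'paddle'

import tensorlayer as tl
from resnet import ResNet50              # the model definition above is written against the tl.layers API

model = ResNet50(pretrained=False, end_with='avg_pool')  # feature extractor up to global average pooling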