
resnet_example.py

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
resnet50 example
"""
import numpy as np

import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P

from ..ut_filter import non_graph_engine
def conv3x3(in_channels, out_channels, stride=1, padding=1):
    """3x3 convolution with constant weight initialization"""
    weight = Tensor(np.ones([out_channels, in_channels, 3, 3]).astype(np.float32) * 0.01)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, weight_init=weight)


def conv1x1(in_channels, out_channels, stride=1, padding=0):
    """1x1 convolution with constant weight initialization"""
    weight = Tensor(np.ones([out_channels, in_channels, 1, 1]).astype(np.float32) * 0.01)
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, weight_init=weight)


def bn_with_initialize(out_channels):
    """BatchNorm2d with constant parameter initialization"""
    shape = (out_channels,)
    mean = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    var = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    beta = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    gamma = Tensor(np.ones(shape).astype(np.float32) * 0.01)
    return nn.BatchNorm2d(num_features=out_channels,
                          beta_init=beta,
                          gamma_init=gamma,
                          moving_mean_init=mean,
                          moving_var_init=var)
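
# Note: every weight, bias and batch-norm statistic above is initialised to a
# constant 0.01 rather than randomly; since this file is a unit-test example,
# that presumably keeps the network output deterministic across runs.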


class ResidualBlock(nn.Cell):
    """
    Residual bottleneck block
    """
    expansion = 4

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 down_sample=False):
        super(ResidualBlock, self).__init__()

        out_chls = out_channels // self.expansion
        self.conv1 = conv1x1(in_channels, out_chls, stride=stride, padding=0)
        self.bn1 = bn_with_initialize(out_chls)

        self.conv2 = conv3x3(out_chls, out_chls, stride=1, padding=1)
        self.bn2 = bn_with_initialize(out_chls)

        self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
        self.bn3 = bn_with_initialize(out_channels)

        self.relu = nn.ReLU()
        self.downsample = down_sample

        self.conv_down_sample = conv1x1(in_channels, out_channels,
                                        stride=stride, padding=0)
        self.bn_down_sample = bn_with_initialize(out_channels)
        self.add = P.TensorAdd()

    def construct(self, x):
        """
        :param x: input feature map
        :return: output feature map
        """
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample:
            identity = self.conv_down_sample(identity)
            identity = self.bn_down_sample(identity)

        out = self.add(out, identity)
        out = self.relu(out)

        return out
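
# Note on the bottleneck layout above (derived from the code itself, not from
# any external spec): with expansion = 4, a block built as
# ResidualBlock(in_channels=64, out_channels=256) squeezes the channels to
# out_chls = 256 // 4 = 64 for the 1x1 -> 3x3 path, then the final 1x1
# convolution expands them back to 256. When down_sample=True, the identity
# branch is projected with a strided 1x1 convolution and batch norm so that
# its shape matches `out` before the element-wise TensorAdd.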


class MakeLayer3(nn.Cell):
    """
    ResNet50 stage with 3 residual blocks
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer3, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)
        return x


class MakeLayer4(nn.Cell):
    """
    ResNet50 stage with 4 residual blocks
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer4, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)
        self.block3 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        return x


class MakeLayer6(nn.Cell):
    """
    ResNet50 stage with 6 residual blocks
    """

    def __init__(self, block, in_channels, out_channels, stride):
        super(MakeLayer6, self).__init__()
        self.block_down_sample = block(in_channels, out_channels,
                                       stride=stride, down_sample=True)
        self.block1 = block(out_channels, out_channels, stride=1)
        self.block2 = block(out_channels, out_channels, stride=1)
        self.block3 = block(out_channels, out_channels, stride=1)
        self.block4 = block(out_channels, out_channels, stride=1)
        self.block5 = block(out_channels, out_channels, stride=1)

    def construct(self, x):
        x = self.block_down_sample(x)
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        return x
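
# These three helper Cells hard-code the per-stage block counts of ResNet-50,
# i.e. (3, 4, 6, 3): in the ResNet50 Cell below, layer1 and layer4 reuse
# MakeLayer3, layer2 uses MakeLayer4, and layer3 uses MakeLayer6. The first
# block of every stage always projects the identity (down_sample=True) because
# the channel count changes at the stage boundary.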


class ResNet50(nn.Cell):
    """
    ResNet50 network
    """

    def __init__(self, block, num_classes=100):
        super(ResNet50, self).__init__()

        weight_conv = Tensor(np.ones([64, 3, 7, 7]).astype(np.float32) * 0.01)
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, weight_init=weight_conv)
        self.bn1 = bn_with_initialize(64)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)

        self.layer1 = MakeLayer3(
            block, in_channels=64, out_channels=256, stride=1)
        self.layer2 = MakeLayer4(
            block, in_channels=256, out_channels=512, stride=2)
        self.layer3 = MakeLayer6(
            block, in_channels=512, out_channels=1024, stride=2)
        self.layer4 = MakeLayer3(
            block, in_channels=1024, out_channels=2048, stride=2)

        self.avgpool = nn.AvgPool2d(7, 1)
        self.flatten = nn.Flatten()

        weight_fc = Tensor(np.ones([num_classes, 512 * block.expansion]).astype(np.float32) * 0.01)
        bias_fc = Tensor(np.ones([num_classes]).astype(np.float32) * 0.01)
        self.fc = nn.Dense(512 * block.expansion, num_classes, weight_init=weight_fc, bias_init=bias_fc)

    def construct(self, x):
        """
        :param x: input image tensor
        :return: class logits
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.fc(x)

        return x
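
# Rough shape trace for the 224x224 test input used below (assuming the usual
# ResNet geometry under MindSpore's default padding behaviour): channels grow
# 64 -> 256 -> 512 -> 1024 -> 2048 through layer1..layer4 while the spatial
# size shrinks to roughly 7x7, the 7x7 average pool reduces it to 1x1, and
# flattening therefore yields 512 * block.expansion = 2048 features, which is
# exactly the input width of the final nn.Dense classifier.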


def resnet50():
    """Build a 10-class ResNet50 using ResidualBlock as the bottleneck unit."""
    return ResNet50(ResidualBlock, 10)


@non_graph_engine
def test_compile():
    net = resnet50()
    input_data = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)

    output = net(input_data)
    print(output.asnumpy())