You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

res18_example.py 7.5 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """
  16. ResNet18 / ResNet9 example networks (bottleneck residual blocks) and graph-compile tests
  17. """
  18. import numpy as np
  19. import mindspore.nn as nn # pylint: disable=C0414
  20. from mindspore import Tensor
  21. from mindspore.common.api import _executor
  22. from mindspore.ops.operations import TensorAdd
  23. from ...train_step_wrap import train_step_with_loss_warp
  24. def conv3x3(in_channels, out_channels, stride=1, padding=1, pad_mode='pad'):
  25. """3x3 convolution """
  26. return nn.Conv2d(in_channels, out_channels,
  27. kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode)
  28. def conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='pad'):
  29. """1x1 convolution"""
  30. return nn.Conv2d(in_channels, out_channels,
  31. kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode)
  32. class ResidualBlock(nn.Cell):
  33. """
  34. residual Block
  35. """
  36. expansion = 4
  37. def __init__(self,
  38. in_channels,
  39. out_channels,
  40. stride=1,
  41. down_sample=False):
  42. super(ResidualBlock, self).__init__()
  43. out_chls = out_channels // self.expansion
  44. self.conv1 = conv1x1(in_channels, out_chls, stride=1, padding=0)
  45. self.bn1 = nn.BatchNorm2d(out_chls)
  46. self.conv2 = conv3x3(out_chls, out_chls, stride=stride, padding=1)
  47. self.bn2 = nn.BatchNorm2d(out_chls)
  48. self.conv3 = conv1x1(out_chls, out_channels, stride=1, padding=0)
  49. self.bn3 = nn.BatchNorm2d(out_channels)
  50. self.relu = nn.ReLU()
  51. self.downsample = down_sample
  52. self.conv_down_sample = conv1x1(in_channels, out_channels,
  53. stride=stride, padding=0)
  54. self.bn_down_sample = nn.BatchNorm2d(out_channels)
  55. self.add = TensorAdd()
  56. def construct(self, x):
  57. """
  58. :param x:
  59. :return:
  60. """
  61. identity = x
  62. out = self.conv1(x)
  63. out = self.bn1(out)
  64. out = self.relu(out)
  65. out = self.conv2(out)
  66. out = self.bn2(out)
  67. out = self.relu(out)
  68. out = self.conv3(out)
  69. out = self.bn3(out)
  70. if self.downsample:
  71. identity = self.conv_down_sample(identity)
  72. identity = self.bn_down_sample(identity)
  73. out = self.add(out, identity)
  74. out = self.relu(out)
  75. return out
  76. class ResNet18(nn.Cell):
  77. """
  78. resnet nn.Cell
  79. """
  80. def __init__(self, block, num_classes=100):
  81. super(ResNet18, self).__init__()
  82. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, pad_mode='pad')
  83. self.bn1 = nn.BatchNorm2d(64)
  84. self.relu = nn.ReLU()
  85. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same')
  86. self.layer1 = self.MakeLayer(
  87. block, 2, in_channels=64, out_channels=256, stride=1)
  88. self.layer2 = self.MakeLayer(
  89. block, 2, in_channels=256, out_channels=512, stride=2)
  90. self.layer3 = self.MakeLayer(
  91. block, 2, in_channels=512, out_channels=1024, stride=2)
  92. self.layer4 = self.MakeLayer(
  93. block, 2, in_channels=1024, out_channels=2048, stride=2)
  94. self.avgpool = nn.AvgPool2d(7, 1)
  95. self.flatten = nn.Flatten()
  96. self.fc = nn.Dense(512 * block.expansion, num_classes)
  97. def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
  98. """
  99. make block layer
  100. :param block:
  101. :param layer_num:
  102. :param in_channels:
  103. :param out_channels:
  104. :param stride:
  105. :return:
  106. """
  107. layers = []
  108. resblk = block(in_channels, out_channels,
  109. stride=stride, down_sample=True)
  110. layers.append(resblk)
  111. for _ in range(1, layer_num):
  112. resblk = block(out_channels, out_channels, stride=1)
  113. layers.append(resblk)
  114. return nn.SequentialCell(layers)
  115. def construct(self, x):
  116. """
  117. :param x:
  118. :return:
  119. """
  120. x = self.conv1(x)
  121. x = self.bn1(x)
  122. x = self.relu(x)
  123. x = self.maxpool(x)
  124. x = self.layer1(x)
  125. x = self.layer2(x)
  126. x = self.layer3(x)
  127. x = self.layer4(x)
  128. x = self.avgpool(x)
  129. x = self.flatten(x)
  130. x = self.fc(x)
  131. return x
  132. class ResNet9(nn.Cell):
  133. """
  134. resnet nn.Cell
  135. """
  136. def __init__(self, block, num_classes=100):
  137. super(ResNet9, self).__init__()
  138. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
  139. self.bn1 = nn.BatchNorm2d(64)
  140. self.relu = nn.ReLU()
  141. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
  142. self.layer1 = self.MakeLayer(
  143. block, 1, in_channels=64, out_channels=256, stride=1)
  144. self.layer2 = self.MakeLayer(
  145. block, 1, in_channels=256, out_channels=512, stride=2)
  146. self.layer3 = self.MakeLayer(
  147. block, 1, in_channels=512, out_channels=1024, stride=2)
  148. self.layer4 = self.MakeLayer(
  149. block, 1, in_channels=1024, out_channels=2048, stride=2)
  150. self.avgpool = nn.AvgPool2d(7, 1)
  151. self.flatten = nn.Flatten()
  152. self.fc = nn.Dense(512 * block.expansion, num_classes)
  153. def MakeLayer(self, block, layer_num, in_channels, out_channels, stride):
  154. """
  155. make block layer
  156. :param block:
  157. :param layer_num:
  158. :param in_channels:
  159. :param out_channels:
  160. :param stride:
  161. :return:
  162. """
  163. layers = []
  164. resblk = block(in_channels, out_channels,
  165. stride=stride, down_sample=True)
  166. layers.append(resblk)
  167. for _ in range(1, layer_num):
  168. resblk = block(out_channels, out_channels, stride=1)
  169. layers.append(resblk)
  170. return nn.SequentialCell(layers)
  171. def construct(self, x):
  172. """
  173. :param x:
  174. :return:
  175. """
  176. x = self.conv1(x)
  177. x = self.bn1(x)
  178. x = self.relu(x)
  179. x = self.maxpool(x)
  180. x = self.layer1(x)
  181. x = self.layer2(x)
  182. x = self.layer3(x)
  183. x = self.layer4(x)
  184. x = self.avgpool(x)
  185. x = self.flatten(x)
  186. x = self.fc(x)
  187. return x
  188. def resnet18():
  189. return ResNet18(ResidualBlock, 10)
  190. def resnet9():
  191. return ResNet9(ResidualBlock, 10)
  192. def test_compile():
  193. net = resnet18()
  194. input_data = Tensor(np.ones([1, 3, 224, 224]))
  195. _executor.compile(net, input_data)
  196. def test_train_step():
  197. net = train_step_with_loss_warp(resnet9())
  198. input_data = Tensor(np.ones([1, 3, 224, 224]))
  199. label = Tensor(np.zeros([1, 10]))
  200. _executor.compile(net, input_data, label)
  201. def test_train_step_training():
  202. net = train_step_with_loss_warp(resnet9())
  203. input_data = Tensor(np.ones([1, 3, 224, 224]))
  204. label = Tensor(np.zeros([1, 10]))
  205. net.set_train()
  206. _executor.compile(net, input_data, label)