You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

mobilenetv2_combined.py 4.4 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """mobile net v2"""
  16. from mindspore import nn
  17. from mindspore.ops import operations as P
  18. def make_divisible(input_x, div_by=8):
  19. return int((input_x + div_by) // div_by)
  20. def _conv_bn(in_channel,
  21. out_channel,
  22. ksize,
  23. stride=1):
  24. """Get a conv2d batchnorm and relu layer."""
  25. return nn.SequentialCell(
  26. [nn.Conv2dBnAct(in_channel,
  27. out_channel,
  28. kernel_size=ksize,
  29. stride=stride,
  30. batchnorm=True)])
  31. class InvertedResidual(nn.Cell):
  32. def __init__(self, inp, oup, stride, expend_ratio):
  33. super(InvertedResidual, self).__init__()
  34. self.stride = stride
  35. assert stride in [1, 2]
  36. hidden_dim = int(inp * expend_ratio)
  37. self.use_res_connect = self.stride == 1 and inp == oup
  38. if expend_ratio == 1:
  39. self.conv = nn.SequentialCell([
  40. nn.Conv2dBnAct(hidden_dim,
  41. hidden_dim,
  42. 3,
  43. stride,
  44. group=hidden_dim,
  45. batchnorm=True,
  46. activation='relu6'),
  47. nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
  48. batchnorm=True)
  49. ])
  50. else:
  51. self.conv = nn.SequentialCell([
  52. nn.Conv2dBnAct(inp, hidden_dim, 1, 1,
  53. batchnorm=True,
  54. activation='relu6'),
  55. nn.Conv2dBnAct(hidden_dim,
  56. hidden_dim,
  57. 3,
  58. stride,
  59. group=hidden_dim,
  60. batchnorm=True,
  61. activation='relu6'),
  62. nn.Conv2dBnAct(hidden_dim, oup, 1, 1,
  63. batchnorm=True)
  64. ])
  65. self.add = P.TensorAdd()
  66. def construct(self, input_x):
  67. out = self.conv(input_x)
  68. if self.use_res_connect:
  69. out = self.add(input_x, out)
  70. return out
  71. class MobileNetV2(nn.Cell):
  72. def __init__(self, num_class=1000, input_size=224, width_mul=1.):
  73. super(MobileNetV2, self).__init__()
  74. _ = input_size
  75. block = InvertedResidual
  76. input_channel = 32
  77. last_channel = 1280
  78. inverted_residual_setting = [
  79. [1, 16, 1, 1],
  80. [6, 24, 2, 2],
  81. [6, 32, 3, 2],
  82. [6, 64, 4, 2],
  83. [6, 96, 3, 1],
  84. [6, 160, 3, 2],
  85. [6, 230, 1, 1],
  86. ]
  87. if width_mul > 1.0:
  88. last_channel = make_divisible(last_channel * width_mul)
  89. self.last_channel = last_channel
  90. features = [_conv_bn(3, input_channel, 3, 2)]
  91. for t, c, n, s in inverted_residual_setting:
  92. out_channel = make_divisible(c * width_mul) if t > 1 else c
  93. for i in range(n):
  94. if i == 0:
  95. features.append(block(input_channel, out_channel, s, t))
  96. else:
  97. features.append(block(input_channel, out_channel, 1, t))
  98. input_channel = out_channel
  99. features.append(_conv_bn(input_channel, self.last_channel, 1))
  100. self.features = nn.SequentialCell(features)
  101. self.mean = P.ReduceMean(keep_dims=False)
  102. self.classifier = nn.DenseBnAct(self.last_channel, num_class)
  103. def construct(self, input_x):
  104. out = input_x
  105. out = self.features(out)
  106. out = self.mean(out, (2, 3))
  107. out = self.classifier(out)
  108. return out