You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_cell.py 7.8 kB


  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ test cell """
  16. import numpy as np
  17. import pytest
  18. import mindspore.nn as nn
  19. from mindspore import Tensor, Parameter
  20. from mindspore.common.api import _executor
  21. class ModA(nn.Cell):
  22. def __init__(self, tensor):
  23. super(ModA, self).__init__()
  24. self.weight = Parameter(tensor, name="weight")
  25. def construct(self, *inputs):
  26. pass
  27. class ModB(nn.Cell):
  28. def __init__(self, tensor):
  29. super(ModB, self).__init__()
  30. self.weight = Parameter(tensor, name="weight")
  31. def construct(self, *inputs):
  32. pass
  33. class ModC(nn.Cell):
  34. def __init__(self, ta, tb):
  35. super(ModC, self).__init__()
  36. self.mod1 = ModA(ta)
  37. self.mod2 = ModB(tb)
  38. def construct(self, *inputs):
  39. pass
  40. class Net(nn.Cell):
  41. """ Net definition """
  42. name_len = 4
  43. cells_num = 3
  44. def __init__(self, ta, tb):
  45. super(Net, self).__init__()
  46. self.mod1 = ModA(ta)
  47. self.mod2 = ModB(tb)
  48. self.mod3 = ModC(ta, tb)
  49. def construct(self, *inputs):
  50. pass
  51. class Net2(nn.Cell):
  52. def __init__(self, ta, tb):
  53. super(Net2, self).__init__(auto_prefix=False)
  54. self.mod1 = ModA(ta)
  55. self.mod2 = ModB(tb)
  56. self.mod3 = ModC(ta, tb)
  57. def construct(self, *inputs):
  58. pass
  59. class ConvNet(nn.Cell):
  60. """ ConvNet definition """
  61. image_h = 224
  62. image_w = 224
  63. output_ch = 64
  64. def __init__(self, num_classes=10):
  65. super(ConvNet, self).__init__()
  66. self.conv1 = nn.Conv2d(3, ConvNet.output_ch, kernel_size=7, stride=2, pad_mode="pad", padding=3)
  67. self.bn1 = nn.BatchNorm2d(ConvNet.output_ch)
  68. self.relu = nn.ReLU()
  69. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode="same")
  70. self.flatten = nn.Flatten()
  71. self.fc = nn.Dense(
  72. int(ConvNet.image_h * ConvNet.image_w * ConvNet.output_ch / (4 * 4)),
  73. num_classes)
  74. def construct(self, x):
  75. x = self.conv1(x)
  76. x = self.bn1(x)
  77. x = self.relu(x)
  78. x = self.maxpool(x)
  79. x = self.flatten(x)
  80. x = self.fc(x)
  81. return x
  82. def test_basic():
  83. ta = Tensor(np.ones([2, 3]))
  84. tb = Tensor(np.ones([1, 4]))
  85. n = Net(ta, tb)
  86. names = list(n.parameters_dict().keys())
  87. assert len(names) == n.name_len
  88. assert names[0] == "mod1.weight"
  89. assert names[1] == "mod2.weight"
  90. assert names[2] == "mod3.mod1.weight"
  91. assert names[3] == "mod3.mod2.weight"
  92. def test_parameter_name():
  93. """ test_parameter_name """
  94. ta = Tensor(np.ones([2, 3]))
  95. tb = Tensor(np.ones([1, 4]))
  96. n = Net(ta, tb)
  97. names = []
  98. for m in n.parameters_and_names():
  99. if m[0]:
  100. names.append(m[0])
  101. assert names[0] == "mod1.weight"
  102. assert names[1] == "mod2.weight"
  103. assert names[2] == "mod3.mod1.weight"
  104. assert names[3] == "mod3.mod2.weight"
  105. def test_cell_name():
  106. """ test_cell_name """
  107. ta = Tensor(np.ones([2, 3]))
  108. tb = Tensor(np.ones([1, 4]))
  109. n = Net(ta, tb)
  110. n.insert_child_to_cell('modNone', None)
  111. names = []
  112. for m in n.cells_and_names():
  113. if m[0]:
  114. names.append(m[0])
  115. assert names[0] == "mod1"
  116. assert names[1] == "mod2"
  117. assert names[2] == "mod3"
  118. assert names[3] == "mod3.mod1"
  119. assert names[4] == "mod3.mod2"
  120. def test_cells():
  121. ta = Tensor(np.ones([2, 3]))
  122. tb = Tensor(np.ones([1, 4]))
  123. n = Net(ta, tb)
  124. ch = list(n.cells())
  125. assert len(ch) == n.cells_num
  126. def test_exceptions():
  127. """ test_exceptions """
  128. t = Tensor(np.ones([2, 3]))
  129. class ModError(nn.Cell):
  130. def __init__(self, tensor):
  131. self.weight = Parameter(tensor, name="weight")
  132. super(ModError, self).__init__()
  133. def construct(self, *inputs):
  134. pass
  135. with pytest.raises(AttributeError):
  136. ModError(t)
  137. class ModError1(nn.Cell):
  138. def __init__(self, tensor):
  139. super().__init__()
  140. self.weight = Parameter(tensor, name="weight")
  141. self.weight = None
  142. self.weight = ModA(tensor)
  143. def construct(self, *inputs):
  144. pass
  145. with pytest.raises(TypeError):
  146. ModError1(t)
  147. class ModError2(nn.Cell):
  148. def __init__(self, tensor):
  149. super().__init__()
  150. self.mod = ModA(tensor)
  151. self.mod = None
  152. self.mod = tensor
  153. def construct(self, *inputs):
  154. pass
  155. with pytest.raises(TypeError):
  156. ModError2(t)
  157. m = nn.Cell()
  158. with pytest.raises(NotImplementedError):
  159. m.construct()
  160. def test_del():
  161. """ test_del """
  162. ta = Tensor(np.ones([2, 3]))
  163. tb = Tensor(np.ones([1, 4]))
  164. n = Net(ta, tb)
  165. names = list(n.parameters_dict().keys())
  166. assert len(names) == n.name_len
  167. del n.mod1
  168. names = list(n.parameters_dict().keys())
  169. assert len(names) == n.name_len - 1
  170. with pytest.raises(AttributeError):
  171. del n.mod1.weight
  172. del n.mod2.weight
  173. names = list(n.parameters_dict().keys())
  174. assert len(names) == n.name_len - 2
  175. with pytest.raises(AttributeError):
  176. del n.mod
  177. def test_add_attr():
  178. """ test_add_attr """
  179. ta = Tensor(np.ones([2, 3]))
  180. tb = Tensor(np.ones([1, 4]))
  181. p = Parameter(ta, name="weight")
  182. m = nn.Cell()
  183. m.insert_param_to_cell('weight', p)
  184. with pytest.raises(TypeError):
  185. m.insert_child_to_cell("network", p)
  186. with pytest.raises(KeyError):
  187. m.insert_param_to_cell('', p)
  188. with pytest.raises(KeyError):
  189. m.insert_param_to_cell('a.b', p)
  190. m.insert_param_to_cell('weight', p)
  191. with pytest.raises(KeyError):
  192. m.insert_child_to_cell('', ModA(ta))
  193. with pytest.raises(KeyError):
  194. m.insert_child_to_cell('a.b', ModB(tb))
  195. with pytest.raises(TypeError):
  196. m.insert_child_to_cell('buffer', tb)
  197. with pytest.raises(TypeError):
  198. m.insert_param_to_cell('w', ta)
  199. with pytest.raises(TypeError):
  200. m.insert_child_to_cell('m', p)
  201. class ModAddCellError(nn.Cell):
  202. def __init__(self, tensor):
  203. self.mod = ModA(tensor)
  204. super().__init__()
  205. def construct(self, *inputs):
  206. pass
  207. with pytest.raises(AttributeError):
  208. ModAddCellError(ta)
  209. def test_train_eval():
  210. m = nn.Cell()
  211. assert not m.training
  212. m.set_train()
  213. assert m.training
  214. m.set_train(False)
  215. assert not m.training
  216. def test_stop_update_name():
  217. ta = Tensor(np.ones([2, 3]))
  218. tb = Tensor(np.ones([1, 4]))
  219. n = Net2(ta, tb)
  220. names = list(n.parameters_dict().keys())
  221. assert names[0] == "weight"
  222. assert names[1] == "mod1.weight"
  223. assert names[2] == "mod2.weight"
  224. class ModelName(nn.Cell):
  225. def __init__(self, tensor):
  226. super(ModelName, self).__init__()
  227. self.w2 = Parameter(tensor, name="weight")
  228. self.w1 = Parameter(tensor, name="weight")
  229. self.w3 = Parameter(tensor, name=None)
  230. self.w4 = Parameter(tensor, name=None)
  231. def construct(self, *inputs):
  232. pass
  233. def test_cell_names():
  234. ta = Tensor(np.ones([2, 3]))
  235. mn = ModelName(ta)
  236. with pytest.raises(ValueError):
  237. _executor.compile(mn)