You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_serialize.py 12 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ut for model serialize(save/load)"""
  16. import os
  17. import stat
  18. import time
  19. import pytest
  20. import numpy as np
  21. import mindspore.common.dtype as mstype
  22. import mindspore.nn as nn
  23. from mindspore import context
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.common.tensor import Tensor
  26. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  27. from mindspore.nn import WithLossCell, TrainOneStepCell
  28. from mindspore.nn.optim.momentum import Momentum
  29. from mindspore.ops import operations as P
  30. from mindspore.train.callback import _CheckpointManager
  31. from mindspore.train.serialization import save_checkpoint, load_checkpoint, load_param_into_net, \
  32. _exec_save_checkpoint, export, _save_graph
  33. from ..ut_filter import non_graph_engine
# Configure MindSpore to run every network built in this test module in GRAPH_MODE.
context.set_context(mode=context.GRAPH_MODE)
  35. class Net(nn.Cell):
  36. """Net definition."""
  37. def __init__(self, num_classes=10):
  38. super(Net, self).__init__()
  39. self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0, weight_init="zeros")
  40. self.bn1 = nn.BatchNorm2d(64)
  41. self.relu = nn.ReLU()
  42. self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
  43. self.flatten = nn.Flatten()
  44. self.fc = nn.Dense(int(224 * 224 * 64 / 16), num_classes)
  45. def construct(self, x):
  46. x = self.conv1(x)
  47. x = self.bn1(x)
  48. x = self.relu(x)
  49. x = self.maxpool(x)
  50. x = self.flatten(x)
  51. x = self.fc(x)
  52. return x
  53. _input_x = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  54. _cur_dir = os.path.dirname(os.path.realpath(__file__))
  55. def setup_module():
  56. import shutil
  57. if os.path.exists('./test_files'):
  58. shutil.rmtree('./test_files')
  59. def test_save_graph():
  60. """ test_exec_save_graph """
  61. class Net1(nn.Cell):
  62. def __init__(self):
  63. super(Net1, self).__init__()
  64. self.add = P.TensorAdd()
  65. def construct(self, x, y):
  66. z = self.add(x, y)
  67. return z
  68. net = Net1()
  69. net.set_train()
  70. out_me_list = []
  71. x = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
  72. y = Tensor(np.array([1.2]).astype(np.float32))
  73. out_put = net(x, y)
  74. _save_graph(network=net, file_name="net-graph.meta")
  75. out_me_list.append(out_put)
  76. def test_save_checkpoint():
  77. """ test_save_checkpoint """
  78. parameter_list = []
  79. one_param = {}
  80. param1 = {}
  81. param2 = {}
  82. one_param['name'] = "param_test"
  83. one_param['data'] = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)
  84. param1['name'] = "param"
  85. param1['data'] = Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)
  86. param2['name'] = "new_param"
  87. param2['data'] = Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)
  88. parameter_list.append(one_param)
  89. parameter_list.append(param1)
  90. parameter_list.append(param2)
  91. if os.path.exists('./parameters.ckpt'):
  92. os.chmod('./parameters.ckpt', stat.S_IWRITE)
  93. os.remove('./parameters.ckpt')
  94. ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  95. save_checkpoint(parameter_list, ckpoint_file_name)
  96. def test_load_checkpoint_error_filename():
  97. ckpoint_file_name = 1
  98. with pytest.raises(ValueError):
  99. load_checkpoint(ckpoint_file_name)
  100. def test_load_checkpoint():
  101. ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
  102. par_dict = load_checkpoint(ckpoint_file_name)
  103. assert len(par_dict) == 3
  104. assert par_dict['param_test'].name == 'param_test'
  105. assert par_dict['param_test'].data.dtype == mstype.float32
  106. assert par_dict['param_test'].data.shape == (1, 3, 224, 224)
  107. assert isinstance(par_dict, dict)
  108. def test_checkpoint_manager():
  109. """ test_checkpoint_manager """
  110. ckp_mgr = _CheckpointManager()
  111. ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
  112. with open(ckpoint_file_name, 'w'):
  113. os.chmod(ckpoint_file_name, stat.S_IWUSR | stat.S_IRUSR)
  114. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  115. assert ckp_mgr.ckpoint_num == 1
  116. ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
  117. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  118. assert ckp_mgr.ckpoint_num == 0
  119. assert not os.path.exists(ckpoint_file_name)
  120. another_file_name = os.path.join(_cur_dir, './test2.ckpt')
  121. another_file_name = os.path.realpath(another_file_name)
  122. with open(another_file_name, 'w'):
  123. os.chmod(another_file_name, stat.S_IWUSR | stat.S_IRUSR)
  124. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  125. assert ckp_mgr.ckpoint_num == 1
  126. ckp_mgr.remove_oldest_ckpoint_file()
  127. ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
  128. assert ckp_mgr.ckpoint_num == 0
  129. assert not os.path.exists(another_file_name)
  130. # test keep_one_ckpoint_per_minutes
  131. file1 = os.path.realpath(os.path.join(_cur_dir, './time_file1.ckpt'))
  132. file2 = os.path.realpath(os.path.join(_cur_dir, './time_file2.ckpt'))
  133. file3 = os.path.realpath(os.path.join(_cur_dir, './time_file3.ckpt'))
  134. with open(file1, 'w'):
  135. os.chmod(file1, stat.S_IWUSR | stat.S_IRUSR)
  136. with open(file2, 'w'):
  137. os.chmod(file2, stat.S_IWUSR | stat.S_IRUSR)
  138. with open(file3, 'w'):
  139. os.chmod(file3, stat.S_IWUSR | stat.S_IRUSR)
  140. time1 = time.time()
  141. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  142. assert ckp_mgr.ckpoint_num == 3
  143. ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
  144. ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
  145. assert ckp_mgr.ckpoint_num == 1
  146. if os.path.exists(_cur_dir + '/time_file1.ckpt'):
  147. os.chmod(_cur_dir + '/time_file1.ckpt', stat.S_IWRITE)
  148. os.remove(_cur_dir + '/time_file1.ckpt')
  149. def test_load_param_into_net_error_net():
  150. parameter_dict = {}
  151. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  152. name="conv1.weight")
  153. parameter_dict["conv1.weight"] = one_param
  154. with pytest.raises(TypeError):
  155. load_param_into_net('', parameter_dict)
  156. def test_load_param_into_net_error_dict():
  157. net = Net(10)
  158. with pytest.raises(TypeError):
  159. load_param_into_net(net, '')
  160. def test_load_param_into_net_erro_dict_param():
  161. net = Net(10)
  162. net.init_parameters_data()
  163. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  164. parameter_dict = {}
  165. one_param = ''
  166. parameter_dict["conv1.weight"] = one_param
  167. with pytest.raises(TypeError):
  168. load_param_into_net(net, parameter_dict)
  169. def test_load_param_into_net_has_more_param():
  170. """ test_load_param_into_net_has_more_param """
  171. net = Net(10)
  172. net.init_parameters_data()
  173. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  174. parameter_dict = {}
  175. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  176. name="conv1.weight")
  177. parameter_dict["conv1.weight"] = one_param
  178. two_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  179. name="conv1.weight")
  180. parameter_dict["conv1.w"] = two_param
  181. load_param_into_net(net, parameter_dict)
  182. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
  183. def test_load_param_into_net_param_type_and_shape_error():
  184. net = Net(10)
  185. net.init_parameters_data()
  186. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  187. parameter_dict = {}
  188. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7))), name="conv1.weight")
  189. parameter_dict["conv1.weight"] = one_param
  190. with pytest.raises(RuntimeError):
  191. load_param_into_net(net, parameter_dict)
  192. def test_load_param_into_net_param_type_error():
  193. net = Net(10)
  194. net.init_parameters_data()
  195. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  196. parameter_dict = {}
  197. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
  198. name="conv1.weight")
  199. parameter_dict["conv1.weight"] = one_param
  200. with pytest.raises(RuntimeError):
  201. load_param_into_net(net, parameter_dict)
  202. def test_load_param_into_net_param_shape_error():
  203. net = Net(10)
  204. net.init_parameters_data()
  205. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  206. parameter_dict = {}
  207. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7,)), dtype=mstype.int32),
  208. name="conv1.weight")
  209. parameter_dict["conv1.weight"] = one_param
  210. with pytest.raises(RuntimeError):
  211. load_param_into_net(net, parameter_dict)
  212. def test_load_param_into_net():
  213. net = Net(10)
  214. net.init_parameters_data()
  215. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
  216. parameter_dict = {}
  217. one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
  218. name="conv1.weight")
  219. parameter_dict["conv1.weight"] = one_param
  220. load_param_into_net(net, parameter_dict)
  221. assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
  222. def test_exec_save_checkpoint():
  223. net = Net()
  224. loss = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
  225. opt = Momentum(net.trainable_params(), 0.0, 0.9, 0.0001, 1024)
  226. loss_net = WithLossCell(net, loss)
  227. train_network = TrainOneStepCell(loss_net, opt)
  228. _exec_save_checkpoint(train_network, ckpoint_file_name="./new_ckpt.ckpt")
  229. load_checkpoint("new_ckpt.ckpt")
  230. def test_load_checkpoint_empty_file():
  231. os.mknod("empty.ckpt")
  232. with pytest.raises(ValueError):
  233. load_checkpoint("empty.ckpt")
  234. class MYNET(nn.Cell):
  235. """ NET definition """
  236. def __init__(self):
  237. super(MYNET, self).__init__()
  238. self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal', pad_mode='valid')
  239. self.bn = nn.BatchNorm2d(64)
  240. self.relu = nn.ReLU()
  241. self.flatten = nn.Flatten()
  242. self.fc = nn.Dense(64 * 222 * 222, 3) # padding=0
  243. def construct(self, x):
  244. x = self.conv(x)
  245. x = self.bn(x)
  246. x = self.relu(x)
  247. x = self.flatten(x)
  248. out = self.fc(x)
  249. return out
  250. @non_graph_engine
  251. def test_export():
  252. net = MYNET()
  253. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  254. export(net, input_data, file_name="./me_export.pb", file_format="GEIR")
  255. @non_graph_engine
  256. def test_binary_export():
  257. net = MYNET()
  258. input_data = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
  259. export(net, input_data, file_name="./me_binary_export.pb", file_format="BINARY")
  260. def teardown_module():
  261. files = ['parameters.ckpt', 'new_ckpt.ckpt', 'empty.ckpt']
  262. for item in files:
  263. file_name = './' + item
  264. if not os.path.exists(file_name):
  265. continue
  266. os.chmod(file_name, stat.S_IWRITE)
  267. os.remove(file_name)