You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_serialize.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # Copyright 2020 Huawei Technologies Co., Ltd
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ============================================================================
  15. """ut for model serialize(save/load)"""
  16. import os
  17. import stat
  18. import time
  19. import numpy as np
  20. import pytest
  21. import mindspore.nn as nn
  22. import mindspore.common.dtype as mstype
  23. from mindspore.common.tensor import Tensor
  24. from mindspore.common.parameter import Parameter
  25. from mindspore.ops import operations as P
  26. from mindspore.nn import SoftmaxCrossEntropyWithLogits
  27. from mindspore.nn.optim.momentum import Momentum
  28. from mindspore.nn import WithLossCell, TrainOneStepCell
  29. from mindspore.train.callback import _CheckpointManager
  30. from mindspore.train.serialization import save_checkpoint, load_checkpoint,load_param_into_net, \
  31. _exec_save_checkpoint, export, _save_graph
  32. from ..ut_filter import run_on_onnxruntime, non_graph_engine
  33. from mindspore import context
  # Select MindSpore GRAPH_MODE for every network built by the tests in this module.
  context.set_context(mode=context.GRAPH_MODE)
  class Net(nn.Cell):
      """Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d -> Flatten -> Dense network.

      ``conv1`` is created with ``weight_init="zeros"``; the load tests rely on
      this to detect whether ``load_param_into_net`` replaced the weight.

      Args:
          num_classes: Output width of the final Dense layer.
      """

      def __init__(self, num_classes=10):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,
                                 weight_init="zeros")
          self.bn1 = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2)
          self.flatten = nn.Flatten()
          # Equal to int(224*224*64/16) == 200704. No test in this file runs the
          # net forward, so this width is never checked against the real
          # flattened feature map (NOTE(review): confirm before running it).
          dense_in = 224 * 224 * 64 // 16
          self.fc = nn.Dense(dense_in, num_classes)

      def construct(self, x):
          out = self.conv1(x)
          out = self.bn1(out)
          out = self.relu(out)
          out = self.maxpool(out)
          out = self.flatten(out)
          return self.fc(out)
  # Directory containing this test file (symlinks resolved); checkpoints live here.
  _cur_dir = os.path.dirname(os.path.realpath(__file__))
  # Random 1x3x224x224 float32 input; not referenced by the tests in this file.
  _input_x = Tensor(np.random.randint(0, 255, size=(1, 3, 224, 224)).astype(np.float32))
  def setup_module():
      """Remove a ``./test_files`` directory left over from an earlier run."""
      import shutil
      stale_dir = './test_files'
      if not os.path.exists(stale_dir):
          return
      shutil.rmtree(stale_dir)
  def test_save_graph():
      """Run a single-TensorAdd network once, then dump its graph via ``_save_graph``."""

      class Net(nn.Cell):
          """Network wrapping one ``P.TensorAdd`` op."""

          def __init__(self):
              super().__init__()
              self.add = P.TensorAdd()

          def construct(self, x, y):
              return self.add(x, y)

      net = Net()
      net.set_train()
      collected = []
      lhs = Tensor(np.random.rand(2, 1, 2, 3).astype(np.float32))
      rhs = Tensor(np.array([1.2]).astype(np.float32))
      # Execute once before saving; presumably this produces the graph to dump.
      result = net(lhs, rhs)
      _save_graph(network=net, file_name="net-graph.meta")
      collected.append(result)
  def test_save_checkpoint():
      """Save three named float32 tensors to ``parameters.ckpt`` in ``_cur_dir``.

      ``test_load_checkpoint`` reads the same path and expects exactly these
      three entries.
      """
      ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
      # Clear a stale checkpoint at the exact path being written. A
      # cwd-relative check would look at a different file whenever pytest is
      # not launched from this directory.
      if os.path.exists(ckpoint_file_name):
          os.chmod(ckpoint_file_name, stat.S_IWRITE)
          os.remove(ckpoint_file_name)
      parameter_list = [
          {'name': "param_test",
           'data': Tensor(np.random.randint(0, 255, [1, 3, 224, 224]), dtype=mstype.float32)},
          {'name': "param",
           'data': Tensor(np.random.randint(0, 255, [12, 1024]), dtype=mstype.float32)},
          {'name': "new_param",
           'data': Tensor(np.random.randint(0, 255, [12, 1024, 1]), dtype=mstype.float32)},
      ]
      save_checkpoint(parameter_list, ckpoint_file_name)
  def test_load_checkpoint_error_filename():
      """``load_checkpoint`` rejects an integer file name with ValueError."""
      bad_name = 1
      with pytest.raises(ValueError):
          load_checkpoint(bad_name)
  def test_load_checkpoint():
      """Reload ``parameters.ckpt`` written by ``test_save_checkpoint`` and check it.

      Relies on ``test_save_checkpoint`` having run earlier in the session.
      """
      ckpoint_file_name = os.path.join(_cur_dir, './parameters.ckpt')
      par_dict = load_checkpoint(ckpoint_file_name)
      # Check the container type before indexing into it so this assertion
      # can actually fail on its own.
      assert isinstance(par_dict, dict)
      assert len(par_dict) == 3
      assert set(par_dict) == {'param_test', 'param', 'new_param'}
      assert par_dict['param_test'].name == 'param_test'
      assert par_dict['param_test'].data.dtype() == mstype.float32
      assert par_dict['param_test'].data.shape() == (1, 3, 224, 224)
  def test_checkpoint_manager():
      """Exercise ``_CheckpointManager`` bookkeeping on empty ``.ckpt`` files in ``_cur_dir``.

      Covers ``update_ckpoint_filelist``, ``remove_ckpoint_file``,
      ``remove_oldest_ckpoint_file`` and ``keep_one_ckpoint_per_minutes``.
      Every file the test creates is removed in ``finally``, even if an
      assertion fails part-way through.
      """

      def _touch(path):
          # Create an empty file readable/writable by the owner.
          with open(path, 'w'):
              os.chmod(path, stat.S_IWUSR | stat.S_IRUSR)

      ckp_mgr = _CheckpointManager()
      ckpoint_file_name = os.path.join(_cur_dir, './test1.ckpt')
      another_file_name = os.path.realpath(os.path.join(_cur_dir, './test2.ckpt'))
      file1, file2, file3 = [os.path.realpath(os.path.join(_cur_dir, name))
                             for name in ('./time_file1.ckpt', './time_file2.ckpt',
                                          './time_file3.ckpt')]
      try:
          # remove_ckpoint_file deletes the named file from disk.
          _touch(ckpoint_file_name)
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
          assert ckp_mgr.ckpoint_num == 1
          ckp_mgr.remove_ckpoint_file(ckpoint_file_name)
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
          assert ckp_mgr.ckpoint_num == 0
          assert not os.path.exists(ckpoint_file_name)

          # remove_oldest_ckpoint_file deletes the single tracked file.
          _touch(another_file_name)
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
          assert ckp_mgr.ckpoint_num == 1
          ckp_mgr.remove_oldest_ckpoint_file()
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "test")
          assert ckp_mgr.ckpoint_num == 0
          assert not os.path.exists(another_file_name)

          # test keep_one_ckpoint_per_minutes: expected to reduce three fresh
          # files to one.
          _touch(file1)
          _touch(file2)
          _touch(file3)
          time1 = time.time()
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
          assert ckp_mgr.ckpoint_num == 3
          ckp_mgr.keep_one_ckpoint_per_minutes(1, time1)
          ckp_mgr.update_ckpoint_filelist(_cur_dir, "time_file")
          assert ckp_mgr.ckpoint_num == 1
      finally:
          # Which time file survives is not asserted, so remove all of them,
          # plus anything left behind if an assertion above failed.
          for leftover in (ckpoint_file_name, another_file_name, file1, file2, file3):
              if os.path.exists(leftover):
                  os.chmod(leftover, stat.S_IWRITE)
                  os.remove(leftover)
  def test_load_param_into_net_error_net():
      """``load_param_into_net`` raises TypeError when given a string instead of a network."""
      weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                         name="conv1.weight")
      parameter_dict = {"conv1.weight": weight}
      with pytest.raises(TypeError):
          load_param_into_net('', parameter_dict)
  def test_load_param_into_net_error_dict():
      """``load_param_into_net`` raises TypeError when the parameter dict is a string."""
      model = Net(10)
      with pytest.raises(TypeError):
          load_param_into_net(model, '')
  def test_load_param_into_net_erro_dict_param():
      """A dict value that is not a Parameter makes ``load_param_into_net`` raise TypeError."""
      net = Net(10)
      # conv1 starts zero-initialised.
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
      bad_dict = {"conv1.weight": ''}
      with pytest.raises(TypeError):
          load_param_into_net(net, bad_dict)
  def test_load_param_into_net_has_more_param():
      """An extra dict key with no matching net parameter does not stop loading.

      ``conv1.weight`` is still overwritten with ones; the unmatched
      ``conv1.w`` entry raises no error.
      """
      net = Net(10)
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
      parameter_dict = {
          "conv1.weight": Parameter(
              Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
              name="conv1.weight"),
          # No counterpart in Net; its Parameter deliberately reuses the same name.
          "conv1.w": Parameter(
              Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
              name="conv1.weight"),
      }
      load_param_into_net(net, parameter_dict)
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 1
  def test_load_param_into_net_param_type_and_shape_error():
      """Loading a parameter whose dtype AND shape both differ raises RuntimeError.

      The tensor is built from a float64 numpy array with no explicit dtype
      (presumably yielding a float64 tensor, versus the net's float32). Its
      shape (64, 3, 7) also differs from conv1's (64, 3, 7, 7), so both
      mismatches are present as the test name promises.
      """
      net = Net(10)
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
      parameter_dict = {}
      one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7))), name="conv1.weight")
      parameter_dict["conv1.weight"] = one_param
      with pytest.raises(RuntimeError):
          load_param_into_net(net, parameter_dict)
  def test_load_param_into_net_param_type_error():
      """A correctly shaped int32 parameter for a float32 weight raises RuntimeError."""
      net = Net(10)
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
      int_weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.int32),
                             name="conv1.weight")
      with pytest.raises(RuntimeError):
          load_param_into_net(net, {"conv1.weight": int_weight})
  def test_load_param_into_net_param_shape_error():
      """A float32 parameter with the wrong shape raises RuntimeError.

      The dtype matches the net's float32 weight, so only the shape
      ((64, 3, 7) vs (64, 3, 7, 7)) differs. An int32 tensor would trip the
      dtype mismatch and never isolate the shape check.
      """
      net = Net(10)
      assert net.conv1.weight.default_input.asnumpy()[0][0][0][0] == 0
      parameter_dict = {}
      one_param = Parameter(Tensor(np.ones(shape=(64, 3, 7)), dtype=mstype.float32),
                            name="conv1.weight")
      parameter_dict["conv1.weight"] = one_param
      with pytest.raises(RuntimeError):
          load_param_into_net(net, parameter_dict)
  def test_load_param_into_net():
      """A matching float32 parameter overwrites conv1's zero-initialised weight."""
      net = Net(10)
      weight_before = net.conv1.weight.default_input.asnumpy()[0][0][0][0]
      assert weight_before == 0
      ones_weight = Parameter(Tensor(np.ones(shape=(64, 3, 7, 7)), dtype=mstype.float32),
                              name="conv1.weight")
      load_param_into_net(net, {"conv1.weight": ones_weight})
      weight_after = net.conv1.weight.default_input.asnumpy()[0][0][0][0]
      assert weight_after == 1
  def test_exec_save_checkpoint():
      """Save a TrainOneStepCell-wrapped Net with ``_exec_save_checkpoint`` and reload it."""
      backbone = Net()
      criterion = SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
      optimizer = Momentum(backbone.trainable_params(), 0.0, 0.9, 0.0001, 1024)
      with_loss = WithLossCell(backbone, criterion)
      trainer = TrainOneStepCell(with_loss, optimizer)
      _exec_save_checkpoint(trainer, ckpoint_file_name="./new_ckpt.ckpt")
      # Reloading must not raise; the returned dict is not inspected.
      load_checkpoint("new_ckpt.ckpt")
  def test_load_checkpoint_empty_file():
      """``load_checkpoint`` raises ValueError for a zero-byte checkpoint file."""
      # Create (or truncate) the empty file portably. os.mknod raises
      # FileExistsError if an earlier aborted run left the file behind, and it
      # is Unix-only.
      with open("empty.ckpt", "w"):
          pass
      with pytest.raises(ValueError):
          load_checkpoint("empty.ckpt")
  class MYNET(nn.Cell):
      """Conv2d (no bias, 'normal' init) -> BatchNorm2d -> ReLU -> Flatten -> Dense(3).

      Used by ``test_export`` with a 1x3x224x224 input.
      """

      def __init__(self):
          super().__init__()
          self.conv = nn.Conv2d(3, 64, 3, has_bias=False, weight_init='normal',
                                pad_mode='valid')
          self.bn = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.flatten = nn.Flatten()
          # 222 = 224 - 3 + 1 for a 3x3 'valid' conv on a 224x224 input
          # (assuming the default stride of 1). padding=0
          self.fc = nn.Dense(64 * 222 * 222, 3)

      def construct(self, x):
          feat = self.conv(x)
          feat = self.bn(feat)
          feat = self.relu(feat)
          feat = self.flatten(feat)
          return self.fc(feat)
  @non_graph_engine
  def test_export():
      """Export MYNET to ``./me_export.pb`` in GEIR format."""
      # Build the net before drawing the input: both may consume random state.
      model = MYNET()
      sample = Tensor(np.random.randint(0, 255, [1, 3, 224, 224]).astype(np.float32))
      export(model, sample, file_name="./me_export.pb", file_format="GEIR")
  class BatchNormTester(nn.Cell):
      """Single BatchNorm2d cell, used to export a training-mode network to ONNX.

      Args:
          num_features: Channel count passed to BatchNorm2d.
      """

      def __init__(self, num_features):
          super().__init__()
          self.bn = nn.BatchNorm2d(num_features)

      def construct(self, x):
          normalized = self.bn(x)
          return normalized
  class DepthwiseConv2dAndReLU6(nn.Cell):
      """DepthwiseConv2dNative with a ones-initialised weight, followed by ReLU6.

      Args:
          input_channel: Number of input channels.
          kernel_size: Side length of the square depthwise kernel.
      """

      def __init__(self, input_channel, kernel_size):
          super().__init__()
          from mindspore.common.initializer import initializer
          kernel = (kernel_size, kernel_size)
          # Weight shape [1, input_channel, k, k] with channel_multiplier=1.
          self.weight = Parameter(initializer('ones', [1, input_channel, *kernel]),
                                  name='weight')
          self.depthwise_conv = P.DepthwiseConv2dNative(channel_multiplier=1,
                                                        kernel_size=kernel)
          self.relu6 = nn.ReLU6()

      def construct(self, x):
          conv_out = self.depthwise_conv(x, self.weight)
          return self.relu6(conv_out)
  def test_batchnorm_train_onnx_export():
      """Export a training-mode BatchNorm network to ONNX; training mode must persist.

      Raises:
          ValueError: if the network is not in training mode before or after export.
      """
      # Named input_tensor so the builtin ``input`` is not shadowed.
      input_tensor = Tensor(np.ones([1, 3, 32, 32]).astype(np.float32) * 0.01)
      net = BatchNormTester(3)
      net.set_train()
      if not net.training:
          raise ValueError('network is not in training mode')
      export(net, input_tensor, file_name='batch_norm.onnx', file_format='ONNX')
      # Exporting must not switch the network out of training mode.
      if not net.training:
          raise ValueError('network is not in training mode')
  class LeNet5(nn.Cell):
      """LeNet-5 style classifier producing 10 logits (tests feed 1x1x32x32 inputs)."""

      def __init__(self):
          super().__init__()
          self.conv1 = nn.Conv2d(1, 6, 5, pad_mode='valid')
          self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
          # 16 channels x 5x5 after two valid 5x5 convs and two 2x2/2 pools on a
          # 32x32 input (assuming default conv stride 1).
          self.fc1 = nn.Dense(16 * 5 * 5, 120)
          self.fc2 = nn.Dense(120, 84)
          self.fc3 = nn.Dense(84, 10)
          self.relu = nn.ReLU()
          self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
          self.flatten = P.Flatten()

      def construct(self, x):
          x = self.conv1(x)
          x = self.max_pool2d(self.relu(x))
          x = self.conv2(x)
          x = self.max_pool2d(self.relu(x))
          x = self.flatten(x)
          x = self.relu(self.fc1(x))
          x = self.relu(self.fc2(x))
          return self.fc3(x)
  def test_lenet5_onnx_export():
      """Export LeNet5 to ``lenet5.onnx`` in ONNX format."""
      input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
      model = LeNet5()
      export(model, input_tensor, file_name='lenet5.onnx', file_format='ONNX')
  class DefinedNet(nn.Cell):
      """Conv/BN/ReLU network that pools with ``P.MaxPoolWithArgmax``.

      Only the pooled values feed the classifier; the argmax output is discarded.

      Args:
          num_classes: Output width of the final Dense layer.
      """

      def __init__(self, num_classes=10):
          super().__init__()
          self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=0,
                                 weight_init="zeros")
          self.bn1 = nn.BatchNorm2d(64)
          self.relu = nn.ReLU()
          self.maxpool = P.MaxPoolWithArgmax(padding="same", ksize=2, strides=2)
          self.flatten = nn.Flatten()
          self.fc = nn.Dense(56 * 56 * 64, num_classes)

      def construct(self, x):
          out = self.conv1(x)
          out = self.bn1(out)
          out = self.relu(out)
          out, argmax = self.maxpool(out)
          out = self.flatten(out)
          return self.fc(out)
  def test_net_onnx_maxpoolwithargmax_export():
      """Export DefinedNet (MaxPoolWithArgmax) to ``definedNet.onnx``."""
      input_tensor = Tensor(np.ones([1, 3, 224, 224]).astype(np.float32) * 0.01)
      model = DefinedNet()
      export(model, input_tensor, file_name='definedNet.onnx', file_format='ONNX')
  @run_on_onnxruntime
  def test_lenet5_onnx_load_run():
      """Export LeNet5 to ONNX, validate it with onnx, then run it in onnxruntime.

      The model runs twice: once with only the input ``x`` and once with every
      trainable parameter overridden by ones (fed by parameter name).
      """
      onnx_file = 'lenet5.onnx'
      input_tensor = Tensor(np.ones([1, 1, 32, 32]).astype(np.float32) * 0.01)
      net = LeNet5()
      export(net, input_tensor, file_name=onnx_file, file_format='ONNX')
      import onnx
      import onnxruntime as ort

      print('--------------------- onnx load ---------------------')
      model = onnx.load(onnx_file)
      onnx.checker.check_model(model)  # raises if the IR is malformed
      print(onnx.helper.printable_graph(model.graph))

      print('------------------ onnxruntime run ------------------')
      session = ort.InferenceSession(onnx_file)
      feeds = {'x' : input_tensor.asnumpy()}
      print(session.run(None, feeds)[0])
      # Replace the exported default weights with ones.
      feeds.update({param.name: np.ones(param.default_input.asnumpy().shape, dtype=np.float32)
                    for param in net.trainable_params()})
      print(session.run(None, feeds)[0])
  @run_on_onnxruntime
  def test_depthwiseconv_relu6_onnx_load_run():
      """Export DepthwiseConv2dAndReLU6 to ONNX, validate it, and run it in onnxruntime.

      The model runs twice: once with only the input ``x`` and once with every
      trainable parameter overridden by ones (fed by parameter name).
      """
      onnx_file = 'depthwiseconv_relu6.onnx'
      input_channel = 3
      input_tensor = Tensor(np.ones([1, input_channel, 32, 32]).astype(np.float32) * 0.01)
      net = DepthwiseConv2dAndReLU6(input_channel, kernel_size=3)
      export(net, input_tensor, file_name=onnx_file, file_format='ONNX')
      import onnx
      import onnxruntime as ort

      print('--------------------- onnx load ---------------------')
      model = onnx.load(onnx_file)
      onnx.checker.check_model(model)  # raises if the IR is malformed
      print(onnx.helper.printable_graph(model.graph))

      print('------------------ onnxruntime run ------------------')
      session = ort.InferenceSession(onnx_file)
      feeds = {'x' : input_tensor.asnumpy()}
      print(session.run(None, feeds)[0])
      # Replace the exported default weights with ones.
      feeds.update({param.name: np.ones(param.default_input.asnumpy().shape, dtype=np.float32)
                    for param in net.trainable_params()})
      print(session.run(None, feeds)[0])
  def teardown_module():
      """Delete the files the tests in this module write into the current directory.

      Missing files are skipped, so calling this repeatedly is safe.
      """
      files = ['parameters.ckpt', 'new_ckpt.ckpt', 'lenet5.onnx', 'batch_norm.onnx', 'empty.ckpt',
               # Outputs of test_save_graph, test_export and the remaining ONNX tests.
               'net-graph.meta', 'me_export.pb', 'definedNet.onnx', 'depthwiseconv_relu6.onnx']
      for item in files:
          file_name = './' + item
          if not os.path.exists(file_name):
              continue
          # Grant owner write permission before removing.
          os.chmod(file_name, stat.S_IWRITE)
          os.remove(file_name)