You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_model_save.py 11 kB

4 years ago

  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import numpy as np
  7. import tensorflow as tf
  8. import tensorlayer as tl
  9. from tensorlayer.layers import *
  10. from tensorlayer.models import *
  11. from tests.utils import CustomTestCase
  12. def basic_static_model(include_top=True):
  13. ni = Input((None, 24, 24, 3))
  14. nn = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, name="conv1")(ni)
  15. nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')(nn)
  16. nn = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, name="conv2")(nn)
  17. nn = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')(nn)
  18. nn = Flatten(name='flatten')(nn)
  19. nn = Dense(100, act=None, name="dense1")(nn)
  20. if include_top is True:
  21. nn = Dense(10, act=None, name="dense2")(nn)
  22. M = Model(inputs=ni, outputs=nn)
  23. return M
  24. class basic_dynamic_model(Model):
  25. def __init__(self, include_top=True):
  26. super(basic_dynamic_model, self).__init__()
  27. self.include_top = include_top
  28. self.conv1 = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, in_channels=3, name="conv1")
  29. self.pool1 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool1')
  30. self.conv2 = Conv2d(16, (5, 5), (1, 1), padding='SAME', act=tf.nn.relu, in_channels=16, name="conv2")
  31. self.pool2 = MaxPool2d((3, 3), (2, 2), padding='SAME', name='pool2')
  32. self.flatten = Flatten(name='flatten')
  33. self.dense1 = Dense(100, act=None, in_channels=576, name="dense1")
  34. if include_top is True:
  35. self.dense2 = Dense(10, act=None, in_channels=100, name="dense2")
  36. def forward(self, x):
  37. x = self.conv1(x)
  38. x = self.pool1(x)
  39. x = self.conv2(x)
  40. x = self.pool2(x)
  41. x = self.flatten(x)
  42. x = self.dense1(x)
  43. if self.include_top:
  44. x = self.dense2(x)
  45. return x
  46. class Nested_VGG(Model):
  47. def __init__(self):
  48. super(Nested_VGG, self).__init__()
  49. self.vgg1 = tl.models.vgg16()
  50. self.vgg2 = tl.models.vgg16()
  51. def forward(self, x):
  52. pass
  53. class Model_Save_Test(CustomTestCase):
  54. @classmethod
  55. def setUpClass(cls):
  56. cls.static_basic = basic_static_model()
  57. cls.dynamic_basic = basic_dynamic_model()
  58. cls.static_basic_skip = basic_static_model(include_top=False)
  59. cls.dynamic_basic_skip = basic_dynamic_model(include_top=False)
  60. print([l.name for l in cls.dynamic_basic.all_layers])
  61. print([l.name for l in cls.dynamic_basic_skip.all_layers])
  62. pass
  63. @classmethod
  64. def tearDownClass(cls):
  65. pass
  66. def normal_save(self, model_basic):
  67. # Default save
  68. model_basic.save_weights('./model_basic.none')
  69. # hdf5
  70. print('testing hdf5 saving...')
  71. modify_val = np.zeros_like(model_basic.all_weights[-2].numpy())
  72. ori_val = model_basic.all_weights[-2].numpy()
  73. model_basic.save_weights("./model_basic.h5")
  74. model_basic.all_weights[-2].assign(modify_val)
  75. model_basic.load_weights("./model_basic.h5")
  76. self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)
  77. model_basic.all_weights[-2].assign(modify_val)
  78. model_basic.load_weights("./model_basic.h5", format="hdf5")
  79. self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)
  80. model_basic.all_weights[-2].assign(modify_val)
  81. model_basic.load_weights("./model_basic.h5", format="hdf5", in_order=False)
  82. self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)
  83. # npz
  84. print('testing npz saving...')
  85. model_basic.save_weights("./model_basic.npz", format='npz')
  86. model_basic.all_weights[-2].assign(modify_val)
  87. model_basic.load_weights("./model_basic.npz")
  88. model_basic.all_weights[-2].assign(modify_val)
  89. model_basic.load_weights("./model_basic.npz", format='npz')
  90. model_basic.save_weights("./model_basic.npz")
  91. self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)
  92. # npz_dict
  93. print('testing npz_dict saving...')
  94. model_basic.save_weights("./model_basic.npz", format='npz_dict')
  95. model_basic.all_weights[-2].assign(modify_val)
  96. model_basic.load_weights("./model_basic.npz", format='npz_dict')
  97. self.assertLess(np.max(np.abs(ori_val - model_basic.all_weights[-2].numpy())), 1e-7)
  98. # ckpt
  99. try:
  100. model_basic.save_weights('./model_basic.ckpt', format='ckpt')
  101. except Exception as e:
  102. self.assertIsInstance(e, NotImplementedError)
  103. # other cases
  104. try:
  105. model_basic.save_weights('./model_basic.xyz', format='xyz')
  106. except Exception as e:
  107. self.assertIsInstance(e, ValueError)
  108. try:
  109. model_basic.load_weights('./model_basic.xyz', format='xyz')
  110. except Exception as e:
  111. self.assertIsInstance(e, FileNotFoundError)
  112. try:
  113. model_basic.load_weights('./model_basic.h5', format='xyz')
  114. except Exception as e:
  115. self.assertIsInstance(e, ValueError)
  116. def test_normal_save(self):
  117. print('-' * 20, 'test save weights', '-' * 20)
  118. self.normal_save(self.static_basic)
  119. self.normal_save(self.dynamic_basic)
  120. print('testing save dynamic and load static...')
  121. try:
  122. self.dynamic_basic.save_weights("./model_basic.h5")
  123. self.static_basic.load_weights("./model_basic.h5", in_order=False)
  124. except Exception as e:
  125. print(e)
  126. def test_skip(self):
  127. print('-' * 20, 'test skip save/load', '-' * 20)
  128. print("testing dynamic skip load...")
  129. self.dynamic_basic.save_weights("./model_basic.h5")
  130. ori_weights = self.dynamic_basic_skip.all_weights
  131. ori_val = ori_weights[1].numpy()
  132. modify_val = np.zeros_like(ori_val)
  133. self.dynamic_basic_skip.all_weights[1].assign(modify_val)
  134. self.dynamic_basic_skip.load_weights("./model_basic.h5", skip=True)
  135. self.assertLess(np.max(np.abs(ori_val - self.dynamic_basic_skip.all_weights[1].numpy())), 1e-7)
  136. try:
  137. self.dynamic_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
  138. except Exception as e:
  139. print(e)
  140. print("testing static skip load...")
  141. self.static_basic.save_weights("./model_basic.h5")
  142. ori_weights = self.static_basic_skip.all_weights
  143. ori_val = ori_weights[1].numpy()
  144. modify_val = np.zeros_like(ori_val)
  145. self.static_basic_skip.all_weights[1].assign(modify_val)
  146. self.static_basic_skip.load_weights("./model_basic.h5", skip=True)
  147. self.assertLess(np.max(np.abs(ori_val - self.static_basic_skip.all_weights[1].numpy())), 1e-7)
  148. try:
  149. self.static_basic_skip.load_weights("./model_basic.h5", in_order=False, skip=False)
  150. except Exception as e:
  151. print(e)
  152. def test_nested_vgg(self):
  153. print('-' * 20, 'test nested vgg', '-' * 20)
  154. nested_vgg = Nested_VGG()
  155. print([l.name for l in nested_vgg.all_layers])
  156. nested_vgg.save_weights("nested_vgg.h5")
  157. # modify vgg1 weight val
  158. tar_weight1 = nested_vgg.vgg1.layers[0].all_weights[0]
  159. print(tar_weight1.name)
  160. ori_val1 = tar_weight1.numpy()
  161. modify_val1 = np.zeros_like(ori_val1)
  162. tar_weight1.assign(modify_val1)
  163. # modify vgg2 weight val
  164. tar_weight2 = nested_vgg.vgg2.layers[1].all_weights[0]
  165. print(tar_weight2.name)
  166. ori_val2 = tar_weight2.numpy()
  167. modify_val2 = np.zeros_like(ori_val2)
  168. tar_weight2.assign(modify_val2)
  169. nested_vgg.load_weights("nested_vgg.h5")
  170. self.assertLess(np.max(np.abs(ori_val1 - tar_weight1.numpy())), 1e-7)
  171. self.assertLess(np.max(np.abs(ori_val2 - tar_weight2.numpy())), 1e-7)
  172. def test_double_nested_vgg(self):
  173. print('-' * 20, 'test_double_nested_vgg', '-' * 20)
  174. class mymodel(Model):
  175. def __init__(self):
  176. super(mymodel, self).__init__()
  177. self.inner = Nested_VGG()
  178. self.list = LayerList(
  179. [
  180. tl.layers.Dense(n_units=4, in_channels=10, name='dense1'),
  181. tl.layers.Dense(n_units=3, in_channels=4, name='dense2')
  182. ]
  183. )
  184. def forward(self, *inputs, **kwargs):
  185. pass
  186. net = mymodel()
  187. net.save_weights("double_nested.h5")
  188. print([x.name for x in net.all_layers])
  189. # modify vgg1 weight val
  190. tar_weight1 = net.inner.vgg1.layers[0].all_weights[0]
  191. ori_val1 = tar_weight1.numpy()
  192. modify_val1 = np.zeros_like(ori_val1)
  193. tar_weight1.assign(modify_val1)
  194. # modify vgg2 weight val
  195. tar_weight2 = net.inner.vgg2.layers[1].all_weights[0]
  196. ori_val2 = tar_weight2.numpy()
  197. modify_val2 = np.zeros_like(ori_val2)
  198. tar_weight2.assign(modify_val2)
  199. net.load_weights("double_nested.h5")
  200. self.assertLess(np.max(np.abs(ori_val1 - tar_weight1.numpy())), 1e-7)
  201. self.assertLess(np.max(np.abs(ori_val2 - tar_weight2.numpy())), 1e-7)
  202. def test_layerlist(self):
  203. print('-' * 20, 'test_layerlist', '-' * 20)
  204. # simple modellayer
  205. ni = tl.layers.Input([10, 4])
  206. nn = tl.layers.Dense(n_units=3, name='dense1')(ni)
  207. modellayer = tl.models.Model(inputs=ni, outputs=nn, name='modellayer').as_layer()
  208. # nested layerlist with modellayer
  209. inputs = tl.layers.Input([10, 5])
  210. layer1 = tl.layers.LayerList([tl.layers.Dense(n_units=4, name='dense1'), modellayer])(inputs)
  211. model = tl.models.Model(inputs=inputs, outputs=layer1, name='layerlistmodel')
  212. model.save_weights("layerlist.h5")
  213. tar_weight = model.get_layer(index=-1)[0].all_weights[0]
  214. print(tar_weight.name)
  215. ori_val = tar_weight.numpy()
  216. modify_val = np.zeros_like(ori_val)
  217. tar_weight.assign(modify_val)
  218. model.load_weights("layerlist.h5")
  219. self.assertLess(np.max(np.abs(ori_val - tar_weight.numpy())), 1e-7)
  220. def test_exceptions(self):
  221. print('-' * 20, 'test_exceptions', '-' * 20)
  222. try:
  223. ni = Input([4, 784])
  224. model = Model(inputs=ni, outputs=ni)
  225. model.save_weights('./empty_model.h5')
  226. except Exception as e:
  227. print(e)
  228. if __name__ == '__main__':
  229. unittest.main()

TensorLayer 3.0 是一款以多种深度学习框架为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。