You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_utils_saveload.py 4.7 kB

4 years ago
Line numbers: 1–117
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import numpy as np
  7. import tensorflow as tf
  8. import tensorlayer as tl
  9. from tensorlayer.layers import *
  10. from tensorlayer.models import *
  11. from tests.utils import CustomTestCase
  def basic_static_model():
      """Build the static (functional-API) CNN used by the save/load tests.

      Layer graph:
          Input(None, 24, 24, 3) -> conv1 -> pool1 -> conv2 -> pool2
          -> flatten -> dense1 (100 units) -> dense2 (10 units)

      Returns:
          The assembled ``Model``, named ``'basic_static'``.
      """
      # Options shared by both conv stages and by both pool stages.
      conv_opts = dict(padding='SAME', act=tf.nn.relu)
      pool_opts = dict(padding='SAME')

      inputs = Input((None, 24, 24, 3))
      net = Conv2d(16, (5, 5), (1, 1), name="conv1", **conv_opts)(inputs)
      net = MaxPool2d((3, 3), (2, 2), name='pool1', **pool_opts)(net)
      net = Conv2d(16, (5, 5), (1, 1), name="conv2", **conv_opts)(net)
      net = MaxPool2d((3, 3), (2, 2), name='pool2', **pool_opts)(net)
      net = Flatten(name='flatten')(net)
      net = Dense(100, act=None, name="dense1")(net)
      logits = Dense(10, act=None, name="dense2")(net)
      return Model(inputs=inputs, outputs=logits, name='basic_static')
  class basic_dynamic_model(Model):
      """Dynamic (subclassed) counterpart of ``basic_static_model``.

      Uses the same layer stack and layer names. Each trainable layer is
      given an explicit ``in_channels`` instead of inferring shapes from an
      ``Input`` placeholder.
      """

      def __init__(self):
          super().__init__(name="basic_dynamic")
          conv_opts = dict(padding='SAME', act=tf.nn.relu)
          pool_opts = dict(padding='SAME')
          # Assignment order is kept stable on purpose: the model's weight
          # list may be collected in attribute order (TODO confirm against
          # TensorLayer's Model), and the tests index into it.
          self.conv1 = Conv2d(16, (5, 5), (1, 1), in_channels=3, name="conv1", **conv_opts)
          self.pool1 = MaxPool2d((3, 3), (2, 2), name='pool1', **pool_opts)
          self.conv2 = Conv2d(16, (5, 5), (1, 1), in_channels=16, name="conv2", **conv_opts)
          self.pool2 = MaxPool2d((3, 3), (2, 2), name='pool2', **pool_opts)
          self.flatten = Flatten(name='flatten')
          # 576 = 6 * 6 * 16. This assumes 24x24 inputs, as in
          # basic_static_model: two stride-2 SAME pools take 24 -> 12 -> 6,
          # with 16 channels.
          self.dense1 = Dense(100, act=None, in_channels=576, name="dense1")
          self.dense2 = Dense(10, act=None, in_channels=100, name="dense2")

      def forward(self, x):
          """Run *x* through every layer in declaration order and return the logits."""
          stages = (
              self.conv1, self.pool1,
              self.conv2, self.pool2,
              self.flatten,
              self.dense1, self.dense2,
          )
          for stage in stages:
              x = stage(x)
          return x
  class Model_Core_Test(CustomTestCase):
      """Round-trip tests for TensorLayer's weight save/load helpers.

      Each test saves a model's weights, overwrites one weight tensor
      (``all_weights[-2]``) with zeros, reloads it from disk and checks that
      the original values came back. The two models are created once per
      class and shared by all tests, so every test must leave them as it
      found them.
      """

      @classmethod
      def setUpClass(cls):
          import tempfile

          cls.static_model = basic_static_model()
          cls.dynamic_model = basic_dynamic_model()
          # Write checkpoints into a private scratch directory instead of the
          # CWD, so runs neither leave files behind nor collide with each other.
          cls._tmpdir = tempfile.TemporaryDirectory()
          cls.h5_path = os.path.join(cls._tmpdir.name, "model_basic.h5")
          cls.npz_path = os.path.join(cls._tmpdir.name, "model_basic.npz")

      @classmethod
      def tearDownClass(cls):
          # Removes every checkpoint written by the tests.
          cls._tmpdir.cleanup()

      def _assert_weight_equals(self, weight, expected):
          """Assert that variable *weight* matches array *expected* within 1e-7."""
          self.assertLess(np.max(np.abs(expected - weight.numpy())), 1e-7)

      def test_hdf5(self):
          model = self.static_model
          ori_val = model.all_weights[-2].numpy()
          modify_val = np.zeros_like(ori_val)
          tl.files.save_weights_to_hdf5(self.h5_path, model)

          # Round trip through the in-order loader.
          model.all_weights[-2].assign(modify_val)
          tl.files.load_hdf5_to_weights_in_order(self.h5_path, model)
          self._assert_weight_equals(model.all_weights[-2], ori_val)

          # Round trip through load_hdf5_to_weights.
          model.all_weights[-2].assign(modify_val)
          tl.files.load_hdf5_to_weights(self.h5_path, model)
          self._assert_weight_equals(model.all_weights[-2], ori_val)

          # Hide the first weight so the file holds an entry the model no longer
          # has. skip=True is expected to let the loader ignore it (confirm in
          # the tl.files docs). try/finally guarantees the shared model's
          # weight list is restored even if the assertion fails.
          ori_weights = model._all_weights
          model._all_weights = ori_weights[1:]
          try:
              model.all_weights[-2].assign(modify_val)
              tl.files.load_hdf5_to_weights(self.h5_path, model, skip=True)
              self._assert_weight_equals(model.all_weights[-2], ori_val)
          finally:
              model._all_weights = ori_weights

      def test_npz(self):
          model = self.dynamic_model
          ori_val = model.all_weights[-2].numpy()
          tl.files.save_npz(model.all_weights, self.npz_path)
          model.all_weights[-2].assign(np.zeros_like(ori_val))
          tl.files.load_and_assign_npz(self.npz_path, model)
          self._assert_weight_equals(model.all_weights[-2], ori_val)

      def test_npz_dict(self):
          model = self.dynamic_model
          ori_val = model.all_weights[-2].numpy()
          modify_val = np.zeros_like(ori_val)
          tl.files.save_npz_dict(model.all_weights, self.npz_path)

          model.all_weights[-2].assign(modify_val)
          tl.files.load_and_assign_npz_dict(self.npz_path, model)
          self._assert_weight_equals(model.all_weights[-2], ori_val)

          # Same skip=True check as test_hdf5. Slice this model's OWN weight
          # list: slicing another model's list would zero and reload that
          # model's variables, so this model would never be exercised.
          ori_weights = model._all_weights
          model._all_weights = ori_weights[1:]
          try:
              model.all_weights[-2].assign(modify_val)
              tl.files.load_and_assign_npz_dict(self.npz_path, model, skip=True)
              self._assert_weight_equals(model.all_weights[-2], ori_val)
          finally:
              model._all_weights = ori_weights
  if __name__ == "__main__":
      # Run this module's test cases when it is executed directly as a script.
      unittest.main()

TensorLayer 3.0 是一款兼容以多种深度学习框架为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore、Paddle。