You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_reuse_mlp.py 1.7 kB

4 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import tensorflow as tf
  7. import tensorlayer as tl
  8. from tests.utils import CustomTestCase
  9. # define the network
  10. def mlp(x, is_train=True, reuse=False):
  11. with tf.variable_scope("MLP", reuse=reuse):
  12. tl.layers.set_name_reuse(reuse) # print warning
  13. network = tl.layers.InputLayer(x, name='input')
  14. network = tl.layers.DropoutLayer(network, keep=0.8, is_fix=True, is_train=is_train, name='drop1')
  15. network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu1')
  16. network = tl.layers.DropoutLayer(network, keep=0.5, is_fix=True, is_train=is_train, name='drop2')
  17. network = tl.layers.DenseLayer(network, n_units=800, act=tf.nn.relu, name='relu2')
  18. network = tl.layers.DropoutLayer(network, keep=0.5, is_fix=True, is_train=is_train, name='drop3')
  19. network = tl.layers.DenseLayer(network, n_units=10, name='output')
  20. return network
  21. class MLP_Reuse_Test(CustomTestCase):
  22. @classmethod
  23. def setUpClass(cls):
  24. # define placeholder
  25. cls.x = tf.placeholder(tf.float32, shape=[None, 784], name='x')
  26. # define inferences
  27. mlp(cls.x, is_train=True, reuse=False)
  28. mlp(cls.x, is_train=False, reuse=True)
  29. @classmethod
  30. def tearDownClass(cls):
  31. tf.reset_default_graph()
  32. def test_reuse(self):
  33. with self.assertRaises(Exception):
  34. mlp(self.x, is_train=False, reuse=False) # Already defined model with the same var_scope
  35. if __name__ == '__main__':
  36. tf.logging.set_verbosity(tf.logging.DEBUG)
  37. tl.logging.set_verbosity(tl.logging.DEBUG)
  38. unittest.main()

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。