You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_layers_basic.py 1.5 kB

4 years ago
Lines 1–64
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import tensorflow as tf
  7. import tensorlayer as tl
  8. from tests.utils import CustomTestCase
  9. class Layer_Basic_Test(CustomTestCase):
  10. @classmethod
  11. def setUpClass(cls):
  12. x = tf.placeholder(tf.float32, [None, 100])
  13. n = tl.layers.InputLayer(x, name='in')
  14. n = tl.layers.DenseLayer(n, n_units=80, name='d1')
  15. n = tl.layers.DenseLayer(n, n_units=80, name='d2')
  16. n.print_layers()
  17. n.print_params(False)
  18. n2 = n[:, :30]
  19. n2.print_layers()
  20. cls.n_params = n.count_params()
  21. cls.all_layers = n.all_layers
  22. cls.all_params = n.all_params
  23. cls.shape_n = n.outputs.get_shape().as_list()
  24. cls.shape_n2 = n2.outputs.get_shape().as_list()
  25. @classmethod
  26. def tearDownClass(cls):
  27. tf.reset_default_graph()
  28. def test_n_params(self):
  29. self.assertEqual(self.n_params, 14560)
  30. def test_shape_n(self):
  31. self.assertEqual(self.shape_n[-1], 80)
  32. def test_all_layers(self):
  33. self.assertEqual(len(self.all_layers), 3)
  34. def test_all_params(self):
  35. self.assertEqual(len(self.all_params), 4)
  36. def test_shape_n2(self):
  37. self.assertEqual(self.shape_n2[-1], 30)
  38. if __name__ == '__main__':
  39. tf.logging.set_verbosity(tf.logging.DEBUG)
  40. tl.logging.set_verbosity(tl.logging.DEBUG)
  41. unittest.main()

TensorLayer 3.0 是一款以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。