
test_layers_extend.py 795 B

4 years ago
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import unittest

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import tensorlayer as tl
from tests.utils import CustomTestCase


class Layer_Extend_Test(CustomTestCase):

    @classmethod
    def setUpClass(cls):
        pass

    @classmethod
    def tearDownClass(cls):
        pass

    def test_expand_dims(self):
        x = tl.layers.Input([8, 3])
        expandlayer = tl.layers.ExpandDims(axis=-1)
        y = expandlayer(x)
        self.assertEqual(tl.get_tensor_shape(y), [8, 3, 1])

    def test_tile(self):
        x = tl.layers.Input([8, 3])
        tilelayer = tl.layers.Tile(multiples=[2, 3])
        y = tilelayer(x)
        self.assertEqual(tl.get_tensor_shape(y), [16, 9])


if __name__ == '__main__':
    unittest.main()
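
The shapes asserted in the two tests above can be cross-checked without TensorLayer. The sketch below uses NumPy as a stand-in, on the assumption that `ExpandDims` and `Tile` follow the usual expand-dims and tile shape semantics that the assertions imply. The helper names `expand_dims_shape` and `tile_shape` are hypothetical and not part of TensorLayer.

```python
import numpy as np


def expand_dims_shape(shape, axis):
    """Return the shape after inserting a length-1 axis (NumPy stand-in, not the TensorLayer API)."""
    return list(np.expand_dims(np.zeros(shape), axis=axis).shape)


def tile_shape(shape, multiples):
    """Return the shape after repeating each axis `multiples[i]` times (NumPy stand-in)."""
    return list(np.tile(np.zeros(shape), multiples).shape)


if __name__ == "__main__":
    # Same input shape and arguments as the tests above.
    print(expand_dims_shape([8, 3], -1))
    print(tile_shape([8, 3], [2, 3]))
```

With the test inputs, tiling multiplies each dimension by its entry in `multiples` (8 × 2 and 3 × 3), and expanding at `axis=-1` appends a trailing dimension of size 1.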

TensorLayer3.0 is a deep learning library that supports multiple deep learning frameworks as its computational backend. Planned backends: TensorFlow, PyTorch, MindSpore, and Paddle.