You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_layers_shape.py 4.2 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. import numpy as np
  6. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  7. import tensorlayer as tl
  8. from tests.utils import CustomTestCase
  9. class Layer_Shape_Test(CustomTestCase):
  10. @classmethod
  11. def setUpClass(cls):
  12. cls.data = tl.layers.Input(shape=(8, 4, 3), init=tl.initializers.random_normal())
  13. cls.imgdata = tl.layers.Input(shape=(2, 16, 16, 8), init=tl.initializers.random_normal())
  14. @classmethod
  15. def tearDownClass(cls):
  16. pass
  17. def test_flatten(self):
  18. class CustomizeModel(tl.layers.Module):
  19. def __init__(self):
  20. super(CustomizeModel, self).__init__()
  21. self.flatten = tl.layers.Flatten()
  22. def forward(self, x):
  23. return self.flatten(x)
  24. model = CustomizeModel()
  25. print(model.flatten)
  26. model.set_train()
  27. out = model(self.data)
  28. self.assertEqual(out.get_shape().as_list(), [8, 12])
  29. def test_reshape(self):
  30. class CustomizeModel(tl.layers.Module):
  31. def __init__(self):
  32. super(CustomizeModel, self).__init__()
  33. self.reshape1 = tl.layers.Reshape(shape=(8, 12))
  34. self.reshape2 = tl.layers.Reshape(shape=(-1, 12))
  35. self.reshape3 = tl.layers.Reshape(shape=())
  36. def forward(self, x):
  37. return self.reshape1(x), self.reshape2(x), self.reshape3(x[0][0][0])
  38. model = CustomizeModel()
  39. print(model.reshape1)
  40. print(model.reshape2)
  41. print(model.reshape3)
  42. model.set_train()
  43. out1, out2, out3 = model(self.data)
  44. self.assertEqual(out1.get_shape().as_list(), [8, 12])
  45. self.assertEqual(out2.get_shape().as_list(), [8, 12])
  46. self.assertEqual(out3.get_shape().as_list(), [])
  47. def test_transpose(self):
  48. class CustomizeModel(tl.layers.Module):
  49. def __init__(self):
  50. super(CustomizeModel, self).__init__()
  51. self.transpose1 = tl.layers.Transpose()
  52. self.transpose2 = tl.layers.Transpose([2, 1, 0])
  53. self.transpose3 = tl.layers.Transpose([0, 2, 1])
  54. self.transpose4 = tl.layers.Transpose(conjugate=True)
  55. def forward(self, x):
  56. return self.transpose1(x), self.transpose2(x), self.transpose3(x), self.transpose4(x)
  57. real = tl.layers.Input(shape=(8, 4, 3), init=tl.initializers.random_normal())
  58. comp = tl.layers.Input(shape=(8, 4, 3), init=tl.initializers.random_normal())
  59. import tensorflow as tf
  60. complex_data = tf.dtypes.complex(real, comp)
  61. model = CustomizeModel()
  62. print(model.transpose1)
  63. print(model.transpose2)
  64. print(model.transpose3)
  65. print(model.transpose4)
  66. model.set_train()
  67. out1, out2, out3, out4 = model(self.data)
  68. self.assertEqual(out1.get_shape().as_list(), [3, 4, 8])
  69. self.assertEqual(out2.get_shape().as_list(), [3, 4, 8])
  70. self.assertEqual(out3.get_shape().as_list(), [8, 3, 4])
  71. self.assertEqual(out4.get_shape().as_list(), [3, 4, 8])
  72. self.assertTrue(np.array_equal(out1.numpy(), out4.numpy()))
  73. out1, out2, out3, out4 = model(complex_data)
  74. self.assertEqual(out1.get_shape().as_list(), [3, 4, 8])
  75. self.assertEqual(out2.get_shape().as_list(), [3, 4, 8])
  76. self.assertEqual(out3.get_shape().as_list(), [8, 3, 4])
  77. self.assertEqual(out4.get_shape().as_list(), [3, 4, 8])
  78. self.assertTrue(np.array_equal(np.conj(out1.numpy()), out4.numpy()))
  79. def test_shuffle(self):
  80. class CustomizeModel(tl.layers.Module):
  81. def __init__(self, x):
  82. super(CustomizeModel, self).__init__()
  83. self.shuffle = tl.layers.Shuffle(x)
  84. def forward(self, x):
  85. return self.shuffle(x)
  86. model = CustomizeModel(2)
  87. print(model.shuffle)
  88. model.set_train()
  89. out = model(self.imgdata)
  90. self.assertEqual(out.get_shape().as_list(), [2, 16, 16, 8])
if __name__ == '__main__':
    # Run every test in this module when the file is executed directly.
    unittest.main()

TensorLayer 3.0 是一款兼容以多种深度学习框架为计算后端的深度学习库。计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。