You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_layers_merge.py 2.4 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import numpy as np
  7. import tensorlayer as tl
  8. from tests.utils import CustomTestCase
  9. class Layer_Merge_Test(CustomTestCase):
  10. @classmethod
  11. def setUpClass(cls):
  12. pass
  13. @classmethod
  14. def tearDownClass(cls):
  15. pass
  16. def test_concat(self):
  17. class CustomModel(tl.layers.Module):
  18. def __init__(self):
  19. super(CustomModel, self).__init__()
  20. self.dense1 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu1_1')
  21. self.dense2 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu2_1')
  22. self.concat = tl.layers.Concat(concat_dim=1, name='concat_layer')
  23. def forward(self, inputs):
  24. d1 = self.dense1(inputs)
  25. d2 = self.dense2(inputs)
  26. outputs = self.concat([d1, d2])
  27. return outputs
  28. model = CustomModel()
  29. model.set_train()
  30. inputs = tl.ops.convert_to_tensor(np.random.random([4, 20]).astype(np.float32))
  31. outputs = model(inputs)
  32. print(model)
  33. self.assertEqual(outputs.get_shape().as_list(), [4, 20])
  34. def test_elementwise(self):
  35. class CustomModel(tl.layers.Module):
  36. def __init__(self):
  37. super(CustomModel, self).__init__()
  38. self.dense1 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu1_1')
  39. self.dense2 = tl.layers.Dense(in_channels=20, n_units=10, act=tl.ReLU, name='relu2_1')
  40. self.element = tl.layers.Elementwise(combine_fn=tl.minimum, name='minimum', act=None)
  41. def forward(self, inputs):
  42. d1 = self.dense1(inputs)
  43. d2 = self.dense2(inputs)
  44. outputs = self.element([d1, d2])
  45. return outputs, d1, d2
  46. model = CustomModel()
  47. model.set_train()
  48. inputs = tl.ops.convert_to_tensor(np.random.random([4, 20]).astype(np.float32))
  49. outputs, d1, d2 = model(inputs)
  50. print(model)
  51. min = tl.ops.minimum(d1, d2)
  52. self.assertEqual(outputs.get_shape().as_list(), [4, 10])
  53. self.assertTrue(np.array_equal(min.numpy(), outputs.numpy()))
  54. if __name__ == '__main__':
  55. unittest.main()

TensorLayer 3.0 是一款支持以多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。