You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_utils_predict.py 1.5 kB

4 years ago
Lines 1–55
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. import os
  4. import unittest
  5. os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  6. import numpy as np
  7. import tensorflow as tf
  8. import tensorlayer as tl
  9. from tests.utils import CustomTestCase
  10. class Util_Predict_Test(CustomTestCase):
  11. @classmethod
  12. def setUpClass(cls):
  13. cls.x1 = tf.placeholder(tf.float32, [None, 5, 5, 3])
  14. cls.x2 = tf.placeholder(tf.float32, [8, 5, 5, 3])
  15. cls.X1 = np.ones([127, 5, 5, 3])
  16. cls.X2 = np.ones([7, 5, 5, 3])
  17. cls.batch_size = 8
  18. @classmethod
  19. def tearDownClass(cls):
  20. tf.reset_default_graph()
  21. def test_case1(self):
  22. with self.assertNotRaises(Exception):
  23. with tf.Session() as sess:
  24. n = tl.layers.InputLayer(self.x1)
  25. y = n.outputs
  26. y_op = tf.nn.softmax(y)
  27. tl.utils.predict(sess, n, self.X1, self.x1, y_op, batch_size=self.batch_size)
  28. sess.close()
  29. def test_case2(self):
  30. with self.assertRaises(Exception):
  31. with tf.Session() as sess:
  32. n = tl.layers.InputLayer(self.x2)
  33. y = n.outputs
  34. y_op = tf.nn.softmax(y)
  35. tl.utils.predict(sess, n, self.X2, self.x2, y_op, batch_size=self.batch_size)
  36. sess.close()
  37. if __name__ == '__main__':
  38. tf.logging.set_verbosity(tf.logging.DEBUG)
  39. tl.logging.set_verbosity(tl.logging.DEBUG)
  40. unittest.main()

TensorLayer 3.0 是一款兼容多种深度学习框架作为计算后端的深度学习库，计划兼容 TensorFlow、PyTorch、MindSpore 和 Paddle。