You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlFlowApiTest.cs 2.2 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.ManagedAPI
  7. {
  8. [TestClass]
  9. public class ControlFlowApiTest
  10. {
  11. [TestMethod]
  12. public void WhileLoopOneInputEagerMode()
  13. {
  14. tf.enable_eager_execution();
  15. var i = tf.constant(2);
  16. Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
  17. Func<Tensor, Tensor> b = (x) => tf.add(x, 1);
  18. var r = tf.while_loop(c, b, i);
  19. Assert.AreEqual(10, (int)r);
  20. }
  21. [TestMethod]
  22. public void WhileLoopTwoInputsEagerMode()
  23. {
  24. tf.enable_eager_execution();
  25. var i = tf.constant(2);
  26. var j = tf.constant(3);
  27. Func<Tensors, Tensor> c = (x) => tf.less(x[0] + x[1], 10);
  28. Func<Tensors, Tensors> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };
  29. var r = tf.while_loop(c, b, new[] { i, j });
  30. Assert.AreEqual(5, (int)r[0]);
  31. Assert.AreEqual(6, (int)r[1]);
  32. }
  33. [TestMethod, Ignore]
  34. public void WhileLoopGraphMode()
  35. {
  36. tf.compat.v1.disable_eager_execution();
  37. var i = tf.constant(2);
  38. Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
  39. Func<Tensor, Tensor> b = (x) => tf.add(x, 1);
  40. var r = tf.while_loop(c, b, i);
  41. Assert.AreEqual(10, (int)r);
  42. }
  43. [TestMethod, Ignore]
  44. public void ScanFunctionGraphMode()
  45. {
  46. tf.compat.v1.disable_eager_execution();
  47. Func<Tensor, Tensor, Tensor> fn = (prev, current) => tf.add(prev, current);
  48. var input = tf.placeholder(TF_DataType.TF_FLOAT, new Shape(6));
  49. var scan = tf.scan(fn, input);
  50. var sess = tf.Session();
  51. sess.run(tf.global_variables_initializer());
  52. var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6)));
  53. Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>());
  54. }
  55. }
  56. }