You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlFlowApiTest.cs 1.5 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for <c>tf.while_loop</c> through the managed API: a single loop
      /// variable and an array of loop variables in eager mode, and a single
      /// loop variable in graph mode.
      /// </summary>
      [TestClass]
      public class ControlFlowApiTest
      {
          /// <summary>
          /// Increments a scalar starting at 2 while it is less than 10;
          /// the loop must stop with the value 10.
          /// </summary>
          [TestMethod]
          public void WhileLoopOneInputEagerMode()
          {
              tf.enable_eager_execution();
              var i = tf.constant(2);
              Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
              Func<Tensor, Tensor> b = (x) => tf.add(x, 1);

              var r = tf.while_loop(c, b, i);

              Assert.AreEqual(10, (int)r);
          }

          /// <summary>
          /// Increments two scalars (2 and 3) in lock-step while their sum is
          /// less than 10. Sums go 5, 7, 9, 11, so the loop stops at (5, 6).
          /// </summary>
          [TestMethod]
          public void WhileLoopTwoInputsEagerMode()
          {
              tf.enable_eager_execution();
              var i = tf.constant(2);
              var j = tf.constant(3);
              Func<Tensor[], Tensor> c = (x) => tf.less(x[0] + x[1], 10);
              Func<Tensor[], Tensor[]> b = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };

              var r = tf.while_loop(c, b, new[] { i, j });

              Assert.AreEqual(5, (int)r[0]);
              Assert.AreEqual(6, (int)r[1]);
          }

          /// <summary>
          /// Same loop as <see cref="WhileLoopOneInputEagerMode"/>, but built
          /// with eager execution disabled.
          /// </summary>
          [TestMethod, Ignore]
          public void WhileLoopGraphMode()
          {
              // disable_eager_execution mutates the shared static `tf` binding,
              // so the mode must be restored or it leaks into later tests.
              tf.compat.v1.disable_eager_execution();
              try
              {
                  var i = tf.constant(2);
                  Func<Tensor, Tensor> c = (x) => tf.less(x, 10);
                  Func<Tensor, Tensor> b = (x) => tf.add(x, 1);

                  var r = tf.while_loop(c, b, i);

                  // NOTE(review): in graph mode `r` is presumably a symbolic tensor,
                  // so this cast may need a session run to yield a value —
                  // possibly why the test is ignored; confirm before un-ignoring.
                  Assert.AreEqual(10, (int)r);
              }
              finally
              {
                  tf.enable_eager_execution();
              }
          }
      }
  }