You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ControlFlowApiTest.cs 1.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq.Expressions;
  5. using System.Runtime.CompilerServices;
  6. using System.Security.Cryptography.X509Certificates;
  7. using System.Text;
  8. using Tensorflow;
  9. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for <c>tf.while_loop</c> with a single tensor loop variable and with an
      /// array of loop variables, in eager mode and (currently ignored) graph mode.
      /// </summary>
      [TestClass]
      public class ControlFlowApiTest
      {
          /// <summary>
          /// Single loop variable: starting from 2 and adding 1 while the value is
          /// less than 10, the loop stops at exactly 10.
          /// </summary>
          [TestMethod]
          public void WhileLoopOneInputEagerMode()
          {
              // Enable eager mode explicitly so this test does not depend on the
              // execution mode left behind by other tests (WhileLoopGraphMode disables it).
              tf.enable_eager_execution();
              var i = tf.constant(2);
              Func<Tensor, Tensor> cond = (x) => tf.less(x, 10);
              Func<Tensor, Tensor> body = (x) => tf.add(x, 1);

              var r = tf.while_loop(cond, body, i);

              Assert.AreEqual(10, (int)r);
          }

          /// <summary>
          /// Two loop variables advanced together, one step per iteration:
          /// (2, 3) -> (3, 4) -> (4, 5) -> (5, 6). The loop exits once their sum (11)
          /// is no longer less than 10.
          /// </summary>
          [TestMethod]
          public void WhileLoopTwoInputsEagerMode()
          {
              tf.enable_eager_execution();
              var i = tf.constant(2);
              var j = tf.constant(3);
              Func<Tensor[], Tensor> cond = (x) => tf.less(x[0] + x[1], 10);
              Func<Tensor[], Tensor[]> body = (x) => new[] { tf.add(x[0], 1), tf.add(x[1], 1) };

              var r = tf.while_loop(cond, body, new[] { i, j });

              Assert.AreEqual(5, (int)r[0]);
              Assert.AreEqual(6, (int)r[1]);
          }

          /// <summary>
          /// Graph-mode counterpart of <see cref="WhileLoopOneInputEagerMode"/>.
          /// </summary>
          /// <remarks>
          /// The execution mode lives on the shared <c>tf</c> binding (the eager tests above
          /// re-enable it for that reason). This test therefore restores eager mode in a
          /// <c>finally</c> block, so that, once un-ignored, it cannot leave later tests
          /// running in graph mode.
          /// NOTE(review): in graph mode <c>r</c> is presumably a symbolic tensor, so
          /// <c>(int)r</c> may need the value to be evaluated in a session first; confirm
          /// before removing [Ignore].
          /// </remarks>
          [TestMethod, Ignore]
          public void WhileLoopGraphMode()
          {
              tf.compat.v1.disable_eager_execution();
              try
              {
                  var i = tf.constant(2);
                  Func<Tensor, Tensor> cond = (x) => tf.less(x, 10);
                  Func<Tensor, Tensor> body = (x) => tf.add(x, 1);

                  var r = tf.while_loop(cond, body, i);

                  Assert.AreEqual(10, (int)r);
              }
              finally
              {
                  // Undo the global mode switch even when the assertion fails.
                  tf.enable_eager_execution();
              }
          }
      }
  }