You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionApiTest.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using Tensorflow.Graphs;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for converting C# methods and lambdas into graph functions,
      /// either via the <c>[AutoGraph]</c> attribute or via <c>tf.autograph.to_graph</c>.
      /// </summary>
      [TestClass]
      public class FunctionApiTest : TFNetApiTest
      {
          /// <summary>
          /// Returns <paramref name="a"/> when <c>a &lt; b</c>, otherwise <paramref name="b"/>,
          /// using the lambda-based overload of <c>tf.cond</c>.
          /// </summary>
          /// <remarks>Not referenced by any test in this file at present.</remarks>
          private Tensor Min(Tensor a, Tensor b)
          {
              return tf.cond(a < b, () => a, () => b);
          }

          [TestMethod]
          public void MulInAutoGraph()
          {
              var a = tf.constant(1);
              var b = tf.constant(2);
              // On the first call, TF.NET records the operations in graph mode
              // and registers the resulting function with the TensorFlow op library.
              var output = Mul(a, b);
              Assert.AreEqual(2, (int)output);

              var c = tf.constant(3);
              // On subsequent calls, Mul is intercepted and run in eager mode.
              output = Mul(b, c);
              Assert.AreEqual(6, (int)output);
          }

          /// <summary>
          /// Method with the AutoGraph attribute; it is converted to a FuncGraph
          /// the first time it is invoked.
          /// </summary>
          /// <param name="a">Left operand.</param>
          /// <param name="b">Right operand.</param>
          /// <returns>The product <c>a * b</c>.</returns>
          [AutoGraph]
          private Tensor Mul(Tensor a, Tensor b)
          {
              return a * b;
          }

          [TestMethod]
          public void TwoInputs_OneOutput()
          {
              var func = tf.autograph.to_graph(Add);
              var a = tf.constant(1);
              var b = tf.constant(2);
              var output = func(a, b);
              Assert.AreEqual(3, (int)output);
          }

          /// <summary>Returns <c>a + b</c>; used as the graph source in <see cref="TwoInputs_OneOutput"/>.</summary>
          private Tensor Add(Tensor a, Tensor b)
          {
              return a + b;
          }

          [TestMethod]
          public void TwoInputs_OneOutput_Condition()
          {
              var func = tf.autograph.to_graph(Condition);
              var a = tf.constant(3);
              var b = tf.constant(2);
              var output = func(a, b);
              // 3 < 2 is false, so the false branch (b == 2) is selected.
              Assert.AreEqual(2, (int)output);
          }

          /// <summary>
          /// Returns <paramref name="a"/> when <c>a &lt; b</c>, otherwise <paramref name="b"/>,
          /// using the tensor-valued (non-lambda) overload of <c>tf.cond</c>.
          /// </summary>
          private Tensor Condition(Tensor a, Tensor b)
          {
              return tf.cond(a < b, a, b);
          }

          [TestMethod]
          public void TwoInputs_OneOutput_Lambda()
          {
              var func = tf.autograph.to_graph((x, y) => x * y);
              var output = func(tf.constant(3), tf.constant(2));
              Assert.AreEqual(6, (int)output);
          }

          [TestMethod]
          public void TwoInputs_OneOutput_WhileLoop()
          {
              // TODO(review): this body is an exact copy of TwoInputs_OneOutput_Lambda and
              // does not exercise a while loop. Replace it with a call to WhileLoop()
              // once that helper is implemented.
              var func = tf.autograph.to_graph((x, y) => x * y);
              var output = func(tf.constant(3), tf.constant(2));
              Assert.AreEqual(6, (int)output);
          }

          /// <summary>
          /// Placeholder for a while-loop graph intended to start a counter at 0
          /// and increment it by 1 while it is less than 10. Not implemented yet.
          /// </summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          private Tensor WhileLoop()
          {
              var i = tf.constant(0);
              // Loop condition and body. The lambda parameter is named `x` so it does
              // not shadow the outer `i`.
              Func<Tensor, Tensor> c = x => tf.less(x, 10);
              Func<Tensor, Tensor> b = x => tf.add(x, 1);
              // TODO: build the loop from the condition `c`, body `b` and initial value `i`
              // (e.g. with a while_loop op) and return the final counter.
              throw new NotImplementedException("WhileLoop: graph-mode while loop is not implemented yet.");
          }
      }
  }