You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionApiTest.cs 2.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. namespace TensorFlowNET.UnitTest.ManagedAPI
  6. {
  7. [TestClass]
  8. public class FunctionApiTest : TFNetApiTest
  9. {
  10. Tensor Min(Tensor a, Tensor b)
  11. {
  12. return tf.cond(a < b, () => a, () => b);
  13. }
  14. [TestMethod]
  15. public void MulInAutoGraph()
  16. {
  17. var a = tf.constant(1);
  18. var b = tf.constant(2);
  19. // For first time running, tf.net will record the operations in graph mode.
  20. // And register to tensorflow op library.
  21. var output = Mul(a, b);
  22. Assert.AreEqual(2, (int)output);
  23. var c = tf.constant(3);
  24. // for the following invoke, Mul will be intercepted and run it in eager mode.
  25. output = Mul(b, c);
  26. Assert.AreEqual(6, (int)output);
  27. }
  28. /// <summary>
  29. /// Method with AutoGraph attribute will be converted to FuncGraph
  30. /// when it's invoked for the first time.
  31. /// </summary>
  32. /// <param name="a"></param>
  33. /// <param name="b"></param>
  34. /// <returns></returns>
  35. // [AutoGraph]
  36. Tensor Mul(Tensor a, Tensor b)
  37. {
  38. return a * b;
  39. }
  40. [TestMethod]
  41. public void TwoInputs_OneOutput()
  42. {
  43. var func = tf.autograph.to_graph(Add);
  44. var a = tf.constant(1);
  45. var b = tf.constant(2);
  46. var output = func(a, b);
  47. Assert.AreEqual(3, (int)output);
  48. }
  49. Tensor Add(Tensor a, Tensor b)
  50. {
  51. return a + b;
  52. }
  53. [TestMethod]
  54. public void TwoInputs_OneOutput_Condition()
  55. {
  56. var func = tf.autograph.to_graph(Condition);
  57. var a = tf.constant(3);
  58. var b = tf.constant(2);
  59. var output = func(a, b);
  60. Assert.AreEqual(2, (int)output);
  61. }
  62. Tensor Condition(Tensor a, Tensor b)
  63. {
  64. return tf.cond(a < b, a, b);
  65. }
  66. [TestMethod]
  67. public void TwoInputs_OneOutput_Lambda()
  68. {
  69. var func = tf.autograph.to_graph((x, y) => x * y);
  70. var output = func(tf.constant(3), tf.constant(2));
  71. Assert.AreEqual(6, (int)output);
  72. }
  73. [TestMethod]
  74. public void TwoInputs_OneOutput_WhileLoop()
  75. {
  76. var func = tf.autograph.to_graph((x, y) => x * y);
  77. var output = func(tf.constant(3), tf.constant(2));
  78. Assert.AreEqual(6, (int)output);
  79. }
  80. Tensor WhileLoop()
  81. {
  82. var i = tf.constant(0);
  83. Func<Tensor, Tensor> c = i => tf.less(i, 10);
  84. Func<Tensor, Tensor> b = i => tf.add(i, 1);
  85. //var r = tf.(c, b, [i])
  86. throw new NotImplementedException("");
  87. }
  88. }
  89. }