You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionApiTest.cs 3.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.Graphs;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for turning C# methods and lambdas into TensorFlow graph functions,
      /// either through the <c>[AutoGraph]</c> attribute or through
      /// <c>tf.autograph.to_graph</c>.
      /// </summary>
      [TestClass]
      public class FunctionApiTest : TFNetApiTest
      {
          /// <summary>
          /// Returns <paramref name="a"/> when <c>a &lt; b</c>, otherwise <paramref name="b"/>,
          /// using the lambda-branch overload of <c>tf.cond</c>.
          /// Not referenced by any test in this class.
          /// </summary>
          private Tensor Min(Tensor a, Tensor b)
          {
              return tf.cond(a < b, () => a, () => b);
          }

          [TestMethod]
          public void MulInAutoGraph()
          {
              var a = tf.constant(1);
              var b = tf.constant(2);

              // On the first call, tf.net records Mul's operations in graph mode
              // and registers the resulting function with the TensorFlow op library.
              var output = Mul(a, b);
              Assert.AreEqual(2, (int)output);

              // On later calls, the invocation of Mul is intercepted and the
              // registered function is run in eager mode.
              var c = tf.constant(3);
              output = Mul(b, c);
              Assert.AreEqual(6, (int)output);
          }

          /// <summary>
          /// Multiplies two tensors. Because of the <c>[AutoGraph]</c> attribute, the
          /// method is converted to a FuncGraph when it is invoked for the first time.
          /// </summary>
          /// <param name="a">Left operand.</param>
          /// <param name="b">Right operand.</param>
          /// <returns>The product <c>a * b</c>.</returns>
          [AutoGraph]
          private Tensor Mul(Tensor a, Tensor b)
          {
              return a * b;
          }

          [TestMethod]
          public void TwoInputs_OneOutput()
          {
              var func = tf.autograph.to_graph(Add);
              var a = tf.constant(1);
              var b = tf.constant(2);

              var output = func(a, b);

              Assert.AreEqual(3, (int)output);
          }

          /// <summary>Returns <c>a + b</c>; converted with <c>tf.autograph.to_graph</c> in tests.</summary>
          private Tensor Add(Tensor a, Tensor b)
          {
              return a + b;
          }

          [TestMethod]
          public void TwoInputs_OneOutput_Condition()
          {
              var func = tf.autograph.to_graph(Condition);
              var a = tf.constant(3);
              var b = tf.constant(2);

              var output = func(a, b);

              // 3 < 2 is false, so the "else" tensor (b) is selected.
              Assert.AreEqual(2, (int)output);
          }

          /// <summary>
          /// Returns <paramref name="a"/> when <c>a &lt; b</c>, otherwise <paramref name="b"/>,
          /// using the tensor-branch overload of <c>tf.cond</c>.
          /// </summary>
          private Tensor Condition(Tensor a, Tensor b)
          {
              return tf.cond(a < b, a, b);
          }

          [TestMethod]
          public void TwoInputs_OneOutput_Lambda()
          {
              var func = tf.autograph.to_graph((x, y) => x * y);

              var output = func(tf.constant(3), tf.constant(2));

              Assert.AreEqual(6, (int)output);
          }

          /// <summary>
          /// Covers a graph-built while loop via <see cref="WhileLoop"/>.
          /// Ignored until that helper is implemented; it currently always throws.
          /// </summary>
          [TestMethod]
          [Ignore("WhileLoop() is not implemented yet.")]
          public void TwoInputs_OneOutput_WhileLoop()
          {
              var output = WhileLoop();

              // The loop counts from 0 and stops once the counter is no longer < 10.
              Assert.AreEqual(10, (int)output);
          }

          /// <summary>
          /// Placeholder for a while loop that starts from <c>i = 0</c> and repeats
          /// <c>i = i + 1</c> while <c>i &lt; 10</c> (condition <c>tf.less(i, 10)</c>,
          /// body <c>tf.add(i, 1)</c>). Not implemented yet.
          /// </summary>
          /// <exception cref="NotImplementedException">Always thrown.</exception>
          private Tensor WhileLoop()
          {
              // TODO: build the loop with TensorFlow.NET's while-loop op once it is wired up here.
              throw new NotImplementedException("WhileLoop is not implemented yet.");
          }
      }
  }