You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

FunctionApiTest.cs 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.Graphs;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for <c>tf.autograph.to_graph</c>: each test converts a managed
      /// delegate into a graph function and invokes it with constant tensors.
      /// </summary>
      [TestClass]
      public class FunctionApiTest : TFNetApiTest
      {
          /// <summary>
          /// A named two-argument method converted with <c>to_graph</c> returns
          /// the sum of its inputs: 1 + 2 == 3.
          /// </summary>
          [TestMethod]
          public void TwoInputs_OneOutput()
          {
              var func = tf.autograph.to_graph(Add);
              var a = tf.constant(1);
              var b = tf.constant(2);

              var output = func(a, b);

              Assert.AreEqual(3, (int)output);
          }

          /// <summary>Graph body for <see cref="TwoInputs_OneOutput"/>: returns <c>a + b</c>.</summary>
          private Tensor Add(Tensor a, Tensor b)
          {
              return a + b;
          }

          /// <summary>
          /// A method that branches with <c>tf.cond</c> can be converted with
          /// <c>to_graph</c>. With a = 3 and b = 2 the predicate <c>a &lt; b</c>
          /// is false, so the expected result is b (2).
          /// </summary>
          [TestMethod]
          public void TwoInputs_OneOutput_Condition()
          {
              var func = tf.autograph.to_graph(Condition);
              var a = tf.constant(3);
              var b = tf.constant(2);

              var output = func(a, b);

              Assert.AreEqual(2, (int)output);
          }

          /// <summary>
          /// Graph body for <see cref="TwoInputs_OneOutput_Condition"/>: passes
          /// <c>a &lt; b</c> as the predicate, <paramref name="a"/> as the true
          /// value and <paramref name="b"/> as the false value of <c>tf.cond</c>.
          /// </summary>
          private Tensor Condition(Tensor a, Tensor b)
          {
              return tf.cond(a < b, a, b);
          }

          /// <summary>
          /// An inline lambda can be converted with <c>to_graph</c>: 3 * 2 == 6.
          /// </summary>
          [TestMethod]
          public void TwoInputs_OneOutput_Lambda()
          {
              var func = tf.autograph.to_graph((x, y) => x * y);

              var output = func(tf.constant(3), tf.constant(2));

              Assert.AreEqual(6, (int)output);
          }

          /// <summary>
          /// Intended to cover a graph built around a while-loop
          /// (<see cref="WhileLoop"/>). That helper is not implemented yet, so the
          /// test is ignored rather than re-running the lambda multiplication check
          /// under a while-loop name, which would report a misleading pass.
          /// </summary>
          [TestMethod]
          [Ignore("WhileLoop() is not implemented yet.")]
          public void TwoInputs_OneOutput_WhileLoop()
          {
              var output = WhileLoop();

              // Counting up from 0 by 1 while the value is < 10 stops at 10.
              Assert.AreEqual(10, (int)output);
          }

          /// <summary>
          /// Placeholder for a counting while-loop: start at 0 and add 1 while the
          /// value is less than 10.
          /// </summary>
          /// <exception cref="NotImplementedException">Always thrown; the loop is not built yet.</exception>
          private Tensor WhileLoop()
          {
              // Distinct lambda parameter name so it does not shadow the start value.
              var start = tf.constant(0);
              Func<Tensor, Tensor> cond = x => tf.less(x, 10);
              Func<Tensor, Tensor> body = x => tf.add(x, 1);

              // TODO: run cond/body from `start` (presumably via tf.while_loop once
              // this binding exposes it) and return the final loop value.
              throw new NotImplementedException("WhileLoop is not implemented yet.");
          }
      }
  }