You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ScanTestCase.cs 1.4 kB

5 years ago
5 years ago
5 years ago
1234567891011121314151617181920212223242526272829303132333435363738394041
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.FunctionalOpsTest
  7. {
  /// <summary>
  /// Graph-mode tests for <c>functional_ops.scan</c>.
  /// https://www.tensorflow.org/api_docs/python/tf/scan
  /// </summary>
  [TestClass]
  public class ScanTestCase : GraphModeTestBase
  {
      /// <summary>
      /// Scanning [1, 2, 3, 4, 5, 6] front to back with <c>tf.add</c> is expected
      /// to produce the running (prefix) sums.
      /// </summary>
      [TestMethod, Ignore("need UpdateEdge API")]
      public void ScanForward()
      {
          AssertAddScan(
              new[] { 1, 2, 3, 4, 5, 6 },
              new[] { 1, 3, 6, 10, 15, 21 },
              reverse: false);
      }

      /// <summary>
      /// Scanning [1, 2, 3, 4, 5, 6] back to front with <c>tf.add</c> is expected
      /// to produce the suffix sums.
      /// </summary>
      [TestMethod, Ignore("need UpdateEdge API")]
      public void ScanReverse()
      {
          AssertAddScan(
              new[] { 1, 2, 3, 4, 5, 6 },
              new[] { 21, 20, 18, 15, 11, 6 },
              reverse: true);
      }

      /// <summary>
      /// Builds a graph that scans an int32 placeholder with <c>tf.add</c>, feeds
      /// <paramref name="values"/> through a session, and asserts that the output
      /// equals <paramref name="expected"/>.
      /// </summary>
      /// <param name="values">Values fed to the placeholder; their count fixes the placeholder shape.</param>
      /// <param name="expected">Expected output of the scan.</param>
      /// <param name="reverse">Forwarded to <c>functional_ops.scan</c> to choose the scan direction.</param>
      private static void AssertAddScan(int[] values, int[] expected, bool reverse)
      {
          var fn = new Func<Tensor, Tensor, Tensor>((a, x) => tf.add(a, x));

          // The session is disposable; release it even if run or the assertion throws.
          using (var sess = tf.Session().as_default())
          {
              var input = tf.placeholder(TF_DataType.TF_INT32, new Shape(values.Length));
              var scan = functional_ops.scan(fn, input, reverse: reverse);

              var result = sess.run(scan, (input, np.array(values)));

              // MSTest's contract is AreEqual(expected, actual); keeping that order
              // makes failure messages label the two values correctly.
              Assert.AreEqual(np.array(expected), result);
          }
      }
  }
  35. }