You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ScanTestCase.cs 1.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. using System;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using NumSharp;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  6. namespace TensorFlowNET.UnitTest.functional_ops_test
  7. {
  /// <summary>
  /// Tests for <c>functional_ops.scan</c>, using <c>tf.add</c> as the accumulating function.
  /// Reference: https://www.tensorflow.org/api_docs/python/tf/scan
  /// </summary>
  /// <remarks>
  /// The class and both of its tests are currently marked <c>[Ignore]</c> (reason given: "TODO").
  /// </remarks>
  [Ignore]
  [TestClass]
  public class ScanTestCase
  {
    /// <summary>
    /// Feeds [1, 2, 3, 4, 5, 6] through a scan with addition and expects the
    /// running sums [1, 3, 6, 10, 15, 21].
    /// </summary>
    [Ignore("TODO")]
    [TestMethod]
    public void ScanForward()
    {
      Func<Tensor, Tensor, Tensor> fn = (a, x) => tf.add(a, x);

      // Dispose the session when the test finishes instead of leaving it alive
      // for the remainder of the test run.
      using (var sess = tf.Session().as_default())
      {
        var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
        var scan = functional_ops.scan(fn, input);

        sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6)))
          .Should().Be(np.array(1, 3, 6, 10, 15, 21));
      }
    }

    /// <summary>
    /// Same as <see cref="ScanForward"/> but with <c>reverse: true</c>; the sums are
    /// accumulated from the last element backwards, so the expected result is
    /// [21, 20, 18, 15, 11, 6].
    /// </summary>
    [Ignore("TODO")]
    [TestMethod]
    public void ScanReverse()
    {
      Func<Tensor, Tensor, Tensor> fn = (a, x) => tf.add(a, x);

      // Dispose the session when the test finishes instead of leaving it alive
      // for the remainder of the test run.
      using (var sess = tf.Session().as_default())
      {
        var input = tf.placeholder(TF_DataType.TF_INT32, new TensorShape(6));
        var scan = functional_ops.scan(fn, input, reverse: true);

        sess.run(scan, (input, np.array(1, 2, 3, 4, 5, 6)))
          .Should().Be(np.array(21, 20, 18, 15, 11, 6));
      }
    }
  }
  36. }