diff --git a/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs index bad1f926..c1754393 100644 --- a/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs +++ b/test/TensorFlowNET.UnitTest/ManagedAPI/ControlFlowApiTest.cs @@ -1,4 +1,5 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; +using NumSharp; using System; using Tensorflow; using static Tensorflow.Binding; @@ -45,5 +46,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI var r = tf.while_loop(c, b, i); Assert.AreEqual(10, (int)r); } + + [TestMethod, Ignore] + public void ScanFunctionGraphMode() + { + tf.compat.v1.disable_eager_execution(); + Func fn = (prev, current) => tf.add(prev, current); + var input = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(6)); + var scan = tf.scan(fn, input); + + using (var sess = tf.Session()) + { + sess.run(tf.global_variables_initializer()); + var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); + Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray()); + } + } } }