|
|
@@ -1,4 +1,5 @@ |
|
|
|
using Microsoft.VisualStudio.TestTools.UnitTesting; |
|
|
|
using NumSharp; |
|
|
|
using System; |
|
|
|
using Tensorflow; |
|
|
|
using static Tensorflow.Binding; |
|
|
@@ -45,5 +46,21 @@ namespace TensorFlowNET.UnitTest.ManagedAPI |
|
|
|
var r = tf.while_loop(c, b, i); |
|
|
|
Assert.AreEqual(10, (int)r); |
|
|
|
} |
|
|
|
|
|
|
|
[TestMethod, Ignore] |
|
|
|
public void ScanFunctionGraphMode() |
|
|
|
{ |
|
|
|
tf.compat.v1.disable_eager_execution(); |
|
|
|
Func<Tensor, Tensor, Tensor> fn = (prev, current) => tf.add(prev, current); |
|
|
|
var input = tf.placeholder(TF_DataType.TF_FLOAT, new TensorShape(6)); |
|
|
|
var scan = tf.scan(fn, input); |
|
|
|
|
|
|
|
using (var sess = tf.Session()) |
|
|
|
{ |
|
|
|
sess.run(tf.global_variables_initializer()); |
|
|
|
var result = sess.run(scan, new FeedItem(input, np.array(1, 2, 3, 4, 5, 6))); |
|
|
|
Assert.AreEqual(new float[] { 1, 3, 6, 10, 15, 21 }, result.ToArray<float>()); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |