From 6f8beab362d744f385c0647ef6f634f6bdd5a56b Mon Sep 17 00:00:00 2001 From: Banyc <36535895+Banyc@users.noreply.github.com> Date: Fri, 4 Dec 2020 23:43:58 +0800 Subject: [PATCH] Add test case for Tensor.assign #653 The test case is currently not passed yet. --- .../Basics/TensorTest.cs | 40 +++++++++++++++++++ 1 file changed, 40 insertions(+) diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs index a81074b3..2811b850 100644 --- a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs +++ b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs @@ -298,5 +298,45 @@ namespace TensorFlowNET.UnitTest.NativeAPI tf.compat.v1.disable_eager_execution(); } + + /// + /// Assign tensor to slice of other tensor. + /// + [TestMethod] + public void TestAssignOfficial() + { + // example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ + + // python + // import tensorflow as tf + // A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) + // with tf.compat.v1.Session() as sess: + // sess.run(tf.compat.v1.global_variables_initializer()) + // print(sess.run(A[:2, :2])) # => [[1,2], [4,5]] + + // op = A[:2,:2].assign(22. * tf.ones((2, 2))) + // print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] + + // C# + // [[1,2,3], [4,5,6], [7,8,9]] + double[][] initial = new double[][] + { + new double[] { 1, 2, 3 }, + new double[] { 4, 5, 6 }, + new double[] { 7, 8, 9 } + }; + Tensor A = tf.Variable(initial, dtype: tf.float32); + // Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]] + Tensor result1 = A[":2", ":2"]; + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray())); + + // An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible' + Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2))); + // Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]] + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray())); + Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray())); + } } } \ No newline at end of file