From 6f8beab362d744f385c0647ef6f634f6bdd5a56b Mon Sep 17 00:00:00 2001
From: Banyc <36535895+Banyc@users.noreply.github.com>
Date: Fri, 4 Dec 2020 23:43:58 +0800
Subject: [PATCH] Add test case for Tensor.assign #653
The test case is currently not passed yet.
---
.../Basics/TensorTest.cs | 40 +++++++++++++++++++
1 file changed, 40 insertions(+)
diff --git a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
index a81074b3..2811b850 100644
--- a/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
+++ b/test/TensorFlowNET.UnitTest/Basics/TensorTest.cs
@@ -298,5 +298,45 @@ namespace TensorFlowNET.UnitTest.NativeAPI
tf.compat.v1.disable_eager_execution();
}
+
+ ///
+ /// Assign tensor to slice of other tensor.
+ ///
+ [TestMethod]
+ public void TestAssignOfficial()
+ {
+ // example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__
+
+ // python
+ // import tensorflow as tf
+ // A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
+ // with tf.compat.v1.Session() as sess:
+ // sess.run(tf.compat.v1.global_variables_initializer())
+ // print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]
+
+ // op = A[:2,:2].assign(22. * tf.ones((2, 2)))
+ // print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]
+
+ // C#
+ // [[1,2,3], [4,5,6], [7,8,9]]
+ double[][] initial = new double[][]
+ {
+ new double[] { 1, 2, 3 },
+ new double[] { 4, 5, 6 },
+ new double[] { 7, 8, 9 }
+ };
+ Tensor A = tf.Variable(initial, dtype: tf.float32);
+ // Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
+ Tensor result1 = A[":2", ":2"];
+ Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray()));
+ Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray()));
+
+ // An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
+ Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2)));
+ // Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
+ Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray()));
+ Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray()));
+ Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray()));
+ }
}
}
\ No newline at end of file