|
|
@@ -298,5 +298,45 @@ namespace TensorFlowNET.UnitTest.NativeAPI |
|
|
|
|
|
|
|
tf.compat.v1.disable_eager_execution(); |
|
|
|
} |
|
|
|
|
|
|
|
/// <summary> |
|
|
|
/// Assign tensor to slice of other tensor. |
|
|
|
/// </summary> |
|
|
|
[TestMethod] |
|
|
|
public void TestAssignOfficial() |
|
|
|
{ |
|
|
|
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__ |
|
|
|
|
|
|
|
// python |
|
|
|
// import tensorflow as tf |
|
|
|
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32) |
|
|
|
// with tf.compat.v1.Session() as sess: |
|
|
|
// sess.run(tf.compat.v1.global_variables_initializer()) |
|
|
|
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]] |
|
|
|
|
|
|
|
// op = A[:2,:2].assign(22. * tf.ones((2, 2))) |
|
|
|
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]] |
|
|
|
|
|
|
|
// C# |
|
|
|
// [[1,2,3], [4,5,6], [7,8,9]] |
|
|
|
double[][] initial = new double[][] |
|
|
|
{ |
|
|
|
new double[] { 1, 2, 3 }, |
|
|
|
new double[] { 4, 5, 6 }, |
|
|
|
new double[] { 7, 8, 9 } |
|
|
|
}; |
|
|
|
Tensor A = tf.Variable(initial, dtype: tf.float32); |
|
|
|
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]] |
|
|
|
Tensor result1 = A[":2", ":2"]; |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>())); |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>())); |
|
|
|
|
|
|
|
// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible' |
|
|
|
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2))); |
|
|
|
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]] |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>())); |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>())); |
|
|
|
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>())); |
|
|
|
} |
|
|
|
} |
|
|
|
} |