Browse Source

Add test case for Tensor.assign #653

The test case is currently not passed yet.
tags/v0.30
Banyc Haiping 4 years ago
parent
commit
6f8beab362
1 changed files with 40 additions and 0 deletions
  1. +40
    -0
      test/TensorFlowNET.UnitTest/Basics/TensorTest.cs

+ 40
- 0
test/TensorFlowNET.UnitTest/Basics/TensorTest.cs View File

@@ -298,5 +298,45 @@ namespace TensorFlowNET.UnitTest.NativeAPI

tf.compat.v1.disable_eager_execution();
}

/// <summary>
/// Assign tensor to slice of other tensor.
/// </summary>
[TestMethod]
public void TestAssignOfficial()
{
// example from https://www.tensorflow.org/api_docs/python/tf/Variable#__getitem__

// python
// import tensorflow as tf
// A = tf.Variable([[1,2,3], [4,5,6], [7,8,9]], dtype=tf.float32)
// with tf.compat.v1.Session() as sess:
// sess.run(tf.compat.v1.global_variables_initializer())
// print(sess.run(A[:2, :2])) # => [[1,2], [4,5]]

// op = A[:2,:2].assign(22. * tf.ones((2, 2)))
// print(sess.run(op)) # => [[22, 22, 3], [22, 22, 6], [7,8,9]]

// C#
// [[1,2,3], [4,5,6], [7,8,9]]
double[][] initial = new double[][]
{
new double[] { 1, 2, 3 },
new double[] { 4, 5, 6 },
new double[] { 7, 8, 9 }
};
Tensor A = tf.Variable(initial, dtype: tf.float32);
// Console.WriteLine(A[":2", ":2"]); // => [[1,2], [4,5]]
Tensor result1 = A[":2", ":2"];
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 1, 2 }, result1[0].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 4, 5 }, result1[1].ToArray<double>()));

// An unhandled exception of type 'System.ArgumentException' occurred in TensorFlow.NET.dll: 'Dimensions {2, 2, and {2, 2, are not compatible'
Tensor op = A[":2", ":2"].assign(22.0 * tf.ones((2, 2)));
// Console.WriteLine(op); // => [[22, 22, 3], [22, 22, 6], [7,8,9]]
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 3 }, op[0].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 22, 22, 6 }, op[1].ToArray<double>()));
Assert.IsTrue(Enumerable.SequenceEqual(new double[] { 7, 8, 9 }, op[2].ToArray<double>()));
}
}
}

Loading…
Cancel
Save