You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ArrayOpsTest.cs 3.4 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using NumSharp.Utilities;
  4. using Tensorflow;
  5. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Eager-mode tests for <c>array_ops.slice</c>, <c>array_ops.gather</c> and <c>tf.gather</c>.
      /// </summary>
      [TestClass]
      public class ArrayOpsTest : EagerModeTestBase
      {
          /// <summary>
          /// Checks <c>array_ops.slice</c> against the examples from
          /// https://www.tensorflow.org/api_docs/python/tf/slice
          /// </summary>
          [TestMethod]
          public void Slice()
          {
              // Tests based on example code in TF documentation.
              // Shape (3, 2, 3); each innermost row repeats one value: [1,1,1], [2,2,2], ..., [6,6,6].
              var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape(3, 2, 3));

              // begin [1,0,0], size [1,1,3]: the first row of the second outer block.
              var r1 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 1, 3 });
              Assert.AreEqual(new TensorShape(1, 1, 3), r1.shape);
              // Expected-first so MSTest reports mismatches the right way round;
              // ToArray flattens the (single) converted NDArray in row-major order.
              CollectionAssert.AreEqual(new[] { 3, 3, 3 }, r1.numpy().ToArray<int>());

              // begin [1,0,0], size [1,2,3]: both rows of the second outer block.
              var r2 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 2, 3 });
              Assert.AreEqual(new TensorShape(1, 2, 3), r2.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 4, 4, 4 }, r2.numpy().ToArray<int>());

              // begin [1,0,0], size [2,1,3]: the first row of the second and third outer blocks.
              var r3 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 2, 1, 3 });
              Assert.AreEqual(new TensorShape(2, 1, 3), r3.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 5, 5, 5 }, r3.numpy().ToArray<int>());
          }

          /// <summary>
          /// Checks <c>array_ops.gather</c> values and <c>tf.gather</c> output shapes, following
          /// https://www.tensorflow.org/api_docs/python/tf/gather
          /// </summary>
          [TestMethod]
          public void Gather()
          {
              // [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] as float32.
              var input_array = tf.constant(np.arange(12).reshape(3, 4).astype(np.float32));
              var indices = tf.constant(np.array(new int[] { 0, 2 }));
              var result = array_ops.gather(input_array, indices);
              Assert.AreEqual(new TensorShape(2, 4), result.shape);
              // Rows 0 and 2, flattened row-major; the tensor is converted to NumPy only once.
              CollectionAssert.AreEqual(
                  new[] { 0.0f, 1.0f, 2.0f, 3.0f, 8.0f, 9.0f, 10.0f, 11.0f },
                  result.numpy().ToArray<float>());

              // Tests based on example code in Python doc string for tf.gather().
              // Output shape is params.shape[:axis] + indices.shape + params.shape[axis + 1:].
              var p1 = tf.random.normal(new TensorShape(5, 6, 7, 8));
              // maxval: 7 equals the size of p1's axis 2 (TF's uniform range excludes maxval).
              var i1 = tf.random_uniform(new TensorShape(10, 11), maxval: 7, dtype: tf.int32);
              var r1 = tf.gather(p1, i1, axis: 2);
              Assert.AreEqual(new TensorShape(5, 6, 10, 11, 8), r1.shape);

              var p2 = tf.random.normal(new TensorShape(4, 3));
              var i2 = tf.constant(new int[,] { { 0, 2 } });
              var r2 = tf.gather(p2, i2, axis: 0);
              Assert.AreEqual(new TensorShape(1, 2, 3), r2.shape);
              var r3 = tf.gather(p2, i2, axis: 1);
              Assert.AreEqual(new TensorShape(4, 1, 2), r3.shape);
          }
      }
  }