You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

Array.Indexing.Test.cs 3.6 kB

4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Text;
  6. using Tensorflow;
  7. using Tensorflow.NumPy;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.NumPy
  {
      /// <summary>
      /// Indexing and slicing tests for <c>NDArray</c> and <c>ShapeHelper</c>, modelled on
      /// https://numpy.org/doc/stable/user/basics.indexing.html
      /// </summary>
      /// <remarks>
      /// NOTE(review): the assertions pass the actual value first and the expected value
      /// second, which is the reverse of MSTest's <c>Assert.AreEqual(expected, actual)</c>.
      /// Failure messages therefore label the two values backwards. The order is kept
      /// deliberately: the comparisons go through NDArray/Shape <c>Equals</c> overloads
      /// (e.g. NDArray vs <c>int[]</c>), and swapping the arguments may not be symmetric.
      /// </remarks>
      [TestClass]
      public class ArrayIndexingTest : EagerModeTestBase
      {
          /// <summary>
          /// Integer indexing: a full index selects a scalar, a partial index selects a
          /// sub-array, and assignment through either form writes back into the source array.
          /// </summary>
          [TestMethod]
          public void int_params()
          {
              var x = np.arange(24).reshape((2, 3, 4));

              // Three indices on the rank-3 array -> a scalar element.
              x[1, 2, 3] = 1;
              var y = x[1, 2, 3];
              Assert.AreEqual(y.shape, Shape.Scalar);
              Assert.AreEqual(y, 1);

              // Two indices -> the innermost row of length 4, writable as a whole.
              x[0, 0] = new[] { 3, 1, 1, 2 };
              y = x[0, 0];
              Assert.AreEqual(y.shape, 4);
              Assert.AreEqual(y, new[] { 3, 1, 1, 2 });

              // One index -> a (3, 4) slab.
              y = x[0];
              Assert.AreEqual(y.shape, (3, 4));

              // Assigning a whole (3, 4) array through a single index.
              var z = np.arange(12).reshape((3, 4));
              x[1] = z;
              Assert.AreEqual(x[1], z);
          }

          /// <summary>
          /// Range slicing with <c>Slice</c> objects; equivalent to NumPy <c>x[0:1, 2:]</c>
          /// on a (3, 4) array.
          /// </summary>
          [TestMethod]
          public void slice_params()
          {
              var x = np.arange(12).reshape((3, 4));
              var y = x[new Slice(0, 1), new Slice(2)];
              Assert.AreEqual(y.shape, (1, 2));
              Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
          }

          /// <summary>
          /// Same selection as <see cref="slice_params"/>, but built from the string
          /// form <c>"0:1,2:"</c> via <c>Slice.ParseSlices</c>.
          /// </summary>
          [TestMethod]
          public void slice_string_params()
          {
              var x = np.arange(12).reshape((3, 4));
              var y = x[Slice.ParseSlices("0:1,2:")];
              Assert.AreEqual(y.shape, (1, 2));
              Assert.AreEqual(y, np.array(new[] { 2, 3 }).reshape((1, 2)));
          }

          /// <summary>
          /// Writes to the last element of the NDArray returned by <c>numpy()</c>, using
          /// the tensor's <c>size</c> to compute the index. There is no explicit
          /// assertion: the test passes as long as the indexed write does not throw.
          /// </summary>
          [TestMethod]
          public void slice_out_bound()
          {
              // `using` releases the tensor even if the indexed write below throws.
              using (var input_shape = tf.constant(new int[] { 1, 1 }))
              {
                  var input_shape_val = input_shape.numpy();
                  input_shape_val[(int)input_shape.size - 1] = 1;
              }
          }

          /// <summary>
          /// <c>ShapeHelper.GetShape</c> on a (4, 3, 2) shape. Index slices
          /// (<c>isIndex: true</c>) drop their dimension, range slices shrink it, and
          /// dimensions without a matching slice are kept unchanged.
          /// </summary>
          [TestMethod]
          public void shape_helper_get_shape_3dim()
          {
              var x = np.arange(24).reshape((4, 3, 2));

              // [1] -> axis 0 dropped.
              var shape1 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true));
              Assert.AreEqual(shape1, (3, 2));

              // [1:] -> axis 0 shrinks from 4 to 3.
              var shape2 = ShapeHelper.GetShape(x.shape, new Slice(1));
              Assert.AreEqual(shape2, (3, 3, 2));

              // [2:, :] -> axis 0 shrinks from 4 to 2.
              var shape3 = ShapeHelper.GetShape(x.shape, new Slice(2), Slice.All);
              Assert.AreEqual(shape3, (2, 3, 2));

              // [1, 2:] -> axis 0 dropped, axis 1 shrinks from 3 to 1.
              var shape4 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(2));
              Assert.AreEqual(shape4, (1, 2));

              // [1, 1:] -> axis 0 dropped, axis 1 shrinks from 3 to 2.
              var shape5 = ShapeHelper.GetShape(x.shape, new Slice(1, isIndex: true), new Slice(1));
              Assert.AreEqual(shape5, (2, 2));

              // [1:, 1, 1:] -> axis 0 -> 3, axis 1 dropped, axis 2 -> 1.
              var shape6 = ShapeHelper.GetShape(x.shape, new Slice(1), new Slice(1, isIndex: true), new Slice(1));
              Assert.AreEqual(shape6, (3, 1));
          }

          /// <summary>
          /// <c>ShapeHelper.GetShape</c> on a (4, 3, 2, 5) shape, mixing index and range
          /// slices, and using <c>Slice.All</c> with fewer slices than dimensions.
          /// </summary>
          [TestMethod]
          public void shape_helper_get_shape_4dim()
          {
              var x = np.arange(120).reshape((4, 3, 2, 5));

              // [1, 1:, 0, 1:] -> axes 0 and 2 dropped; axis 1 -> 2, axis 3 -> 4.
              var slices = new[] { new Slice(1, isIndex: true), new Slice(1), new Slice(0, isIndex: true), new Slice(1) };
              var shape1 = ShapeHelper.GetShape(x.shape, slices);
              Assert.AreEqual(shape1, (2, 4));

              // A lone Slice.All leaves the shape unchanged.
              var shape2 = ShapeHelper.GetShape(x.shape, Slice.All);
              Assert.AreEqual(shape2, (4, 3, 2, 5));

              // NOTE(review): with NumPy ':' semantics, [:, 0] would give (4, 2, 5). The
              // expected (4, 3, 2) implies ShapeHelper expands a leading Slice.All over the
              // unspecified axes (ellipsis-like), so the index applies to the last axis.
              // Confirm this is intended.
              var shape3 = ShapeHelper.GetShape(x.shape, Slice.All, new Slice(0, isIndex: true));
              Assert.AreEqual(shape3, (4, 3, 2));
          }
      }
  }