You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ArrayOpsTest.cs 10 kB

2 years ago
2 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using Tensorflow;
  4. using static Tensorflow.Binding;
  5. using System.Linq;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for array manipulation ops (slice, gather, TensorArray, reverse),
      /// built on <see cref="EagerModeTestBase"/>.
      /// </summary>
      [TestClass]
      public class ArrayOpsTest : EagerModeTestBase
      {
          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/slice
          /// </summary>
          [TestMethod]
          public void Slice()
          {
              // Tests based on example code in TF documentation.
              // input_array = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]
              var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape((3, 2, 3)));

              // Each check compares the full flattened (row-major) slice, expected value first.
              // begin [1, 0, 0], size [1, 1, 3]: first row of the second block.
              var r1 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 1, 3 }));
              Assert.AreEqual(new Shape(1, 1, 3), r1.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3 }, r1.ToArray<int>());

              // begin [1, 0, 0], size [1, 2, 3]: both rows of the second block.
              var r2 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 2, 3 }));
              Assert.AreEqual(new Shape(1, 2, 3), r2.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 4, 4, 4 }, r2.ToArray<int>());

              // begin [1, 0, 0], size [2, 1, 3]: first row of the second and third blocks.
              var r3 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 2, 1, 3 }));
              Assert.AreEqual(new Shape(2, 1, 3), r3.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 5, 5, 5 }, r3.ToArray<int>());
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/gather
          /// </summary>
          [TestMethod]
          public void Gather()
          {
              // input_array = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] as float32.
              var input_array = tf.constant(np.arange(12).reshape((3, 4)).astype(np.float32));
              var indices = tf.constant(np.array(new int[] { 0, 2 }));
              var result = array_ops.gather(input_array, indices);
              Assert.AreEqual(new Shape(2, 4), result.shape);
              // Rows 0 and 2 of input_array, flattened.
              CollectionAssert.AreEqual(new[] { 0f, 1f, 2f, 3f, 8f, 9f, 10f, 11f }, result.ToArray<float>());

              // Tests based on example code in Python doc string for tf.gather().
              // Expected shapes follow params.shape[:axis] + indices.shape + params.shape[axis + 1:].
              var p1 = tf.random.normal(new Shape(5, 6, 7, 8));
              // maxval: 7 keeps every index inside p1's axis-2 dimension (size 7).
              var i1 = tf.random_uniform(new Shape(10, 11), maxval: 7, dtype: tf.int32);
              var r1 = tf.gather(p1, i1, axis: 2);
              Assert.AreEqual(new Shape(5, 6, 10, 11, 8), r1.shape);

              var p2 = tf.random.normal(new Shape(4, 3));
              var i2 = tf.constant(new int[,] { { 0, 2 } });
              var r2 = tf.gather(p2, i2, axis: 0);
              Assert.AreEqual(new Shape(1, 2, 3), r2.shape);
              var r3 = tf.gather(p2, i2, axis: 1);
              Assert.AreEqual(new Shape(4, 1, 2), r3.shape);
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/TensorArray
          /// </summary>
          [TestMethod]
          public void TensorArray()
          {
              // Starts at size 0 with dynamic_size: true, so the writes below must grow it;
              // clear_after_read: false keeps values readable after a read (per the TF docs).
              var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false);
              ta.write(0, 10);
              ta.write(1, 20);
              ta.write(2, 30);

              // Compare with the expected value first so failure messages are labelled correctly.
              CollectionAssert.AreEqual(new[] { 10f }, ta.read(0).ToArray<float>());
              CollectionAssert.AreEqual(new[] { 20f }, ta.read(1).ToArray<float>());
              CollectionAssert.AreEqual(new[] { 30f }, ta.read(2).ToArray<float>());
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/reverse
          /// Checks tf.reverse_v2 on a 3x3 int matrix, with the axis given both as a tensor and as an int[].
          /// </summary>
          [TestMethod]
          public void Reverse()
          {
              /*
               * Expected values were generated with the Python script below
               * (it also prints axis -2, which is not exercised in this test):
               *
               * import tensorflow as tf
               * from tensorflow.python.ops import array_ops
               * data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
               * data2 = tf.constant(data)
               * print('test data shaper:', data2.shape)
               * print('test data:', data2)
               * axis = [-2, -1, 0, 1]
               * for i in axis:
               *     print('')
               *     print('axis:', i)
               *     ax = tf.constant([i])
               *     datar = tf.reverse(data2, ax)
               *     datar2 = array_ops.reverse(data2, ax)
               *     print(datar)
               *     print(datar2)
               */
              var inputTensor = tf.constant(np.array(new int[,] { { 1, 2, 3 }, { 4, 5, 6 }, { 7, 8, 9 } }));
              var axes = new[] { -1, 0, 1 };
              // Flattened (row-major) expected output for the axis at the same position in axes.
              var expectedOutputs = new[]
              {
                  new[] { 3, 2, 1, 6, 5, 4, 9, 8, 7 },
                  new[] { 7, 8, 9, 4, 5, 6, 1, 2, 3 },
                  new[] { 3, 2, 1, 6, 5, 4, 9, 8, 7 },
              };

              for (var i = 0; i < axes.Length; i++)
              {
                  var axis = axes[i];
                  var expected = expectedOutputs[i];

                  // Axis supplied as a tensor.
                  var axisTensor = tf.constant(new[] { axis });
                  var outputTensor = tf.reverse_v2(inputTensor, axisTensor);
                  Assert.AreEqual(new Shape(3, 3), outputTensor.shape, $"axis:{axis}");
                  CollectionAssert.AreEqual(expected, outputTensor.ToArray<int>(), $"axis:{axis}");

                  // Axis supplied as a plain int[].
                  var outputTensor2 = tf.reverse_v2(inputTensor, new[] { axis });
                  Assert.AreEqual(new Shape(3, 3), outputTensor2.shape, $"axis:{axis}");
                  CollectionAssert.AreEqual(expected, outputTensor2.ToArray<int>(), $"axis:{axis}");
              }
          }
      }
  }
  134. using Microsoft.VisualStudio.TestTools.UnitTesting;
  135. using Tensorflow.NumPy;
  136. using Tensorflow;
  137. using static Tensorflow.Binding;
  138. using System.Linq;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for array manipulation ops (slice, gather, TensorArray, reverse),
      /// built on <see cref="EagerModeTestBase"/>.
      /// </summary>
      [TestClass]
      public class ArrayOpsTest : EagerModeTestBase
      {
          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/slice
          /// </summary>
          [TestMethod]
          public void Slice()
          {
              // Tests based on example code in TF documentation.
              // input_array = [[[1, 1, 1], [2, 2, 2]], [[3, 3, 3], [4, 4, 4]], [[5, 5, 5], [6, 6, 6]]]
              var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape((3, 2, 3)));

              // Each check compares the full flattened (row-major) slice, expected value first.
              // begin [1, 0, 0], size [1, 1, 3]: first row of the second block.
              var r1 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 1, 3 }));
              Assert.AreEqual(new Shape(1, 1, 3), r1.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3 }, r1.ToArray<int>());

              // begin [1, 0, 0], size [1, 2, 3]: both rows of the second block.
              var r2 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 1, 2, 3 }));
              Assert.AreEqual(new Shape(1, 2, 3), r2.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 4, 4, 4 }, r2.ToArray<int>());

              // begin [1, 0, 0], size [2, 1, 3]: first row of the second and third blocks.
              var r3 = array_ops.slice(input_array, ops.convert_n_to_tensor(new object[] { 1, 0, 0 }), ops.convert_n_to_tensor(new object[] { 2, 1, 3 }));
              Assert.AreEqual(new Shape(2, 1, 3), r3.shape);
              CollectionAssert.AreEqual(new[] { 3, 3, 3, 5, 5, 5 }, r3.ToArray<int>());
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/gather
          /// </summary>
          [TestMethod]
          public void Gather()
          {
              // input_array = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]] as float32.
              var input_array = tf.constant(np.arange(12).reshape((3, 4)).astype(np.float32));
              var indices = tf.constant(np.array(new int[] { 0, 2 }));
              var result = array_ops.gather(input_array, indices);
              Assert.AreEqual(new Shape(2, 4), result.shape);
              // Rows 0 and 2 of input_array, flattened.
              CollectionAssert.AreEqual(new[] { 0f, 1f, 2f, 3f, 8f, 9f, 10f, 11f }, result.ToArray<float>());

              // Tests based on example code in Python doc string for tf.gather().
              // Expected shapes follow params.shape[:axis] + indices.shape + params.shape[axis + 1:].
              var p1 = tf.random.normal(new Shape(5, 6, 7, 8));
              // maxval: 7 keeps every index inside p1's axis-2 dimension (size 7).
              var i1 = tf.random_uniform(new Shape(10, 11), maxval: 7, dtype: tf.int32);
              var r1 = tf.gather(p1, i1, axis: 2);
              Assert.AreEqual(new Shape(5, 6, 10, 11, 8), r1.shape);

              var p2 = tf.random.normal(new Shape(4, 3));
              var i2 = tf.constant(new int[,] { { 0, 2 } });
              var r2 = tf.gather(p2, i2, axis: 0);
              Assert.AreEqual(new Shape(1, 2, 3), r2.shape);
              var r3 = tf.gather(p2, i2, axis: 1);
              Assert.AreEqual(new Shape(4, 1, 2), r3.shape);
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/TensorArray
          /// </summary>
          [TestMethod]
          public void TensorArray()
          {
              // Starts at size 0 with dynamic_size: true, so the writes below must grow it;
              // clear_after_read: false keeps values readable after a read (per the TF docs).
              var ta = tf.TensorArray(tf.float32, size: 0, dynamic_size: true, clear_after_read: false);
              ta.write(0, 10);
              ta.write(1, 20);
              ta.write(2, 30);

              // Compare with the expected value first so failure messages are labelled correctly.
              CollectionAssert.AreEqual(new[] { 10f }, ta.read(0).ToArray<float>());
              CollectionAssert.AreEqual(new[] { 20f }, ta.read(1).ToArray<float>());
              CollectionAssert.AreEqual(new[] { 30f }, ta.read(2).ToArray<float>());
          }

          /// <summary>
          /// https://www.tensorflow.org/api_docs/python/tf/reverse
          /// Reversing a random 2x3 tensor along the last axis must reverse each row.
          /// </summary>
          [TestMethod]
          public void ReverseArray()
          {
              const int rows = 2;
              var a = tf.random.normal((rows, 3));
              var b = tf.reverse(a, -1);
              Assert.AreEqual(new Shape(rows, 3), b.shape);

              for (var row = 0; row < rows; row++)
              {
                  var expected = a[row].ToArray<float>().Reverse().ToArray();
                  CollectionAssert.AreEqual(expected, b[row].ToArray<float>(), $"row:{row}");
              }
          }
      }
  }