You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

TensorOperate.cs 5.7 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System.Linq;
  4. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.ManagedAPI
  {
      /// <summary>
      /// Tests for tensor-manipulation operations exposed through <c>tf</c>:
      /// transpose, constant initialization, concat/split, and ones_like/zeros_like.
      /// </summary>
      /// <remarks>
      /// NOTE(review): several assertions compare a plain C# array (e.g. <c>int[]</c>)
      /// against an <c>NDArray</c> via <c>Assert.AreEqual</c>; this relies on
      /// TensorFlow.NET's NDArray equality/conversion support to compare element-wise — confirm.
      /// </remarks>
      [TestClass]
      public class TensorOperate
      {
          /// <summary>
          /// Checks <c>tf.transpose</c> on a 2x3 matrix with no explicit permutation,
          /// and on a rank-4 tensor with the explicit permutation [3, 1, 2, 0].
          /// </summary>
          [TestMethod]
          public void TransposeTest()
          {
              // https://www.tensorflow.org/api_docs/python/tf/transpose#for_example_2
              var x = tf.constant(new int[,]
              {
                  { 1, 2, 3 },
                  { 4, 5, 6 }
              });
              var transpose_x = tf.transpose(x);

              // The 2x3 input becomes 3x2: row i of the result is column i of x.
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 3, 2 }, transpose_x.shape.dims));
              Assert.AreEqual(new[] { 1, 4 }, transpose_x[0].numpy());
              Assert.AreEqual(new[] { 2, 5 }, transpose_x[1].numpy());
              Assert.AreEqual(new[] { 3, 6 }, transpose_x[2].numpy());

              #region constant a
              // Shape (2, 2, 1, 4).
              var a = tf.constant(np.array(new[, , ,]
              {
                  {
                      { { 1, 11, 2, 22 } },
                      { { 3, 33, 4, 44 } }
                  },
                  {
                      { { 5, 55, 6, 66 } },
                      { { 7, 77, 8, 88 } }
                  }
              }));
              #endregion

              // perm [3, 1, 2, 0] swaps the first and last axes: (2, 2, 1, 4) -> (4, 2, 1, 2).
              var actual_transposed_a = tf.transpose(a, new[] { 3, 1, 2, 0 });

              #region constant transpose_a
              // expected[k][j][0][i] == a[i][j][0][k]
              var expected_transposed_a = tf.constant(np.array(new[, , ,]
              {
                  {
                      { { 1, 5 } },
                      { { 3, 7 } }
                  },
                  {
                      { { 11, 55 } },
                      { { 33, 77 } }
                  },
                  {
                      { { 2, 6 } },
                      { { 4, 8 } }
                  },
                  {
                      { { 22, 66 } },
                      { { 44, 88 } }
                  }
              }));
              #endregion

              Assert.AreEqual((4, 2, 1, 2), actual_transposed_a.TensorShape);
              Assert.AreEqual(expected_transposed_a.numpy(), actual_transposed_a.numpy());
          }

          /// <summary>
          /// Checks that a rank-3 int array yields shape (2, 3, 1), both when wrapped
          /// in <c>np.array</c> first and when passed directly to <c>tf.constant</c>.
          /// </summary>
          [TestMethod]
          public void InitTensorTest()
          {
              var viaNumPy = tf.constant(np.array(new[, ,]
              {
                  { { 1 }, { 2 }, { 3 } },
                  { { 4 }, { 5 }, { 6 } }
              }));
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, viaNumPy.shape.dims));

              var direct = tf.constant(new[, ,]
              {
                  { { 1 }, { 2 }, { 3 } },
                  { { 4 }, { 5 }, { 6 } }
              });
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3, 1 }, direct.shape.dims));
          }

          /// <summary>
          /// Concatenates three 2x2 int tensors along axis 0 and checks both the
          /// resulting (6, 2) shape and that rows appear in input order.
          /// </summary>
          [TestMethod]
          public void ConcatTest()
          {
              var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
              var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });
              var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });

              var concatValue = tf.concat(new[] { a, b, c }, axis: 0);

              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
              // A shape check alone would pass even if the inputs were stacked in the wrong order.
              Assert.AreEqual(new[] { 1, 2 }, concatValue[0].numpy());
              Assert.AreEqual(new[] { 5, 6 }, concatValue[2].numpy());
              Assert.AreEqual(new[] { 11, 12 }, concatValue[5].numpy());
          }

          /// <summary>
          /// Same as <see cref="ConcatTest"/> but with double-valued tensors.
          /// </summary>
          [TestMethod]
          public void ConcatDoubleTest()
          {
              var a = tf.constant(new[,] { { 1.0, 2.0 }, { 3.0, 4.0 } });
              var b = tf.constant(new[,] { { 5.0, 6.0 }, { 7.0, 8.0 } });
              var c = tf.constant(new[,] { { 9.0, 10.0 }, { 11.0, 12.0 } });

              var concatValue = tf.concat(new[] { a, b, c }, axis: 0);

              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 6, 2 }, concatValue.shape.dims));
              // Values are small integers, exactly representable as doubles, so exact comparison is safe.
              Assert.AreEqual(new[] { 1.0, 2.0 }, concatValue[0].numpy());
              Assert.AreEqual(new[] { 11.0, 12.0 }, concatValue[5].numpy());
          }

          /// <summary>
          /// Concatenates three 2x2 tensors along axis 0, splits the result back into
          /// three pieces, and checks that every piece matches an original 2x2 input.
          /// </summary>
          [TestMethod]
          public void ConcatAndSplitTest()
          {
              var a = tf.constant(new[,] { { 1, 2 }, { 3, 4 } });
              var b = tf.constant(new[,] { { 5, 6 }, { 7, 8 } });
              var c = tf.constant(new[,] { { 9, 10 }, { 11, 12 } });
              var value = tf.concat(new[] { a, b, c }, axis: 0);

              var splitValue = tf.split(value, 3, axis: 0);

              Assert.AreEqual(3, splitValue.Length);
              // Check every piece, not only the first one.
              foreach (var piece in splitValue)
              {
                  Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 2 }, piece.shape.dims));
              }
              Assert.AreEqual(new[] { 1, 2 }, splitValue[0][0].numpy());
              Assert.AreEqual(new[] { 7, 8 }, splitValue[1][1].numpy());
              Assert.AreEqual(new[] { 11, 12 }, splitValue[2][1].numpy());
          }

          #region ones/zeros like
          /// <summary>
          /// Checks that <c>tf.ones_like</c> preserves the input shape and fills it with ones.
          /// </summary>
          [TestMethod]
          public void TestOnesLike()
          {
              #region 2-dimension (2x3)
              var ones2D = tf.ones_like(new int[,]
              {
                  { 1, 2, 3 },
                  { 4, 5, 6 }
              });
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, ones2D.shape.dims));
              Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[0].numpy());
              Assert.AreEqual(new[] { 1, 1, 1 }, ones2D[1].numpy());
              #endregion

              #region single-row 2-dimension (1x3)
              // Although it has a single row, this input is still rank 2 (shape 1x3), not a 1-D vector.
              var onesSingleRow = tf.ones_like(new int[,]
              {
                  { 1, 2, 3 }
              });
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 3 }, onesSingleRow.shape.dims));
              Assert.AreEqual(new[] { 1, 1, 1 }, onesSingleRow[0].numpy());
              #endregion
          }

          /// <summary>
          /// Checks that <c>tf.zeros_like</c> preserves the input shape and fills it with zeros.
          /// </summary>
          [TestMethod]
          public void TestZerosLike()
          {
              #region 2-dimension (2x3)
              var zeros2D = tf.zeros_like(new int[,]
              {
                  { 1, 2, 3 },
                  { 4, 5, 6 }
              });
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 2, 3 }, zeros2D.shape.dims));
              Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[0].numpy());
              Assert.AreEqual(new[] { 0, 0, 0 }, zeros2D[1].numpy());
              #endregion

              #region single-row 2-dimension (1x3)
              // Although it has a single row, this input is still rank 2 (shape 1x3), not a 1-D vector.
              var zerosSingleRow = tf.zeros_like(new int[,]
              {
                  { 1, 2, 3 }
              });
              Assert.IsTrue(Enumerable.SequenceEqual(new long[] { 1, 3 }, zerosSingleRow.shape.dims));
              Assert.AreEqual(new[] { 0, 0, 0 }, zerosSingleRow[0].numpy());
              #endregion
          }
          #endregion
      }
  }