You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 6.6 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System.Linq;
  4. using Tensorflow;
  5. using static Tensorflow.Python;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for creating constant tensors (scalars, zeros, ones, n-dimensional
      /// NumSharp arrays, simple arithmetic on constants) and evaluating them in a session.
      /// </summary>
      [TestClass]
      public class ConstantTest
      {
          // Only referenced by the commented-out body of StringEncode below.
          Status status = new Status();

          /// <summary>
          /// Smoke test: tf.constant accepts int, float and double scalars without throwing.
          /// No values are checked.
          /// </summary>
          [TestMethod]
          public void ScalarConst()
          {
              var tensor1 = tf.constant(8); // int
              var tensor2 = tf.constant(6.0f); // float
              var tensor3 = tf.constant(6.0); // double
          }

          /* Currently disabled (commented out): implicit scalar -> Tensor casts and string constants.
          [DataTestMethod]
          [DataRow(int.MinValue)]
          [DataRow(-1)]
          [DataRow(0)]
          [DataRow(1)]
          [DataRow(int.MaxValue)]
          public void ScalarConstTypecast_int(int value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<int>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(double.NegativeInfinity)]
          [DataRow(double.MinValue)]
          [DataRow(-1d)]
          [DataRow(0d)]
          [DataRow(double.Epsilon)]
          [DataRow(1d)]
          [DataRow(double.MaxValue)]
          [DataRow(double.PositiveInfinity)]
          [DataRow(double.NaN)]
          public void ScalarConstTypecast_double(double value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<double>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(float.NegativeInfinity)]
          [DataRow(float.MinValue)]
          [DataRow(-1f)]
          [DataRow(0f)]
          [DataRow(float.Epsilon)]
          [DataRow(1f)]
          [DataRow(float.MaxValue)]
          [DataRow(float.PositiveInfinity)]
          [DataRow(float.NaN)]
          public void ScalarConstTypecast_float(float value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  // The tensor holds a float, so read it back as float (previously Data<double>).
                  Assert.AreEqual(value, result.Data<float>()[0]);
              });
          }

          [TestMethod]
          public void StringConst()
          {
              string str = "Hello, TensorFlow.NET!";
              var tensor = tf.constant(str);
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(str, result.Data<string>()[0]);
              });
          }
          */

          /// <summary>
          /// tf.zeros produces an all-zero INT32 tensor of the requested shape,
          /// for both a small (3x2) and a larger (200x100) shape.
          /// </summary>
          [TestMethod]
          public void ZerosConst()
          {
              // small size: every element is compared
              var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small");
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(3, result.shape[0]);
                  Assert.AreEqual(2, result.shape[1]);
                  Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
              });

              // big size: spot-check the first, one interior, and the last element
              tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big");
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(200, result.shape[0]);
                  Assert.AreEqual(100, result.shape[1]);
                  var data = result.Data<int>();
                  Assert.AreEqual(0, data[0]);
                  Assert.AreEqual(0, data[500]);
                  Assert.AreEqual(0, data[result.size - 1]);
              });
          }

          /// <summary>
          /// tf.ones produces a 3x2 tensor of ones with the requested TF_DOUBLE dtype.
          /// </summary>
          [TestMethod]
          public void OnesConst()
          {
              var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones");
              with(tf.Session(), sess =>
              {
                  var result = sess.run(ones);
                  Assert.AreEqual(3, result.shape[0]);
                  Assert.AreEqual(2, result.shape[1]);
                  // The tensor is TF_DOUBLE, so its data must be read back as double, not int.
                  Assert.IsTrue(new[] { 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 }.SequenceEqual(result.Data<double>()));
              });
          }

          /// <summary>
          /// Multiplying a 3x2 TF_DOUBLE ones tensor by the scalar 0.5 yields all 0.5 values.
          /// </summary>
          [TestMethod]
          public void OnesToHalves()
          {
              var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones");
              var halfes = ones * 0.5;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(halfes);
                  Assert.AreEqual(3, result.shape[0]);
                  Assert.AreEqual(2, result.shape[1]);
                  // 0.5 is exactly representable in binary floating point, so exact comparison is safe.
                  Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>()));
              });
          }

          /// <summary>
          /// tf.constant built from a 2x3 NumSharp int array round-trips shape and row-major data.
          /// </summary>
          [TestMethod]
          public void NDimConst()
          {
              var nd = np.array(new int[][]
              {
                  new int[]{ 3, 1, 1 },
                  new int[]{ 2, 1, 3 }
              });
              var tensor = tf.constant(nd);
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  var data = result.Data<int>();
                  Assert.AreEqual(2, result.shape[0]);
                  Assert.AreEqual(3, result.shape[1]);
                  Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
              });
          }

          /// <summary>
          /// Multiplying two double scalar constants (3.0 * 2.0) evaluates to 6.0.
          /// </summary>
          [TestMethod]
          public void Multiply()
          {
              var a = tf.constant(3.0);
              var b = tf.constant(2.0);
              var c = a * b;

              var sess = tf.Session();
              double result;
              try
              {
                  result = sess.run(c);
              }
              finally
              {
                  // Close the session even if run() throws, so a failing test does not leak it.
                  sess.close();
              }

              Assert.AreEqual(6.0, result);
          }

          /// <summary>
          /// Placeholder for a TF_StringEncode test. The body is commented out,
          /// so this test currently passes without asserting anything.
          /// </summary>
          [TestMethod]
          public void StringEncode()
          {
              /*string str = "Hello, TensorFlow.NET!";
              var handle = Marshal.StringToHGlobalAnsi(str);
              ulong dst_len = c_api.TF_StringEncodedSize((UIntPtr)str.Length);
              Assert.AreEqual((ulong)23, dst_len);
              IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
              ulong encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status);
              Assert.AreEqual((ulong)23, encoded_len);
              Assert.AreEqual(TF_Code.TF_OK, status.Code);
              string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
              Assert.AreEqual(str, encoded_str);
              Assert.AreEqual(str.Length, Marshal.ReadByte(dst));*/
              //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
          }

          /// <summary>
          /// Port target: tensorflow\c\c_api_test.cc, TestEncodeDecode.
          /// Not implemented yet; this test currently passes without asserting anything.
          /// </summary>
          [TestMethod]
          public void EncodeDecode()
          {
          }
      }
  }