You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 6.7 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using Tensorflow;
  7. using static Tensorflow.Python;
  8. namespace TensorFlowNET.UnitTest
  9. {
  10. [TestClass]
  11. public class ConstantTest
  12. {
  13. Status status = new Status();
  14. [TestMethod]
  15. public void ScalarConst()
  16. {
  17. var tensor1 = tf.constant(8); // int
  18. var tensor2 = tf.constant(6.0f); // float
  19. var tensor3 = tf.constant(6.0); // double
  20. }
  21. /*[DataTestMethod]
  22. [DataRow(int.MinValue)]
  23. [DataRow(-1)]
  24. [DataRow(0)]
  25. [DataRow(1)]
  26. [DataRow(int.MaxValue)]
  27. public void ScalarConstTypecast_int(int value)
  28. {
  29. var tensor = (Tensor)value;
  30. with(tf.Session(), sess =>
  31. {
  32. var result = sess.run(tensor);
  33. Assert.AreEqual(result.Data<int>()[0], value);
  34. });
  35. }
  36. [DataTestMethod]
  37. [DataRow(double.NegativeInfinity)]
  38. [DataRow(double.MinValue)]
  39. [DataRow(-1d)]
  40. [DataRow(0d)]
  41. [DataRow(double.Epsilon)]
  42. [DataRow(1d)]
  43. [DataRow(double.MaxValue)]
  44. [DataRow(double.PositiveInfinity)]
  45. [DataRow(double.NaN)]
  46. public void ScalarConstTypecast_double(double value)
  47. {
  48. var tensor = (Tensor)value;
  49. with(tf.Session(), sess =>
  50. {
  51. var result = sess.run(tensor);
  52. Assert.AreEqual(result.Data<double>()[0], value);
  53. });
  54. }
  55. [DataTestMethod]
  56. [DataRow(float.NegativeInfinity)]
  57. [DataRow(float.MinValue)]
  58. [DataRow(-1f)]
  59. [DataRow(0f)]
  60. [DataRow(float.Epsilon)]
  61. [DataRow(1f)]
  62. [DataRow(float.MaxValue)]
  63. [DataRow(float.PositiveInfinity)]
  64. [DataRow(float.NaN)]
  65. public void ScalarConstTypecast_float(float value)
  66. {
  67. var tensor = (Tensor)value;
  68. with(tf.Session(), sess =>
  69. {
  70. var result = sess.run(tensor);
  71. Assert.AreEqual(result.Data<double>()[0], value);
  72. });
  73. }
  74. [TestMethod]
  75. public void StringConst()
  76. {
  77. string str = "Hello, TensorFlow.NET!";
  78. var tensor = tf.constant(str);
  79. with(tf.Session(), sess =>
  80. {
  81. var result = sess.run(tensor);
  82. Assert.IsTrue(result.Data<string>()[0] == str);
  83. });
  84. }*/
  85. [TestMethod]
  86. public void ZerosConst()
  87. {
  88. // small size
  89. var tensor = tf.zeros(new Shape(3, 2), TF_DataType.TF_INT32, "small");
  90. using (var sess = tf.Session())
  91. {
  92. var result = sess.run(tensor);
  93. Assert.AreEqual(result.shape[0], 3);
  94. Assert.AreEqual(result.shape[1], 2);
  95. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, result.Data<int>()));
  96. }
  97. // big size
  98. tensor = tf.zeros(new Shape(200, 100), TF_DataType.TF_INT32, "big");
  99. using (var sess = tf.Session())
  100. {
  101. var result = sess.run(tensor);
  102. Assert.AreEqual(result.shape[0], 200);
  103. Assert.AreEqual(result.shape[1], 100);
  104. var data = result.Data<int>();
  105. Assert.AreEqual(0, data[0]);
  106. Assert.AreEqual(0, data[500]);
  107. Assert.AreEqual(0, data[result.size - 1]);
  108. }
  109. }
  110. [TestMethod]
  111. public void OnesConst()
  112. {
  113. var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones");
  114. using (var sess = tf.Session())
  115. {
  116. var result = sess.run(ones);
  117. Assert.AreEqual(result.shape[0], 3);
  118. Assert.AreEqual(result.shape[1], 2);
  119. Assert.IsTrue(new[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(result.Data<int>()));
  120. }
  121. }
  122. [TestMethod]
  123. public void OnesToHalves()
  124. {
  125. var ones = tf.ones(new Shape(3, 2), TF_DataType.TF_DOUBLE, "ones");
  126. var halfes = ones * 0.5;
  127. using (var sess = tf.Session())
  128. {
  129. var result = sess.run(halfes);
  130. Assert.AreEqual(result.shape[0], 3);
  131. Assert.AreEqual(result.shape[1], 2);
  132. Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(result.Data<double>()));
  133. }
  134. }
  135. [TestMethod]
  136. public void NDimConst()
  137. {
  138. var nd = np.array(new int[][]
  139. {
  140. new int[]{ 3, 1, 1 },
  141. new int[]{ 2, 1, 3 }
  142. });
  143. var tensor = tf.constant(nd);
  144. using (var sess = tf.Session())
  145. {
  146. var result = sess.run(tensor);
  147. var data = result.Data<int>();
  148. Assert.AreEqual(result.shape[0], 2);
  149. Assert.AreEqual(result.shape[1], 3);
  150. Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
  151. }
  152. }
  153. [TestMethod]
  154. public void Multiply()
  155. {
  156. var a = tf.constant(3.0);
  157. var b = tf.constant(2.0);
  158. var c = a * b;
  159. var sess = tf.Session();
  160. double result = sess.run(c);
  161. sess.close();
  162. Assert.AreEqual(6.0, result);
  163. }
  164. [TestMethod]
  165. public void StringEncode()
  166. {
  167. string str = "Hello, TensorFlow.NET!";
  168. var handle = Marshal.StringToHGlobalAnsi(str);
  169. ulong dst_len = (ulong)c_api.TF_StringEncodedSize((UIntPtr)str.Length);
  170. Assert.AreEqual(dst_len, (ulong)23);
  171. IntPtr dst = Marshal.AllocHGlobal((int)dst_len);
  172. ulong encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status);
  173. Assert.AreEqual((ulong)23, encoded_len);
  174. Assert.AreEqual(status.Code, TF_Code.TF_OK);
  175. string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte));
  176. Assert.AreEqual(encoded_str, str);
  177. Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
  178. //c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status);
  179. }
  180. /// <summary>
  181. /// tensorflow\c\c_api_test.cc
  182. /// TestEncodeDecode
  183. /// </summary>
  184. [TestMethod]
  185. public void EncodeDecode()
  186. {
  187. }
  188. }
  189. }