You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 5.9 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Tests for constant-tensor creation (tf.constant / tf.zeros / tf.ones)
      /// and for the TF_StringEncode C API.
      /// </summary>
      [TestClass]
      public class ConstantTest
      {
          // Status handle passed to C API calls that report errors through a TF_Status.
          private readonly Status status = new Status();

          /// <summary>
          /// tf.constant infers the tensor dtype from the CLR type of a scalar:
          /// int -> int32, float -> float32, double -> float64.
          /// </summary>
          [TestMethod]
          public void ScalarConst()
          {
              var tensor1 = tf.constant(8); // int
              var tensor2 = tf.constant(6.0f); // float
              var tensor3 = tf.constant(6.0); // double

              Assert.AreEqual(tf.int32, tensor1.dtype);
              Assert.AreEqual(tf.float32, tensor2.dtype);
              Assert.AreEqual(tf.float64, tensor3.dtype);
          }

          // NOTE(review): the tests below are disabled. They use the Session /
          // sess.run style (`with(tf.Session(), ...)`) rather than the eager
          // `numpy()` checks used by the active tests. Confirm that this API is
          // still available before re-enabling them.
          /*[DataTestMethod]
          [DataRow(int.MinValue)]
          [DataRow(-1)]
          [DataRow(0)]
          [DataRow(1)]
          [DataRow(int.MaxValue)]
          public void ScalarConstTypecast_int(int value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<int>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(double.NegativeInfinity)]
          [DataRow(double.MinValue)]
          [DataRow(-1d)]
          [DataRow(0d)]
          [DataRow(double.Epsilon)]
          [DataRow(1d)]
          [DataRow(double.MaxValue)]
          [DataRow(double.PositiveInfinity)]
          [DataRow(double.NaN)]
          public void ScalarConstTypecast_double(double value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<double>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(float.NegativeInfinity)]
          [DataRow(float.MinValue)]
          [DataRow(-1f)]
          [DataRow(0f)]
          [DataRow(float.Epsilon)]
          [DataRow(1f)]
          [DataRow(float.MaxValue)]
          [DataRow(float.PositiveInfinity)]
          [DataRow(float.NaN)]
          public void ScalarConstTypecast_float(float value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  // A float tensor must be read back as float, not double.
                  Assert.AreEqual(value, result.Data<float>()[0]);
              });
          }

          [TestMethod]
          public void StringConst()
          {
              string str = "Hello, TensorFlow.NET!";
              var tensor = tf.constant(str);
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.IsTrue(result.Data<string>()[0] == str);
              });
          }*/

          /// <summary>
          /// tf.zeros produces a tensor of the requested shape filled with zeros,
          /// for both a small and a larger shape.
          /// </summary>
          [TestMethod]
          public void ZerosConst()
          {
              // small size
              var tensor = tf.zeros(new Shape(3, 2), tf.int32, "small");
              Assert.AreEqual(3, tensor.shape[0]);
              Assert.AreEqual(2, tensor.shape[1]);
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>()));

              // big size: spot-check the first, a middle, and the last element.
              tensor = tf.zeros(new Shape(200, 100), tf.int32, "big");
              Assert.AreEqual(200, tensor.shape[0]);
              Assert.AreEqual(100, tensor.shape[1]);
              var data = tensor.numpy().ToArray<int>();
              Assert.AreEqual(200 * 100, data.Length);
              Assert.AreEqual(0, data[0]);
              Assert.AreEqual(0, data[500]);
              Assert.AreEqual(0, data[data.Length - 1]);
          }

          /// <summary>
          /// tf.ones produces a float32 tensor of the requested shape filled with ones.
          /// </summary>
          [TestMethod]
          public void OnesConst()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float32, "ones");
              Assert.AreEqual(tf.float32, ones.dtype);
              Assert.AreEqual(3, ones.shape[0]);
              Assert.AreEqual(2, ones.shape[1]);
              Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray<float>()));
          }

          /// <summary>
          /// Multiplying a float64 ones tensor by a scalar 0.5 keeps the shape
          /// and yields 0.5 in every element.
          /// </summary>
          [TestMethod]
          public void OnesToHalves()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float64, "ones");
              var halves = ones * 0.5;
              Assert.AreEqual(3, halves.shape[0]);
              Assert.AreEqual(2, halves.shape[1]);
              Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(halves.numpy().ToArray<double>()));
          }

          /// <summary>
          /// tf.constant built from a 2x3 NDArray keeps its shape and row-major data.
          /// </summary>
          [TestMethod]
          public void NDimConst()
          {
              var nd = np.array(new int[][]
              {
                  new int[] { 3, 1, 1 },
                  new int[] { 2, 1, 3 }
              });
              var tensor = tf.constant(nd);
              var data = tensor.numpy().ToArray<int>();
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape));
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
          }

          /// <summary>
          /// Scalar tensor multiplication: 3.0 * 2.0 == 6.0.
          /// </summary>
          [TestMethod]
          public void Multiply()
          {
              var a = tf.constant(3.0);
              var b = tf.constant(2.0);
              var c = a * b;
              Assert.AreEqual(6.0, (double)c);
          }

          /// <summary>
          /// TF_StringEncode writes a length prefix followed by the raw string bytes.
          /// For this 22-character string the prefix is one byte, so the encoded
          /// size is 23 bytes and the buffer carries no NUL terminator.
          /// </summary>
          [TestMethod]
          public void StringEncode()
          {
              string str = "Hello, TensorFlow.NET!";
              IntPtr handle = IntPtr.Zero;
              IntPtr dst = IntPtr.Zero;
              try
              {
                  handle = Marshal.StringToHGlobalAnsi(str);
                  ulong dst_len = (ulong)c_api.TF_StringEncodedSize((UIntPtr)str.Length);
                  Assert.AreEqual((ulong)23, dst_len);

                  dst = Marshal.AllocHGlobal((int)dst_len);
                  ulong encoded_len = c_api.TF_StringEncode(handle, (ulong)str.Length, dst, dst_len, status);
                  Assert.AreEqual((ulong)23, encoded_len);
                  Assert.AreEqual(TF_Code.TF_OK, status.Code);

                  // The buffer is exactly prefix + payload with no NUL terminator, so
                  // read a bounded number of bytes instead of scanning for '\0'
                  // (which would run past the allocation into uninitialised memory).
                  string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte), str.Length);
                  Assert.AreEqual(str, encoded_str);
                  // The first byte is the length prefix.
                  Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
                  // TODO: round-trip through c_api.TF_StringDecode(dst, dst_len, ...) once wired up.
              }
              finally
              {
                  // Both buffers are unmanaged and must be released explicitly.
                  if (dst != IntPtr.Zero)
                      Marshal.FreeHGlobal(dst);
                  if (handle != IntPtr.Zero)
                      Marshal.FreeHGlobal(handle);
              }
          }

          /// <summary>
          /// tensorflow\c\c_api_test.cc
          /// TestEncodeDecode
          /// </summary>
          /// <remarks>
          /// TODO: not yet ported. The body is empty, so this test currently
          /// passes without checking anything.
          /// </remarks>
          [TestMethod]
          public void EncodeDecode()
          {
          }
      }
  }