You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 6.4 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for constant-producing ops (<c>tf.constant</c>, <c>tf.zeros</c>, <c>tf.ones</c>),
      /// simple arithmetic and reshape on such tensors, and the raw <c>TF_StringEncode</c> C API.
      /// </summary>
      [TestClass]
      public class ConstantTest
      {
          // Status object whose handle is passed to C API calls (used by StringEncode).
          private readonly Status status = new Status();

          /// <summary>
          /// <c>tf.constant</c> picks the dtype from the CLR type of a scalar argument.
          /// </summary>
          [TestMethod]
          public void ScalarConst()
          {
              var tensor1 = tf.constant(8); // int
              Assert.AreEqual(TF_DataType.TF_INT32, tensor1.dtype);
              var tensor2 = tf.constant(6.0f); // float
              Assert.AreEqual(TF_DataType.TF_FLOAT, tensor2.dtype);
              var tensor3 = tf.constant(6.0); // double
              Assert.AreEqual(TF_DataType.TF_DOUBLE, tensor3.dtype);
          }

          // NOTE(review): the tests below are disabled; they use the tf.Session()/sess.run API.
          // Confirm whether they should be ported or deleted.
          /*[DataTestMethod]
          [DataRow(int.MinValue)]
          [DataRow(-1)]
          [DataRow(0)]
          [DataRow(1)]
          [DataRow(int.MaxValue)]
          public void ScalarConstTypecast_int(int value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<int>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(double.NegativeInfinity)]
          [DataRow(double.MinValue)]
          [DataRow(-1d)]
          [DataRow(0d)]
          [DataRow(double.Epsilon)]
          [DataRow(1d)]
          [DataRow(double.MaxValue)]
          [DataRow(double.PositiveInfinity)]
          [DataRow(double.NaN)]
          public void ScalarConstTypecast_double(double value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<double>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(float.NegativeInfinity)]
          [DataRow(float.MinValue)]
          [DataRow(-1f)]
          [DataRow(0f)]
          [DataRow(float.Epsilon)]
          [DataRow(1f)]
          [DataRow(float.MaxValue)]
          [DataRow(float.PositiveInfinity)]
          [DataRow(float.NaN)]
          public void ScalarConstTypecast_float(float value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<float>()[0]);
              });
          }

          [TestMethod]
          public void StringConst()
          {
              string str = "Hello, TensorFlow.NET!";
              var tensor = tf.constant(str);
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.IsTrue(result.Data<string>()[0] == str);
              });
          }*/

          /// <summary>
          /// <c>tf.zeros</c> produces a tensor of the requested shape filled entirely with zeros.
          /// </summary>
          [TestMethod]
          public void ZerosConst()
          {
              // small size
              var tensor = tf.zeros((3, 2), tf.int32, "small");
              Assert.AreEqual(3, tensor.shape[0]);
              Assert.AreEqual(2, tensor.shape[1]);
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>()));

              // big size: verify the element count and every element, not just a few samples
              tensor = tf.zeros((200, 100), tf.int32, "big");
              Assert.AreEqual(200, tensor.shape[0]);
              Assert.AreEqual(100, tensor.shape[1]);
              var data = tensor.numpy().ToArray<int>();
              Assert.AreEqual(200 * 100, data.Length);
              Assert.IsTrue(data.All(x => x == 0));
          }

          /// <summary>
          /// <c>tf.ones</c> produces a float32 tensor of the requested shape filled with ones.
          /// </summary>
          [TestMethod]
          public void OnesConst()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float32, "ones");
              Assert.AreEqual(tf.float32, ones.dtype);
              Assert.AreEqual(3, ones.shape[0]);
              Assert.AreEqual(2, ones.shape[1]);
              Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray<float>()));
          }

          /// <summary>
          /// Multiplying a float64 ones tensor by a scalar keeps the shape and scales every element.
          /// </summary>
          [TestMethod]
          public void OnesToHalves()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float64, "ones");
              var halves = ones * 0.5;
              Assert.AreEqual(3, halves.shape[0]);
              Assert.AreEqual(2, halves.shape[1]);
              // 1.0 * 0.5 is exactly representable, so exact comparison is safe here.
              Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(halves.numpy().ToArray<double>()));
          }

          /// <summary>
          /// <c>tf.constant</c> built from a 2-D NumSharp array keeps its shape and row-major data.
          /// </summary>
          [TestMethod]
          public void NDimConst()
          {
              var nd = np.array(new int[][]
              {
                  new int[]{ 3, 1, 1 },
                  new int[]{ 2, 1, 3 }
              });
              var tensor = tf.constant(nd);
              var data = tensor.numpy().ToArray<int>();
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape));
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
          }

          /// <summary>
          /// The <c>*</c> operator on two scalar tensors yields their product.
          /// </summary>
          [TestMethod]
          public void Multiply()
          {
              var a = tf.constant(3.0);
              var b = tf.constant(2.0);
              var c = a * b;
              // 3.0 * 2.0 is exact in binary floating point.
              Assert.AreEqual(6.0, (double)c);
          }

          /// <summary>
          /// Exercises <c>TF_StringEncodedSize</c> / <c>TF_StringEncode</c> on an ASCII string and
          /// checks that the output is a single length byte followed by the original bytes.
          /// </summary>
          [TestMethod]
          public void StringEncode()
          {
              string str = "Hello, TensorFlow.NET!";
              // str is ASCII, so its char count equals its byte count in the ANSI buffer.
              var srcLen = (ulong)str.Length;
              // 22 payload bytes + 1 length-prefix byte.
              const ulong expectedEncodedLen = 23;

              IntPtr handle = Marshal.StringToHGlobalAnsi(str);
              IntPtr dst = IntPtr.Zero;
              try
              {
                  var dst_len = c_api.TF_StringEncodedSize(srcLen);
                  Assert.AreEqual(expectedEncodedLen, dst_len);

                  dst = Marshal.AllocHGlobal((int)dst_len);
                  var encoded_len = c_api.TF_StringEncode(handle, srcLen, dst, dst_len, status.Handle);
                  Assert.AreEqual(expectedEncodedLen, encoded_len);
                  Assert.AreEqual(TF_Code.TF_OK, status.Code);

                  // The encoded buffer is not NUL-terminated, so read exactly str.Length bytes
                  // after the prefix byte instead of scanning for a terminator.
                  string encoded_str = Marshal.PtrToStringUTF8(dst + sizeof(byte), str.Length);
                  Assert.AreEqual(str, encoded_str);
                  Assert.AreEqual(str.Length, Marshal.ReadByte(dst));
                  // c_api.TF_StringDecode(dst, (ulong)str.Length, IntPtr.Zero, ref dst_len, status.Handle);
              }
              finally
              {
                  // Both buffers are unmanaged; FreeHGlobal is a no-op for IntPtr.Zero.
                  Marshal.FreeHGlobal(handle);
                  Marshal.FreeHGlobal(dst);
              }
          }

          /// <summary>
          /// <c>tf.reshape</c> changes the shape while preserving dtype and element values.
          /// </summary>
          [TestMethod]
          public void Reshape()
          {
              var ones = tf.ones((3, 2), tf.float32, "ones");
              var reshaped = tf.reshape(ones, (2, 3));
              Assert.AreEqual(tf.float32, reshaped.dtype);
              Assert.AreEqual(2, reshaped.shape[0]);
              Assert.AreEqual(3, reshaped.shape[1]);
              // Verify the reshaped tensor's values, not the source tensor's.
              Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(reshaped.numpy().ToArray<float>()));
          }
      }
  }