You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 5.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using System.Text;
  7. using Tensorflow;
  8. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Tests for constant tensors created with <c>tf.constant</c>, <c>tf.zeros</c> and
      /// <c>tf.ones</c>, and for simple operations on them (scalar multiply, reshape).
      /// </summary>
      /// <remarks>
      /// MSTest's <c>Assert.AreEqual</c> takes <c>(expected, actual)</c>; the assertions below
      /// follow that order so failure messages report the two values the right way round.
      /// </remarks>
      [TestClass]
      public class ConstantTest
      {
          // NOTE(review): not referenced by any test in this file — confirm it is still needed.
          Status status = new Status();

          /// <summary>
          /// Checks that a scalar constant's dtype follows the CLR literal type:
          /// int → TF_INT32, float → TF_FLOAT, double → TF_DOUBLE.
          /// </summary>
          [TestMethod]
          public void ScalarConst()
          {
              var tensor1 = tf.constant(8); // int
              Assert.AreEqual(TF_DataType.TF_INT32, tensor1.dtype);

              var tensor2 = tf.constant(6.0f); // float
              Assert.AreEqual(TF_DataType.TF_FLOAT, tensor2.dtype);

              var tensor3 = tf.constant(6.0); // double
              Assert.AreEqual(TF_DataType.TF_DOUBLE, tensor3.dtype);
          }

          // NOTE(review): the tests below are disabled. They use the session-based API
          // (`with(tf.Session(), ...)`, `sess.run`, `Data<T>()`) rather than the `numpy()`
          // accessors used by the active tests in this class — port them or delete them.
          /*[DataTestMethod]
          [DataRow(int.MinValue)]
          [DataRow(-1)]
          [DataRow(0)]
          [DataRow(1)]
          [DataRow(int.MaxValue)]
          public void ScalarConstTypecast_int(int value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<int>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(double.NegativeInfinity)]
          [DataRow(double.MinValue)]
          [DataRow(-1d)]
          [DataRow(0d)]
          [DataRow(double.Epsilon)]
          [DataRow(1d)]
          [DataRow(double.MaxValue)]
          [DataRow(double.PositiveInfinity)]
          [DataRow(double.NaN)]
          public void ScalarConstTypecast_double(double value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.AreEqual(value, result.Data<double>()[0]);
              });
          }

          [DataTestMethod]
          [DataRow(float.NegativeInfinity)]
          [DataRow(float.MinValue)]
          [DataRow(-1f)]
          [DataRow(0f)]
          [DataRow(float.Epsilon)]
          [DataRow(1f)]
          [DataRow(float.MaxValue)]
          [DataRow(float.PositiveInfinity)]
          [DataRow(float.NaN)]
          public void ScalarConstTypecast_float(float value)
          {
              var tensor = (Tensor)value;
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  // A float tensor must be read back as float, not double.
                  Assert.AreEqual(value, result.Data<float>()[0]);
              });
          }

          [TestMethod]
          public void StringConst()
          {
              string str = "Hello, TensorFlow.NET!";
              var tensor = tf.constant(str);
              with(tf.Session(), sess =>
              {
                  var result = sess.run(tensor);
                  Assert.IsTrue(result.Data<string>()[0] == str);
              });
          }*/

          /// <summary>
          /// Checks that <c>tf.zeros</c> produces a tensor of the requested shape filled with
          /// zeros, for a small (3x2) and a larger (200x100) int32 tensor.
          /// </summary>
          [TestMethod]
          public void ZerosConst()
          {
              // small size: compare every element
              var tensor = tf.zeros((3, 2), tf.int32, "small");
              Assert.AreEqual(3, tensor.shape[0]);
              Assert.AreEqual(2, tensor.shape[1]);
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>()));

              // big size: check the element count, then spot-check first, middle and last
              tensor = tf.zeros((200, 100), tf.int32, "big");
              Assert.AreEqual(200, tensor.shape[0]);
              Assert.AreEqual(100, tensor.shape[1]);
              var data = tensor.numpy().ToArray<int>();
              Assert.AreEqual(200 * 100, data.Length);
              Assert.AreEqual(0, data[0]);
              Assert.AreEqual(0, data[500]);
              Assert.AreEqual(0, data[data.Length - 1]);
          }

          /// <summary>
          /// Checks that <c>tf.ones</c> produces a float32 tensor of shape 3x2 filled with ones.
          /// </summary>
          [TestMethod]
          public void OnesConst()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float32, "ones");
              Assert.AreEqual(tf.float32, ones.dtype);
              Assert.AreEqual(3, ones.shape[0]);
              Assert.AreEqual(2, ones.shape[1]);
              Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(ones.numpy().ToArray<float>()));
          }

          /// <summary>
          /// Checks that multiplying a float64 ones tensor by the scalar 0.5 keeps the
          /// 3x2 shape and yields 0.5 in every element.
          /// </summary>
          [TestMethod]
          public void OnesToHalves()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float64, "ones");
              var halves = ones * 0.5;
              Assert.AreEqual(3, halves.shape[0]);
              Assert.AreEqual(2, halves.shape[1]);
              // 0.5 is exactly representable, so exact comparison is safe here.
              Assert.IsTrue(new[] { .5, .5, .5, .5, .5, .5 }.SequenceEqual(halves.numpy().ToArray<double>()));
          }

          /// <summary>
          /// Checks that a 2x3 NumSharp array converted with <c>tf.constant</c> keeps its shape
          /// and that its data reads back flattened row by row.
          /// </summary>
          [TestMethod]
          public void NDimConst()
          {
              var nd = np.array(new int[][]
              {
                  new int[] { 3, 1, 1 },
                  new int[] { 2, 1, 3 }
              });
              var tensor = tf.constant(nd);
              var data = tensor.numpy().ToArray<int>();
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 2, 3 }, tensor.shape));
              Assert.IsTrue(Enumerable.SequenceEqual(new int[] { 3, 1, 1, 2, 1, 3 }, data));
          }

          /// <summary>
          /// Checks that multiplying two scalar double constants (3.0 * 2.0) yields 6.0.
          /// </summary>
          [TestMethod]
          public void Multiply()
          {
              var a = tf.constant(3.0);
              var b = tf.constant(2.0);
              var c = a * b;
              Assert.AreEqual(6.0, (double)c);
          }

          /// <summary>
          /// Checks that <c>tf.reshape</c> turns a 3x2 float32 ones tensor into a 2x3 tensor
          /// with the same dtype and contents.
          /// </summary>
          [TestMethod]
          public void Reshape()
          {
              var ones = tf.ones((3, 2), tf.float32, "ones");
              var reshaped = tf.reshape(ones, (2, 3));
              Assert.AreEqual(tf.float32, reshaped.dtype);
              Assert.AreEqual(2, reshaped.shape[0]);
              Assert.AreEqual(3, reshaped.shape[1]);
              // Verify the data of the reshaped tensor itself, not of the source `ones` tensor.
              Assert.IsTrue(new float[] { 1, 1, 1, 1, 1, 1 }.SequenceEqual(reshaped.numpy().ToArray<float>()));
          }
      }
  }