You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

ConstantTest.cs 5.5 kB

6 years ago
6 years ago
6 years ago
6 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Linq;
  5. using Tensorflow;
  6. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Eager-mode tests for tensors produced by <c>tf.constant</c>, <c>tf.zeros</c> and
      /// <c>tf.ones</c>, plus simple arithmetic and <c>tf.reshape</c> on those tensors.
      /// </summary>
      [TestClass]
      public class ConstantTest : EagerModeTestBase
      {
          /// <summary>
          /// <c>tf.constant</c> picks the tensor dtype from the CLR type of a scalar argument:
          /// int → TF_INT32, float → TF_FLOAT, double → TF_DOUBLE.
          /// </summary>
          [TestMethod]
          public void ScalarConst()
          {
              var intTensor = tf.constant(8);
              Assert.AreEqual(TF_DataType.TF_INT32, intTensor.dtype);

              var floatTensor = tf.constant(6.0f);
              Assert.AreEqual(TF_DataType.TF_FLOAT, floatTensor.dtype);

              var doubleTensor = tf.constant(6.0);
              Assert.AreEqual(TF_DataType.TF_DOUBLE, doubleTensor.dtype);
          }

          /// <summary>
          /// <c>tf.zeros</c> yields a tensor of the requested shape whose every element is zero,
          /// checked for both a small (3x2) and a larger (200x100) shape.
          /// </summary>
          [TestMethod]
          public void ZerosConst()
          {
              // small size: compare the full contents
              var tensor = tf.zeros((3, 2), tf.int32, "small");
              Assert.AreEqual(3, tensor.shape[0]);
              Assert.AreEqual(2, tensor.shape[1]);
              CollectionAssert.AreEqual(new int[] { 0, 0, 0, 0, 0, 0 }, tensor.numpy().ToArray<int>());

              // big size: check the element count and that no element is non-zero
              tensor = tf.zeros((200, 100), tf.int32, "big");
              Assert.AreEqual(200, tensor.shape[0]);
              Assert.AreEqual(100, tensor.shape[1]);
              var data = tensor.numpy().ToArray<int>();
              Assert.AreEqual(200 * 100, data.Length);
              Assert.IsTrue(data.All(v => v == 0), "tf.zeros produced a non-zero element");
          }

          /// <summary>
          /// <c>tf.ones</c> yields a float32 tensor of the requested shape filled with 1.
          /// </summary>
          [TestMethod]
          public void OnesConst()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float32, "ones");

              Assert.AreEqual(tf.float32, ones.dtype);
              Assert.AreEqual(3, ones.shape[0]);
              Assert.AreEqual(2, ones.shape[1]);
              CollectionAssert.AreEqual(new float[] { 1, 1, 1, 1, 1, 1 }, ones.numpy().ToArray<float>());
          }

          /// <summary>
          /// Multiplying a float64 ones tensor by the scalar 0.5 keeps the shape and
          /// halves every element.
          /// </summary>
          [TestMethod]
          public void OnesToHalves()
          {
              var ones = tf.ones(new Shape(3, 2), tf.float64, "ones");
              var halves = ones * 0.5;

              Assert.AreEqual(3, halves.shape[0]);
              Assert.AreEqual(2, halves.shape[1]);
              CollectionAssert.AreEqual(new[] { .5, .5, .5, .5, .5, .5 }, halves.numpy().ToArray<double>());
          }

          /// <summary>
          /// <c>tf.constant</c> built from a 2-D NumPy array keeps the array's dims and
          /// its elements in the same flattened order.
          /// </summary>
          [TestMethod]
          public void NDimConst()
          {
              var nd = np.array(new int[,]
              {
                  { 3, 1, 1 },
                  { 2, 1, 3 }
              });

              var tensor = tf.constant(nd);

              CollectionAssert.AreEqual(new long[] { 2, 3 }, tensor.shape.dims);
              CollectionAssert.AreEqual(new int[] { 3, 1, 1, 2, 1, 3 }, tensor.numpy().ToArray<int>());
          }

          /// <summary>
          /// Element-wise multiplication of two scalar double constants.
          /// </summary>
          [TestMethod]
          public void Multiply()
          {
              var a = tf.constant(3.0);
              var b = tf.constant(2.0);

              var c = a * b;

              // 3.0 * 2.0 is exactly representable, so exact equality is safe here.
              Assert.AreEqual(6.0, (double)c);
          }

          /// <summary>
          /// <c>tf.reshape</c> changes a 3x2 float32 ones tensor into a 2x3 tensor with the
          /// same dtype and contents.
          /// </summary>
          [TestMethod]
          public void Reshape()
          {
              var ones = tf.ones((3, 2), tf.float32, "ones");

              var reshaped = tf.reshape(ones, (2, 3));

              Assert.AreEqual(tf.float32, reshaped.dtype);
              Assert.AreEqual(2, reshaped.shape[0]);
              Assert.AreEqual(3, reshaped.shape[1]);
              // Inspect the reshape *result*; checking the input tensor would not test reshape at all.
              CollectionAssert.AreEqual(new float[] { 1, 1, 1, 1, 1, 1 }, reshaped.numpy().ToArray<float>());
          }
      }
  }