You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SessionTest.cs 4.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest.Basics
  {
      /// <summary>
      /// Graph-mode tests that evaluate tensors through a session and feed placeholders
      /// via <c>Session.run</c> with a <c>feed_dict</c>.
      /// </summary>
      [TestClass]
      public class SessionTest : GraphModeTestBase
      {
          // The tests below previously used lock (this). MSTest creates a new test-class
          // instance per test method, so that lock never excluded anything. A private static
          // gate actually serializes the tests that take it.
          // NOTE(review): only the two tests that previously locked take this gate; confirm
          // whether the remaining tests need it when tests run in parallel.
          private static readonly object s_sessionGate = new object();

          /// <summary>
          /// Evaluates the matmul of two 1x1 constants (3.0 and 2.0) and checks the product is 6.0.
          /// </summary>
          [TestMethod]
          public void EvalTensor()
          {
              lock (s_sessionGate)
              {
                  var a = constant_op.constant(np.array(3.0).reshape((1, 1)));
                  var b = constant_op.constant(np.array(2.0).reshape((1, 1)));
                  var c = math_ops.matmul(a, b, name: "matmul");
                  using (var sess = tf.Session())
                  {
                      var result = c.eval(sess);
                      // NOTE(review): the actual value is passed first, contrary to MSTest's
                      // (expected, actual) convention. Left as-is because the comparison presumably
                      // relies on NDArray's Equals override accepting a double; confirm before swapping.
                      Assert.AreEqual(result[0], 6.0);
                  }
              }
          }

          /// <summary>
          /// Takes an 8-character substring starting at position 4 of a short string constant
          /// and checks it decodes to "heythere".
          /// </summary>
          [TestMethod]
          public void Eval_SmallString_Scalar()
          {
              var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
              var c = tf.strings.substr(a, 4, 8);
              using (var sess = tf.Session())
              {
                  var result = c.eval(sess).StringData();
                  Assert.AreEqual("heythere", result[0]);
              }
          }

          /// <summary>
          /// Takes a prefix of a 30,000-character string constant and checks the decoded
          /// result is exactly that many 'a' characters.
          /// </summary>
          [TestMethod]
          public void Eval_LargeString_Scalar()
          {
              lock (s_sessionGate)
              {
                  const int size = 30_000;
                  const int length = size - 5000;
                  var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
                  var c = tf.strings.substr(a, 0, length);
                  using (var sess = tf.Session())
                  {
                      // Decode through StringData(), as Eval_SmallString_Scalar does. Assert on the
                      // result instead of only printing 25,000 characters to the test output.
                      var result = c.eval(sess).StringData()[0];
                      Assert.AreEqual(length, result.Length);
                      Assert.AreEqual(new string('a', length), result);
                  }
              }
          }

          /// <summary>
          /// Runs the global-variables initializer through a value typed as <see cref="ITensorOrOperation"/>.
          /// </summary>
          [TestMethod]
          public void Autocast_Case0()
          {
              using (var sess = tf.Session())
              {
                  sess.as_default();
                  // The declared ITensorOrOperation type is essential: it selects the
                  // Session.run overload whose signature this test exercises.
                  ITensorOrOperation operation = tf.global_variables_initializer();
                  sess.run(operation);
              }
          }

          /// <summary>
          /// Feeds an int array into an int32 placeholder, reshapes it to 2x3, and checks
          /// the shape and values of the result.
          /// </summary>
          [TestMethod]
          public void Autocast_Case1()
          {
              using (var sess = tf.Session())
              {
                  sess.as_default();
                  var input = tf.placeholder(tf.int32, shape: new Shape(6));
                  var op = tf.reshape(input, new int[] { 2, 3 });
                  sess.run(tf.global_variables_initializer());

                  var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));

                  Assert.AreEqual(new Shape(2, 3), ret.shape);
                  assertAllEqual(ret.ToArray<int>(), new[] { 1, 2, 3, 4, 5, 6 });
                  print(ret.dtype);
                  print(ret);
              }
          }

          /// <summary>
          /// Feeds a float32 array into a float32 placeholder and reshapes it.
          /// This test makes no assertions; it only checks that the run completes.
          /// </summary>
          [TestMethod]
          public void Autocast_Case2()
          {
              using (var sess = tf.Session())
              {
                  sess.as_default();
                  var input = tf.placeholder(tf.float32, shape: new Shape(6));
                  var op = tf.reshape(input, new int[] { 2, 3 });
                  sess.run(tf.global_variables_initializer());
                  sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
              }
          }

          /// <summary>
          /// Same feed as <see cref="Autocast_Case2"/>, with assertions on the result. Currently ignored.
          /// </summary>
          [TestMethod, Ignore]
          public void Autocast_Case3()
          {
              using (var sess = tf.Session())
              {
                  sess.as_default();
                  var input = tf.placeholder(tf.float32, shape: new Shape(6));
                  var op = tf.reshape(input, new int[] { 2, 3 });
                  sess.run(tf.global_variables_initializer());

                  var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));

                  Assert.AreEqual(new Shape(2, 3), ret.shape);
                  // NOTE(review): the fed values are 1.1 .. 6.1, but this compares against integers
                  // 1 .. 6; confirm the intended expectation before re-enabling.
                  Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
                  print(ret.dtype);
                  print(ret);
              }
          }

          /// <summary>
          /// Feeds a float32 array into a byte8 placeholder, so the fed dtype differs from the
          /// placeholder dtype. Currently ignored.
          /// </summary>
          [TestMethod, Ignore]
          public void Autocast_Case4()
          {
              using (var sess = tf.Session())
              {
                  sess.as_default();
                  var input = tf.placeholder(tf.byte8, shape: new Shape(6));
                  var op = tf.reshape(input, new int[] { 2, 3 });
                  sess.run(tf.global_variables_initializer());

                  var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));

                  Assert.AreEqual(new Shape(2, 3), ret.shape);
                  Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
                  print(ret.dtype);
                  print(ret);
              }
          }
      }
  }