You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

SessionTest.cs 4.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using Tensorflow.NumPy;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Binding;
  8. namespace TensorFlowNET.UnitTest.Basics
  9. {
  10. [TestClass]
  11. public class SessionTest : GraphModeTestBase
  12. {
  13. [TestMethod]
  14. public void EvalTensor()
  15. {
  16. lock (this)
  17. {
  18. var a = constant_op.constant(np.array(3.0).reshape((1, 1)));
  19. var b = constant_op.constant(np.array(2.0).reshape((1, 1)));
  20. var c = math_ops.matmul(a, b, name: "matmul");
  21. var sess = tf.Session();
  22. var result = c.eval(sess);
  23. Assert.AreEqual(result[0], 6.0);
  24. }
  25. }
  26. [TestMethod]
  27. public void Eval_SmallString_Scalar()
  28. {
  29. var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
  30. var c = tf.strings.substr(a, 4, 8);
  31. var sess = tf.Session();
  32. var result = c.eval(sess).StringData();
  33. Assert.AreEqual(result[0], "heythere");
  34. }
  35. [TestMethod]
  36. public void Eval_LargeString_Scalar()
  37. {
  38. lock (this)
  39. {
  40. const int size = 30_000;
  41. var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
  42. var c = tf.strings.substr(a, 0, size - 5000);
  43. var sess = tf.Session();
  44. var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
  45. Console.WriteLine(result);
  46. }
  47. }
  48. [TestMethod]
  49. public void Autocast_Case0()
  50. {
  51. var sess = tf.Session().as_default();
  52. ITensorOrOperation operation = tf.global_variables_initializer();
  53. // the cast to ITensorOrOperation is essential for the test of this method signature
  54. var ret = sess.run(operation);
  55. }
  56. [TestMethod]
  57. public void Autocast_Case1()
  58. {
  59. var sess = tf.Session().as_default();
  60. var input = tf.placeholder(tf.int32, shape: new Shape(6));
  61. var op = tf.reshape(input, new int[] { 2, 3 });
  62. sess.run(tf.global_variables_initializer());
  63. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));
  64. Assert.AreEqual(ret.shape, (2, 3));
  65. assertAllEqual(ret.ToArray<int>(), new[] { 1, 2, 3, 4, 5, 6 });
  66. print(ret.dtype);
  67. print(ret);
  68. }
  69. [TestMethod]
  70. public void Autocast_Case2()
  71. {
  72. var sess = tf.Session().as_default();
  73. var input = tf.placeholder(tf.float32, shape: new Shape(6));
  74. var op = tf.reshape(input, new int[] { 2, 3 });
  75. sess.run(tf.global_variables_initializer());
  76. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  77. }
  78. [TestMethod, Ignore]
  79. public void Autocast_Case3()
  80. {
  81. var sess = tf.Session().as_default();
  82. var input = tf.placeholder(tf.float32, shape: new Shape(6));
  83. var op = tf.reshape(input, new int[] { 2, 3 });
  84. sess.run(tf.global_variables_initializer());
  85. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  86. Assert.AreEqual(ret.shape, (2, 3));
  87. Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
  88. print(ret.dtype);
  89. print(ret);
  90. }
  91. [TestMethod, Ignore]
  92. public void Autocast_Case4()
  93. {
  94. var sess = tf.Session().as_default();
  95. var input = tf.placeholder(tf.byte8, shape: new Shape(6));
  96. var op = tf.reshape(input, new int[] { 2, 3 });
  97. sess.run(tf.global_variables_initializer());
  98. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  99. Assert.AreEqual(ret.shape, (2, 3));
  100. Assert.AreEqual(ret, new[] { 1, 2, 3, 4, 5, 6 });
  101. print(ret.dtype);
  102. print(ret);
  103. }
  104. }
  105. }