You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

SessionTest.cs 4.7 kB

4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. using FluentAssertions;
  2. using Microsoft.VisualStudio.TestTools.UnitTesting;
  3. using Tensorflow.NumPy;
  4. using System;
  5. using System.Collections.Generic;
  6. using System.Text;
  7. using Tensorflow;
  8. using Tensorflow.Util;
  9. using static Tensorflow.Binding;
  10. namespace TensorFlowNET.UnitTest
  11. {
  12. [TestClass, Ignore]
  13. public class SessionTest
  14. {
  15. [TestMethod]
  16. public void EvalTensor()
  17. {
  18. lock (this)
  19. {
  20. var a = constant_op.constant(np.array(3.0).reshape((1, 1)));
  21. var b = constant_op.constant(np.array(2.0).reshape((1, 1)));
  22. var c = math_ops.matmul(a, b, name: "matmul");
  23. using (var sess = tf.Session())
  24. {
  25. var result = c.eval(sess);
  26. Assert.AreEqual(result[0], 6.0);
  27. }
  28. }
  29. }
  30. [TestMethod]
  31. public void Eval_SmallString_Scalar()
  32. {
  33. lock (this)
  34. {
  35. var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
  36. var c = tf.strings.substr(a, 4, 8);
  37. using (var sess = tf.Session())
  38. {
  39. var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
  40. Console.WriteLine(result);
  41. result.Should().Be("heythere");
  42. }
  43. }
  44. }
  45. [TestMethod]
  46. public void Eval_LargeString_Scalar()
  47. {
  48. lock (this)
  49. {
  50. const int size = 30_000;
  51. var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
  52. var c = tf.strings.substr(a, 0, size - 5000);
  53. using (var sess = tf.Session())
  54. {
  55. var result = UTF8Encoding.UTF8.GetString(c.eval(sess).ToByteArray());
  56. Console.WriteLine(result);
  57. result.Should().HaveLength(size - 5000).And.ContainAll("a");
  58. }
  59. }
  60. }
  61. [TestMethod]
  62. public void Autocast_Case0()
  63. {
  64. var sess = tf.Session().as_default();
  65. ITensorOrOperation operation = tf.global_variables_initializer();
  66. // the cast to ITensorOrOperation is essential for the test of this method signature
  67. var ret = sess.run(operation);
  68. ret.Should().BeNull();
  69. }
  70. [TestMethod]
  71. public void Autocast_Case1()
  72. {
  73. var sess = tf.Session().as_default();
  74. var input = tf.placeholder(tf.float32, shape: new Shape(6));
  75. var op = tf.reshape(input, new int[] { 2, 3 });
  76. sess.run(tf.global_variables_initializer());
  77. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6)));
  78. ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
  79. print(ret.dtype);
  80. print(ret);
  81. }
  82. [TestMethod]
  83. public void Autocast_Case2()
  84. {
  85. var sess = tf.Session().as_default();
  86. var input = tf.placeholder(tf.float64, shape: new Shape(6));
  87. var op = tf.reshape(input, new int[] { 2, 3 });
  88. sess.run(tf.global_variables_initializer());
  89. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  90. ret.Should().BeShaped(2, 3).And.BeOfValuesApproximately(0.001d, 1.1, 2.1, 3.1, 4.1, 5.1, 6.1);
  91. print(ret.dtype);
  92. print(ret);
  93. }
  94. [TestMethod]
  95. public void Autocast_Case3()
  96. {
  97. var sess = tf.Session().as_default();
  98. var input = tf.placeholder(tf.int64, shape: new Shape(6));
  99. var op = tf.reshape(input, new int[] { 2, 3 });
  100. sess.run(tf.global_variables_initializer());
  101. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  102. ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
  103. print(ret.dtype);
  104. print(ret);
  105. }
  106. [TestMethod]
  107. public void Autocast_Case4()
  108. {
  109. var sess = tf.Session().as_default();
  110. var input = tf.placeholder(tf.byte8, shape: new Shape(6));
  111. var op = tf.reshape(input, new int[] { 2, 3 });
  112. sess.run(tf.global_variables_initializer());
  113. var ret = sess.run(op, feed_dict: (input, np.array(1, 2, 3, 4, 5, 6).astype(np.float32) + 0.1f));
  114. ret.Should().BeShaped(2, 3).And.BeOfValues(1, 2, 3, 4, 5, 6);
  115. print(ret.dtype);
  116. print(ret);
  117. }
  118. }
  119. }