You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SessionTest.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Reflection;
  6. using System.Runtime.CompilerServices;
  7. using System.Text;
  8. using FluentAssertions;
  9. using Google.Protobuf;
  10. using Tensorflow;
  11. using Tensorflow.Util;
  12. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Session tests: a port of the TensorFlow C API session test plus graph-mode
      /// <c>eval</c> checks for a numeric tensor and for short and long string scalars.
      /// </summary>
      /// <remarks>
      /// Every test body runs under <c>Locks.ProcessWide</c>. MSTest creates a separate
      /// instance of the test class for each test method, so locking on <c>this</c>
      /// would not stop two tests from running concurrently.
      /// </remarks>
      [TestClass]
      public class SessionTest : CApiTest
      {
          /// <summary>
          /// Port of tensorflow\c\c_api_test.cc `TEST(CAPI, Session)`.
          /// Builds <c>feed + 2</c>, runs it with feed = 3 and expects 5, then adds a
          /// negation to the same graph, runs it with feed = 7 and expects -9.
          /// </summary>
          [TestMethod, Ignore]
          public void Session()
          {
              lock (Locks.ProcessWide)
              {
                  var s = new Status();
                  var graph = new Graph().as_default();

                  // Make a placeholder operation.
                  var feed = c_test_util.Placeholder(graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Make a constant operation with the scalar "2".
                  var two = c_test_util.ScalarConst(2, graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Add operation.
                  var add = c_test_util.Add(feed, two, graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Create a session for this graph.
                  var csession = new CSession(graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Run the graph with feed = 3.
                  var inputs = new Dictionary<Operation, Tensor> { { feed, new Tensor(3) } };
                  csession.SetInputs(inputs);
                  var outputs = new TF_Output[] { new TF_Output(add, 0) };
                  csession.SetOutputs(outputs);
                  csession.Run(s);
                  // Fail on the run status before touching the output tensor.
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
                  Tensor outTensor = csession.output_tensor(0);
                  ASSERT_TRUE(outTensor != IntPtr.Zero);
                  EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
                  EXPECT_EQ(0, outTensor.NDims); // scalar
                  // The tensor is TF_INT32 and is read back as int, so size it as int.
                  ASSERT_EQ((ulong) sizeof(int), outTensor.bytesize);
                  var output_contents = outTensor.ToArray<int>();
                  EXPECT_EQ(3 + 2, output_contents[0]);

                  // Add another operation to the graph.
                  var neg = c_test_util.Neg(add, graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Run up to the new operation with feed = 7.
                  inputs = new Dictionary<Operation, Tensor> { { feed, new Tensor(7) } };
                  csession.SetInputs(inputs);
                  outputs = new TF_Output[] { new TF_Output(neg, 0) };
                  csession.SetOutputs(outputs);
                  csession.Run(s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
                  outTensor = csession.output_tensor(0);
                  ASSERT_TRUE(outTensor != IntPtr.Zero);
                  EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
                  EXPECT_EQ(0, outTensor.NDims); // scalar
                  ASSERT_EQ((ulong) sizeof(int), outTensor.bytesize);
                  output_contents = outTensor.ToArray<int>();
                  EXPECT_EQ(-(7 + 2), output_contents[0]);

                  // Clean up.
                  csession.CloseAndDelete(s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
              }
          }

          /// <summary>
          /// Evaluates the matmul of two 1x1 constants (3.0 and 2.0) in a session and
          /// expects 6.
          /// </summary>
          [TestMethod]
          public void EvalTensor()
          {
              lock (Locks.ProcessWide)
              {
                  var a = constant_op.constant(np.array(3.0).reshape(1, 1));
                  var b = constant_op.constant(np.array(2.0).reshape(1, 1));
                  var c = math_ops.matmul(a, b, name: "matmul");
                  using (var sess = tf.Session())
                  {
                      var result = c.eval(sess);
                      // Read back as double; 3.0 * 2.0 is exactly representable, so an
                      // exact comparison is safe here.
                      Assert.AreEqual(6.0, result.GetAtIndex<double>(0));
                  }
              }
          }

          /// <summary>
          /// Takes the 8-character substring starting at position 4 of a short string
          /// scalar and expects "heythere".
          /// </summary>
          [TestMethod]
          public void Eval_SmallString_Scalar()
          {
              lock (Locks.ProcessWide)
              {
                  var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
                  var c = tf.strings.substr(a, 4, 8);
                  using (var sess = tf.Session())
                  {
                      var result = (string) c.eval(sess);
                      Console.WriteLine(result);
                      result.Should().Be("heythere");
                  }
              }
          }

          /// <summary>
          /// Takes the first 25,000 characters of a 30,000-character string scalar and
          /// checks both the length and the exact content of the result.
          /// </summary>
          [TestMethod]
          public void Eval_LargeString_Scalar()
          {
              lock (Locks.ProcessWide)
              {
                  const int size = 30_000;
                  const int substrLength = size - 5000;
                  var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
                  var c = tf.strings.substr(a, 0, substrLength);
                  using (var sess = tf.Session())
                  {
                      var result = (string) c.eval(sess);
                      Console.WriteLine(result);
                      // Compare the exact content. A length check plus ContainAll("a")
                      // would accept any string of the right length containing one 'a'.
                      result.Should().HaveLength(substrLength)
                          .And.Be(new string('a', substrLength));
                  }
              }
          }
      }
  }