You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SessionTest.cs 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Reflection;
  6. using System.Runtime.CompilerServices;
  7. using System.Text;
  8. using FluentAssertions;
  9. using Google.Protobuf;
  10. using Tensorflow;
  11. using static Tensorflow.Binding;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Session tests: a port of the C API session test plus checks that
      /// evaluate numeric and string tensors through <c>tf.Session()</c>.
      /// </summary>
      [TestClass]
      public class SessionTest : CApiTest
      {
          /// <summary>
          /// Lock shared by every test in this class. MSTest creates a new
          /// test-class instance for each test method, so locking on <c>this</c>
          /// would never exclude another test; a static object is shared by all
          /// instances and actually serializes these tests.
          /// </summary>
          private static readonly object s_gate = new object();

          /// <summary>
          /// tensorflow\c\c_api_test.cc
          /// `TEST(CAPI, Session)`
          /// </summary>
          [TestMethod]
          public void Session()
          {
              lock (s_gate)
              {
                  var s = new Status();
                  var graph = new Graph();

                  // Make a placeholder operation.
                  var feed = c_test_util.Placeholder(graph, s);

                  // Make a constant operation with the scalar "2".
                  var two = c_test_util.ScalarConst(2, graph, s);

                  // Add operation.
                  var add = c_test_util.Add(feed, two, graph, s);

                  var csession = new CSession(graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Run the graph: feed 3 and fetch feed + 2.
                  csession.SetInputs(new Dictionary<Operation, Tensor> { { feed, new Tensor(3) } });
                  csession.SetOutputs(new TF_Output[] { new TF_Output(add, 0) });
                  csession.Run(s);
                  // Check the run status before touching the output, the same way the
                  // second run below does; otherwise a failed Run surfaces as a
                  // confusing error on the output tensor instead.
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
                  ExpectInt32Scalar(csession.output_tensor(0), 3 + 2);

                  // Add another operation to the graph.
                  var neg = c_test_util.Neg(add, graph, s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);

                  // Run up to the new operation: feed 7 and fetch -(feed + 2).
                  csession.SetInputs(new Dictionary<Operation, Tensor> { { feed, new Tensor(7) } });
                  csession.SetOutputs(new TF_Output[] { new TF_Output(neg, 0) });
                  csession.Run(s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
                  ExpectInt32Scalar(csession.output_tensor(0), -(7 + 2));

                  // Clean up
                  csession.CloseAndDelete(s);
                  ASSERT_EQ(TF_Code.TF_OK, s.Code);
              }
          }

          /// <summary>
          /// Asserts that <paramref name="tensor"/> is not <see cref="IntPtr.Zero"/>,
          /// is an INT32 scalar (rank 0, 4 bytes), and holds <paramref name="expected"/>.
          /// </summary>
          /// <param name="tensor">Output tensor fetched from a session run.</param>
          /// <param name="expected">The value the scalar is expected to hold.</param>
          private void ExpectInt32Scalar(Tensor tensor, int expected)
          {
              ASSERT_TRUE(tensor != IntPtr.Zero);
              EXPECT_EQ(TF_DataType.TF_INT32, tensor.dtype);
              EXPECT_EQ(0, tensor.NDims); // scalar
              ASSERT_EQ((ulong) sizeof(int), tensor.bytesize);
              EXPECT_EQ(expected, tensor.ToArray<int>()[0]);
          }

          /// <summary>
          /// Evaluates a 1x1 by 1x1 matmul of 3.0 and 2.0 through a session.
          /// </summary>
          [TestMethod]
          public void EvalTensor()
          {
              lock (s_gate)
              {
                  var a = constant_op.constant(np.array(3.0).reshape(1, 1));
                  var b = constant_op.constant(np.array(2.0).reshape(1, 1));
                  var c = math_ops.matmul(a, b, name: "matmul");
                  using (var sess = tf.Session())
                  {
                      var result = c.eval(sess);
                      // Floating-point result: compare with a tolerance, not exact equality.
                      Assert.AreEqual(6.0, result.Data<double>()[0], 1e-9);
                  }
              }
          }

          /// <summary>
          /// Takes an 8-character substring starting at index 4 of a short string
          /// scalar and checks the decoded result.
          /// </summary>
          [TestMethod]
          public void Eval_SmallString_Scalar()
          {
              lock (s_gate)
              {
                  var a = constant_op.constant("123 heythere 123 ", TF_DataType.TF_STRING);
                  var c = tf.strings.substr(a, 4, 8);
                  using (var sess = tf.Session())
                  {
                      var result = (string) c.eval(sess);
                      result.Should().Be("heythere");
                  }
              }
          }

          /// <summary>
          /// Takes a 25,000-character prefix of a 30,000-character string scalar
          /// and checks the decoded result.
          /// </summary>
          [TestMethod]
          public void Eval_LargeString_Scalar()
          {
              lock (s_gate)
              {
                  const int size = 30_000;
                  const int length = size - 5000;
                  var a = constant_op.constant(new string('a', size), TF_DataType.TF_STRING);
                  var c = tf.strings.substr(a, 0, length);
                  using (var sess = tf.Session())
                  {
                      var result = (string) c.eval(sess);
                      // Compare the whole string: a "contains" check would also pass
                      // for a result of the right length with unexpected characters.
                      result.Should().Be(new string('a', length));
                  }
              }
          }
      }
  }