You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

SessionTest.cs 3.1 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using NumSharp;
  3. using System;
  4. using System.Collections.Generic;
  5. using System.Text;
  6. using Tensorflow;
  7. using static Tensorflow.Python;
  8. namespace TensorFlowNET.UnitTest
  9. {
  /// <summary>
  /// Session-level tests: a port of the TensorFlow C API session test plus a
  /// high-level tensor evaluation check.
  /// </summary>
  [TestClass]
  public class SessionTest : CApiTest
  {
      /// <summary>
      /// Port of tensorflow\c\c_api_test.cc `TEST(CAPI, Session)`.
      /// Builds the graph <c>feed + 2</c>, runs it through a <see cref="CSession"/> with
      /// feed = 3, then adds a negation to the same graph and runs it again with feed = 7.
      /// For each run it checks the status and the dtype, rank, byte size and value of
      /// the scalar output.
      /// </summary>
      [TestMethod]
      public void Session()
      {
          var s = new Status();
          var graph = new Graph();
          try
          {
              // Make a placeholder operation.
              var feed = c_test_util.Placeholder(graph, s);
              // Make a constant operation with the scalar "2".
              var two = c_test_util.ScalarConst(2, graph, s);
              // Add operation.
              var add = c_test_util.Add(feed, two, graph, s);
              var csession = new CSession(graph, s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);

              // Run the graph: feed = 3, fetch add:0.
              var inputs = new Dictionary<Operation, Tensor>();
              inputs.Add(feed, new Tensor(3));
              csession.SetInputs(inputs);
              var outputs = new TF_Output[] { new TF_Output(add, 0) };
              csession.SetOutputs(outputs);
              csession.Run(s);
              // Check the run status before inspecting outputs, as the C test does after every run.
              ASSERT_EQ(TF_Code.TF_OK, s.Code);
              Tensor outTensor = csession.output_tensor(0);
              ASSERT_TRUE(outTensor != IntPtr.Zero);
              EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
              EXPECT_EQ(0, outTensor.NDims); // scalar
              // The output dtype is TF_INT32, so its byte size is that of a 32-bit int.
              ASSERT_EQ((ulong)sizeof(int), outTensor.bytesize);
              var output_contents = outTensor.Data<int>();
              EXPECT_EQ(3 + 2, output_contents[0]);

              // Add another operation to the graph.
              var neg = c_test_util.Neg(add, graph, s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);

              // Run up to the new operation: feed = 7, fetch neg:0.
              inputs = new Dictionary<Operation, Tensor>();
              inputs.Add(feed, new Tensor(7));
              csession.SetInputs(inputs);
              outputs = new TF_Output[] { new TF_Output(neg, 0) };
              csession.SetOutputs(outputs);
              csession.Run(s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);
              outTensor = csession.output_tensor(0);
              ASSERT_TRUE(outTensor != IntPtr.Zero);
              EXPECT_EQ(TF_DataType.TF_INT32, outTensor.dtype);
              EXPECT_EQ(0, outTensor.NDims); // scalar
              ASSERT_EQ((ulong)sizeof(int), outTensor.bytesize);
              output_contents = outTensor.Data<int>();
              EXPECT_EQ(-(7 + 2), output_contents[0]);

              // Clean up the session.
              csession.CloseAndDelete(s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);
          }
          finally
          {
              // Release the graph and status even when an assertion above throws,
              // so a failing test does not leak them.
              graph.Dispose();
              s.Dispose();
          }
      }

      /// <summary>
      /// Evaluates the 1x1 matrix product [[3.0]] x [[2.0]] via <c>c.eval()</c> inside a
      /// <c>with(tf.Session(), ...)</c> scope and checks that the single element is 6.
      /// </summary>
      [TestMethod]
      public void EvalTensor()
      {
          var a = constant_op.constant(np.array(3.0).reshape(1, 1));
          var b = constant_op.constant(np.array(2.0).reshape(1, 1));
          var c = math_ops.matmul(a, b, name: "matmul");
          with(tf.Session(), delegate
          {
              var result = c.eval();
              // Floating-point result: compare with a tolerance rather than exact equality.
              Assert.AreEqual(6.0, result.Data<double>()[0], 1e-9);
          });
      }
  }
  80. }

TensorFlow 框架的 .NET 版本，提供了丰富的特性和 API，可以借此很方便地在 .NET 平台下搭建深度学习训练与推理流程。