You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiFunctionTest.cs 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Runtime.InteropServices;
  6. using Tensorflow;
  7. using Tensorflow.Functions;
  8. using static TensorFlowNET.UnitTest.c_test_util;
  namespace TensorFlowNET.UnitTest.NativeAPI
  {
      /// <summary>
      /// Port of tensorflow\c\c_api_function_test.cc, `class CApiFunctionTest`.
      /// A test builds a function body in <c>func_graph_</c>, turns it into a native
      /// TF_Function with TF_GraphToFunction, copies that function into <c>host_graph_</c>,
      /// instantiates it there as an operation, runs it, and inspects the FunctionDef.
      /// </summary>
      [TestClass]
      public class CApiFunctionTest : CApiTest, IDisposable
      {
          // Graph that holds the body of the function under test.
          Graph func_graph_;
          // Graph the function is copied into and instantiated/run from.
          Graph host_graph_;
          readonly string func_name_ = "MyFunc";
          readonly string func_node_name_ = "MyFunc_0";
          Status s_;
          // Native function handle produced by DefineT; released in Dispose.
          IntPtr func_;

          [TestInitialize]
          public void Initialize()
          {
              func_graph_ = new Graph();
              host_graph_ = new Graph();
              s_ = new Status();
          }

          [TestMethod]
          public void OneOp_ZeroInputs_OneOutput()
          {
              var c = ScalarConst(10, func_graph_, s_, "scalar10");
              // Define
              Define(-1, new Operation[0], new Operation[0], new[] { c }, new string[0]);
              // Use, run, and verify
              var func_op = Use(new Operation[0]);
              Run(new KeyValuePair<Operation, Tensor>[0], func_op, 10);
              VerifyFDef(new[] { "scalar10_0" });
          }

          /// <summary>
          /// Convenience overload of <see cref="DefineT"/> that uses output 0 of every
          /// input and output operation.
          /// </summary>
          void Define(int num_opers, Operation[] opers,
              Operation[] inputs, Operation[] outputs,
              string[] output_names, bool expect_failure = false)
              => DefineT(num_opers, opers,
                  inputs.Select(x => new TF_Output(x, 0)).ToArray(),
                  outputs.Select(x => new TF_Output(x, 0)).ToArray(),
                  output_names, expect_failure);

          /// <summary>
          /// Converts <c>func_graph_</c> into a function named <c>func_name_</c>, stores the
          /// handle in <c>func_</c> and copies the function into <c>host_graph_</c>.
          /// </summary>
          /// <param name="num_opers">Number of entries in <paramref name="opers"/>, or -1 to
          /// pass no explicit op list.</param>
          /// <param name="opers">Operations forming the body; ignored when num_opers is -1.</param>
          /// <param name="inputs">Function inputs.</param>
          /// <param name="outputs">Function outputs.</param>
          /// <param name="output_names">Optional output names; an empty (or null) array is
          /// passed to the native call as NULL.</param>
          /// <param name="expect_failure">When true, asserts that no function was produced.</param>
          void DefineT(int num_opers, Operation[] opers,
              TF_Output[] inputs, TF_Output[] outputs,
              string[] output_names, bool expect_failure = false)
          {
              // A fixture defines at most one function; a second definition would leak the first.
              ASSERT_EQ(IntPtr.Zero, func_);

              var output_names_ptr = IntPtr.Zero;
              var name_ptrs = new List<IntPtr>();
              try
              {
                  // Marshal output_names into a native `const char**`, mirroring the C++ test's
                  // ToArray helper (empty list -> NULL pointer).
                  if (output_names != null && output_names.Length > 0)
                  {
                      output_names_ptr = Marshal.AllocHGlobal(IntPtr.Size * output_names.Length);
                      for (int i = 0; i < output_names.Length; i++)
                      {
                          var name_ptr = Marshal.StringToHGlobalAnsi(output_names[i]);
                          name_ptrs.Add(name_ptr);
                          Marshal.WriteIntPtr(output_names_ptr, i * IntPtr.Size, name_ptr);
                      }
                  }

                  func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false,
                      num_opers, num_opers == -1 ? new IntPtr[0] : opers.Select(x => (IntPtr)x).ToArray(),
                      inputs.Length, inputs,
                      outputs.Length, outputs,
                      output_names_ptr, IntPtr.Zero, null, s_.Handle);
              }
              finally
              {
                  // The native name buffers are only needed for the duration of the call above.
                  foreach (var name_ptr in name_ptrs)
                      Marshal.FreeHGlobal(name_ptr);
                  if (output_names_ptr != IntPtr.Zero)
                      Marshal.FreeHGlobal(output_names_ptr);
              }

              if (expect_failure)
              {
                  ASSERT_EQ(IntPtr.Zero, func_);
                  return;
              }

              ASSERT_EQ(TF_OK, s_.Code, s_.Message);
              // Guard before dereferencing the handle in TF_FunctionName.
              ASSERT_NE(IntPtr.Zero, func_);
              ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
              c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
              ASSERT_EQ(TF_OK, s_.Code, s_.Message);
          }

          /// <summary>Instantiates the function in <c>host_graph_</c>, feeding output 0 of each input op.</summary>
          Operation Use(Operation[] inputs)
              => UseT(inputs.Select(x => new TF_Output(x, 0)).ToArray());

          /// <summary>Instantiates the function in <c>host_graph_</c> with explicit input endpoints.</summary>
          Operation UseT(TF_Output[] inputs)
              => UseHelper(inputs);

          /// <summary>
          /// Adds an operation of type <c>func_name_</c> named <c>func_node_name_</c> to
          /// <c>host_graph_</c>, wires the given inputs, pins it to "/cpu:0" and asserts success.
          /// </summary>
          Operation UseHelper(TF_Output[] inputs)
          {
              var desc = TF_NewOperation(host_graph_, func_name_, func_node_name_);
              foreach (var input in inputs)
                  TF_AddInput(desc, input);
              // Set device to CPU because some ops inside the function might not be available on GPU.
              c_api.TF_SetDevice(desc, "/cpu:0");
              var op = TF_FinishOperation(desc, s_);
              ASSERT_EQ(TF_OK, s_.Code, s_.Message);
              // Compare the native handle rather than the managed wrapper: an Operation object is
              // presumably never Equal to a boxed IntPtr, which would make that check vacuous.
              Assert.IsNotNull(op);
              ASSERT_NE(IntPtr.Zero, (IntPtr)op);
              return op;
          }

          /// <summary>Runs the host graph and checks a single int32 scalar output.</summary>
          void Run(KeyValuePair<Operation, Tensor>[] inputs, Operation output, int expected_result)
              => Run(inputs, new[] { new TF_Output(output, 0) }, new[] { expected_result });

          /// <summary>
          /// Runs <c>host_graph_</c> with the given feeds and asserts that output i is a
          /// rank-0 TF_INT32 tensor equal to expected_results[i].
          /// </summary>
          unsafe void Run(KeyValuePair<Operation, Tensor>[] inputs, TF_Output[] outputs, int[] expected_results)
          {
              var csession = new CSession(host_graph_, s_);
              ASSERT_EQ(TF_OK, s_.Code, s_.Message);
              csession.SetInputs(inputs);
              csession.SetOutputs(outputs);
              csession.Run(s_);
              ASSERT_EQ(TF_OK, s_.Code, s_.Message);
              for (int i = 0; i < expected_results.Length; ++i)
              {
                  var output = csession.output_tensor(i);
                  ASSERT_NE(output, IntPtr.Zero);
                  EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(output));
                  EXPECT_EQ(0, c_api.TF_NumDims(output));
                  ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(output));
                  var output_contents = c_api.TF_TensorData(output);
                  EXPECT_EQ(expected_results[i], *(int*)output_contents.ToPointer());
              }
          }

          /// <summary>Fetches the FunctionDef of <c>func_</c> and checks its node names.</summary>
          void VerifyFDef(string[] nodes)
          {
              var fdef = GetFunctionDef(func_);
              // fdef is a managed object, so test it for null (comparing to IntPtr.Zero never fails).
              Assert.IsNotNull(fdef, "GetFunctionDef returned no FunctionDef");
              VerifyFDefNodes(fdef, nodes);
          }

          /// <summary>
          /// Asserts that <paramref name="fdef"/> contains exactly the nodes named in
          /// <paramref name="nodes"/>: same count, and every node name is expected.
          /// </summary>
          void VerifyFDefNodes(FunctionDef fdef, string[] nodes)
          {
              ASSERT_EQ(nodes.Length, fdef.NodeDef.Count,
                  $"Got unexpected number of nodes. Expected: [{string.Join(", ", nodes)}]");
              var expected = new HashSet<string>(nodes);
              foreach (var node_def in fdef.NodeDef)
                  Assert.IsTrue(expected.Contains(node_def.Name),
                      $"Got unexpected node: {node_def.Name}");
          }

          /// <summary>
          /// Releases the native function handle and the graphs/status created in Initialize.
          /// </summary>
          public void Dispose()
          {
              if (func_ != IntPtr.Zero)
              {
                  c_api.TF_DeleteFunction(func_);
                  func_ = IntPtr.Zero;
              }
              host_graph_?.Dispose();
              func_graph_?.Dispose();
              s_?.Dispose();
          }
      }
  }