You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 15 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Runtime.InteropServices;
  5. using System.Text;
  6. using Tensorflow;
  7. using Buffer = Tensorflow.Buffer;
  namespace TensorFlowNET.UnitTest
  {
      /// <summary>
      /// Ports of the graph tests from TensorFlow's c_api_test.cc, driven through the
      /// raw <c>c_api</c> bindings and the <c>c_test_util</c> helpers.
      /// </summary>
      [TestClass]
      public class GraphTest : CApiTest
      {
          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, Graph)`
          /// </summary>
          /// <remarks>
          /// Builds a Placeholder (<c>feed</c>) and a scalar constant 3, sums them with an
          /// AddN op (<c>add</c>), then appends a Neg op. Verifies the per-operation query
          /// functions, producer/consumer wiring, GraphDef/NodeDef serialization, lookup by
          /// name, and iteration over every operation in the graph.
          /// </remarks>
          [TestMethod]
          public void c_api_Graph()
          {
              var s = new Status();
              var graph = new Graph();

              // Make a placeholder operation.
              var feed = c_test_util.Placeholder(graph, s);
              EXPECT_EQ("feed", feed.Name);
              EXPECT_EQ("Placeholder", feed.OpType);
              EXPECT_EQ("", feed.Device);
              EXPECT_EQ(1, feed.NumOutputs);
              EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0));
              EXPECT_EQ(1, feed.OutputListLength("output"));
              EXPECT_EQ(0, feed.NumInputs);
              EXPECT_EQ(0, feed.OutputNumConsumers(0));
              EXPECT_EQ(0, feed.NumControlInputs);
              EXPECT_EQ(0, feed.NumControlOutputs);

              AttrValue attr_value = null;
              ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
              // Expected value first, matching every other EXPECT_EQ in this file.
              EXPECT_EQ(DataType.DtInt32, attr_value.Type);

              // Test not found errors in TF_Operation*() query functions.
              EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s));
              EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code);
              Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
              EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message);

              // Make a constant oper with the scalar "3".
              var three = c_test_util.ScalarConst(3, graph, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              // Add oper.
              var add = c_test_util.Add(feed, three, graph, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              // Test TF_Operation*() query functions.
              EXPECT_EQ("add", add.Name);
              EXPECT_EQ("AddN", add.OpType);
              EXPECT_EQ("", add.Device);
              EXPECT_EQ(1, add.NumOutputs);
              EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0));
              EXPECT_EQ(1, add.OutputListLength("sum"));
              EXPECT_EQ(TF_Code.TF_OK, s.Code);
              EXPECT_EQ(2, add.InputListLength("inputs"));
              EXPECT_EQ(TF_Code.TF_OK, s.Code);
              EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0));
              EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1));

              // The AddN inputs are wired, in order, to feed:0 and three:0.
              var add_in_0 = add.Input(0);
              EXPECT_EQ(feed, add_in_0.oper);
              EXPECT_EQ(0, add_in_0.index);
              var add_in_1 = add.Input(1);
              EXPECT_EQ(three, add_in_1.oper);
              EXPECT_EQ(0, add_in_1.index);
              EXPECT_EQ(0, add.OutputNumConsumers(0));
              EXPECT_EQ(0, add.NumControlInputs);
              EXPECT_EQ(0, add.NumControlOutputs);
              ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
              EXPECT_EQ(DataType.DtInt32, attr_value.Type);
              ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
              EXPECT_EQ(2, attr_value.I);

              // Placeholder oper now has a consumer.
              EXPECT_EQ(1, feed.OutputNumConsumers(0));
              TF_Input[] feed_port = feed.OutputConsumers(0, 1);
              // Check the length before indexing into the consumer array.
              ASSERT_EQ(1, feed_port.Length);
              EXPECT_EQ(add, feed_port[0].oper);
              EXPECT_EQ(0, feed_port[0].index);

              // The scalar const oper also has a consumer.
              EXPECT_EQ(1, three.OutputNumConsumers(0));
              TF_Input[] three_port = three.OutputConsumers(0, 1);
              // Same length guard as for feed_port; previously this array was indexed unchecked.
              ASSERT_EQ(1, three_port.Length);
              EXPECT_EQ(add, three_port[0].oper);
              EXPECT_EQ(1, three_port[0].index);

              // Serialize to GraphDef.
              var graph_def = c_test_util.GetGraphDef(graph);

              // Validate GraphDef is what we expect: each node appears exactly once and
              // nothing else is present.
              bool found_placeholder = false;
              bool found_scalar_const = false;
              bool found_add = false;
              foreach (var n in graph_def.Node)
              {
                  if (c_test_util.IsPlaceholder(n))
                  {
                      Assert.IsFalse(found_placeholder);
                      found_placeholder = true;
                  }
                  else if (c_test_util.IsScalarConst(n, 3))
                  {
                      Assert.IsFalse(found_scalar_const);
                      found_scalar_const = true;
                  }
                  else if (c_test_util.IsAddN(n, 2))
                  {
                      Assert.IsFalse(found_add);
                      found_add = true;
                  }
                  else
                  {
                      Assert.Fail($"Unexpected NodeDef: {n}");
                  }
              }
              ASSERT_TRUE(found_placeholder);
              ASSERT_TRUE(found_scalar_const);
              ASSERT_TRUE(found_add);

              // Add another oper to the graph.
              var neg = c_test_util.Neg(add, graph, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              // Serialize to NodeDef.
              var node_def = c_test_util.GetNodeDef(neg);

              // Validate NodeDef is what we expect.
              ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));

              // Serialize to GraphDef.
              var graph_def2 = c_test_util.GetGraphDef(graph);

              // Compare with first GraphDef + added NodeDef.
              graph_def.Node.Add(node_def);
              EXPECT_EQ(graph_def.ToString(), graph_def2.ToString());

              // Look up some nodes by name.
              Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
              EXPECT_EQ(neg, neg2);
              var node_def2 = c_test_util.GetNodeDef(neg2);
              EXPECT_EQ(node_def.ToString(), node_def2.ToString());

              Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
              EXPECT_EQ(feed, feed2);
              node_def = c_test_util.GetNodeDef(feed);
              node_def2 = c_test_util.GetNodeDef(feed2);
              EXPECT_EQ(node_def.ToString(), node_def2.ToString());

              // Test iterating through the nodes of a graph. `pos` is the iteration cursor;
              // the loop ends when TF_GraphNextOperation yields a zero handle.
              found_placeholder = false;
              found_scalar_const = false;
              found_add = false;
              bool found_neg = false;
              uint pos = 0;
              Operation oper;
              while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
              {
                  if (oper.Equals(feed))
                  {
                      Assert.IsFalse(found_placeholder);
                      found_placeholder = true;
                  }
                  else if (oper.Equals(three))
                  {
                      Assert.IsFalse(found_scalar_const);
                      found_scalar_const = true;
                  }
                  else if (oper.Equals(add))
                  {
                      Assert.IsFalse(found_add);
                      found_add = true;
                  }
                  else if (oper.Equals(neg))
                  {
                      Assert.IsFalse(found_neg);
                      found_neg = true;
                  }
                  else
                  {
                      node_def = c_test_util.GetNodeDef(oper);
                      Assert.Fail($"Unexpected Node: {node_def.ToString()}");
                  }
              }
              ASSERT_TRUE(found_placeholder);
              ASSERT_TRUE(found_scalar_const);
              ASSERT_TRUE(found_add);
              ASSERT_TRUE(found_neg);

              graph.Dispose();
              s.Dispose();
          }

          /// <summary>
          /// Port from c_api_test.cc
          /// `TEST(CAPI, ImportGraphDef)`
          /// </summary>
          /// <remarks>
          /// Exports a feed/scalar/neg graph to a serialized GraphDef. It then imports it
          /// four times under different prefixes:
          /// "imported" is a plain import;
          /// "imported2" adds an input mapping plus requested return outputs/operations;
          /// "imported3" adds control dependencies;
          /// "imported4" uses a remapped control dependency.
          /// </remarks>
          [TestMethod]
          public void c_api_ImportGraphDef()
          {
              var s = new Status();
              var graph = new Graph();

              // Create a simple graph.
              c_test_util.Placeholder(graph, s);
              var oper = c_test_util.ScalarConst(3, graph, s);
              c_test_util.Neg(oper, graph, s);

              // Export to a GraphDef.
              var graph_def = new Buffer();
              c_api.TF_GraphToGraphDef(graph, graph_def, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              // Import it, with a prefix, in a fresh graph.
              graph.Dispose();
              graph = new Graph();
              var opts = c_api.TF_NewImportGraphDefOptions();
              c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
              c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              Operation scalar = graph.OperationByName("imported/scalar");
              Operation feed = graph.OperationByName("imported/feed");
              Operation neg = graph.OperationByName("imported/neg");
              // Fail fast on a missing node, as is already done for the "imported3" nodes below.
              ASSERT_TRUE(scalar != IntPtr.Zero);
              ASSERT_TRUE(feed != IntPtr.Zero);
              ASSERT_TRUE(neg != IntPtr.Zero);

              // Test basic structure of the imported graph.
              EXPECT_EQ(0, scalar.NumInputs);
              EXPECT_EQ(0, feed.NumInputs);
              EXPECT_EQ(1, neg.NumInputs);
              var neg_input = neg.Input(0);
              EXPECT_EQ(scalar, neg_input.oper);
              EXPECT_EQ(0, neg_input.index);

              // Test that we can't see control edges involving the source and sink nodes.
              EXPECT_EQ(0, scalar.NumControlInputs);
              EXPECT_EQ(0, scalar.GetControlInputs().Length);
              EXPECT_EQ(0, scalar.NumControlOutputs);
              EXPECT_EQ(0, scalar.GetControlOutputs().Length);
              EXPECT_EQ(0, feed.NumControlInputs);
              EXPECT_EQ(0, feed.GetControlInputs().Length);
              EXPECT_EQ(0, feed.NumControlOutputs);
              EXPECT_EQ(0, feed.GetControlOutputs().Length);
              EXPECT_EQ(0, neg.NumControlInputs);
              EXPECT_EQ(0, neg.GetControlInputs().Length);
              EXPECT_EQ(0, neg.NumControlOutputs);
              EXPECT_EQ(0, neg.GetControlOutputs().Length);

              // Import it again, with an input mapping, return outputs, and a return
              // operation, into the same graph.
              c_api.TF_DeleteImportGraphDefOptions(opts);
              opts = c_api.TF_NewImportGraphDefOptions();
              c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
              c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
              c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
              c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
              EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
              c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
              EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
              var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              Operation scalar2 = graph.OperationByName("imported2/scalar");
              Operation feed2 = graph.OperationByName("imported2/feed");
              Operation neg2 = graph.OperationByName("imported2/neg");
              ASSERT_TRUE(scalar2 != IntPtr.Zero);
              ASSERT_TRUE(feed2 != IntPtr.Zero);
              ASSERT_TRUE(neg2 != IntPtr.Zero);

              // Check input mapping: "scalar:0" was mapped onto the first import's scalar,
              // so the newly imported neg2 must consume `scalar`, not `scalar2`.
              // This must inspect neg2; inspecting `neg` would only re-check the first import.
              neg_input = neg2.Input(0);
              EXPECT_EQ(scalar, neg_input.oper);
              EXPECT_EQ(0, neg_input.index);

              // Check return outputs
              var return_outputs = graph.ReturnOutputs(results);
              ASSERT_EQ(2, return_outputs.Length);
              EXPECT_EQ(feed2, return_outputs[0].oper);
              EXPECT_EQ(0, return_outputs[0].index);
              EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
              EXPECT_EQ(0, return_outputs[1].index);

              // Check return operation
              var return_opers = graph.ReturnOperations(results);
              ASSERT_EQ(1, return_opers.Length);
              EXPECT_EQ(scalar2, return_opers[0]); // not remapped
              c_api.TF_DeleteImportGraphDefResults(results);

              // Import again, with control dependencies, into the same graph.
              c_api.TF_DeleteImportGraphDefOptions(opts);
              opts = c_api.TF_NewImportGraphDefOptions();
              c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
              c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
              c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
              c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              var scalar3 = graph.OperationByName("imported3/scalar");
              var feed3 = graph.OperationByName("imported3/feed");
              var neg3 = graph.OperationByName("imported3/neg");
              ASSERT_TRUE(scalar3 != IntPtr.Zero);
              ASSERT_TRUE(feed3 != IntPtr.Zero);
              ASSERT_TRUE(neg3 != IntPtr.Zero);

              // Check that newly-imported scalar and feed have control deps (neg3 will
              // inherit them from input)
              var control_inputs = scalar3.GetControlInputs();
              ASSERT_EQ(2, scalar3.NumControlInputs);
              EXPECT_EQ(feed, control_inputs[0]);
              EXPECT_EQ(feed2, control_inputs[1]);

              control_inputs = feed3.GetControlInputs();
              ASSERT_EQ(2, feed3.NumControlInputs);
              EXPECT_EQ(feed, control_inputs[0]);
              EXPECT_EQ(feed2, control_inputs[1]);

              // Export to a graph def so we can import a graph with control dependencies
              graph_def.Dispose();
              graph_def = new Buffer();
              c_api.TF_GraphToGraphDef(graph, graph_def, s);
              EXPECT_EQ(TF_Code.TF_OK, s.Code);

              // Import again, with remapped control dependency, into the same graph
              c_api.TF_DeleteImportGraphDefOptions(opts);
              opts = c_api.TF_NewImportGraphDefOptions();
              c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
              c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
              c_api.TF_GraphImportGraphDef(graph, graph_def, opts, s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);

              var scalar4 = graph.OperationByName("imported4/imported3/scalar");
              var feed4 = graph.OperationByName("imported4/imported2/feed");
              ASSERT_TRUE(scalar4 != IntPtr.Zero);
              ASSERT_TRUE(feed4 != IntPtr.Zero);

              // Check that imported `imported3/scalar` has remapped control dep from
              // original graph and imported control dep
              control_inputs = scalar4.GetControlInputs();
              ASSERT_EQ(2, scalar4.NumControlInputs);
              EXPECT_EQ(feed, control_inputs[0]);
              EXPECT_EQ(feed4, control_inputs[1]);

              c_api.TF_DeleteImportGraphDefOptions(opts);
              // Release through the wrapper, exactly like the first export buffer above,
              // rather than deleting the native handle behind the wrapper's back.
              graph_def.Dispose();

              // Can add nodes to the imported graph without trouble.
              c_test_util.Add(feed, scalar, graph, s);
              ASSERT_EQ(TF_Code.TF_OK, s.Code);

              // NOTE(review): the upstream C test deletes the graph here. The dispose below is
              // disabled and the reason is not visible in this file; confirm before re-enabling.
              //graph.Dispose();
              s.Dispose();
          }
      }
  }

TensorFlow框架的.NET版本，提供了丰富的特性和API，可以借此很方便地在.NET平台下搭建深度学习训练与推理流程。

Contributors (1)