You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

GraphTest.cs 18 kB

6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using Tensorflow;
  5. using Buffer = Tensorflow.Buffer;
  6. using static Tensorflow.Binding;
  7. namespace TensorFlowNET.UnitTest.NativeAPI
  8. {
  9. [TestClass]
  10. public class GraphTest : CApiTest
  11. {
  12. /// <summary>
  13. /// Port from c_api_test.cc
  14. /// `TEST(CAPI, Graph)`
  15. /// </summary>
  16. [TestMethod]
  17. public void Graph()
  18. {
  19. var s = new Status();
  20. var graph = new Graph();
  21. // Make a placeholder operation.
  22. var feed = c_test_util.Placeholder(graph, s);
  23. EXPECT_EQ("feed", feed.name);
  24. EXPECT_EQ("Placeholder", feed.OpType);
  25. EXPECT_EQ("", feed.Device);
  26. EXPECT_EQ(1, feed.NumOutputs);
  27. EXPECT_EQ(TF_DataType.TF_INT32, feed.OutputType(0));
  28. EXPECT_EQ(1, feed.OutputListLength("output"));
  29. EXPECT_EQ(0, feed.NumInputs);
  30. EXPECT_EQ(0, feed.OutputNumConsumers(0));
  31. EXPECT_EQ(0, feed.NumControlInputs);
  32. EXPECT_EQ(0, feed.NumControlOutputs);
  33. AttrValue attr_value = null;
  34. ASSERT_TRUE(c_test_util.GetAttrValue(feed, "dtype", ref attr_value, s));
  35. EXPECT_EQ(attr_value.Type, DataType.DtInt32);
  36. // Test not found errors in TF_Operation*() query functions.
  37. EXPECT_EQ(-1, c_api.TF_OperationOutputListLength(feed, "bogus", s.Handle));
  38. EXPECT_EQ(TF_Code.TF_INVALID_ARGUMENT, s.Code);
  39. Assert.IsFalse(c_test_util.GetAttrValue(feed, "missing", ref attr_value, s));
  40. EXPECT_EQ("Operation 'feed' has no attr named 'missing'.", s.Message);
  41. // Make a constant oper with the scalar "3".
  42. var three = c_test_util.ScalarConst(3, graph, s);
  43. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  44. // Add oper.
  45. var add = c_test_util.Add(feed, three, graph, s);
  46. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  47. // Test TF_Operation*() query functions.
  48. EXPECT_EQ("add", add.name);
  49. EXPECT_EQ("AddN", add.OpType);
  50. EXPECT_EQ("", add.Device);
  51. EXPECT_EQ(1, add.NumOutputs);
  52. EXPECT_EQ(TF_DataType.TF_INT32, add.OutputType(0));
  53. EXPECT_EQ(1, add.OutputListLength("sum"));
  54. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  55. EXPECT_EQ(2, add.InputListLength("inputs"));
  56. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  57. EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(0));
  58. EXPECT_EQ(TF_DataType.TF_INT32, add.InputType(1));
  59. var add_in_0 = add.Input(0);
  60. EXPECT_EQ(feed, add_in_0.oper);
  61. EXPECT_EQ(0, add_in_0.index);
  62. var add_in_1 = add.Input(1);
  63. EXPECT_EQ(three, add_in_1.oper);
  64. EXPECT_EQ(0, add_in_1.index);
  65. EXPECT_EQ(0, add.OutputNumConsumers(0));
  66. EXPECT_EQ(0, add.NumControlInputs);
  67. EXPECT_EQ(0, add.NumControlOutputs);
  68. ASSERT_TRUE(c_test_util.GetAttrValue(add, "T", ref attr_value, s));
  69. EXPECT_EQ(DataType.DtInt32, attr_value.Type);
  70. ASSERT_TRUE(c_test_util.GetAttrValue(add, "N", ref attr_value, s));
  71. EXPECT_EQ(2, (int)attr_value.I);
  72. // Placeholder oper now has a consumer.
  73. EXPECT_EQ(1, feed.OutputNumConsumers(0));
  74. TF_Input[] feed_port = feed.OutputConsumers(0, 1);
  75. EXPECT_EQ(1, feed_port.Length);
  76. EXPECT_EQ(add, feed_port[0].oper);
  77. EXPECT_EQ(0, feed_port[0].index);
  78. // The scalar const oper also has a consumer.
  79. EXPECT_EQ(1, three.OutputNumConsumers(0));
  80. TF_Input[] three_port = three.OutputConsumers(0, 1);
  81. EXPECT_EQ(add, three_port[0].oper);
  82. EXPECT_EQ(1, three_port[0].index);
  83. // Serialize to GraphDef.
  84. var graph_def = c_test_util.GetGraphDef(graph);
  85. // Validate GraphDef is what we expect.
  86. bool found_placeholder = false;
  87. bool found_scalar_const = false;
  88. bool found_add = false;
  89. foreach (var n in graph_def.Node)
  90. {
  91. if (c_test_util.IsPlaceholder(n))
  92. {
  93. Assert.IsFalse(found_placeholder);
  94. found_placeholder = true;
  95. }
  96. else if (c_test_util.IsScalarConst(n, 3))
  97. {
  98. Assert.IsFalse(found_scalar_const);
  99. found_scalar_const = true;
  100. }
  101. else if (c_test_util.IsAddN(n, 2))
  102. {
  103. Assert.IsFalse(found_add);
  104. found_add = true;
  105. }
  106. else
  107. {
  108. Assert.Fail($"Unexpected NodeDef: {n}");
  109. }
  110. }
  111. ASSERT_TRUE(found_placeholder);
  112. ASSERT_TRUE(found_scalar_const);
  113. ASSERT_TRUE(found_add);
  114. // Add another oper to the graph.
  115. var neg = c_test_util.Neg(add, graph, s);
  116. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  117. // Serialize to NodeDef.
  118. var node_def = neg.node_def;
  119. // Validate NodeDef is what we expect.
  120. ASSERT_TRUE(c_test_util.IsNeg(node_def, "add"));
  121. // Serialize to GraphDef.
  122. var graph_def2 = c_test_util.GetGraphDef(graph);
  123. // Compare with first GraphDef + added NodeDef.
  124. graph_def.Node.Add(node_def);
  125. EXPECT_EQ(graph_def, graph_def2);
  126. // Look up some nodes by name.
  127. Operation neg2 = c_api.TF_GraphOperationByName(graph, "neg");
  128. EXPECT_EQ(neg, neg2);
  129. var node_def2 = neg2.node_def;
  130. EXPECT_EQ(node_def, node_def2);
  131. Operation feed2 = c_api.TF_GraphOperationByName(graph, "feed");
  132. EXPECT_EQ(feed, feed2);
  133. node_def = feed.node_def;
  134. node_def2 = feed2.node_def;
  135. EXPECT_EQ(node_def, node_def2);
  136. // Test iterating through the nodes of a graph.
  137. found_placeholder = false;
  138. found_scalar_const = false;
  139. found_add = false;
  140. bool found_neg = false;
  141. uint pos = 0;
  142. Operation oper;
  143. while ((oper = c_api.TF_GraphNextOperation(graph, ref pos)) != IntPtr.Zero)
  144. {
  145. if (oper.Equals(feed))
  146. {
  147. Assert.IsFalse(found_placeholder);
  148. found_placeholder = true;
  149. }
  150. else if (oper.Equals(three))
  151. {
  152. Assert.IsFalse(found_scalar_const);
  153. found_scalar_const = true;
  154. }
  155. else if (oper.Equals(add))
  156. {
  157. Assert.IsFalse(found_add);
  158. found_add = true;
  159. }
  160. else if (oper.Equals(neg))
  161. {
  162. Assert.IsFalse(found_neg);
  163. found_neg = true;
  164. }
  165. else
  166. {
  167. node_def = oper.node_def;
  168. Assert.Fail($"Unexpected Node: {node_def.ToString()}");
  169. }
  170. }
  171. ASSERT_TRUE(found_placeholder);
  172. ASSERT_TRUE(found_scalar_const);
  173. ASSERT_TRUE(found_add);
  174. ASSERT_TRUE(found_neg);
  175. graph.Dispose();
  176. s.Dispose();
  177. }
  178. /// <summary>
  179. /// Port from c_api_test.cc
  180. /// `TEST(CAPI, ImportGraphDef)`
  181. /// </summary>
  182. [TestMethod]
  183. public void ImportGraphDef()
  184. {
  185. var s = new Status();
  186. var graph = new Graph().as_default();
  187. // Create a simple graph.
  188. c_test_util.Placeholder(graph, s);
  189. var oper = c_test_util.ScalarConst(3, graph, s);
  190. c_test_util.Neg(oper, graph, s);
  191. // Export to a GraphDef.
  192. var graph_def = new Buffer();
  193. c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
  194. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  195. // Import it, with a prefix, in a fresh graph.
  196. graph.Dispose();
  197. graph = new Graph().as_default();
  198. using (var opts = c_api.TF_NewImportGraphDefOptions())
  199. {
  200. c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported");
  201. c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
  202. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  203. }
  204. Operation scalar = graph.OperationByName("imported/scalar");
  205. Operation feed = graph.OperationByName("imported/feed");
  206. Operation neg = graph.OperationByName("imported/neg");
  207. // Test basic structure of the imported graph.
  208. EXPECT_EQ(0, scalar.NumInputs);
  209. EXPECT_EQ(0, feed.NumInputs);
  210. EXPECT_EQ(1, neg.NumInputs);
  211. var neg_input = neg.Input(0);
  212. EXPECT_EQ(scalar, neg_input.oper);
  213. EXPECT_EQ(0, neg_input.index);
  214. // Test that we can't see control edges involving the source and sink nodes.
  215. EXPECT_EQ(0, scalar.NumControlInputs);
  216. EXPECT_EQ(0, scalar.GetControlInputs().Length);
  217. EXPECT_EQ(0, scalar.NumControlOutputs);
  218. EXPECT_EQ(0, scalar.GetControlOutputs().Length);
  219. EXPECT_EQ(0, feed.NumControlInputs);
  220. EXPECT_EQ(0, feed.GetControlInputs().Length);
  221. EXPECT_EQ(0, feed.NumControlOutputs);
  222. EXPECT_EQ(0, feed.GetControlOutputs().Length);
  223. EXPECT_EQ(0, neg.NumControlInputs);
  224. EXPECT_EQ(0, neg.GetControlInputs().Length);
  225. EXPECT_EQ(0, neg.NumControlOutputs);
  226. EXPECT_EQ(0, neg.GetControlOutputs().Length);
  227. static SafeImportGraphDefResultsHandle ImportGraph(Status s, Graph graph, Buffer graph_def, Operation scalar)
  228. {
  229. using var opts = c_api.TF_NewImportGraphDefOptions();
  230. c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported2");
  231. c_api.TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, new TF_Output(scalar, 0));
  232. c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0);
  233. c_api.TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0);
  234. EXPECT_EQ(2, c_api.TF_ImportGraphDefOptionsNumReturnOutputs(opts));
  235. c_api.TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar");
  236. EXPECT_EQ(1, c_api.TF_ImportGraphDefOptionsNumReturnOperations(opts));
  237. var results = c_api.TF_GraphImportGraphDefWithResults(graph, graph_def.Handle, opts, s.Handle);
  238. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  239. return results;
  240. }
  241. // Import it again, with an input mapping, return outputs, and a return
  242. // operation, into the same graph.
  243. Operation feed2;
  244. using (SafeImportGraphDefResultsHandle results = ImportGraph(s, graph, graph_def, scalar))
  245. {
  246. Operation scalar2 = graph.OperationByName("imported2/scalar");
  247. feed2 = graph.OperationByName("imported2/feed");
  248. Operation neg2 = graph.OperationByName("imported2/neg");
  249. // Check input mapping
  250. neg_input = neg.Input(0);
  251. EXPECT_EQ(scalar, neg_input.oper);
  252. EXPECT_EQ(0, neg_input.index);
  253. // Check return outputs
  254. var return_outputs = graph.ReturnOutputs(results);
  255. ASSERT_EQ(2, return_outputs.Length);
  256. EXPECT_EQ(feed2, return_outputs[0].oper);
  257. EXPECT_EQ(0, return_outputs[0].index);
  258. EXPECT_EQ(scalar, return_outputs[1].oper); // remapped
  259. EXPECT_EQ(0, return_outputs[1].index);
  260. // Check return operation
  261. var return_opers = graph.ReturnOperations(results);
  262. ASSERT_EQ(1, return_opers.Length);
  263. EXPECT_EQ(scalar2, return_opers[0]); // not remapped
  264. }
  265. // Import again, with control dependencies, into the same graph.
  266. using (var opts = c_api.TF_NewImportGraphDefOptions())
  267. {
  268. c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported3");
  269. c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed);
  270. c_api.TF_ImportGraphDefOptionsAddControlDependency(opts, feed2);
  271. c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
  272. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  273. }
  274. var scalar3 = graph.OperationByName("imported3/scalar");
  275. var feed3 = graph.OperationByName("imported3/feed");
  276. var neg3 = graph.OperationByName("imported3/neg");
  277. ASSERT_TRUE(scalar3 != IntPtr.Zero);
  278. ASSERT_TRUE(feed3 != IntPtr.Zero);
  279. ASSERT_TRUE(neg3 != IntPtr.Zero);
  280. // Check that newly-imported scalar and feed have control deps (neg3 will
  281. // inherit them from input)
  282. var control_inputs = scalar3.GetControlInputs();
  283. ASSERT_EQ(2, scalar3.NumControlInputs);
  284. EXPECT_EQ(feed, control_inputs[0]);
  285. EXPECT_EQ(feed2, control_inputs[1]);
  286. control_inputs = feed3.GetControlInputs();
  287. ASSERT_EQ(2, feed3.NumControlInputs);
  288. EXPECT_EQ(feed, control_inputs[0]);
  289. EXPECT_EQ(feed2, control_inputs[1]);
  290. // Export to a graph def so we can import a graph with control dependencies
  291. graph_def = new Buffer();
  292. c_api.TF_GraphToGraphDef(graph, graph_def.Handle, s.Handle);
  293. EXPECT_EQ(TF_Code.TF_OK, s.Code);
  294. // Import again, with remapped control dependency, into the same graph
  295. using (var opts = c_api.TF_NewImportGraphDefOptions())
  296. {
  297. c_api.TF_ImportGraphDefOptionsSetPrefix(opts, "imported4");
  298. c_api.TF_ImportGraphDefOptionsRemapControlDependency(opts, "imported/feed", feed);
  299. c_api.TF_GraphImportGraphDef(graph, graph_def.Handle, opts, s.Handle);
  300. ASSERT_EQ(TF_Code.TF_OK, s.Code);
  301. }
  302. var scalar4 = graph.OperationByName("imported4/imported3/scalar");
  303. var feed4 = graph.OperationByName("imported4/imported2/feed");
  304. // Check that imported `imported3/scalar` has remapped control dep from
  305. // original graph and imported control dep
  306. control_inputs = scalar4.GetControlInputs();
  307. ASSERT_EQ(2, scalar4.NumControlInputs);
  308. EXPECT_EQ(feed, control_inputs[0]);
  309. EXPECT_EQ(feed4, control_inputs[1]);
  310. // Can add nodes to the imported graph without trouble.
  311. c_test_util.Add(feed, scalar, graph, s);
  312. ASSERT_EQ(TF_Code.TF_OK, s.Code);
  313. }
  314. /// <summary>
  315. /// Port from c_api_test.cc
  316. /// `TEST(CAPI, ImportGraphDef_WithReturnOutputs)`
  317. /// </summary>
  318. [TestMethod]
  319. public void ImportGraphDef_WithReturnOutputs()
  320. {
  321. var s = new Status();
  322. var graph = new Graph().as_default();
  323. // Create a graph with two nodes: x and 3
  324. c_test_util.Placeholder(graph, s);
  325. ASSERT_TRUE(graph.OperationByName("feed") != null);
  326. var oper = c_test_util.ScalarConst(3, graph, s);
  327. ASSERT_TRUE(graph.OperationByName("scalar") != null);
  328. c_test_util.Neg(oper, graph, s);
  329. ASSERT_TRUE(graph.OperationByName("neg") != null);
  330. // Export to a GraphDef.
  331. var graph_def = graph.ToGraphDef(s);
  332. ASSERT_EQ(TF_Code.TF_OK, s.Code);
  333. // Import it in a fresh graph with return outputs.
  334. graph.Dispose();
  335. graph = new Graph().as_default();
  336. var opts = new ImportGraphDefOptions();
  337. opts.AddReturnOutput("feed", 0);
  338. opts.AddReturnOutput("scalar", 0);
  339. EXPECT_EQ(2, opts.NumReturnOutputs);
  340. var return_outputs = graph.ImportGraphDefWithReturnOutputs(graph_def, opts, s);
  341. ASSERT_EQ(TF_Code.TF_OK, s.Code);
  342. var scalar = graph.OperationByName("scalar");
  343. var feed = graph.OperationByName("feed");
  344. var neg = graph.OperationByName("neg");
  345. ASSERT_TRUE(scalar != IntPtr.Zero);
  346. ASSERT_TRUE(feed != IntPtr.Zero);
  347. ASSERT_TRUE(neg != IntPtr.Zero);
  348. // Check return outputs
  349. EXPECT_EQ(feed, return_outputs[0].oper);
  350. EXPECT_EQ(0, return_outputs[0].index);
  351. EXPECT_EQ(scalar, return_outputs[1].oper);
  352. EXPECT_EQ(0, return_outputs[1].index);
  353. opts.Dispose();
  354. graph_def.Dispose();
  355. graph.Dispose();
  356. s.Dispose();
  357. }
  358. /// <summary>
  359. /// `TEST(CAPI, ImportGraphDef_MissingUnusedInputMappings)`
  360. /// </summary>
  361. [TestMethod]
  362. public void ImportGraphDef_MissingUnusedInputMappings()
  363. {
  364. }
  365. [Ignore]
  366. [TestMethod]
  367. public void ImportGraphMeta()
  368. {
  369. var dir = "my-save-dir/";
  370. using (var sess = tf.Session())
  371. {
  372. var new_saver = tf.train.import_meta_graph(dir + "my-model-10000.meta");
  373. new_saver.restore(sess, dir + "my-model-10000");
  374. var labels = tf.constant(0, dtype: tf.int32, shape: new int[] { 100 }, name: "labels");
  375. var batch_size = tf.size(labels);
  376. var logits = tf.get_collection<ITensorOrOperation>("logits")[0] as Tensor;
  377. var loss = tf.losses.sparse_softmax_cross_entropy(labels: labels,
  378. logits: logits);
  379. }
  380. }
  381. }
  382. }