You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

CApiFunctionTest.cs 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using Tensorflow;
  6. using static TensorFlowNET.UnitTest.c_test_util;
  7. namespace TensorFlowNET.UnitTest.NativeAPI
  8. {
  9. /// <summary>
  10. /// tensorflow\c\c_api_function_test.cc
  /// `class CApiFunctionTest`
  12. /// </summary>
  13. [TestClass]
  14. public class CApiFunctionTest : CApiTest, IDisposable
  15. {
  // Graph in which each test builds the function body; converted by DefineT.
  Graph func_graph_;
  // Graph that receives a copy of the function and hosts the call node (UseHelper/Run).
  Graph host_graph_;
  // Name the function is registered under; never reassigned.
  readonly string func_name_ = "MyFunc";
  // Node name of the call site added to host_graph_; never reassigned.
  readonly string func_node_name_ = "MyFunc_0";
  // Status object passed to the native calls made by this fixture.
  Status s_;
  // Native function handle returned by TF_GraphToFunction (IntPtr.Zero until DefineT succeeds).
  IntPtr func_;
  /// <summary>
  /// Per-test setup: allocates a new function graph, a new host graph and a new status.
  /// </summary>
  [TestInitialize]
  public void Initialize()
  {
      this.func_graph_ = new Graph();
      this.host_graph_ = new Graph();
      this.s_ = new Status();
  }
  /// <summary>
  /// A function with no inputs whose single output is the constant node "scalar10" (value 10).
  /// </summary>
  [TestMethod]
  public void OneOp_ZeroInputs_OneOutput()
  {
      // Define: the function body is a single constant.
      var constant = ScalarConst(10, func_graph_, s_, "scalar10");
      Define(num_opers: -1,
             opers: new Operation[0],
             inputs: new Operation[0],
             outputs: new[] { constant },
             output_names: new string[0]);

      // Use and run: the call takes no arguments and must yield 10.
      var call = Use(new Operation[0]);
      Run(new KeyValuePair<Operation, Tensor>[0], call, 10);

      // Verify the generated FunctionDef.
      var expectedNodes = new[] { "scalar10_0" };
      var expectedOutputs = new List<IOSpec> { new IOSpec("scalar10", DataType.DtInt32) };
      var expectedEdges = new List<EdgeSpec> { new EdgeSpec("scalar10_0:output:0", "scalar10") };
      VerifyFDef(expectedNodes, new List<IOSpec>(), expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// One placeholder input feeding a Neg op; calling the function with 3 must return -3.
  /// </summary>
  [TestMethod]
  public void OneOp_OneInput_OneOutput()
  {
      // Define
      var input = Placeholder(func_graph_, s_);
      var negated = Neg(input, func_graph_, s_);
      Define(num_opers: -1,
             opers: new Operation[0],
             inputs: new[] { input },
             outputs: new[] { negated },
             output_names: new string[0]);

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed", "neg_0:0"),
          new EdgeSpec("neg_0:y:0", "neg"),
      };
      VerifyFDef(new[] { "neg_0" },
                 new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) },
                 new List<IOSpec> { new IOSpec("neg", DataType.DtInt32) },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// Same shape as OneOp_OneInput_OneOutput, but the single output is explicitly named "negated_num".
  /// </summary>
  [TestMethod]
  public void OneOutput_OutputNames()
  {
      // Define
      var input = Placeholder(func_graph_, s_);
      var negated = Neg(input, func_graph_, s_);
      Define(num_opers: -1,
             opers: new Operation[0],
             inputs: new[] { input },
             outputs: new[] { negated },
             output_names: new[] { "negated_num" });

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed", "neg:0"),
          new EdgeSpec("neg:y:0", "negated_num"),
      };
      VerifyFDef(new[] { "neg" },
                 new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) },
                 new List<IOSpec> { new IOSpec("negated_num", DataType.DtInt32) },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// The output is named "negation", which collides with the input placeholder's name;
  /// the input argument is expected to appear as "negation_0" in the FunctionDef.
  /// </summary>
  [TestMethod]
  public void OutputNames_SameNameAsInput()
  {
      // Define
      var input = Placeholder(func_graph_, s_, "negation");
      var negated = Neg(input, func_graph_, s_, "neg");
      Define(num_opers: -1,
             opers: new Operation[0],
             inputs: new[] { input },
             outputs: new[] { negated },
             output_names: new[] { "negation" });

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("negation_0", "neg:0"),
          new EdgeSpec("neg:y:0", "negation"),
      };
      VerifyFDef(new[] { "neg" },
                 new List<IOSpec> { new IOSpec("negation_0", DataType.DtInt32) },
                 new List<IOSpec> { new IOSpec("negation", DataType.DtInt32) },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// A function with no body ops that returns its single input unchanged.
  /// </summary>
  [TestMethod]
  public void ZeroOps_Identity()
  {
      // Define: the same placeholder serves as both the input and the output.
      var passthrough = Placeholder(func_graph_, s_);
      var io = new[] { passthrough };
      Define(num_opers: -1,
             opers: new Operation[0],
             inputs: io,
             outputs: io,
             output_names: new string[0]);

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 3);

      VerifyFDef(new string[0],
                 new List<IOSpec> { new IOSpec("feed_0", DataType.DtInt32) },
                 new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) },
                 new List<EdgeSpec> { new EdgeSpec("feed_0", "feed") },
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// A function with no body ops that returns its two inputs in swapped order.
  /// </summary>
  [TestMethod]
  public void ZeroOps_Permutation()
  {
      // Define
      var first = Placeholder(func_graph_, s_, "feed1");
      var second = Placeholder(func_graph_, s_, "feed2");
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { first, second },
             outputs: new[] { second, first },
             output_names: null);

      // Use and run: called with (2, 3), the outputs must come back as (3, 2).
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = Enumerable.Range(0, 2).Select(i => new TF_Output(call, i)).ToArray();
      Run(feeds, fetches, new[] { 3, 2 });

      // Verify
      VerifyFDef(new string[0],
                 new List<IOSpec> { new IOSpec("feed1_0"), new IOSpec("feed2_0") },
                 new List<IOSpec> { new IOSpec("feed2"), new IOSpec("feed1") },
                 new List<EdgeSpec> { new EdgeSpec("feed1_0", "feed1"), new EdgeSpec("feed2_0", "feed2") },
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// Like ZeroOps_Permutation, but the swapped outputs are explicitly named "first" and "second".
  /// </summary>
  [TestMethod]
  public void ZeroOps_Permutation_OutputNames()
  {
      // Define
      var first = Placeholder(func_graph_, s_, "feed1");
      var second = Placeholder(func_graph_, s_, "feed2");
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { first, second },
             outputs: new[] { second, first },
             output_names: new[] { "first", "second" });

      // Use and run: called with (2, 3), the outputs must come back as (3, 2).
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = Enumerable.Range(0, 2).Select(i => new TF_Output(call, i)).ToArray();
      Run(feeds, fetches, new[] { 3, 2 });

      // Verify
      VerifyFDef(new string[0],
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
                 new List<IOSpec> { new IOSpec("first"), new IOSpec("second") },
                 new List<EdgeSpec> { new EdgeSpec("feed1", "second"), new EdgeSpec("feed2", "first") },
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// A single Add op over two inputs; calling with (2, 3) must return 5.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_OneOutput()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { lhs, rhs },
             outputs: new[] { sum },
             output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 2 + 3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add_0:0"),
          new EdgeSpec("feed2", "add_0:1"),
          new EdgeSpec("add_0:sum:0", "add"),
      };
      VerifyFDef(new[] { "add_0" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
                 new List<IOSpec> { new IOSpec("add") },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// A function whose body holds an Add op but which exposes no outputs.
  /// Only the FunctionDef is verified; nothing is run.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_ZeroOutputs()
  {
      // Define: the Add result is never referenced; it shows up in the body as node "add".
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { lhs, rhs },
             outputs: new Operation[0],
             output_names: null);

      // Use: the call node is added to host_graph_, but there is nothing to fetch.
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      Use(new[] { two, hostFeed });

      // Verify
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add:0"),
          new EdgeSpec("feed2", "add:1"),
      };
      VerifyFDef(new[] { "add" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
                 new List<IOSpec>(),
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// Two chained Add ops over three inputs, exposing only the final sum (2 + 10 + 3).
  /// </summary>
  [TestMethod]
  public void TwoOps_ThreeInputs_OneOutput()
  {
      // Define: add2 = (feed1 + feed2) + feed3
      var a = Placeholder(func_graph_, s_, "feed1");
      var b = Placeholder(func_graph_, s_, "feed2");
      var c = Placeholder(func_graph_, s_, "feed3");
      var partial = Add(a, b, func_graph_, s_, "add1");
      var total = Add(partial, c, func_graph_, s_, "add2");
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { a, b, c },
             outputs: new[] { total },
             output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_, "two");
      var ten = ScalarConst(10, host_graph_, s_, "ten");
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, ten, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 2 + 10 + 3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add1:0"),
          new EdgeSpec("feed2", "add1:1"),
          new EdgeSpec("add1:sum:0", "add2_0:0"),
          new EdgeSpec("feed3", "add2_0:1"),
          new EdgeSpec("add2_0:sum:0", "add2"),
      };
      VerifyFDef(new[] { "add1", "add2_0" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") },
                 new List<IOSpec> { new IOSpec("add2") },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// The same Add result is returned twice; both outputs must equal 2 + 3.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_TwoDuplicateOutputs()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { lhs, rhs },
             outputs: new[] { sum, sum },
             output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = Enumerable.Range(0, 2).Select(i => new TF_Output(call, i)).ToArray();
      Run(feeds, fetches, new[] { 5, 5 });

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add_1:0"),
          new EdgeSpec("feed2", "add_1:1"),
          new EdgeSpec("add_1:sum:0", "add"),
          new EdgeSpec("add_1:sum:0", "add_0"),
      };
      VerifyFDef(new[] { "add_1" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
                 new List<IOSpec> { new IOSpec("add"), new IOSpec("add_0") },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// The same Add result is returned twice under the explicit names "out1" and "out2".
  /// </summary>
  [TestMethod]
  public void TwoDuplicateOutputs_OutputNames()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { lhs, rhs },
             outputs: new[] { sum, sum },
             output_names: new[] { "out1", "out2" });

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = Enumerable.Range(0, 2).Select(i => new TF_Output(call, i)).ToArray();
      Run(feeds, fetches, new[] { 5, 5 });

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add:0"),
          new EdgeSpec("feed2", "add:1"),
          new EdgeSpec("add:sum:0", "out1"),
          new EdgeSpec("add:sum:0", "out2"),
      };
      VerifyFDef(new[] { "add" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
                 new List<IOSpec> { new IOSpec("out1"), new IOSpec("out2") },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// Two chained Add ops over three inputs, exposing both the intermediate sum (2 + 10)
  /// and the final sum (2 + 10 + 3).
  /// </summary>
  [TestMethod]
  public void TwoOps_ThreeInputs_TwoOutputs()
  {
      // Define: add1 = feed1 + feed2, add2 = add1 + feed3
      var a = Placeholder(func_graph_, s_, "feed1");
      var b = Placeholder(func_graph_, s_, "feed2");
      var c = Placeholder(func_graph_, s_, "feed3");
      var partial = Add(a, b, func_graph_, s_, "add1");
      var total = Add(partial, c, func_graph_, s_, "add2");
      Define(num_opers: -1,
             opers: null,
             inputs: new[] { a, b, c },
             outputs: new[] { partial, total },
             output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_, "two");
      var ten = ScalarConst(10, host_graph_, s_, "ten");
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, ten, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = Enumerable.Range(0, 2).Select(i => new TF_Output(call, i)).ToArray();
      Run(feeds, fetches, new[] { 12, 15 });

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add1_0:0"),
          new EdgeSpec("feed2", "add1_0:1"),
          new EdgeSpec("add1_0:sum:0", "add2_0:0"),
          new EdgeSpec("feed3", "add2_0:1"),
          new EdgeSpec("add1_0:sum:0", "add1"),
          new EdgeSpec("add2_0:sum:0", "add2"),
      };
      VerifyFDef(new[] { "add1_0", "add2_0" },
                 new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") },
                 new List<IOSpec> { new IOSpec("add1"), new IOSpec("add2") },
                 expectedEdges,
                 new List<EdgeSpec>());
  }
  /// <summary>
  /// Operation-based overload of DefineT: every function input and output is taken
  /// from output 0 of the corresponding operation.
  /// </summary>
  void Define(int num_opers, Operation[] opers,
              Operation[] inputs, Operation[] outputs,
              string[] output_names, bool expect_failure = false)
  {
      TF_Output[] inputPorts = Array.ConvertAll(inputs, op => new TF_Output(op, 0));
      TF_Output[] outputPorts = Array.ConvertAll(outputs, op => new TF_Output(op, 0));
      DefineT(num_opers, opers, inputPorts, outputPorts, output_names, expect_failure);
  }
  /// <summary>
  /// Converts func_graph_ into a function named func_name_ with TF_GraphToFunction and,
  /// unless a failure is expected, checks the result and copies it into host_graph_.
  /// </summary>
  /// <param name="num_opers">Number of body operations; -1 passes no explicit operation list
  /// (per the TF C API this means "use the whole graph" — confirm against c_api.h).</param>
  /// <param name="opers">Body operations; ignored when <paramref name="num_opers"/> is -1.</param>
  /// <param name="inputs">Graph outputs that become the function's inputs.</param>
  /// <param name="outputs">Graph outputs that become the function's outputs.</param>
  /// <param name="output_names">Explicit output names; null or empty passes no names.</param>
  /// <param name="expect_failure">When true, only asserts that no function was produced.</param>
  void DefineT(int num_opers, Operation[] opers,
               TF_Output[] inputs, TF_Output[] outputs,
               string[] output_names, bool expect_failure = false)
  {
      // A test defines the function at most once; a second call would overwrite (leak) func_.
      ASSERT_EQ(IntPtr.Zero, func_);

      IntPtr[] operHandles = null;
      if (num_opers != -1)
      {
          // TF_GraphToFunction receives num_opers alongside this array, so the array must
          // hold at least that many handles.
          ASSERT_TRUE(opers != null && opers.Length >= num_opers,
              $"Expected at least {num_opers} body operations in opers");
          operHandles = opers.Select(x => (IntPtr)x).ToArray();
      }
      string[] names = output_names == null || output_names.Length == 0 ? null : output_names;

      // inputs/outputs are already arrays, so they are passed directly without copying.
      func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false,
          num_opers, operHandles,
          inputs.Length, inputs,
          outputs.Length, outputs,
          names,
          IntPtr.Zero, null, s_.Handle);
      if (expect_failure)
      {
          ASSERT_EQ(IntPtr.Zero, func_);
          return;
      }
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);
      ASSERT_NE(func_, IntPtr.Zero);
      ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));
      c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);
  }
  /// <summary>
  /// Adds a call to the function in host_graph_, passing output 0 of each given operation.
  /// </summary>
  Operation Use(Operation[] inputs)
  {
      TF_Output[] args = Array.ConvertAll(inputs, op => new TF_Output(op, 0));
      return UseT(args);
  }
  /// <summary>
  /// Adds a call to the function in host_graph_ with explicit TF_Output arguments.
  /// </summary>
  Operation UseT(TF_Output[] inputs)
  {
      return UseHelper(inputs);
  }
  /// <summary>
  /// Adds a node named func_node_name_ whose op type is func_name_ (the function copied into
  /// host_graph_ by DefineT), wires <paramref name="inputs"/> into it, pins it to "/cpu:0",
  /// and returns the finished operation.
  /// </summary>
  Operation UseHelper(TF_Output[] inputs)
  {
      var desc = TF_NewOperation(host_graph_, func_name_, func_node_name_);
      foreach (var input in inputs)
      {
          TF_AddInput(desc, input);
      }
      // NOTE(review): presumably pinned to CPU because some ops inside the function may lack
      // GPU kernels — confirm.
      c_api.TF_SetDevice(desc, "/cpu:0");
      var op = TF_FinishOperation(desc, s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);
      // Check the native handle itself: comparing the Operation wrapper with a boxed
      // IntPtr.Zero (object equality) can never detect a null handle.
      ASSERT_TRUE(op != null && (IntPtr)op != IntPtr.Zero, "TF_FinishOperation returned a null operation");
      return op;
  }
  /// <summary>
  /// Single-output convenience overload: runs host_graph_ and checks output 0 of
  /// <paramref name="output"/> against <paramref name="expected_result"/>.
  /// </summary>
  void Run(KeyValuePair<Operation, Tensor>[] inputs, Operation output, int expected_result)
  {
      var fetches = new[] { new TF_Output(output, 0) };
      var expected = new[] { expected_result };
      Run(inputs, fetches, expected);
  }
  /// <summary>
  /// Runs host_graph_ in a new CSession with the given feeds and fetches. For each entry of
  /// <paramref name="expected_results"/>, the fetched tensor at the same index must be a
  /// rank-0 TF_INT32 tensor of sizeof(int) bytes holding that value.
  /// </summary>
  unsafe void Run(KeyValuePair<Operation, Tensor>[] inputs, TF_Output[] outputs, int[] expected_results)
  {
      var session = new CSession(host_graph_, s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);

      session.SetInputs(inputs);
      session.SetOutputs(outputs);
      session.Run(s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);

      var index = 0;
      foreach (var expected in expected_results)
      {
          var tensor = session.output_tensor(index++);
          ASSERT_TRUE(tensor != IntPtr.Zero);
          EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(tensor));
          EXPECT_EQ(0, c_api.TF_NumDims(tensor));
          ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(tensor));
          // Read the scalar straight out of the tensor's native buffer.
          var data = c_api.TF_TensorData(tensor);
          EXPECT_EQ(expected, *(int*)data.ToPointer());
      }
  }
  /// <summary>
  /// Fetches the FunctionDef for func_ and checks its body nodes, signature inputs and
  /// outputs, and edges against the given expectations.
  /// </summary>
  /// <param name="nodes">Expected body node names.</param>
  /// <param name="inputs">Expected input arguments (DtInvalid type = do not check type).</param>
  /// <param name="outputs">Expected output arguments (DtInvalid type = do not check type).</param>
  /// <param name="e_edges">Expected data edges.</param>
  /// <param name="c_edges">Expected control edges.</param>
  /// <param name="is_exact_edges">When true, no edges beyond the expected ones may exist.</param>
  void VerifyFDef(string[] nodes, List<IOSpec> inputs, List<IOSpec> outputs,
                  List<EdgeSpec> e_edges, List<EdgeSpec> c_edges,
                  bool is_exact_edges = true)
  {
      var fdef = GetFunctionDef(func_);
      // Null check on the FunctionDef itself: comparing it with IntPtr.Zero can never fail.
      ASSERT_TRUE(fdef != null, $"GetFunctionDef returned null for function {func_name_}");
      VerifyFDefNodes(fdef, nodes);
      VerifyFDefInputs(fdef, inputs);
      VerifyFDefOutputs(fdef, outputs);
      VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
  }
  /// <summary>
  /// Asserts that the function body has as many nodes as <paramref name="nodes"/> and that
  /// every body node's name appears in <paramref name="nodes"/>.
  /// </summary>
  void VerifyFDefNodes(FunctionDef fdef, string[] nodes)
  {
      ASSERT_EQ(nodes.Length, fdef.NodeDef.Count,
          $"Got unexpected number of nodes. Expected: [{string.Join(", ", nodes)}] Actual nodes in fdef: {fdef}");
      foreach (var node in fdef.NodeDef)
      {
          ASSERT_TRUE(nodes.Contains(node.Name), $"Got unexpected node: {node.Name} in fdef: {fdef}");
      }
  }
  /// <summary>
  /// Asserts that the function signature's input arguments match <paramref name="inputs"/>
  /// in order: names always, types only where the spec's type is not DataType.DtInvalid.
  /// </summary>
  void VerifyFDefInputs(FunctionDef fdef, List<IOSpec> inputs)
  {
      var signature = fdef.Signature;
      ASSERT_EQ(inputs.Count, signature.InputArg.Count,
          $"Got unexpected number of inputs. fdef: {fdef}");
      for (int i = 0; i < inputs.Count; ++i)
      {
          var arg = signature.InputArg[i];
          var input = inputs[i];
          // Expected value first so the framework's "Expected/Actual" text is not inverted.
          if (input.Value != DataType.DtInvalid)
          {
              ASSERT_EQ(input.Value, arg.Type, $"Got unexpected type for input {i}. fdef: {fdef}");
          }
          ASSERT_EQ(input.Key, arg.Name, $"Got unexpected name for input {i}. fdef: {fdef}");
      }
  }
  /// <summary>
  /// Asserts that the function signature's output arguments match <paramref name="outputs"/>
  /// in order: names always, types only where the spec's type is not DataType.DtInvalid.
  /// </summary>
  void VerifyFDefOutputs(FunctionDef fdef, List<IOSpec> outputs)
  {
      var signature = fdef.Signature;
      ASSERT_EQ(outputs.Count, signature.OutputArg.Count,
          $"Got unexpected number of outputs. fdef: {fdef}");
      for (int i = 0; i < outputs.Count; ++i)
      {
          var arg = signature.OutputArg[i];
          var output = outputs[i];
          // Expected value first so the framework's "Expected/Actual" text is not inverted.
          if (output.Value != DataType.DtInvalid)
          {
              ASSERT_EQ(output.Value, arg.Type, $"Got unexpected type for output {i}. fdef: {fdef}");
          }
          ASSERT_EQ(output.Key, arg.Name, $"Got unexpected name for output {i}. fdef: {fdef}");
      }
  }
  /// <summary>
  /// Collects the actual edges of <paramref name="fdef"/> and asserts that every expected
  /// data edge (<paramref name="e_edges"/>) and control edge (<paramref name="c_edges"/>) is
  /// present. When <paramref name="is_exact_edges"/> is true, also asserts that there are no others.
  /// </summary>
  void VerifyFDefEdges(FunctionDef fdef, List<EdgeSpec> e_edges, List<EdgeSpec> c_edges, bool is_exact_edges = true)
  {
      // Actual edges, kept as a set so a duplicate is reported instead of being counted twice.
      var a_edges = new HashSet<EdgeSpec>();

      // Edges from function inputs to body nodes and between body nodes: each input string
      // of a node is paired with "<node name>:<input index>".
      foreach (var node in fdef.NodeDef)
      {
          for (int i = 0; i < node.Input.Count; ++i)
          {
              var edge = new EdgeSpec(node.Input[i], $"{node.Name}:{i}");
              ASSERT_TRUE(a_edges.Add(edge), $"Duplicate edge {FormatEdge(edge)}. fdef: {fdef}");
          }
      }

      // Edges into each output: from the source recorded in fdef.Ret when there is one,
      // otherwise from the output argument's own name.
      foreach (var arg in fdef.Signature.OutputArg)
      {
          string source;
          if (!fdef.Ret.TryGetValue(arg.Name, out source))
          {
              source = arg.Name;
          }
          a_edges.Add(new EdgeSpec(source, arg.Name));
      }

      foreach (var edge in e_edges)
      {
          ASSERT_TRUE(a_edges.Contains(edge), $"Failed to find expected edge {FormatEdge(edge)} in fdef: {fdef}");
      }
      foreach (var edge in c_edges)
      {
          ASSERT_TRUE(a_edges.Contains(edge), $"Failed to find expected control edge {FormatEdge(edge)} in fdef: {fdef}");
      }

      // If the caller listed all edges, nothing beyond them may be present.
      if (is_exact_edges)
      {
          ASSERT_EQ(e_edges.Count + c_edges.Count, a_edges.Count,
              $"Expected edges: {FormatEdges(e_edges)}, Expected Control edges: {FormatEdges(c_edges)}, Actual edges: {FormatEdges(a_edges)}");
      }
  }

  // Renders one edge as "source -> destination" for assertion messages.
  static string FormatEdge(EdgeSpec edge) => $"{edge.Key} -> {edge.Value}";

  // Renders a collection of edges as "[a -> b, c -> d]" for assertion messages.
  static string FormatEdges(IEnumerable<EdgeSpec> edges)
      => "[" + string.Join(", ", edges.Select(e => FormatEdge(e))) + "]";
  /// <summary>
  /// Releases the native function handle created by DefineT, if one was created.
  /// NOTE(review): func_graph_, host_graph_ and s_ are not released here — confirm whether
  /// they are cleaned up elsewhere (e.g. by their own finalizers or the CApiTest base).
  /// </summary>
  public void Dispose()
  {
      if (func_ != IntPtr.Zero)
      {
          c_api.TF_DeleteFunction(func_);
          func_ = IntPtr.Zero;
      }
  }
  /// <summary>
  /// Expected function argument: a name plus a type, where DataType.DtInvalid means the
  /// type is not checked by VerifyFDefInputs/VerifyFDefOutputs.
  /// </summary>
  public struct IOSpec
  {
      private readonly KeyValuePair<string, DataType> _spec;

      public IOSpec(string key, DataType value = DataType.DtInvalid)
      {
          _spec = new KeyValuePair<string, DataType>(key, value);
      }

      /// <summary>Argument name.</summary>
      public string Key
      {
          get { return _spec.Key; }
      }

      /// <summary>Argument type; DataType.DtInvalid means "do not check".</summary>
      public DataType Value
      {
          get { return _spec.Value; }
      }
  }
  /// <summary>
  /// An edge of a FunctionDef as a (source, destination) pair of strings. Examples:
  /// "add_0:sum:0" -> "add", or "feed1" -> "add_0:0".
  /// Equality compares both endpoints ordinally, so the type works in List.Contains and HashSet.
  /// </summary>
  public struct EdgeSpec : IEquatable<EdgeSpec>
  {
      KeyValuePair<string, string> pair;

      /// <summary>Source endpoint of the edge.</summary>
      public string Key => pair.Key;

      /// <summary>Destination endpoint of the edge.</summary>
      public string Value => pair.Value;

      public EdgeSpec(string key, string value)
      {
          pair = new KeyValuePair<string, string>(key, value);
      }

      /// <summary>Two edges are equal when both endpoints are ordinally equal.</summary>
      public bool Equals(EdgeSpec other)
          => string.Equals(Key, other.Key, StringComparison.Ordinal)
             && string.Equals(Value, other.Value, StringComparison.Ordinal);

      public override bool Equals(object obj) => obj is EdgeSpec && Equals((EdgeSpec)obj);

      public override int GetHashCode()
      {
          unchecked
          {
              return ((Key?.GetHashCode() ?? 0) * 397) ^ (Value?.GetHashCode() ?? 0);
          }
      }

      /// <summary>Formats the edge as "source -> destination" for assertion messages.</summary>
      public override string ToString() => $"{Key} -> {Value}";
  }
  534. }
  535. }