You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

CApiFunctionTest.cs 24 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. using Microsoft.VisualStudio.TestTools.UnitTesting;
  2. using System;
  3. using System.Collections.Generic;
  4. using System.Linq;
  5. using System.Reflection;
  6. using System.Runtime.InteropServices;
  7. using Tensorflow;
  8. using Tensorflow.Functions;
  9. using static TensorFlowNET.UnitTest.c_test_util;
  10. namespace TensorFlowNET.UnitTest.NativeAPI
  11. {
  12. /// <summary>
  13. /// tensorflow\c\c_api_function_test.cc
  14. /// `class CApiFunctionTest`
  15. /// </summary>
  16. [TestClass]
  17. public class CApiFunctionTest : CApiTest, IDisposable
  18. {
  /// <summary>Graph in which each test builds the body of the function.</summary>
  private Graph func_graph_;

  /// <summary>Graph that receives a copy of the function and in which the function is called.</summary>
  private Graph host_graph_;

  /// <summary>Name given to the function when it is created from <see cref="func_graph_"/>.</summary>
  private string func_name_ = "MyFunc";

  /// <summary>Name of the node that calls the function inside <see cref="host_graph_"/>.</summary>
  private string func_node_name_ = "MyFunc_0";

  /// <summary>Status object shared by the native calls made through this fixture.</summary>
  private Status s_;

  /// <summary>Native handle of the function returned by <c>TF_GraphToFunction</c>.</summary>
  private IntPtr func_;

  /// <summary>
  /// Test set-up: creates a fresh function graph, host graph and status object.
  /// </summary>
  [TestInitialize]
  public void Initialize()
  {
      this.func_graph_ = new Graph();
      this.host_graph_ = new Graph();
      this.s_ = new Status();
  }
  /// <summary>
  /// Function without inputs whose single output is the scalar constant 10.
  /// </summary>
  [TestMethod]
  public void OneOp_ZeroInputs_OneOutput()
  {
      // Define
      var constant = ScalarConst(10, func_graph_, s_, "scalar10");
      Define(num_opers: -1,
          opers: new Operation[0],
          inputs: new Operation[0],
          outputs: new[] { constant },
          output_names: new string[0]);

      // Use, run, and verify
      var call = Use(new Operation[0]);
      var noFeeds = new KeyValuePair<Operation, Tensor>[0];
      Run(noFeeds, call, 10);

      var expectedNodes = new[] { "scalar10_0" };
      var expectedInputs = new List<IOSpec>();
      var expectedOutputs = new List<IOSpec> { new IOSpec("scalar10", DataType.DtInt32) };
      var expectedEdges = new List<EdgeSpec> { new EdgeSpec("scalar10_0:output:0", "scalar10") };
      var expectedControlEdges = new List<EdgeSpec>();
      VerifyFDef(expectedNodes, expectedInputs, expectedOutputs, expectedEdges, expectedControlEdges);
  }
  /// <summary>
  /// Function that negates its single input; feeding 3 must produce -3.
  /// </summary>
  [TestMethod]
  public void OneOp_OneInput_OneOutput()
  {
      // Define
      var input = Placeholder(func_graph_, s_);
      var negated = Neg(input, func_graph_, s_);
      Define(num_opers: -1,
          opers: new Operation[0],
          inputs: new[] { input },
          outputs: new[] { negated },
          output_names: new string[0]);

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed", "neg_0:0"),
          new EdgeSpec("neg_0:y:0", "neg")
      };
      VerifyFDef(
          nodes: new string[] { "neg_0" },
          inputs: new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) },
          outputs: new List<IOSpec> { new IOSpec("neg", DataType.DtInt32) },
          e_edges: expectedEdges,
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Negation function whose output is explicitly named "negated_num".
  /// </summary>
  [TestMethod]
  public void OneOutput_OutputNames()
  {
      // Define
      var input = Placeholder(func_graph_, s_);
      var negated = Neg(input, func_graph_, s_);
      var functionInputs = new[] { input };
      var functionOutputs = new[] { negated };
      var outputNames = new[] { "negated_num" };
      Define(-1, new Operation[0], functionInputs, functionOutputs, outputNames);

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedInputs = new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) };
      var expectedOutputs = new List<IOSpec> { new IOSpec("negated_num", DataType.DtInt32) };
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed", "neg:0"),
          new EdgeSpec("neg:y:0", "negated_num")
      };
      VerifyFDef(new string[] { "neg" }, expectedInputs, expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// Negation function whose requested output name ("negation") equals the name of
  /// its input placeholder; the signature is expected to expose the input as "negation_0".
  /// </summary>
  [TestMethod]
  public void OutputNames_SameNameAsInput()
  {
      // Define
      var input = Placeholder(func_graph_, s_, "negation");
      var negated = Neg(input, func_graph_, s_, "neg");
      Define(num_opers: -1,
          opers: new Operation[0],
          inputs: new[] { input },
          outputs: new[] { negated },
          output_names: new[] { "negation" });

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, -3);

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("negation_0", "neg:0"),
          new EdgeSpec("neg:y:0", "negation")
      };
      VerifyFDef(
          nodes: new string[] { "neg" },
          inputs: new List<IOSpec> { new IOSpec("negation_0", DataType.DtInt32) },
          outputs: new List<IOSpec> { new IOSpec("negation", DataType.DtInt32) },
          e_edges: expectedEdges,
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function with no body nodes that returns its only input unchanged.
  /// </summary>
  [TestMethod]
  public void ZeroOps_Identity()
  {
      // Define: the same placeholder serves as both input and output.
      var passthrough = Placeholder(func_graph_, s_);
      var inputsAndOutputs = new[] { passthrough };
      Define(-1, new Operation[0], inputsAndOutputs, new[] { passthrough }, new string[0]);

      // Use, run, and verify
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 3);

      VerifyFDef(
          nodes: new string[0],
          inputs: new List<IOSpec> { new IOSpec("feed_0", DataType.DtInt32) },
          outputs: new List<IOSpec> { new IOSpec("feed", DataType.DtInt32) },
          e_edges: new List<EdgeSpec> { new EdgeSpec("feed_0", "feed") },
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function with no body nodes that returns its two inputs in swapped order;
  /// called with (2, 3) it must yield (3, 2).
  /// </summary>
  [TestMethod]
  public void ZeroOps_Permutation()
  {
      // Define
      var first = Placeholder(func_graph_, s_, "feed1");
      var second = Placeholder(func_graph_, s_, "feed2");
      Define(num_opers: -1,
          opers: null,
          inputs: new[] { first, second },
          outputs: new[] { second, first },
          output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });

      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = new[] { new TF_Output(call, 0), new TF_Output(call, 1) };
      Run(feeds, fetches, new[] { 3, 2 });

      var expectedInputs = new List<IOSpec> { new IOSpec("feed1_0"), new IOSpec("feed2_0") };
      var expectedOutputs = new List<IOSpec> { new IOSpec("feed2"), new IOSpec("feed1") };
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1_0", "feed1"),
          new EdgeSpec("feed2_0", "feed2")
      };
      VerifyFDef(new string[0], expectedInputs, expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// Same input-swapping function as <c>ZeroOps_Permutation</c>, but with the outputs
  /// explicitly named "first" and "second".
  /// </summary>
  [TestMethod]
  public void ZeroOps_Permutation_OutputNames()
  {
      // Define
      var firstInput = Placeholder(func_graph_, s_, "feed1");
      var secondInput = Placeholder(func_graph_, s_, "feed2");
      var functionInputs = new[] { firstInput, secondInput };
      var functionOutputs = new[] { secondInput, firstInput };
      Define(-1, null, functionInputs, functionOutputs, new[] { "first", "second" });

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });

      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = new[] { new TF_Output(call, 0), new TF_Output(call, 1) };
      var expectedValues = new[] { 3, 2 };
      Run(feeds, fetches, expectedValues);

      VerifyFDef(
          nodes: new string[0],
          inputs: new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
          outputs: new List<IOSpec> { new IOSpec("first"), new IOSpec("second") },
          e_edges: new List<EdgeSpec>
          {
              new EdgeSpec("feed1", "second"),
              new EdgeSpec("feed2", "first")
          },
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function that adds its two inputs; called with (2, 3) it must produce 5.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_OneOutput()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
          opers: null,
          inputs: new[] { lhs, rhs },
          outputs: new[] { sum },
          output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 2 + 3);

      var expectedInputs = new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") };
      var expectedOutputs = new List<IOSpec> { new IOSpec("add") };
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add_0:0"),
          new EdgeSpec("feed2", "add_0:1"),
          new EdgeSpec("add_0:sum:0", "add")
      };
      VerifyFDef(new string[] { "add_0" }, expectedInputs, expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// Function with two inputs, one add node in its body and no outputs. The function is
  /// instantiated in the host graph but not run; only its definition is verified.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_ZeroOutputs()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      // The add node is not an output; it is created so that it shows up in the body
      // (checked through the expected node list below).
      Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
          opers: null,
          inputs: new[] { lhs, rhs },
          outputs: new Operation[0],
          output_names: null);

      // Use and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      Use(new[] { two, hostFeed });

      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add:0"),
          new EdgeSpec("feed2", "add:1")
      };
      VerifyFDef(
          nodes: new string[] { "add" },
          inputs: new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
          outputs: new List<IOSpec>(),
          e_edges: expectedEdges,
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function computing (feed1 + feed2) + feed3 with a single output;
  /// called with (2, 10, 3) it must produce 15.
  /// </summary>
  [TestMethod]
  public void TwoOps_ThreeInputs_OneOutput()
  {
      // Define
      var a = Placeholder(func_graph_, s_, "feed1");
      var b = Placeholder(func_graph_, s_, "feed2");
      var c = Placeholder(func_graph_, s_, "feed3");
      var partialSum = Add(a, b, func_graph_, s_, "add1");
      var totalSum = Add(partialSum, c, func_graph_, s_, "add2");
      var functionInputs = new[] { a, b, c };
      var functionOutputs = new[] { totalSum };
      Define(-1, null, functionInputs, functionOutputs, null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_, "two");
      var ten = ScalarConst(10, host_graph_, s_, "ten");
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, ten, hostFeed });
      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      Run(feeds, call, 2 + 10 + 3);

      VerifyFDef(
          nodes: new string[] { "add1", "add2_0" },
          inputs: new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") },
          outputs: new List<IOSpec> { new IOSpec("add2") },
          e_edges: new List<EdgeSpec>
          {
              new EdgeSpec("feed1", "add1:0"),
              new EdgeSpec("feed2", "add1:1"),
              new EdgeSpec("add1:sum:0", "add2_0:0"),
              new EdgeSpec("feed3", "add2_0:1"),
              new EdgeSpec("add2_0:sum:0", "add2"),
          },
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function that adds its two inputs and returns the same sum as both of its outputs;
  /// called with (2, 3) both outputs must be 5.
  /// </summary>
  [TestMethod]
  public void OneOp_TwoInputs_TwoDuplicateOutputs()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      Define(num_opers: -1,
          opers: null,
          inputs: new[] { lhs, rhs },
          outputs: new[] { sum, sum },
          output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });

      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = new[] { new TF_Output(call, 0), new TF_Output(call, 1) };
      Run(feeds, fetches, new[] { 5, 5 });

      var expectedInputs = new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") };
      var expectedOutputs = new List<IOSpec> { new IOSpec("add"), new IOSpec("add_0") };
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add_1:0"),
          new EdgeSpec("feed2", "add_1:1"),
          new EdgeSpec("add_1:sum:0", "add"),
          new EdgeSpec("add_1:sum:0", "add_0")
      };
      VerifyFDef(new string[] { "add_1" }, expectedInputs, expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// Function returning the same sum twice, with the two outputs explicitly named
  /// "out1" and "out2".
  /// </summary>
  [TestMethod]
  public void TwoDuplicateOutputs_OutputNames()
  {
      // Define
      var lhs = Placeholder(func_graph_, s_, "feed1");
      var rhs = Placeholder(func_graph_, s_, "feed2");
      var sum = Add(lhs, rhs, func_graph_, s_);
      var functionInputs = new[] { lhs, rhs };
      var functionOutputs = new[] { sum, sum };
      Define(-1, null, functionInputs, functionOutputs, new[] { "out1", "out2" });

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_);
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, hostFeed });

      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = new[] { new TF_Output(call, 0), new TF_Output(call, 1) };
      var expectedValues = new[] { 5, 5 };
      Run(feeds, fetches, expectedValues);

      VerifyFDef(
          nodes: new string[] { "add" },
          inputs: new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2") },
          outputs: new List<IOSpec> { new IOSpec("out1"), new IOSpec("out2") },
          e_edges: new List<EdgeSpec>
          {
              new EdgeSpec("feed1", "add:0"),
              new EdgeSpec("feed2", "add:1"),
              new EdgeSpec("add:sum:0", "out1"),
              new EdgeSpec("add:sum:0", "out2")
          },
          c_edges: new List<EdgeSpec>());
  }
  /// <summary>
  /// Function computing add1 = feed1 + feed2 and add2 = add1 + feed3 and returning both;
  /// called with (2, 10, 3) it must yield (12, 15).
  /// </summary>
  [TestMethod]
  public void TwoOps_ThreeInputs_TwoOutputs()
  {
      // Define
      var a = Placeholder(func_graph_, s_, "feed1");
      var b = Placeholder(func_graph_, s_, "feed2");
      var c = Placeholder(func_graph_, s_, "feed3");
      var partialSum = Add(a, b, func_graph_, s_, "add1");
      var totalSum = Add(partialSum, c, func_graph_, s_, "add2");
      Define(num_opers: -1,
          opers: null,
          inputs: new[] { a, b, c },
          outputs: new[] { partialSum, totalSum },
          output_names: null);

      // Use, run, and verify
      var two = ScalarConst(2, host_graph_, s_, "two");
      var ten = ScalarConst(10, host_graph_, s_, "ten");
      var hostFeed = Placeholder(host_graph_, s_);
      var call = Use(new[] { two, ten, hostFeed });

      var feeds = new[] { new KeyValuePair<Operation, Tensor>(hostFeed, Int32Tensor(3)) };
      var fetches = new[] { new TF_Output(call, 0), new TF_Output(call, 1) };
      Run(feeds, fetches, new[] { 12, 15 });

      var expectedInputs = new List<IOSpec> { new IOSpec("feed1"), new IOSpec("feed2"), new IOSpec("feed3") };
      var expectedOutputs = new List<IOSpec> { new IOSpec("add1"), new IOSpec("add2") };
      var expectedEdges = new List<EdgeSpec>
      {
          new EdgeSpec("feed1", "add1_0:0"),
          new EdgeSpec("feed2", "add1_0:1"),
          new EdgeSpec("add1_0:sum:0", "add2_0:0"),
          new EdgeSpec("feed3", "add2_0:1"),
          new EdgeSpec("add1_0:sum:0", "add1"),
          new EdgeSpec("add2_0:sum:0", "add2")
      };
      VerifyFDef(new string[] { "add1_0", "add2_0" }, expectedInputs, expectedOutputs, expectedEdges, new List<EdgeSpec>());
  }
  /// <summary>
  /// Operation-based front end for <see cref="DefineT"/>: every input and output
  /// operation is referenced through its output index 0.
  /// </summary>
  /// <param name="num_opers">Forwarded unchanged to <see cref="DefineT"/>.</param>
  /// <param name="opers">Forwarded unchanged to <see cref="DefineT"/>.</param>
  /// <param name="inputs">Operations whose output 0 becomes a function input.</param>
  /// <param name="outputs">Operations whose output 0 becomes a function output.</param>
  /// <param name="output_names">Forwarded unchanged to <see cref="DefineT"/>.</param>
  /// <param name="expect_failure">Forwarded unchanged to <see cref="DefineT"/>.</param>
  void Define(int num_opers, Operation[] opers,
      Operation[] inputs, Operation[] outputs,
      string[] output_names, bool expect_failure = false)
  {
      var inputPorts = (from op in inputs select new TF_Output(op, 0)).ToArray();
      var outputPorts = (from op in outputs select new TF_Output(op, 0)).ToArray();
      DefineT(num_opers, opers, inputPorts, outputPorts, output_names, expect_failure);
  }
  /// <summary>
  /// Turns <c>func_graph_</c> into a function named <c>func_name_</c> via
  /// <c>TF_GraphToFunction</c>, stores the handle in <c>func_</c> and, on success,
  /// copies the function into <c>host_graph_</c>.
  /// </summary>
  /// <param name="num_opers">Operation count passed to the native call; -1 means no operation array is passed.</param>
  /// <param name="opers">Operations converted to handles and passed when <paramref name="num_opers"/> is not -1.</param>
  /// <param name="inputs">Function inputs.</param>
  /// <param name="outputs">Function outputs.</param>
  /// <param name="output_names">Output names; a null or empty array is passed to the native call as null.</param>
  /// <param name="expect_failure">When true, only asserts that no function handle was produced.</param>
  void DefineT(int num_opers, Operation[] opers,
      TF_Output[] inputs, TF_Output[] outputs,
      string[] output_names, bool expect_failure = false)
  {
      // No operation array is handed to the native call when num_opers is -1.
      IntPtr[] operHandles = null;
      if (num_opers != -1)
      {
          operHandles = opers.Select(x => (IntPtr)x).ToArray();
      }

      // An empty name list is passed as null, just like a missing one.
      string[] names = null;
      if (output_names != null && output_names.Length != 0)
      {
          names = output_names;
      }

      func_ = c_api.TF_GraphToFunction(func_graph_, func_name_, false,
          num_opers, operHandles,
          inputs.Length, inputs.ToArray(),
          outputs.Length, outputs.ToArray(),
          names,
          IntPtr.Zero, null, s_.Handle);

      if (!expect_failure)
      {
          ASSERT_EQ(TF_OK, s_.Code, s_.Message);
          ASSERT_NE(func_, IntPtr.Zero);
          ASSERT_EQ(func_name_, c_api.StringPiece(c_api.TF_FunctionName(func_)));

          c_api.TF_GraphCopyFunction(host_graph_, func_, IntPtr.Zero, s_.Handle);
          ASSERT_EQ(TF_OK, s_.Code, s_.Message);
      }
      else
      {
          ASSERT_EQ(IntPtr.Zero, func_);
      }
  }
  /// <summary>
  /// Adds a call to the function to the host graph, wiring output 0 of each given
  /// operation to the corresponding function input.
  /// </summary>
  /// <param name="inputs">Operations whose output 0 feeds the function.</param>
  /// <returns>The operation returned by <see cref="UseT"/>.</returns>
  Operation Use(Operation[] inputs)
  {
      var ports = (from op in inputs select new TF_Output(op, 0)).ToArray();
      return UseT(ports);
  }
  /// <summary>
  /// Forwards the given inputs to <see cref="UseHelper"/>.
  /// </summary>
  /// <param name="inputs">Outputs to connect to the function's inputs.</param>
  /// <returns>The operation returned by <see cref="UseHelper"/>.</returns>
  Operation UseT(TF_Output[] inputs)
  {
      return UseHelper(inputs);
  }
  /// <summary>
  /// Creates a node of type <c>func_name_</c> named <c>func_node_name_</c> in the host
  /// graph, connects the given inputs, assigns it to "/cpu:0" and finishes it.
  /// </summary>
  /// <param name="inputs">Outputs to add as inputs of the new node, in order.</param>
  /// <returns>The finished operation.</returns>
  Operation UseHelper(TF_Output[] inputs)
  {
      var description = TF_NewOperation(host_graph_, func_name_, func_node_name_);
      for (int i = 0; i < inputs.Length; i++)
      {
          TF_AddInput(description, inputs[i]);
      }
      c_api.TF_SetDevice(description, "/cpu:0");

      var call = TF_FinishOperation(description, s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);
      ASSERT_NE(call, IntPtr.Zero);
      return call;
  }
  /// <summary>
  /// Runs the host graph and checks that output 0 of <paramref name="output"/> is the
  /// expected scalar.
  /// </summary>
  /// <param name="inputs">Feeds: placeholder operation and the tensor supplied for it.</param>
  /// <param name="output">Operation whose output 0 is fetched.</param>
  /// <param name="expected_result">Expected value of the fetched tensor.</param>
  void Run(KeyValuePair<Operation, Tensor>[] inputs, Operation output, int expected_result)
  {
      var fetches = new[] { new TF_Output(output, 0) };
      var expected = new[] { expected_result };
      Run(inputs, fetches, expected);
  }
  /// <summary>
  /// Runs the host graph in a new <c>CSession</c> and checks that each fetched tensor
  /// is an int32 scalar holding the corresponding expected value.
  /// </summary>
  /// <param name="inputs">Feeds: placeholder operation and the tensor supplied for it.</param>
  /// <param name="outputs">Outputs to fetch.</param>
  /// <param name="expected_results">Expected value for each fetched output, by position.</param>
  unsafe void Run(KeyValuePair<Operation, Tensor>[] inputs, TF_Output[] outputs, int[] expected_results)
  {
      var session = new CSession(host_graph_, s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);

      session.SetInputs(inputs);
      session.SetOutputs(outputs);
      session.Run(s_);
      ASSERT_EQ(TF_OK, s_.Code, s_.Message);

      var index = 0;
      foreach (var expected in expected_results)
      {
          var tensor = session.output_tensor(index);
          ASSERT_TRUE(tensor != IntPtr.Zero);

          // Type, rank and byte size are checked before the payload is read as one int.
          EXPECT_EQ(TF_DataType.TF_INT32, c_api.TF_TensorType(tensor));
          EXPECT_EQ(0, c_api.TF_NumDims(tensor));
          ASSERT_EQ(sizeof(int), (int)c_api.TF_TensorByteSize(tensor));

          var data = (int*)c_api.TF_TensorData(tensor).ToPointer();
          EXPECT_EQ(expected, *data);
          index++;
      }
  }
  /// <summary>
  /// Fetches the definition of <c>func_</c> and verifies its body nodes, inputs,
  /// outputs and edges.
  /// </summary>
  /// <param name="nodes">Expected names of the body nodes.</param>
  /// <param name="inputs">Expected input arguments, in order.</param>
  /// <param name="outputs">Expected output arguments, in order.</param>
  /// <param name="e_edges">Expected data edges.</param>
  /// <param name="c_edges">Expected control edges.</param>
  /// <param name="is_exact_edges">Forwarded to <see cref="VerifyFDefEdges"/>.</param>
  void VerifyFDef(string[] nodes, List<IOSpec> inputs, List<IOSpec> outputs,
      List<EdgeSpec> e_edges, List<EdgeSpec> c_edges,
      bool is_exact_edges = true)
  {
      var fdef = GetFunctionDef(func_);
      // The previous check compared the FunctionDef object with IntPtr.Zero, which can
      // never be equal and therefore never failed; check for a missing definition instead.
      ASSERT_TRUE(fdef != null, "GetFunctionDef returned no function definition");

      VerifyFDefNodes(fdef, nodes);
      VerifyFDefInputs(fdef, inputs);
      VerifyFDefOutputs(fdef, outputs);
      VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
  }
  /// <summary>
  /// Checks that the function body has as many nodes as expected and that every
  /// body node's name is among the expected names.
  /// </summary>
  /// <param name="fdef">Function definition to inspect.</param>
  /// <param name="nodes">Expected node names.</param>
  void VerifyFDefNodes(FunctionDef fdef, string[] nodes)
  {
      ASSERT_EQ(nodes.Length, fdef.NodeDef.Count);

      foreach (var name in fdef.NodeDef.Select(n => n.Name))
      {
          ASSERT_TRUE(nodes.Contains(name), $"Got unexpected node: {name} in fdef: {fdef}");
      }
  }
  /// <summary>
  /// Checks the input arguments of the function signature against the expected
  /// names and, where a type is specified, the expected types.
  /// </summary>
  /// <param name="fdef">Function definition to inspect.</param>
  /// <param name="inputs">
  /// Expected inputs in signature order; an entry whose type is
  /// <c>DataType.DtInvalid</c> is only checked by name.
  /// </param>
  void VerifyFDefInputs(FunctionDef fdef, List<IOSpec> inputs)
  {
      var signature = fdef.Signature;
      ASSERT_EQ(inputs.Count, signature.InputArg.Count);
      for (int i = 0; i < inputs.Count; ++i)
      {
          var arg = signature.InputArg[i];
          var input = inputs[i];

          // Expected value first, matching the other assertions in this class, so
          // a failure reports expected/actual the right way round.
          if (input.Value != DataType.DtInvalid)
          {
              ASSERT_EQ(input.Value, arg.Type, $"Got unexpected type for input {i}. fdef: {fdef}");
          }
          ASSERT_EQ(input.Key, arg.Name, $"Got unexpected name for input {i}. fdef: {fdef}");
      }
  }
  /// <summary>
  /// Checks the output arguments of the function signature against the expected
  /// names and, where a type is specified, the expected types.
  /// </summary>
  /// <param name="fdef">Function definition to inspect.</param>
  /// <param name="outputs">
  /// Expected outputs in signature order; an entry whose type is
  /// <c>DataType.DtInvalid</c> is only checked by name.
  /// </param>
  void VerifyFDefOutputs(FunctionDef fdef, List<IOSpec> outputs)
  {
      var signature = fdef.Signature;
      ASSERT_EQ(outputs.Count, signature.OutputArg.Count);
      for (int i = 0; i < outputs.Count; ++i)
      {
          var arg = signature.OutputArg[i];
          var output = outputs[i];

          // Expected value first, matching the other assertions in this class, so
          // a failure reports expected/actual the right way round.
          if (output.Value != DataType.DtInvalid)
          {
              ASSERT_EQ(output.Value, arg.Type, $"Got unexpected type for output {i}. fdef: {fdef}");
          }
          // The message used to say "input" here, which pointed at the wrong list.
          ASSERT_EQ(output.Key, arg.Name, $"Got unexpected name for output {i}. fdef: {fdef}");
      }
  }
  /// <summary>
  /// Collects the edges present in the function definition and checks that every
  /// expected edge is among them.
  /// </summary>
  /// <param name="fdef">Function definition to inspect.</param>
  /// <param name="e_edges">Expected data edges.</param>
  /// <param name="c_edges">Expected control edges.</param>
  /// <param name="is_exact_edges">
  /// When true, additionally requires the number of actual edges to equal the
  /// number of expected edges.
  /// </param>
  void VerifyFDefEdges(FunctionDef fdef, List<EdgeSpec> e_edges, List<EdgeSpec> c_edges, bool is_exact_edges = true)
  {
      // Interpolating a List<EdgeSpec> directly only prints its type name, so edges
      // are rendered explicitly as "source->destination" for failure messages.
      Func<EdgeSpec, string> describeEdge = e => $"{e.Key}->{e.Value}";
      Func<IEnumerable<EdgeSpec>, string> describeEdges =
          edges => "[" + string.Join(", ", edges.Select(describeEdge)) + "]";

      // Build the list of edges from fdef
      var a_edges = new List<EdgeSpec>(); // actual edges

      // Get edges from inputs to body nodes and between body nodes
      foreach (var node in fdef.NodeDef)
      {
          for (int i = 0; i < node.Input.Count; ++i)
          {
              a_edges.Add(new EdgeSpec(node.Input[i], $"{node.Name}:{i}"));
          }
      }

      // Get edges from body nodes to outputs and from inputs to outputs
      foreach (var arg in fdef.Signature.OutputArg)
      {
          var iter = fdef.Ret.FirstOrDefault(x => x.Key == arg.Name);
          if (iter.Key != null)
          {
              a_edges.Add(new EdgeSpec(iter.Value, arg.Name));
          }
          else
          {
              // No return mapping for this output: record an edge from the argument to itself.
              a_edges.Add(new EdgeSpec(arg.Name, arg.Name));
          }
      }

      // Verify edges
      foreach (var edge in e_edges)
      {
          ASSERT_TRUE(a_edges.Contains(edge),
              $"Failed to find expected edge {describeEdge(edge)} in fdef: {fdef}");
      }
      foreach (var edge in c_edges)
      {
          ASSERT_TRUE(a_edges.Contains(edge),
              $"Failed to find expected control edge {describeEdge(edge)} in fdef: {fdef}");
      }

      // If caller specified all edges, check that we have seen all
      if (is_exact_edges)
      {
          ASSERT_EQ(e_edges.Count + c_edges.Count, a_edges.Count,
              $"Expected edges: {describeEdges(e_edges)}, Expected Control edges: {describeEdges(c_edges)}, Actual edges: {describeEdges(a_edges)}");
      }
  }
  /// <summary>
  /// <see cref="IDisposable"/> implementation; currently performs no work.
  /// </summary>
  public void Dispose()
  {
      // NOTE(review): nothing in this class releases func_, the graphs or the status
      // object here - confirm whether that is deliberate.
  }
  /// <summary>
  /// Expected name and, optionally, data type of one function input or output argument.
  /// </summary>
  public struct IOSpec
  {
      KeyValuePair<string, DataType> pair;

      /// <summary>Creates a spec for the argument named <paramref name="key"/>.</summary>
      /// <param name="key">Expected argument name.</param>
      /// <param name="value">Expected argument type; defaults to <c>DataType.DtInvalid</c>.</param>
      public IOSpec(string key, DataType value = DataType.DtInvalid)
      {
          pair = new KeyValuePair<string, DataType>(key, value);
      }

      /// <summary>Expected argument name.</summary>
      public string Key
      {
          get { return pair.Key; }
      }

      /// <summary>Expected argument type.</summary>
      public DataType Value
      {
          get { return pair.Value; }
      }
  }
  /// <summary>
  /// An edge of a function definition, described by a source (<see cref="Key"/>) and a
  /// destination (<see cref="Value"/>). Two edges are equal when both strings match
  /// ordinally.
  /// </summary>
  public struct EdgeSpec : IEquatable<EdgeSpec>
  {
      KeyValuePair<string, string> pair;

      /// <summary>Source end of the edge.</summary>
      public string Key => pair.Key;

      /// <summary>Destination end of the edge.</summary>
      public string Value => pair.Value;

      /// <summary>Creates an edge from <paramref name="key"/> to <paramref name="value"/>.</summary>
      /// <param name="key">Source end of the edge.</param>
      /// <param name="value">Destination end of the edge.</param>
      public EdgeSpec(string key, string value)
      {
          pair = new KeyValuePair<string, string>(key, value);
      }

      /// <summary>Typed equality; lets collections compare edges without boxing.</summary>
      /// <param name="other">Edge to compare with.</param>
      /// <returns>True when both ends are ordinally equal.</returns>
      public bool Equals(EdgeSpec other)
      {
          return string.Equals(Key, other.Key, StringComparison.Ordinal)
              && string.Equals(Value, other.Value, StringComparison.Ordinal);
      }

      /// <inheritdoc/>
      public override bool Equals(object obj)
      {
          return obj is EdgeSpec && Equals((EdgeSpec)obj);
      }

      /// <inheritdoc/>
      public override int GetHashCode()
      {
          unchecked
          {
              var keyHash = Key == null ? 0 : Key.GetHashCode();
              var valueHash = Value == null ? 0 : Value.GetHashCode();
              return (keyHash * 397) ^ valueHash;
          }
      }

      /// <summary>Renders the edge as "source->destination" for diagnostics.</summary>
      public override string ToString()
      {
          return $"{Key}->{Value}";
      }
  }
  537. }
  538. }