switch_op_pass_unittest.cc
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>

#include "omg/omg_inner_types.h"

#define protected public
#define private public
#include "graph/passes/switch_op_pass.h"
#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/op/attr_value_util.h"
#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/graph.h"
#include "graph/passes/control_op_attr_pass.h"
#include "inc/pass_manager.h"
#undef protected
#undef private

using namespace testing;
using namespace ge;

class UtestGraphPassesSwitchOpPass : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}

 public:
  // Simple cond pattern: x, y -> Less is the predicate of three Switch nodes; the true
  // outputs of Add/Switch and Add/Switch_1 feed Add, whose result is merged with the
  // false output of Square/Switch and sent to NetOutput.
  void make_graph(ComputeGraphPtr graph, bool match = true) {
    GeTensorDesc bool_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
    GeTensorDesc int_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
    GeTensorDesc scalar_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);

    auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
    xOpDef->AddOutputDesc(scalar_tensor_desc);
    auto xNode = graph->AddNode(xOpDef);
    auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
    yOpDef->AddOutputDesc(scalar_tensor_desc);
    auto yNode = graph->AddNode(yOpDef);
    auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
    zOpDef->AddOutputDesc(scalar_tensor_desc);
    auto zNode = graph->AddNode(zOpDef);
    auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddOutputDesc(bool_tensor_desc);
    auto condNode = graph->AddNode(condOpDef);
    auto switch_op_def1 = std::make_shared<OpDesc>("Add/Switch", SWITCH);
    switch_op_def1->AddInputDesc(scalar_tensor_desc);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(scalar_tensor_desc);
    switch_op_def1->AddOutputDesc(scalar_tensor_desc);
    auto switch_node1 = graph->AddNode(switch_op_def1);
    auto switch_op_def2 = std::make_shared<OpDesc>("Add/Switch_1", SWITCH);
    switch_op_def2->AddInputDesc(scalar_tensor_desc);
    switch_op_def2->AddInputDesc(bool_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    auto switch_node2 = graph->AddNode(switch_op_def2);
    auto switch_op_def3 = std::make_shared<OpDesc>("Square/Switch", SWITCH);
    switch_op_def3->AddInputDesc(scalar_tensor_desc);
    switch_op_def3->AddInputDesc(bool_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    auto switch_node3 = graph->AddNode(switch_op_def3);
    auto addOpDef = std::make_shared<OpDesc>("Add", "ADD");
    addOpDef->AddInputDesc(scalar_tensor_desc);
    addOpDef->AddInputDesc(scalar_tensor_desc);
    addOpDef->AddOutputDesc(scalar_tensor_desc);
    auto addNode = graph->AddNode(addOpDef);
    auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
    mergeOpDef->AddInputDesc(scalar_tensor_desc);
    mergeOpDef->AddInputDesc(scalar_tensor_desc);
    mergeOpDef->AddOutputDesc(scalar_tensor_desc);
    mergeOpDef->AddOutputDesc(int_tensor_desc);
    auto mergeNode = graph->AddNode(mergeOpDef);
    auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
    output_op_def->AddInputDesc(scalar_tensor_desc);
    output_op_def->AddOutputDesc(scalar_tensor_desc);
    auto output_node = graph->AddNode(output_op_def);

    // Wire the data edges.
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), addNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(1), addNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(addNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  }

  // Const/Identity variant: Equal -> identity drives Switch and Switch_1; the switch_f /
  // switch_t Identity nodes control z and Const via control edges, Switch_1 feeds
  // Mul -> Ceil, and Ceil is merged with Const into NetOutput.
  void make_graph_const(ComputeGraphPtr graph, bool match = true) {
    // resnet50 PolynomialDecay
    GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
    GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
    GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);

    auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
    xOpDef->AddOutputDesc(scalar_tensor_desc);
    auto xNode = graph->AddNode(xOpDef);
    auto yOpDef = std::make_shared<OpDesc>("y", "Const");
    yOpDef->AddOutputDesc(scalar_tensor_desc);
    auto yNode = graph->AddNode(yOpDef);
    auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
    zOpDef->AddOutputDesc(scalar_tensor_desc);
    auto zNode = graph->AddNode(zOpDef);
    auto constOpDef = std::make_shared<OpDesc>("Const", "Const");
    constOpDef->AddOutputDesc(scalar_tensor_desc);
    auto constNode = graph->AddNode(constOpDef);
    auto condOpDef = std::make_shared<OpDesc>("Equal", "Equal");
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddOutputDesc(bool_tensor_desc);
    auto condNode = graph->AddNode(condOpDef);
    auto identityOpDef = std::make_shared<OpDesc>("identity", "Identity");
    identityOpDef->AddInputDesc(bool_tensor_desc);
    identityOpDef->AddOutputDesc(bool_tensor_desc);
    auto identityNode = graph->AddNode(identityOpDef);
    auto switch_op_def1 = std::make_shared<OpDesc>("Switch", SWITCH);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(bool_tensor_desc);
    auto switch_node1 = graph->AddNode(switch_op_def1);
    auto tIdentityOpDef = std::make_shared<OpDesc>("switch_t", "Identity");
    tIdentityOpDef->AddInputDesc(scalar_tensor_desc);
    tIdentityOpDef->AddOutputDesc(scalar_tensor_desc);
    auto tIdentityNode = graph->AddNode(tIdentityOpDef);
    auto fIdentityOpDef = std::make_shared<OpDesc>("switch_f", "Identity");
    fIdentityOpDef->AddInputDesc(scalar_tensor_desc);
    fIdentityOpDef->AddOutputDesc(scalar_tensor_desc);
    auto fIdentityNode = graph->AddNode(fIdentityOpDef);
    auto switch_op_def2 = std::make_shared<OpDesc>("Switch_1", SWITCH);
    switch_op_def2->AddInputDesc(scalar_tensor_desc);
    switch_op_def2->AddInputDesc(bool_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    auto switch_node2 = graph->AddNode(switch_op_def2);
    auto mulOpDef = std::make_shared<OpDesc>("truediv", "Mul");
    mulOpDef->AddInputDesc(scalar_tensor_desc);
    mulOpDef->AddInputDesc(scalar_tensor_desc);
    mulOpDef->AddOutputDesc(scalar_tensor_desc);
    auto mulNode = graph->AddNode(mulOpDef);
    auto ceilOpDef = std::make_shared<OpDesc>("Ceil", "Ceil");
    ceilOpDef->AddInputDesc(scalar_tensor_desc);
    ceilOpDef->AddOutputDesc(scalar_tensor_desc);
    auto ceilNode = graph->AddNode(ceilOpDef);
    auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
    mergeOpDef->AddInputDesc(scalar_tensor_desc);
    mergeOpDef->AddInputDesc(scalar_tensor_desc);
    mergeOpDef->AddOutputDesc(scalar_tensor_desc);
    mergeOpDef->AddOutputDesc(int_tensor_desc);
    auto mergeNode = graph->AddNode(mergeOpDef);
    auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
    output_op_def->AddInputDesc(scalar_tensor_desc);
    output_op_def->AddOutputDesc(scalar_tensor_desc);
    auto output_node = graph->AddNode(output_op_def);

    // Wire the data and control edges.
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), identityNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(0), fIdentityNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), tIdentityNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(fIdentityNode->GetOutControlAnchor(), zNode->GetInControlAnchor());
    (void)GraphUtils::AddEdge(tIdentityNode->GetOutControlAnchor(), constNode->GetInControlAnchor());
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), mulNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), mulNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(mulNode->GetOutDataAnchor(0), ceilNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(constNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(ceilNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  }

  // Two cascaded Switch/Merge blocks sharing the same Less predicate: Merge1 (fed by
  // Square1/Square2) is the data input of Switch_f_2, whose branch is merged again by
  // Merge2 into NetOutput.
  void make_graph_cyclic_dependence(ComputeGraphPtr graph, bool match = true) {
    GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
    GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
    GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);

    auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
    xOpDef->AddOutputDesc(scalar_tensor_desc);
    auto xNode = graph->AddNode(xOpDef);
    auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
    yOpDef->AddOutputDesc(scalar_tensor_desc);
    auto yNode = graph->AddNode(yOpDef);
    auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
    zOpDef->AddOutputDesc(scalar_tensor_desc);
    auto zNode = graph->AddNode(zOpDef);
    auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddInputDesc(scalar_tensor_desc);
    condOpDef->AddOutputDesc(bool_tensor_desc);
    auto condNode = graph->AddNode(condOpDef);
    auto switch_op_def1 = std::make_shared<OpDesc>("Switch_f_1", SWITCH);
    switch_op_def1->AddInputDesc(scalar_tensor_desc);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(scalar_tensor_desc);
    switch_op_def1->AddOutputDesc(scalar_tensor_desc);
    auto switch_node1 = graph->AddNode(switch_op_def1);
    auto switch_op_def2 = std::make_shared<OpDesc>("Switch_t_1", SWITCH);
    switch_op_def2->AddInputDesc(scalar_tensor_desc);
    switch_op_def2->AddInputDesc(bool_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    auto switch_node2 = graph->AddNode(switch_op_def2);
    auto switch_op_def3 = std::make_shared<OpDesc>("Switch_f_2", SWITCH);
    switch_op_def3->AddInputDesc(scalar_tensor_desc);
    switch_op_def3->AddInputDesc(bool_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    auto switch_node3 = graph->AddNode(switch_op_def3);
    auto switch_op_def4 = std::make_shared<OpDesc>("Switch_t_2", SWITCH);
    switch_op_def4->AddInputDesc(scalar_tensor_desc);
    switch_op_def4->AddInputDesc(bool_tensor_desc);
    switch_op_def4->AddOutputDesc(scalar_tensor_desc);
    switch_op_def4->AddOutputDesc(scalar_tensor_desc);
    auto switch_node4 = graph->AddNode(switch_op_def4);
    auto squareOpDef1 = std::make_shared<OpDesc>("Square1", "Square");
    squareOpDef1->AddInputDesc(scalar_tensor_desc);
    squareOpDef1->AddOutputDesc(scalar_tensor_desc);
    auto squareNode1 = graph->AddNode(squareOpDef1);
    auto squareOpDef2 = std::make_shared<OpDesc>("Square2", "Square");
    squareOpDef2->AddInputDesc(scalar_tensor_desc);
    squareOpDef2->AddOutputDesc(scalar_tensor_desc);
    auto squareNode2 = graph->AddNode(squareOpDef2);
    auto squareOpDef3 = std::make_shared<OpDesc>("Square3", "Square");
    squareOpDef3->AddInputDesc(scalar_tensor_desc);
    squareOpDef3->AddOutputDesc(scalar_tensor_desc);
    auto squareNode3 = graph->AddNode(squareOpDef3);
    auto squareOpDef4 = std::make_shared<OpDesc>("Square4", "Square");
    squareOpDef4->AddInputDesc(scalar_tensor_desc);
    squareOpDef4->AddOutputDesc(scalar_tensor_desc);
    auto squareNode4 = graph->AddNode(squareOpDef4);
    auto merge_op_def1 = std::make_shared<OpDesc>("Merge1", "Merge");
    merge_op_def1->AddInputDesc(scalar_tensor_desc);
    merge_op_def1->AddInputDesc(scalar_tensor_desc);
    merge_op_def1->AddOutputDesc(scalar_tensor_desc);
    merge_op_def1->AddOutputDesc(int_tensor_desc);
    auto merge_node1 = graph->AddNode(merge_op_def1);
    auto merge_op_def2 = std::make_shared<OpDesc>("Merge2", "Merge");
    merge_op_def2->AddInputDesc(scalar_tensor_desc);
    merge_op_def2->AddInputDesc(scalar_tensor_desc);
    merge_op_def2->AddOutputDesc(scalar_tensor_desc);
    merge_op_def2->AddOutputDesc(int_tensor_desc);
    auto merge_node2 = graph->AddNode(merge_op_def2);
    auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
    output_op_def->AddInputDesc(scalar_tensor_desc);
    output_op_def->AddOutputDesc(scalar_tensor_desc);
    auto output_node = graph->AddNode(output_op_def);

    // Wire the data edges.
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(0), squareNode1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(1), squareNode2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(squareNode1->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(squareNode2->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(merge_node1->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(0), squareNode3->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node4->GetOutDataAnchor(1), squareNode4->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(squareNode3->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(squareNode4->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  }

  // Case-like nesting: the output of greater/Switch_f is reused as the predicate of
  // less/Switch_t and less/Switch_f; their branches are merged by Merge1, which is then
  // merged with greater/Switch_t by Merge2 into NetOutput.
  void make_graph_case(ComputeGraphPtr graph, bool match = true) {
    GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
    GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
    GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);

    auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
    xOpDef->AddOutputDesc(scalar_tensor_desc);
    auto xNode = graph->AddNode(xOpDef);
    auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
    yOpDef->AddOutputDesc(scalar_tensor_desc);
    auto yNode = graph->AddNode(yOpDef);
    auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
    zOpDef->AddOutputDesc(scalar_tensor_desc);
    auto zNode = graph->AddNode(zOpDef);
    auto greater_op_def = std::make_shared<OpDesc>("Greater", "Greater");
    greater_op_def->AddInputDesc(scalar_tensor_desc);
    greater_op_def->AddInputDesc(scalar_tensor_desc);
    greater_op_def->AddOutputDesc(bool_tensor_desc);
    auto greaterNode = graph->AddNode(greater_op_def);
    auto less_op_def = std::make_shared<OpDesc>("Less", "Less");
    less_op_def->AddInputDesc(scalar_tensor_desc);
    less_op_def->AddInputDesc(scalar_tensor_desc);
    less_op_def->AddOutputDesc(bool_tensor_desc);
    auto less_node = graph->AddNode(less_op_def);
    auto switch_op_def1 = std::make_shared<OpDesc>("greater/Switch_t", SWITCH);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddInputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(bool_tensor_desc);
    switch_op_def1->AddOutputDesc(bool_tensor_desc);
    auto switch_node1 = graph->AddNode(switch_op_def1);
    auto switch_op_def2 = std::make_shared<OpDesc>("greater/Switch_f", SWITCH);
    switch_op_def2->AddInputDesc(scalar_tensor_desc);
    switch_op_def2->AddInputDesc(bool_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    switch_op_def2->AddOutputDesc(scalar_tensor_desc);
    auto switch_node2 = graph->AddNode(switch_op_def2);
    auto switch_op_def3 = std::make_shared<OpDesc>("less/Switch_t", SWITCH);
    switch_op_def3->AddInputDesc(scalar_tensor_desc);
    switch_op_def3->AddInputDesc(bool_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    switch_op_def3->AddOutputDesc(scalar_tensor_desc);
    auto switch_node3 = graph->AddNode(switch_op_def3);
    auto switch_op_def4 = std::make_shared<OpDesc>("less/Switch_f", SWITCH);
    switch_op_def4->AddInputDesc(scalar_tensor_desc);
    switch_op_def4->AddInputDesc(bool_tensor_desc);
    switch_op_def4->AddOutputDesc(scalar_tensor_desc);
    switch_op_def4->AddOutputDesc(scalar_tensor_desc);
    auto switch_node4 = graph->AddNode(switch_op_def4);
    auto merge_op_def1 = std::make_shared<OpDesc>("Merge1", "Merge");
    merge_op_def1->AddInputDesc(scalar_tensor_desc);
    merge_op_def1->AddInputDesc(scalar_tensor_desc);
    merge_op_def1->AddOutputDesc(scalar_tensor_desc);
    merge_op_def1->AddOutputDesc(int_tensor_desc);
    auto merge_node1 = graph->AddNode(merge_op_def1);
    auto merge_op_def2 = std::make_shared<OpDesc>("Merge2", "Merge");
    merge_op_def2->AddInputDesc(scalar_tensor_desc);
    merge_op_def2->AddInputDesc(scalar_tensor_desc);
    merge_op_def2->AddOutputDesc(scalar_tensor_desc);
    merge_op_def2->AddOutputDesc(int_tensor_desc);
    auto merge_node2 = graph->AddNode(merge_op_def2);
    auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
    output_op_def->AddInputDesc(scalar_tensor_desc);
    output_op_def->AddOutputDesc(scalar_tensor_desc);
    auto output_node = graph->AddNode(output_op_def);

    // Wire the data edges.
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), less_node->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), less_node->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(less_node->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(1), merge_node1->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(switch_node4->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), merge_node2->GetInDataAnchor(0));
    (void)GraphUtils::AddEdge(merge_node1->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(1));
    (void)GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  }
};
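
The excerpt above ends with the test fixture and its graph builders; the TEST_F cases themselves are not included. As a rough illustration only, a case exercising the pass on the simple Switch/Merge graph might look like the sketch below. It assumes SwitchOpPass derives from GraphPass and exposes Run(ComputeGraphPtr) returning a ge::Status, as the included headers suggest; the test name and the expected result are hypothetical and not taken from the original file.

// Hypothetical sketch, not part of the original file: builds the fixture's
// simple Switch/Merge graph and runs the pass under test on it.
TEST_F(UtestGraphPassesSwitchOpPass, hypothetical_simple_switch_merge) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("switch_op_pass_test");
  make_graph(graph);  // fixture helper defined above

  SwitchOpPass pass;
  Status ret = pass.Run(graph);  // assumed GraphPass-style entry point
  EXPECT_EQ(ret, SUCCESS);
}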

The Graph Engine (GE) module is a sub-module of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and serves as the bridge between the two. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations on it, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE is mainly composed of two parts: GE API and GE Core.