You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_op_pass_unittest.cc 21 kB

5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include <memory>
  #include <string>
  #include <vector>
  #include <gtest/gtest.h>
  #include "omg/omg_inner_types.h"
  #define protected public
  #define private public
  #include "graph/passes/switch_op_pass.h"
  #include "common/debug/log.h"
  #include "common/debug/memory_dumper.h"
  #include "common/op/attr_value_util.h"
  #include "common/types.h"
  #include "graph/debug/ge_attr_define.h"
  #include "graph/graph.h"
  #include "graph/passes/control_op_attr_pass.h"
  #include "inc/pass_manager.h"
  #undef protected
  #undef private
  31. using namespace domi;
  32. using namespace testing;
  33. using namespace ge;
  34. class UTEST_graph_passes_switch_op_pass : public testing::Test {
  35. protected:
  36. void SetUp() {}
  37. void TearDown() {}
  38. public:
  39. void make_graph(ComputeGraphPtr graph, bool match = true) {
  40. GeTensorDesc boolTensorDesc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
  41. GeTensorDesc intTensorDesc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  42. GeTensorDesc scalarTensorDesc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
  43. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  44. xOpDef->AddOutputDesc(scalarTensorDesc);
  45. auto xNode = graph->AddNode(xOpDef);
  46. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  47. yOpDef->AddOutputDesc(scalarTensorDesc);
  48. auto yNode = graph->AddNode(yOpDef);
  49. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  50. zOpDef->AddOutputDesc(scalarTensorDesc);
  51. auto zNode = graph->AddNode(zOpDef);
  52. auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
  53. condOpDef->AddInputDesc(scalarTensorDesc);
  54. condOpDef->AddInputDesc(scalarTensorDesc);
  55. condOpDef->AddOutputDesc(boolTensorDesc);
  56. auto condNode = graph->AddNode(condOpDef);
  57. auto switchOpDef1 = std::make_shared<OpDesc>("Add/Switch", SWITCH);
  58. switchOpDef1->AddInputDesc(scalarTensorDesc);
  59. switchOpDef1->AddInputDesc(boolTensorDesc);
  60. switchOpDef1->AddOutputDesc(scalarTensorDesc);
  61. switchOpDef1->AddOutputDesc(scalarTensorDesc);
  62. auto switchNode1 = graph->AddNode(switchOpDef1);
  63. auto switchOpDef2 = std::make_shared<OpDesc>("Add/Switch_1", SWITCH);
  64. switchOpDef2->AddInputDesc(scalarTensorDesc);
  65. switchOpDef2->AddInputDesc(boolTensorDesc);
  66. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  67. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  68. auto switchNode2 = graph->AddNode(switchOpDef2);
  69. auto switchOpDef3 = std::make_shared<OpDesc>("Square/Switch", SWITCH);
  70. switchOpDef3->AddInputDesc(scalarTensorDesc);
  71. switchOpDef3->AddInputDesc(boolTensorDesc);
  72. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  73. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  74. auto switchNode3 = graph->AddNode(switchOpDef3);
  75. auto addOpDef = std::make_shared<OpDesc>("Add", "ADD");
  76. addOpDef->AddInputDesc(scalarTensorDesc);
  77. addOpDef->AddInputDesc(scalarTensorDesc);
  78. addOpDef->AddOutputDesc(scalarTensorDesc);
  79. auto addNode = graph->AddNode(addOpDef);
  80. auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
  81. mergeOpDef->AddInputDesc(scalarTensorDesc);
  82. mergeOpDef->AddInputDesc(scalarTensorDesc);
  83. mergeOpDef->AddOutputDesc(scalarTensorDesc);
  84. mergeOpDef->AddOutputDesc(intTensorDesc);
  85. auto mergeNode = graph->AddNode(mergeOpDef);
  86. auto outputOpDef = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  87. outputOpDef->AddInputDesc(scalarTensorDesc);
  88. outputOpDef->AddOutputDesc(scalarTensorDesc);
  89. auto outputNode = graph->AddNode(outputOpDef);
  90. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  91. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  92. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(0));
  93. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(1));
  94. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(0));
  95. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(1));
  96. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(0));
  97. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(1));
  98. (void)GraphUtils::AddEdge(switchNode1->GetOutDataAnchor(1), addNode->GetInDataAnchor(0));
  99. (void)GraphUtils::AddEdge(switchNode2->GetOutDataAnchor(1), addNode->GetInDataAnchor(1));
  100. (void)GraphUtils::AddEdge(addNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
  101. (void)GraphUtils::AddEdge(switchNode3->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
  102. (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), outputNode->GetInDataAnchor(0));
  103. }
  104. void make_graph_const(ComputeGraphPtr graph, bool match = true) {
  105. // resnet50 PolynomialDecay
  106. GeTensorDesc scalarTensorDesc(GeShape({1, 1, 1, 1}));
  107. GeTensorDesc boolTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  108. GeTensorDesc intTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  109. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  110. xOpDef->AddOutputDesc(scalarTensorDesc);
  111. auto xNode = graph->AddNode(xOpDef);
  112. auto yOpDef = std::make_shared<OpDesc>("y", "Const");
  113. yOpDef->AddOutputDesc(scalarTensorDesc);
  114. auto yNode = graph->AddNode(yOpDef);
  115. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  116. zOpDef->AddOutputDesc(scalarTensorDesc);
  117. auto zNode = graph->AddNode(zOpDef);
  118. auto constOpDef = std::make_shared<OpDesc>("Const", "Const");
  119. constOpDef->AddOutputDesc(scalarTensorDesc);
  120. auto constNode = graph->AddNode(constOpDef);
  121. auto condOpDef = std::make_shared<OpDesc>("Equal", "Equal");
  122. condOpDef->AddInputDesc(scalarTensorDesc);
  123. condOpDef->AddInputDesc(scalarTensorDesc);
  124. condOpDef->AddOutputDesc(boolTensorDesc);
  125. auto condNode = graph->AddNode(condOpDef);
  126. auto identityOpDef = std::make_shared<OpDesc>("identity", "Identity");
  127. identityOpDef->AddInputDesc(boolTensorDesc);
  128. identityOpDef->AddOutputDesc(boolTensorDesc);
  129. auto identityNode = graph->AddNode(identityOpDef);
  130. auto switchOpDef1 = std::make_shared<OpDesc>("Switch", SWITCH);
  131. switchOpDef1->AddInputDesc(boolTensorDesc);
  132. switchOpDef1->AddInputDesc(boolTensorDesc);
  133. switchOpDef1->AddOutputDesc(boolTensorDesc);
  134. switchOpDef1->AddOutputDesc(boolTensorDesc);
  135. auto switchNode1 = graph->AddNode(switchOpDef1);
  136. auto tIdentityOpDef = std::make_shared<OpDesc>("switch_t", "Identity");
  137. tIdentityOpDef->AddInputDesc(scalarTensorDesc);
  138. tIdentityOpDef->AddOutputDesc(scalarTensorDesc);
  139. auto tIdentityNode = graph->AddNode(tIdentityOpDef);
  140. auto fIdentityOpDef = std::make_shared<OpDesc>("switch_f", "Identity");
  141. fIdentityOpDef->AddInputDesc(scalarTensorDesc);
  142. fIdentityOpDef->AddOutputDesc(scalarTensorDesc);
  143. auto fIdentityNode = graph->AddNode(fIdentityOpDef);
  144. auto switchOpDef2 = std::make_shared<OpDesc>("Switch_1", SWITCH);
  145. switchOpDef2->AddInputDesc(scalarTensorDesc);
  146. switchOpDef2->AddInputDesc(boolTensorDesc);
  147. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  148. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  149. auto switchNode2 = graph->AddNode(switchOpDef2);
  150. auto mulOpDef = std::make_shared<OpDesc>("truediv", "Mul");
  151. mulOpDef->AddInputDesc(scalarTensorDesc);
  152. mulOpDef->AddInputDesc(scalarTensorDesc);
  153. mulOpDef->AddOutputDesc(scalarTensorDesc);
  154. auto mulNode = graph->AddNode(mulOpDef);
  155. auto ceilOpDef = std::make_shared<OpDesc>("Ceil", "Ceil");
  156. ceilOpDef->AddInputDesc(scalarTensorDesc);
  157. ceilOpDef->AddOutputDesc(scalarTensorDesc);
  158. auto ceilNode = graph->AddNode(ceilOpDef);
  159. auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
  160. mergeOpDef->AddInputDesc(scalarTensorDesc);
  161. mergeOpDef->AddInputDesc(scalarTensorDesc);
  162. mergeOpDef->AddOutputDesc(scalarTensorDesc);
  163. mergeOpDef->AddOutputDesc(intTensorDesc);
  164. auto mergeNode = graph->AddNode(mergeOpDef);
  165. auto outputOpDef = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  166. outputOpDef->AddInputDesc(scalarTensorDesc);
  167. outputOpDef->AddOutputDesc(scalarTensorDesc);
  168. auto outputNode = graph->AddNode(outputOpDef);
  169. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  170. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  171. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), identityNode->GetInDataAnchor(0));
  172. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(0));
  173. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(1));
  174. (void)GraphUtils::AddEdge(switchNode1->GetOutDataAnchor(0), fIdentityNode->GetInDataAnchor(0));
  175. (void)GraphUtils::AddEdge(switchNode1->GetOutDataAnchor(1), tIdentityNode->GetInDataAnchor(0));
  176. (void)GraphUtils::AddEdge(fIdentityNode->GetOutControlAnchor(), zNode->GetInControlAnchor());
  177. (void)GraphUtils::AddEdge(tIdentityNode->GetOutControlAnchor(), constNode->GetInControlAnchor());
  178. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(0));
  179. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(1));
  180. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), mulNode->GetInDataAnchor(0));
  181. (void)GraphUtils::AddEdge(switchNode2->GetOutDataAnchor(0), mulNode->GetInDataAnchor(1));
  182. (void)GraphUtils::AddEdge(mulNode->GetOutDataAnchor(0), ceilNode->GetInDataAnchor(0));
  183. (void)GraphUtils::AddEdge(constNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
  184. (void)GraphUtils::AddEdge(ceilNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
  185. (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), outputNode->GetInDataAnchor(0));
  186. }
  187. void make_graph_cyclic_dependence(ComputeGraphPtr graph, bool match = true) {
  188. GeTensorDesc scalarTensorDesc(GeShape({1, 1, 1, 1}));
  189. GeTensorDesc boolTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  190. GeTensorDesc intTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  191. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  192. xOpDef->AddOutputDesc(scalarTensorDesc);
  193. auto xNode = graph->AddNode(xOpDef);
  194. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  195. yOpDef->AddOutputDesc(scalarTensorDesc);
  196. auto yNode = graph->AddNode(yOpDef);
  197. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  198. zOpDef->AddOutputDesc(scalarTensorDesc);
  199. auto zNode = graph->AddNode(zOpDef);
  200. auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
  201. condOpDef->AddInputDesc(scalarTensorDesc);
  202. condOpDef->AddInputDesc(scalarTensorDesc);
  203. condOpDef->AddOutputDesc(boolTensorDesc);
  204. auto condNode = graph->AddNode(condOpDef);
  205. auto switchOpDef1 = std::make_shared<OpDesc>("Switch_f_1", SWITCH);
  206. switchOpDef1->AddInputDesc(scalarTensorDesc);
  207. switchOpDef1->AddInputDesc(boolTensorDesc);
  208. switchOpDef1->AddOutputDesc(scalarTensorDesc);
  209. switchOpDef1->AddOutputDesc(scalarTensorDesc);
  210. auto switchNode1 = graph->AddNode(switchOpDef1);
  211. auto switchOpDef2 = std::make_shared<OpDesc>("Switch_t_1", SWITCH);
  212. switchOpDef2->AddInputDesc(scalarTensorDesc);
  213. switchOpDef2->AddInputDesc(boolTensorDesc);
  214. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  215. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  216. auto switchNode2 = graph->AddNode(switchOpDef2);
  217. auto switchOpDef3 = std::make_shared<OpDesc>("Switch_f_2", SWITCH);
  218. switchOpDef3->AddInputDesc(scalarTensorDesc);
  219. switchOpDef3->AddInputDesc(boolTensorDesc);
  220. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  221. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  222. auto switchNode3 = graph->AddNode(switchOpDef3);
  223. auto switchOpDef4 = std::make_shared<OpDesc>("Switch_t_2", SWITCH);
  224. switchOpDef4->AddInputDesc(scalarTensorDesc);
  225. switchOpDef4->AddInputDesc(boolTensorDesc);
  226. switchOpDef4->AddOutputDesc(scalarTensorDesc);
  227. switchOpDef4->AddOutputDesc(scalarTensorDesc);
  228. auto switchNode4 = graph->AddNode(switchOpDef4);
  229. auto squareOpDef1 = std::make_shared<OpDesc>("Square1", "Square");
  230. squareOpDef1->AddInputDesc(scalarTensorDesc);
  231. squareOpDef1->AddOutputDesc(scalarTensorDesc);
  232. auto squareNode1 = graph->AddNode(squareOpDef1);
  233. auto squareOpDef2 = std::make_shared<OpDesc>("Square2", "Square");
  234. squareOpDef2->AddInputDesc(scalarTensorDesc);
  235. squareOpDef2->AddOutputDesc(scalarTensorDesc);
  236. auto squareNode2 = graph->AddNode(squareOpDef2);
  237. auto squareOpDef3 = std::make_shared<OpDesc>("Square3", "Square");
  238. squareOpDef3->AddInputDesc(scalarTensorDesc);
  239. squareOpDef3->AddOutputDesc(scalarTensorDesc);
  240. auto squareNode3 = graph->AddNode(squareOpDef3);
  241. auto squareOpDef4 = std::make_shared<OpDesc>("Square4", "Square");
  242. squareOpDef4->AddInputDesc(scalarTensorDesc);
  243. squareOpDef4->AddOutputDesc(scalarTensorDesc);
  244. auto squareNode4 = graph->AddNode(squareOpDef4);
  245. auto mergeOpDef1 = std::make_shared<OpDesc>("Merge1", "Merge");
  246. mergeOpDef1->AddInputDesc(scalarTensorDesc);
  247. mergeOpDef1->AddInputDesc(scalarTensorDesc);
  248. mergeOpDef1->AddOutputDesc(scalarTensorDesc);
  249. mergeOpDef1->AddOutputDesc(intTensorDesc);
  250. auto mergeNode1 = graph->AddNode(mergeOpDef1);
  251. auto mergeOpDef2 = std::make_shared<OpDesc>("Merge2", "Merge");
  252. mergeOpDef2->AddInputDesc(scalarTensorDesc);
  253. mergeOpDef2->AddInputDesc(scalarTensorDesc);
  254. mergeOpDef2->AddOutputDesc(scalarTensorDesc);
  255. mergeOpDef2->AddOutputDesc(intTensorDesc);
  256. auto mergeNode2 = graph->AddNode(mergeOpDef2);
  257. auto outputOpDef = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  258. outputOpDef->AddInputDesc(scalarTensorDesc);
  259. outputOpDef->AddOutputDesc(scalarTensorDesc);
  260. auto outputNode = graph->AddNode(outputOpDef);
  261. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  262. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  263. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(0));
  264. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(1));
  265. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(0));
  266. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(1));
  267. (void)GraphUtils::AddEdge(switchNode1->GetOutDataAnchor(0), squareNode1->GetInDataAnchor(0));
  268. (void)GraphUtils::AddEdge(switchNode2->GetOutDataAnchor(1), squareNode2->GetInDataAnchor(0));
  269. (void)GraphUtils::AddEdge(squareNode1->GetOutDataAnchor(0), mergeNode1->GetInDataAnchor(0));
  270. (void)GraphUtils::AddEdge(squareNode2->GetOutDataAnchor(0), mergeNode1->GetInDataAnchor(1));
  271. (void)GraphUtils::AddEdge(mergeNode1->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(0));
  272. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(1));
  273. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switchNode4->GetInDataAnchor(0));
  274. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switchNode4->GetInDataAnchor(1));
  275. (void)GraphUtils::AddEdge(switchNode3->GetOutDataAnchor(0), squareNode3->GetInDataAnchor(0));
  276. (void)GraphUtils::AddEdge(switchNode4->GetOutDataAnchor(1), squareNode4->GetInDataAnchor(0));
  277. (void)GraphUtils::AddEdge(squareNode3->GetOutDataAnchor(0), mergeNode2->GetInDataAnchor(0));
  278. (void)GraphUtils::AddEdge(squareNode4->GetOutDataAnchor(0), mergeNode2->GetInDataAnchor(1));
  279. (void)GraphUtils::AddEdge(mergeNode2->GetOutDataAnchor(0), outputNode->GetInDataAnchor(0));
  280. }
  281. void make_graph_case(ComputeGraphPtr graph, bool match = true) {
  282. GeTensorDesc scalarTensorDesc(GeShape({1, 1, 1, 1}));
  283. GeTensorDesc boolTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  284. GeTensorDesc intTensorDesc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  285. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  286. xOpDef->AddOutputDesc(scalarTensorDesc);
  287. auto xNode = graph->AddNode(xOpDef);
  288. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  289. yOpDef->AddOutputDesc(scalarTensorDesc);
  290. auto yNode = graph->AddNode(yOpDef);
  291. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  292. zOpDef->AddOutputDesc(scalarTensorDesc);
  293. auto zNode = graph->AddNode(zOpDef);
  294. auto greaterOpDef = std::make_shared<OpDesc>("Greater", "Greater");
  295. greaterOpDef->AddInputDesc(scalarTensorDesc);
  296. greaterOpDef->AddInputDesc(scalarTensorDesc);
  297. greaterOpDef->AddOutputDesc(boolTensorDesc);
  298. auto greaterNode = graph->AddNode(greaterOpDef);
  299. auto lessOpDef = std::make_shared<OpDesc>("Less", "Less");
  300. lessOpDef->AddInputDesc(scalarTensorDesc);
  301. lessOpDef->AddInputDesc(scalarTensorDesc);
  302. lessOpDef->AddOutputDesc(boolTensorDesc);
  303. auto lessNode = graph->AddNode(lessOpDef);
  304. auto switchOpDef1 = std::make_shared<OpDesc>("greater/Switch_t", SWITCH);
  305. switchOpDef1->AddInputDesc(boolTensorDesc);
  306. switchOpDef1->AddInputDesc(boolTensorDesc);
  307. switchOpDef1->AddOutputDesc(boolTensorDesc);
  308. switchOpDef1->AddOutputDesc(boolTensorDesc);
  309. auto switchNode1 = graph->AddNode(switchOpDef1);
  310. auto switchOpDef2 = std::make_shared<OpDesc>("greater/Switch_f", SWITCH);
  311. switchOpDef2->AddInputDesc(scalarTensorDesc);
  312. switchOpDef2->AddInputDesc(boolTensorDesc);
  313. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  314. switchOpDef2->AddOutputDesc(scalarTensorDesc);
  315. auto switchNode2 = graph->AddNode(switchOpDef2);
  316. auto switchOpDef3 = std::make_shared<OpDesc>("less/Switch_t", SWITCH);
  317. switchOpDef3->AddInputDesc(scalarTensorDesc);
  318. switchOpDef3->AddInputDesc(boolTensorDesc);
  319. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  320. switchOpDef3->AddOutputDesc(scalarTensorDesc);
  321. auto switchNode3 = graph->AddNode(switchOpDef3);
  322. auto switchOpDef4 = std::make_shared<OpDesc>("less/Switch_f", SWITCH);
  323. switchOpDef4->AddInputDesc(scalarTensorDesc);
  324. switchOpDef4->AddInputDesc(boolTensorDesc);
  325. switchOpDef4->AddOutputDesc(scalarTensorDesc);
  326. switchOpDef4->AddOutputDesc(scalarTensorDesc);
  327. auto switchNode4 = graph->AddNode(switchOpDef4);
  328. auto mergeOpDef1 = std::make_shared<OpDesc>("Merge1", "Merge");
  329. mergeOpDef1->AddInputDesc(scalarTensorDesc);
  330. mergeOpDef1->AddInputDesc(scalarTensorDesc);
  331. mergeOpDef1->AddOutputDesc(scalarTensorDesc);
  332. mergeOpDef1->AddOutputDesc(intTensorDesc);
  333. auto mergeNode1 = graph->AddNode(mergeOpDef1);
  334. auto mergeOpDef2 = std::make_shared<OpDesc>("Merge2", "Merge");
  335. mergeOpDef2->AddInputDesc(scalarTensorDesc);
  336. mergeOpDef2->AddInputDesc(scalarTensorDesc);
  337. mergeOpDef2->AddOutputDesc(scalarTensorDesc);
  338. mergeOpDef2->AddOutputDesc(intTensorDesc);
  339. auto mergeNode2 = graph->AddNode(mergeOpDef2);
  340. auto outputOpDef = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  341. outputOpDef->AddInputDesc(scalarTensorDesc);
  342. outputOpDef->AddOutputDesc(scalarTensorDesc);
  343. auto outputNode = graph->AddNode(outputOpDef);
  344. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(0));
  345. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(1));
  346. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), lessNode->GetInDataAnchor(0));
  347. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), lessNode->GetInDataAnchor(1));
  348. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(0));
  349. (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switchNode1->GetInDataAnchor(1));
  350. (void)GraphUtils::AddEdge(lessNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(0));
  351. (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switchNode2->GetInDataAnchor(1));
  352. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(0));
  353. (void)GraphUtils::AddEdge(switchNode2->GetOutDataAnchor(0), switchNode3->GetInDataAnchor(1));
  354. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switchNode4->GetInDataAnchor(0));
  355. (void)GraphUtils::AddEdge(switchNode2->GetOutDataAnchor(0), switchNode4->GetInDataAnchor(1));
  356. (void)GraphUtils::AddEdge(switchNode3->GetOutDataAnchor(1), mergeNode1->GetInDataAnchor(0));
  357. (void)GraphUtils::AddEdge(switchNode4->GetOutDataAnchor(0), mergeNode1->GetInDataAnchor(1));
  358. (void)GraphUtils::AddEdge(switchNode1->GetOutDataAnchor(1), mergeNode2->GetInDataAnchor(0));
  359. (void)GraphUtils::AddEdge(mergeNode1->GetOutDataAnchor(0), mergeNode2->GetInDataAnchor(1));
  360. (void)GraphUtils::AddEdge(mergeNode2->GetOutDataAnchor(0), outputNode->GetInDataAnchor(0));
  361. }
  362. };

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: