You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

constant_folding_pass_unittest.cc 26 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/constant_folding_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include <gtest/gtest.h>
  20. #include "common/types.h"
  21. #include "ge/common/ge/ge_util.h"
  22. #include "graph/passes/base_pass.h"
  23. #include "graph/passes/dimension_compute_pass.h"
  24. #include "graph_builder_utils.h"
  25. #include "inc/kernel.h"
  26. #include "inc/kernel_factory.h"
  27. using namespace domi;
  28. namespace ge {
// Op-type names used to register the stub kernels below and to build test
// graphs. The "Yes"/"No" suffix marks whether a folding kernel is registered
// for the type ("Yes" types fold, "No" types do not); the "Wrong*" types are
// kernels that misbehave on purpose to exercise error paths.
const char *AddYesDim = "AddYesDim";
const char *AddNYes = "AddNYes";
const char *AddNNo = "AddNNo";
const char *AddYes = "AddYes";
const char *HuberLossYes = "HuberLossYes";
const char *ShapeNo = "ShapeNo";
const char *DataNo = "dataNo";
const char *WrongYes = "WrongYes";
const char *WrongYes1 = "WrongYes1";
const char *WrongYes2 = "WrongYes2";
const char *WrongYes3 = "WrongYes3";
  40. class TestAddNKernel : public Kernel {
  41. public:
  42. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  43. std::vector<ge::GeTensorPtr> &v_output) override {
  44. auto output = std::make_shared<GeTensor>();
  45. std::vector<uint8_t> data{1, 2, 3};
  46. std::vector<int64_t> shape{3};
  47. output->MutableTensorDesc().SetShape(GeShape(shape));
  48. output->SetData(data);
  49. output->MutableTensorDesc().SetDataType(DT_UINT8);
  50. v_output.push_back(output);
  51. return SUCCESS;
  52. }
  53. };
  54. REGISTER_KERNEL(AddNYes, TestAddNKernel);
  55. class TestHuberLossKernel : public Kernel {
  56. public:
  57. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  58. std::vector<ge::GeTensorPtr> &v_output) override {
  59. auto output1 = std::make_shared<GeTensor>();
  60. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  61. std::vector<int64_t> shape{5};
  62. output1->MutableTensorDesc().SetShape(GeShape(shape));
  63. output1->SetData(data);
  64. output1->MutableTensorDesc().SetDataType(DT_UINT8);
  65. v_output.push_back(output1);
  66. auto output2 = std::make_shared<GeTensor>();
  67. std::vector<uint8_t> data2{1, 2, 3, 4, 5, 6};
  68. std::vector<int64_t> shape2{2, 3};
  69. output2->MutableTensorDesc().SetShape(GeShape(shape2));
  70. output2->SetData(data2);
  71. output2->MutableTensorDesc().SetDataType(DT_UINT8);
  72. v_output.push_back(output2);
  73. return SUCCESS;
  74. }
  75. };
  76. REGISTER_KERNEL(HuberLossYes, TestHuberLossKernel);
  77. class TestAddKernel : public Kernel {
  78. public:
  79. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  80. std::vector<ge::GeTensorPtr> &v_output) override {
  81. auto output = std::make_shared<GeTensor>();
  82. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  83. std::vector<int64_t> shape{5};
  84. output->MutableTensorDesc().SetShape(GeShape(shape));
  85. output->SetData(data);
  86. output->MutableTensorDesc().SetDataType(DT_UINT8);
  87. v_output.push_back(output);
  88. return SUCCESS;
  89. }
  90. };
  91. REGISTER_KERNEL(AddYes, TestAddKernel);
  92. class TestAddDimKernel : public Kernel {
  93. public:
  94. Status Compute(const ge::NodePtr &node, std::vector<ge::GeTensorPtr> &v_output) {
  95. auto output = std::make_shared<GeTensor>();
  96. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  97. std::vector<int64_t> shape{5};
  98. output->MutableTensorDesc().SetShape(GeShape(shape));
  99. output->SetData(data);
  100. output->MutableTensorDesc().SetDataType(DT_UINT8);
  101. v_output.push_back(output);
  102. return SUCCESS;
  103. }
  104. };
  105. REGISTER_KERNEL(AddYesDim, TestAddDimKernel);
  106. class TestWrongKernel : public Kernel {
  107. public:
  108. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  109. std::vector<ge::GeTensorPtr> &v_output) override {
  110. // for test: output weights is null
  111. v_output.push_back(nullptr);
  112. return SUCCESS;
  113. }
  114. };
  115. REGISTER_KERNEL(WrongYes, TestWrongKernel);
  116. class TestWrongKernel1 : public Kernel {
  117. public:
  118. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  119. std::vector<ge::GeTensorPtr> &v_output) override {
  120. // for test: no output weights
  121. return SUCCESS;
  122. }
  123. };
  124. REGISTER_KERNEL(WrongYes1, TestWrongKernel1);
  125. class TestWrongKernel2 : public Kernel {
  126. public:
  127. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  128. std::vector<ge::GeTensorPtr> &v_output) override {
  129. auto output1 = std::make_shared<GeTensor>();
  130. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  131. std::vector<int64_t> shape{5};
  132. output1->MutableTensorDesc().SetShape(GeShape(shape));
  133. output1->SetData(data);
  134. output1->MutableTensorDesc().SetDataType(DT_UINT8);
  135. v_output.push_back(output1);
  136. // for test: output weights < output size
  137. return SUCCESS;
  138. }
  139. };
  140. REGISTER_KERNEL(WrongYes2, TestWrongKernel2);
  141. class TestWrongKernel3 : public Kernel {
  142. public:
  143. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  144. std::vector<ge::GeTensorPtr> &v_output) override {
  145. // for test: return NOT_CHANGED
  146. return NOT_CHANGED;
  147. }
  148. };
  149. REGISTER_KERNEL(WrongYes3, TestWrongKernel3);
// Gtest fixture shared by all constant-folding-pass tests; holds no state.
class UTEST_graph_passes_constant_folding_pass : public testing::Test {
 protected:
  UTEST_graph_passes_constant_folding_pass() = default;
};
  154. namespace {
  155. /// netoutput1
  156. /// |
  157. /// shapeNo1
  158. /// |
  159. /// addnYes1
  160. /// / \
  161. /// / \
  162. /// const1 const2
  163. ComputeGraphPtr BuildGraph1() {
  164. auto builder = ut::GraphBuilder("test");
  165. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  166. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  167. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  168. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  169. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  170. builder.AddDataEdge(const1, 0, addn1, 0);
  171. builder.AddDataEdge(const2, 0, addn1, 1);
  172. builder.AddDataEdge(addn1, 0, shape1, 0);
  173. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  174. return builder.GetGraph();
  175. }
  176. /// netoutput1
  177. /// |
  178. /// shapeNo1
  179. /// |
  180. /// addnYes1 shapeNo2
  181. /// / \ /
  182. /// / \ /
  183. /// const1 const2
  184. ComputeGraphPtr BuildGraph2() {
  185. auto builder = ut::GraphBuilder("test");
  186. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  187. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  188. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  189. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  190. auto shape2 = builder.AddNode("shape2", ShapeNo, 1, 1);
  191. auto netoutput1 = builder.AddNode("netoutput", DataNo, 1, 0);
  192. builder.AddDataEdge(const1, 0, addn1, 0);
  193. builder.AddDataEdge(const2, 0, addn1, 1);
  194. builder.AddDataEdge(const2, 0, shape2, 0);
  195. builder.AddDataEdge(addn1, 0, shape1, 0);
  196. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  197. return builder.GetGraph();
  198. }
  199. /// netoutput1
  200. /// |
  201. /// shapeNo1
  202. /// | c
  203. /// addnYes1 <----- dataNo1
  204. /// / \
  205. /// / \
  206. /// const1 const2
  207. ComputeGraphPtr BuildGraph3() {
  208. auto builder = ut::GraphBuilder("test");
  209. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  210. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  211. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  212. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  213. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  214. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  215. builder.AddDataEdge(const1, 0, addn1, 0);
  216. builder.AddDataEdge(const2, 0, addn1, 1);
  217. builder.AddControlEdge(data1, addn1);
  218. builder.AddDataEdge(addn1, 0, shape1, 0);
  219. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  220. return builder.GetGraph();
  221. }
  222. /// netoutput1
  223. /// |
  224. /// shapeNo1
  225. /// | c
  226. /// addnYes1 <---------
  227. /// / \ \
  228. /// / \ c \
  229. /// const1 const2 <----- dataNo1
  230. ComputeGraphPtr BuildGraph4() {
  231. auto builder = ut::GraphBuilder("test");
  232. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  233. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  234. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  235. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  236. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  237. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  238. builder.AddDataEdge(const1, 0, addn1, 0);
  239. builder.AddDataEdge(const2, 0, addn1, 1);
  240. builder.AddControlEdge(data1, const2);
  241. builder.AddControlEdge(data1, addn1);
  242. builder.AddDataEdge(addn1, 0, shape1, 0);
  243. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  244. return builder.GetGraph();
  245. }
  246. /// netoutput1
  247. /// |
  248. /// shapeNo1
  249. /// | c
  250. /// addnYes1 <----- dataNo1
  251. /// / \
  252. /// / \ c
  253. /// const1 const2 <----- dataNo2
  254. ComputeGraphPtr BuildGraph5() {
  255. auto builder = ut::GraphBuilder("test");
  256. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  257. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  258. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  259. auto data2 = builder.AddNode("data2", DataNo, 0, 1);
  260. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  261. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  262. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  263. builder.AddDataEdge(const1, 0, addn1, 0);
  264. builder.AddDataEdge(const2, 0, addn1, 1);
  265. builder.AddControlEdge(data2, const2);
  266. builder.AddControlEdge(data1, addn1);
  267. builder.AddDataEdge(addn1, 0, shape1, 0);
  268. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  269. return builder.GetGraph();
  270. }
  271. /// netoutput1
  272. /// |
  273. /// shapeNo1
  274. /// |
  275. /// addYes1 <---- const3
  276. /// |
  277. /// addnYes1 <-
  278. /// / \ \
  279. /// / \ \
  280. /// const1 const2 const4
  281. ComputeGraphPtr BuildGraph6() {
  282. auto builder = ut::GraphBuilder("test");
  283. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  284. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  285. auto const3 = builder.AddNode("const3", CONSTANT, 0, 1);
  286. auto const4 = builder.AddNode("const4", CONSTANT, 0, 1);
  287. auto addn1 = builder.AddNode("addn1", AddNYes, 3, 1);
  288. auto add1 = builder.AddNode("add1", AddYes, 2, 1);
  289. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  290. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  291. builder.AddDataEdge(const1, 0, addn1, 0);
  292. builder.AddDataEdge(const2, 0, addn1, 1);
  293. builder.AddDataEdge(const4, 0, addn1, 2);
  294. builder.AddDataEdge(addn1, 0, add1, 0);
  295. builder.AddDataEdge(const3, 0, add1, 1);
  296. builder.AddDataEdge(add1, 0, shape1, 0);
  297. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  298. return builder.GetGraph();
  299. }
  300. /// netoutput1
  301. /// / \
  302. /// shapeNo1 ShpaeNo2
  303. /// \ /
  304. /// huberLoss1
  305. /// / | \
  306. /// / | \
  307. /// const1 const2 const3
  308. ComputeGraphPtr BuildGraph7() {
  309. auto builder = ut::GraphBuilder("test");
  310. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  311. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  312. auto const3 = builder.AddNode("const3", CONSTANT, 0, 1);
  313. auto huberLoss1 = builder.AddNode("huberLoss1", HuberLossYes, 3, 2);
  314. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  315. auto shape2 = builder.AddNode("shape2", ShapeNo, 1, 1);
  316. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  317. builder.AddDataEdge(const1, 0, huberLoss1, 0);
  318. builder.AddDataEdge(const2, 0, huberLoss1, 1);
  319. builder.AddDataEdge(const3, 0, huberLoss1, 2);
  320. builder.AddDataEdge(huberLoss1, 0, shape1, 0);
  321. builder.AddDataEdge(huberLoss1, 1, shape2, 0);
  322. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  323. builder.AddDataEdge(shape2, 1, netoutput1, 0);
  324. return builder.GetGraph();
  325. }
  326. /// netoutput1
  327. /// |
  328. /// shapeNo1
  329. /// |
  330. /// addnNo1
  331. /// / \
  332. /// / \
  333. /// const1 const2
  334. ComputeGraphPtr BuildGraph8() {
  335. auto builder = ut::GraphBuilder("test");
  336. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  337. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  338. auto addn1 = builder.AddNode("addn1", AddNNo, 2, 1);
  339. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  340. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  341. builder.AddDataEdge(const1, 0, addn1, 0);
  342. builder.AddDataEdge(const2, 0, addn1, 1);
  343. builder.AddDataEdge(addn1, 0, shape1, 0);
  344. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  345. return builder.GetGraph();
  346. }
  347. /// netoutput1
  348. /// |
  349. /// shapeNo1
  350. /// |
  351. /// addnYes1
  352. /// / \
  353. /// / \
  354. /// const1 data1
  355. ComputeGraphPtr BuildGraph9() {
  356. auto builder = ut::GraphBuilder("test");
  357. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  358. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  359. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  360. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  361. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  362. builder.AddDataEdge(const1, 0, addn1, 0);
  363. builder.AddDataEdge(data1, 0, addn1, 1);
  364. builder.AddDataEdge(addn1, 0, shape1, 0);
  365. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  366. return builder.GetGraph();
  367. }
  368. /// netoutput1
  369. /// / \
  370. /// addDim sqrt1
  371. /// \ /
  372. /// switch1
  373. /// / \
  374. /// / \
  375. /// const1 const2
  376. ComputeGraphPtr BuildGraph10() {
  377. auto builder = ut::GraphBuilder("test");
  378. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  379. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  380. auto switchNode1 = builder.AddNode("switch1", SWITCH, 2, 2);
  381. auto sqrt1 = builder.AddNode("sqrt1", RSQRT, 1, 1);
  382. auto add1 = builder.AddNode("addDim", AddYesDim, 1, 1);
  383. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  384. builder.AddDataEdge(const1, 0, switchNode1, 0);
  385. builder.AddDataEdge(const2, 0, switchNode1, 1);
  386. builder.AddDataEdge(switchNode1, 0, add1, 0);
  387. builder.AddDataEdge(switchNode1, 1, sqrt1, 0);
  388. builder.AddDataEdge(add1, 0, netoutput1, 0);
  389. builder.AddDataEdge(sqrt1, 0, netoutput1, 1);
  390. return builder.GetGraph();
  391. }
  392. /// netoutput1
  393. /// |
  394. /// FRAMEWORKOP
  395. /// |
  396. /// const1
  397. ComputeGraphPtr BuildWrongGraph1() {
  398. auto builder = ut::GraphBuilder("test");
  399. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  400. auto op = builder.AddNode("fmk_op", FRAMEWORKOP, 1, 1);
  401. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  402. builder.AddDataEdge(const_op, 0, op, 0);
  403. builder.AddDataEdge(op, 0, netoutput1, 0);
  404. return builder.GetGraph();
  405. }
  406. /// netoutput1
  407. /// |
  408. /// WrongYes
  409. /// |
  410. /// const1
  411. ComputeGraphPtr BuildWrongGraph2() {
  412. auto builder = ut::GraphBuilder("test");
  413. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  414. auto op = builder.AddNode("wrong", WrongYes, 1, 1);
  415. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  416. builder.AddDataEdge(const_op, 0, op, 0);
  417. builder.AddDataEdge(op, 0, netoutput1, 0);
  418. return builder.GetGraph();
  419. }
  420. /// netoutput1
  421. /// |
  422. /// WrongYes1
  423. /// |
  424. /// const1
  425. ComputeGraphPtr BuildWrongGraph3() {
  426. auto builder = ut::GraphBuilder("test");
  427. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  428. auto op = builder.AddNode("wrong1", WrongYes1, 1, 1);
  429. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  430. builder.AddDataEdge(const_op, 0, op, 0);
  431. builder.AddDataEdge(op, 0, netoutput1, 0);
  432. return builder.GetGraph();
  433. }
  434. /// netoutput1 WrongYes1
  435. /// | /
  436. /// WrongYes2
  437. /// /
  438. /// const1
  439. ComputeGraphPtr BuildWrongGraph4() {
  440. auto builder = ut::GraphBuilder("test");
  441. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  442. auto op = builder.AddNode("wrong2", WrongYes2, 1, 2);
  443. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  444. auto wrong_op = builder.AddNode("WrongYes1", WrongYes1, 1, 0);
  445. builder.AddDataEdge(const_op_1, 0, op, 0);
  446. builder.AddDataEdge(op, 0, netoutput1, 0);
  447. builder.AddDataEdge(op, 1, wrong_op, 0);
  448. return builder.GetGraph();
  449. }
  450. /// CONVOLUTION
  451. /// |
  452. /// WrongYes2 WrongYes1
  453. /// /
  454. /// const1
  455. ComputeGraphPtr BuildWrongGraph5() {
  456. auto builder = ut::GraphBuilder("test");
  457. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  458. auto op = builder.AddNode("wrong2", WrongYes2, 1, 1);
  459. auto conv = builder.AddNode("conv", CONVOLUTION, 1, 0);
  460. auto wrong_op = builder.AddNode("WrongYes1", WrongYes1, 1, 0);
  461. builder.AddDataEdge(const_op_1, 0, op, 0);
  462. builder.AddDataEdge(op, 0, conv, 0);
  463. return builder.GetGraph();
  464. }
  465. /// CONVOLUTION
  466. /// |
  467. /// WrongYes3
  468. /// /
  469. /// const1
  470. ComputeGraphPtr BuildWrongGraph6() {
  471. auto builder = ut::GraphBuilder("test");
  472. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  473. auto op = builder.AddNode("wrong3", WrongYes3, 1, 1);
  474. auto conv = builder.AddNode("conv", CONVOLUTION, 1, 0);
  475. builder.AddDataEdge(const_op_1, 0, op, 0);
  476. builder.AddDataEdge(op, 0, conv, 0);
  477. return builder.GetGraph();
  478. }
  479. } // namespace
  480. TEST_F(UTEST_graph_passes_constant_folding_pass, FoldingAddN) {
  481. auto graph = BuildGraph1();
  482. NamesToPass names_to_pass;
  483. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  484. GEPass pass(graph);
  485. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  486. EXPECT_EQ(graph->GetAllNodes().size(), 3);
  487. auto shape1 = graph->FindNode("shape1");
  488. EXPECT_NE(shape1, nullptr);
  489. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  490. auto folded_const = shape1->GetInDataNodes().at(0);
  491. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  492. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  493. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  494. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({3}));
  495. for (auto &name_to_pass : names_to_pass) {
  496. delete name_to_pass.second;
  497. }
  498. }
  499. TEST_F(UTEST_graph_passes_constant_folding_pass, FoldingWithoutOneConst) {
  500. auto graph = BuildGraph2();
  501. NamesToPass names_to_pass;
  502. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  503. GEPass pass(graph);
  504. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  505. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  506. EXPECT_EQ(graph->FindNode("addn1"), nullptr);
  507. EXPECT_EQ(graph->FindNode("const1"), nullptr);
  508. auto const2 = graph->FindNode("const2");
  509. EXPECT_NE(const2, nullptr);
  510. EXPECT_EQ(const2->GetOutDataNodes().size(), 1);
  511. EXPECT_EQ(const2->GetOutDataNodes().at(0)->GetName(), "shape2");
  512. auto shape1 = graph->FindNode("shape1");
  513. EXPECT_NE(shape1, nullptr);
  514. EXPECT_EQ(shape1->GetInDataNodes().size(), 1);
  515. EXPECT_EQ(shape1->GetInDataNodes().at(0)->GetType(), CONSTANT);
  516. for (auto &name_to_pass : names_to_pass) {
  517. delete name_to_pass.second;
  518. }
  519. }
  520. TEST_F(UTEST_graph_passes_constant_folding_pass, FoldingWithConstControlEdges) {
  521. auto graph = BuildGraph5();
  522. NamesToPass names_to_pass;
  523. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  524. GEPass pass(graph);
  525. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  526. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  527. auto shape1 = graph->FindNode("shape1");
  528. EXPECT_NE(shape1, nullptr);
  529. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  530. EXPECT_EQ(shape1->GetInControlNodes().size(), 0);
  531. EXPECT_EQ(shape1->GetInDataNodes().at(0)->GetType(), CONSTANT);
  532. std::unordered_set<std::string> node_names;
  533. for (auto node : shape1->GetInControlNodes()) {
  534. node_names.insert(node->GetName());
  535. }
  536. EXPECT_EQ(node_names, std::unordered_set<std::string>());
  537. for (auto &name_to_pass : names_to_pass) {
  538. delete name_to_pass.second;
  539. }
  540. }
  541. TEST_F(UTEST_graph_passes_constant_folding_pass, ContinuesFold) {
  542. auto graph = BuildGraph6();
  543. NamesToPass names_to_pass;
  544. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  545. GEPass pass(graph);
  546. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  547. EXPECT_EQ(graph->GetAllNodes().size(), 3);
  548. auto shape1 = graph->FindNode("shape1");
  549. EXPECT_NE(shape1, nullptr);
  550. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  551. auto folded_const = shape1->GetInDataNodes().at(0);
  552. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  553. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  554. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  555. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({5}));
  556. for (auto &name_to_pass : names_to_pass) {
  557. delete name_to_pass.second;
  558. }
  559. }
  560. TEST_F(UTEST_graph_passes_constant_folding_pass, MultipleOutput) {
  561. auto graph = BuildGraph7();
  562. NamesToPass names_to_pass;
  563. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  564. GEPass pass(graph);
  565. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  566. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  567. auto shape1 = graph->FindNode("shape1");
  568. EXPECT_NE(shape1, nullptr);
  569. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  570. auto folded_const = shape1->GetInDataNodes().at(0);
  571. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  572. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  573. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  574. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({5}));
  575. auto shape2 = graph->FindNode("shape2");
  576. EXPECT_NE(shape2, nullptr);
  577. EXPECT_EQ(shape2->GetInNodes().size(), 1);
  578. auto folded_const2 = shape2->GetInDataNodes().at(0);
  579. EXPECT_EQ(folded_const2->GetType(), CONSTANT);
  580. auto tensor2 = folded_const2->GetOpDesc()->GetOutputDesc(0);
  581. EXPECT_EQ(tensor2.GetDataType(), DT_UINT8);
  582. EXPECT_EQ(tensor2.GetShape().GetDims(), std::vector<int64_t>({2, 3}));
  583. for (auto &name_to_pass : names_to_pass) {
  584. delete name_to_pass.second;
  585. }
  586. }
  587. TEST_F(UTEST_graph_passes_constant_folding_pass, NotChange1) {
  588. auto graph = BuildGraph8();
  589. NamesToPass names_to_pass;
  590. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  591. GEPass pass(graph);
  592. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  593. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  594. for (auto &name_to_pass : names_to_pass) {
  595. delete name_to_pass.second;
  596. }
  597. }
  598. TEST_F(UTEST_graph_passes_constant_folding_pass, NotChange2) {
  599. auto graph = BuildGraph9();
  600. NamesToPass names_to_pass;
  601. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  602. GEPass pass(graph);
  603. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  604. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  605. for (auto &name_to_pass : names_to_pass) {
  606. delete name_to_pass.second;
  607. }
  608. }
  609. TEST_F(UTEST_graph_passes_constant_folding_pass, FoldingSize) {
  610. auto graph = BuildGraph10();
  611. NamesToPass names_to_pass;
  612. names_to_pass.push_back({"Test", new DimensionComputePass});
  613. GEPass pass(graph);
  614. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  615. EXPECT_EQ(graph->GetAllNodes().size(), 7);
  616. auto switchnode = graph->FindNode("switch1");
  617. EXPECT_NE(switchnode, nullptr);
  618. EXPECT_EQ(switchnode->GetOutDataNodes().size(), 2);
  619. EXPECT_EQ(switchnode->GetOutDataNodes().at(0)->GetName(), "addDim_ctrl_identity_0");
  620. for (auto &name_to_pass : names_to_pass) {
  621. delete name_to_pass.second;
  622. }
  623. }
  624. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely1) {
  625. auto graph = BuildWrongGraph1();
  626. NamesToPass names_to_pass;
  627. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  628. GEPass pass(graph);
  629. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  630. for (auto &name_to_pass : names_to_pass) {
  631. delete name_to_pass.second;
  632. }
  633. }
  634. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely2) {
  635. auto graph = BuildWrongGraph2();
  636. NamesToPass names_to_pass;
  637. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  638. GEPass pass(graph);
  639. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  640. for (auto &name_to_pass : names_to_pass) {
  641. delete name_to_pass.second;
  642. }
  643. }
  644. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely3) {
  645. auto graph = BuildWrongGraph3();
  646. NamesToPass names_to_pass;
  647. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  648. GEPass pass(graph);
  649. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  650. for (auto &name_to_pass : names_to_pass) {
  651. delete name_to_pass.second;
  652. }
  653. }
  654. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely4) {
  655. auto graph = BuildWrongGraph4();
  656. NamesToPass names_to_pass;
  657. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  658. GEPass pass(graph);
  659. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  660. for (auto &name_to_pass : names_to_pass) {
  661. delete name_to_pass.second;
  662. }
  663. }
  664. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely5) {
  665. auto graph = BuildWrongGraph5();
  666. NamesToPass names_to_pass;
  667. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  668. GEPass pass(graph);
  669. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  670. for (auto &name_to_pass : names_to_pass) {
  671. delete name_to_pass.second;
  672. }
  673. }
  674. TEST_F(UTEST_graph_passes_constant_folding_pass, Unlikely6) {
  675. auto graph = BuildWrongGraph6();
  676. NamesToPass names_to_pass;
  677. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  678. GEPass pass(graph);
  679. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  680. for (auto &name_to_pass : names_to_pass) {
  681. delete name_to_pass.second;
  682. }
  683. }
  684. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示

Contributors (1)