You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

base_pass_unittest.cc 14 kB

5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <map>
  18. #include <set>
  19. #include <string>
  20. #include <vector>
  21. #include "gtest/gtest.h"
  22. #define protected public
  23. #include "graph/passes/base_pass.h"
  24. #undef protected
  25. #include "external/graph/ge_error_codes.h"
  26. #include "framework/common/ge_inner_error_codes.h"
  27. #include "framework/common/types.h"
  28. #include "graph/node.h"
  29. #include "graph/utils/graph_utils.h"
  30. #include "graph_builder_utils.h"
  31. template class std::unordered_set<ge::NodePtr>;
  32. using namespace domi;
  33. namespace ge {
  34. class TestPass : public BaseNodePass {
  35. public:
  36. TestPass() = default;
  37. TestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {}
  38. Status Run(NodePtr &node) override {
  39. ++run_times_;
  40. iter_nodes_.push_back(node);
  41. auto iter = names_to_add_del_.find(node->GetName());
  42. if (iter != names_to_add_del_.end()) {
  43. for (const auto &node_name : iter->second) {
  44. auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name);
  45. GraphUtils::IsolateNode(del_node, {0});
  46. AddNodeDeleted(del_node.get());
  47. }
  48. }
  49. iter = names_to_add_repass_.find(node->GetName());
  50. if (iter != names_to_add_repass_.end()) {
  51. auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  52. for (const auto &node_name : iter->second) {
  53. for (auto &node_re_pass : all_nodes) {
  54. if (node_re_pass->GetName() == node_name) {
  55. AddRePassNode(node_re_pass);
  56. break;
  57. }
  58. }
  59. }
  60. if (!dead_loop_) {
  61. names_to_add_repass_.erase(iter);
  62. }
  63. }
  64. return SUCCESS;
  65. }
  66. void clear() { iter_nodes_.clear(); }
  67. std::vector<NodePtr> GetIterNodes() { return iter_nodes_; }
  68. void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) {
  69. names_to_add_repass_[iter_node].insert(re_pass_node);
  70. }
  71. void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
  72. names_to_add_del_[iter_node].insert(del_node);
  73. }
  74. unsigned int GetRunTimes() { return run_times_; }
  75. private:
  76. std::vector<NodePtr> iter_nodes_;
  77. std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
  78. std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
  79. bool dead_loop_;
  80. unsigned int run_times_;
  81. };
  82. class TestDelPass : public BaseNodePass {
  83. public:
  84. Status Run(NodePtr &node) override { return SUCCESS; }
  85. };
  86. class UTEST_graph_passes_base_pass : public testing::Test {
  87. protected:
  88. UTEST_graph_passes_base_pass() {
  89. auto p1 = new TestPass;
  90. names_to_pass_.push_back(std::make_pair("test1", p1));
  91. }
  92. void SetUp() override {
  93. for (auto &name_to_pass : names_to_pass_) {
  94. dynamic_cast<TestPass *>(name_to_pass.second)->clear();
  95. }
  96. }
  97. ~UTEST_graph_passes_base_pass() override {
  98. for (auto &name_to_pass : names_to_pass_) {
  99. delete name_to_pass.second;
  100. }
  101. }
  102. NamesToPass names_to_pass_;
  103. };
  104. using namespace domi;
  105. /// reshape1
  106. /// |
  107. /// add1
  108. /// / \
  109. /// | |
  110. /// data1 const1
  111. ComputeGraphPtr BuildGraph1() {
  112. auto builder = ut::GraphBuilder("g1");
  113. auto data = builder.AddNode("data1", DATA, 0, 1);
  114. auto a1 = builder.AddNode("add1", ADD, 2, 1);
  115. auto c1 = builder.AddNode("const1", CONSTANT, 0, 1);
  116. auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
  117. builder.AddDataEdge(data, 0, a1, 0);
  118. builder.AddDataEdge(c1, 0, a1, 1);
  119. builder.AddDataEdge(a1, 0, r1, 0);
  120. return builder.GetGraph();
  121. }
  122. /// sum1
  123. /// / \
  124. /// / \
  125. /// / \
  126. /// reshape1 addn1
  127. /// | c |
  128. /// add1 <--- shape1
  129. /// / \ |
  130. /// | | |
  131. /// data1 const1 const2
  132. ComputeGraphPtr BuildGraph2() {
  133. auto builder = ut::GraphBuilder("g1");
  134. auto data1 = builder.AddNode("data1", DATA, 0, 1);
  135. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  136. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  137. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  138. auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
  139. auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
  140. auto addn1 = builder.AddNode("addn1", ADDN, 1, 1);
  141. auto sum1 = builder.AddNode("sum1", SUM, 2, 1);
  142. builder.AddDataEdge(data1, 0, add1, 0);
  143. builder.AddDataEdge(const1, 0, add1, 1);
  144. builder.AddDataEdge(const2, 0, shape1, 0);
  145. builder.AddControlEdge(shape1, add1);
  146. builder.AddDataEdge(add1, 0, reshape1, 0);
  147. builder.AddDataEdge(shape1, 0, addn1, 0);
  148. builder.AddDataEdge(reshape1, 0, sum1, 0);
  149. builder.AddDataEdge(addn1, 0, sum1, 1);
  150. return builder.GetGraph();
  151. }
  152. /// rnextiteration
  153. /// | |
  154. /// merge
  155. /// |
  156. /// data1
  157. ComputeGraphPtr BuildGraph3() {
  158. auto builder = ut::GraphBuilder("g1");
  159. auto data1 = builder.AddNode("data1", DATA, 0, 1);
  160. auto merge1 = builder.AddNode("merge1", MERGE, 2, 1);
  161. auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1);
  162. builder.AddDataEdge(data1, 0, merge1, 0);
  163. builder.AddDataEdge(merge1, 0, next1, 0);
  164. builder.AddDataEdge(next1, 0, merge1, 1);
  165. builder.AddControlEdge(merge1, next1);
  166. builder.AddControlEdge(next1, merge1);
  167. return builder.GetGraph();
  168. }
  169. void CheckIterOrder(TestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) {
  170. std::unordered_set<std::string> layer_nodes;
  171. size_t layer_index = 0;
  172. for (const auto &node : pass->GetIterNodes()) {
  173. layer_nodes.insert(node->GetName());
  174. EXPECT_LT(layer_index, nodes_layers.size());
  175. if (layer_nodes == nodes_layers[layer_index]) {
  176. layer_index++;
  177. layer_nodes.clear();
  178. }
  179. }
  180. EXPECT_EQ(layer_index, nodes_layers.size());
  181. }
  182. /// Op1
  183. /// |
  184. /// Merge
  185. /// / \
  186. /// Op2 Op3
  187. TEST_F(UTEST_graph_passes_base_pass, DelIsolateFail) {
  188. auto builder = ut::GraphBuilder("g1");
  189. auto merge_node = builder.AddNode("Merge", MERGE, 1, 1);
  190. auto node1 = builder.AddNode("Op1", RELU, 1, 1);
  191. auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1);
  192. auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1);
  193. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
  194. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  195. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  196. EXPECT_EQ(node1->GetOutDataNodes().size(), 1);
  197. TestDelPass del_pass;
  198. auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1});
  199. EXPECT_EQ(ret, FAILED);
  200. OpDescPtr op_desc = std::make_shared<OpDesc>("merge", MERGE);
  201. NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op_desc, nullptr));
  202. ret = del_pass.IsolateAndDeleteNode(node, {0, -1});
  203. EXPECT_EQ(ret, FAILED);
  204. }
  205. /// Op1
  206. /// |
  207. /// Merge
  208. /// / \
  209. /// Op2 Op3
  210. TEST_F(UTEST_graph_passes_base_pass, DelIsolateSuccess) {
  211. auto builder = ut::GraphBuilder("g1");
  212. auto merge_node = builder.AddNode("Merge", MERGE, 1, 2);
  213. auto node1 = builder.AddNode("Op1", RELU, 1, 1);
  214. auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1);
  215. auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1);
  216. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
  217. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  218. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  219. EXPECT_EQ(node1->GetOutDataNodes().size(), 1);
  220. TestDelPass del_pass;
  221. auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1});
  222. EXPECT_EQ(ret, SUCCESS);
  223. }
  224. TEST_F(UTEST_graph_passes_base_pass, DataGraph) {
  225. auto graph = BuildGraph1();
  226. auto ge_pass = GEPass(graph);
  227. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  228. auto *pass = dynamic_cast<TestPass *>(names_to_pass_[0].second);
  229. EXPECT_EQ(pass->GetIterNodes().size(), 4);
  230. std::vector<std::unordered_set<std::string>> layers;
  231. layers.push_back({"data1", "const1"});
  232. layers.push_back({"add1"});
  233. layers.push_back({"reshape1"});
  234. CheckIterOrder(pass, layers);
  235. }
  236. TEST_F(UTEST_graph_passes_base_pass, GraphWithControlLink) {
  237. auto graph = BuildGraph2();
  238. auto ge_pass = GEPass(graph);
  239. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  240. auto *pass = dynamic_cast<TestPass *>(names_to_pass_[0].second);
  241. EXPECT_EQ(pass->GetIterNodes().size(), 8);
  242. EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1");
  243. std::vector<std::unordered_set<std::string>> layers;
  244. layers.push_back({"data1", "const1", "const2"});
  245. layers.push_back({"shape1"});
  246. layers.push_back({"add1", "addn1", "reshape1"});
  247. layers.push_back({"sum1"});
  248. CheckIterOrder(pass, layers);
  249. }
  250. TEST_F(UTEST_graph_passes_base_pass, RePassAfter) {
  251. NamesToPass names_to_pass;
  252. auto test_pass = TestPass();
  253. names_to_pass.push_back(std::make_pair("test", &test_pass));
  254. test_pass.AddRePassNodeName("add1", "sum1");
  255. test_pass.AddRePassNodeName("shape1", "sum1");
  256. test_pass.AddRePassNodeName("shape1", "add1");
  257. test_pass.AddRePassNodeName("data1", "add1");
  258. auto graph = BuildGraph2();
  259. auto ge_pass = GEPass(graph);
  260. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  261. EXPECT_EQ(test_pass.GetIterNodes().size(), 8);
  262. }
  263. TEST_F(UTEST_graph_passes_base_pass, RePassBefore) {
  264. NamesToPass names_to_pass;
  265. auto test_pass = TestPass();
  266. names_to_pass.push_back(std::make_pair("test", &test_pass));
  267. test_pass.AddRePassNodeName("add1", "data1");
  268. auto graph = BuildGraph1();
  269. auto ge_pass = GEPass(graph);
  270. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  271. EXPECT_EQ(test_pass.GetIterNodes().size(), 5);
  272. EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1");
  273. EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1");
  274. EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1");
  275. }
  276. TEST_F(UTEST_graph_passes_base_pass, RePassBeforeMultiTimes) {
  277. NamesToPass names_to_pass;
  278. auto test_pass = TestPass();
  279. names_to_pass.push_back(std::make_pair("test", &test_pass));
  280. test_pass.AddRePassNodeName("add1", "data1");
  281. test_pass.AddRePassNodeName("add1", "const1");
  282. test_pass.AddRePassNodeName("reshape1", "data1");
  283. auto graph = BuildGraph1();
  284. auto ge_pass = GEPass(graph);
  285. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  286. EXPECT_EQ(test_pass.GetIterNodes().size(), 6);
  287. EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1");
  288. EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1");
  289. }
  290. TEST_F(UTEST_graph_passes_base_pass, DelAfter) {
  291. NamesToPass names_to_pass;
  292. auto test_pass = TestPass();
  293. names_to_pass.push_back(std::make_pair("test", &test_pass));
  294. test_pass.AddDelNodeName("add1", "sum1");
  295. auto graph = BuildGraph2();
  296. auto ge_pass = GEPass(graph);
  297. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  298. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  299. }
  300. TEST_F(UTEST_graph_passes_base_pass, DelAfterMultiple) {
  301. NamesToPass names_to_pass;
  302. auto test_pass = TestPass();
  303. names_to_pass.push_back(std::make_pair("test", &test_pass));
  304. test_pass.AddDelNodeName("add1", "sum1");
  305. test_pass.AddDelNodeName("add1", "reshape1");
  306. auto graph = BuildGraph2();
  307. auto ge_pass = GEPass(graph);
  308. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  309. EXPECT_EQ(test_pass.GetIterNodes().size(), 6);
  310. }
  311. TEST_F(UTEST_graph_passes_base_pass, DelAfterBreakLink) {
  312. NamesToPass names_to_pass;
  313. auto test_pass = TestPass();
  314. names_to_pass.push_back(std::make_pair("test", &test_pass));
  315. test_pass.AddDelNodeName("shape1", "add1");
  316. test_pass.AddDelNodeName("shape1", "addn1");
  317. test_pass.AddRePassNodeName("shape1", "shape1");
  318. test_pass.AddRePassNodeName("shape1", "reshape1");
  319. test_pass.AddRePassNodeName("shape1", "sum1");
  320. auto graph = BuildGraph2();
  321. auto ge_pass = GEPass(graph);
  322. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  323. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  324. }
  325. TEST_F(UTEST_graph_passes_base_pass, DelSelfAndAfter) {
  326. NamesToPass names_to_pass;
  327. auto test_pass = TestPass();
  328. names_to_pass.push_back(std::make_pair("test", &test_pass));
  329. test_pass.AddDelNodeName("shape1", "add1");
  330. test_pass.AddDelNodeName("shape1", "addn1");
  331. auto graph = BuildGraph2();
  332. auto ge_pass = GEPass(graph);
  333. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  334. EXPECT_EQ(test_pass.GetIterNodes().size(), 4);
  335. }
  336. TEST_F(UTEST_graph_passes_base_pass, DelBefore) {
  337. NamesToPass names_to_pass;
  338. auto test_pass = TestPass();
  339. names_to_pass.push_back(std::make_pair("test", &test_pass));
  340. test_pass.AddDelNodeName("reshape1", "add1");
  341. test_pass.AddDelNodeName("sum1", "addn1");
  342. auto graph = BuildGraph2();
  343. auto ge_pass = GEPass(graph);
  344. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  345. EXPECT_EQ(test_pass.GetIterNodes().size(), 8);
  346. }
  347. TEST_F(UTEST_graph_passes_base_pass, RePassAndDel) {
  348. NamesToPass names_to_pass;
  349. auto test_pass = TestPass();
  350. names_to_pass.push_back(std::make_pair("test", &test_pass));
  351. test_pass.AddRePassNodeName("add1", "sum1");
  352. test_pass.AddDelNodeName("reshape1", "sum1");
  353. auto graph = BuildGraph2();
  354. auto ge_pass = GEPass(graph);
  355. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  356. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  357. }
  358. TEST_F(UTEST_graph_passes_base_pass, DeadLoop) {
  359. NamesToPass names_to_pass;
  360. auto test_pass = TestPass(true);
  361. names_to_pass.push_back(std::make_pair("test", &test_pass));
  362. test_pass.AddRePassNodeName("add1", "sum1");
  363. test_pass.AddRePassNodeName("sum1", "add1");
  364. auto graph = BuildGraph2();
  365. auto ge_pass = GEPass(graph);
  366. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  367. EXPECT_EQ(test_pass.GetRunTimes(), 1007);
  368. }
  369. TEST_F(UTEST_graph_passes_base_pass, WhileLoop) {
  370. NamesToPass names_to_pass;
  371. auto test_pass = TestPass(true);
  372. names_to_pass.push_back(std::make_pair("test", &test_pass));
  373. auto graph = BuildGraph3();
  374. auto ge_pass = GEPass(graph);
  375. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  376. }
  377. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。

Contributors (1)