You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_group_pass_unittest.cc 14 kB

4 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdint>
  18. #include <string>
  19. #define private public
  20. #include "common/ge_inner_error_codes.h"
  21. #include "inc/pass_manager.h"
  22. #include "utils/graph_utils.h"
  23. #include "graph/passes/parallel_group_pass.h"
  24. #undef private
  25. namespace ge {
  26. namespace {
  27. class UtestGraphPassesParallelGgroupPass : public testing::Test {
  28. protected:
  29. UtestGraphPassesParallelGgroupPass() {
  30. graph_ = std::make_shared<ComputeGraph>("test");
  31. sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph");
  32. vector<int64_t> shape_vec{1, 1, 1, 1};
  33. GeShape shape = GeShape(shape_vec);
  34. default_tensor_desc_ = std::make_shared<GeTensorDesc>();
  35. default_tensor_desc_->SetShape(shape);
  36. default_tensor_desc_->SetFormat(FORMAT_NCHW);
  37. default_tensor_desc_->SetDataType(DT_FLOAT);
  38. }
  39. NodePtr NewNode(const std::string &name, const std::string &type,
  40. int input_cnt, int output_cnt, bool isSubgraph = false) {
  41. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  42. for (int i = 0; i < input_cnt; ++i) {
  43. op_desc->AddInputDesc(default_tensor_desc_->Clone());
  44. }
  45. for (int i = 0; i < output_cnt; ++i) {
  46. op_desc->AddOutputDesc(default_tensor_desc_->Clone());
  47. }
  48. NodePtr node = nullptr;
  49. if (isSubgraph) {
  50. node = sub_graph_->AddNode(op_desc);
  51. (void)node->SetOwnerComputeGraph(sub_graph_);
  52. } else {
  53. node = graph_->AddNode(op_desc);
  54. (void)node->SetOwnerComputeGraph(graph_);
  55. }
  56. return node;
  57. }
  58. void BuildDefaultGraph() {
  59. /// input
  60. /// \
  61. /// sqrt pred
  62. /// \ /
  63. /// cast
  64. /// / \
  65. /// switch_t switch_f
  66. /// | |
  67. /// F T
  68. /// | |
  69. /// Merge
  70. /// |
  71. /// relu
  72. /// |
  73. /// sqrt1
  74. input_node_ = NewNode("input", RELU, 0, 1);
  75. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  76. pred_node_ = NewNode("pred", GREATER, 2, 1);
  77. cast_node_ = NewNode("cast", CAST, 2, 2);
  78. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  79. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  80. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  81. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  82. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  83. output_false_node_ = NewNode("false_output", RELU, 1, 1);
  84. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  85. output_true_node_ = NewNode("true_output", RELU, 1, 1);
  86. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  87. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  88. relu_node_ = NewNode("relu", RELU, 1, 1);
  89. sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
  90. AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  91. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  92. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  93. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  94. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  95. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  96. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  97. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  98. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  99. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  100. GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
  101. GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
  102. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  103. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  104. }
  105. void BuildDefaultGraph1() {
  106. /// input
  107. /// \
  108. /// sqrt pred
  109. /// \ /
  110. /// Switch
  111. /// | |
  112. /// ----F T----
  113. /// \ | / \
  114. /// \ Merge1 Merge2
  115. /// \_________|
  116. input_node_ = NewNode("input", RELU, 0, 1);
  117. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  118. pred_node_ = NewNode("pred", GREATER, 2, 1);
  119. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  120. cast_node_ = NewNode("cast", CAST, 2, 2);
  121. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  122. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  123. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  124. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  125. output_false_node_ = NewNode("false_output", RELU, 1, 2);
  126. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  127. output_true_node_ = NewNode("true_output", RELU, 1, 2);
  128. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  129. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  130. merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
  131. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  132. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  133. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  134. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  135. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  136. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  137. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  138. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  139. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  140. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
  141. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
  142. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  143. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  144. }
  145. void BuildDefaultGraph2() {
  146. /// input input1
  147. /// \ \
  148. /// sqrt pred sqrt1 pred1
  149. /// \ / \ /
  150. /// Switch Switch1
  151. /// | | _______|
  152. /// | | /
  153. /// ____F T____
  154. /// \ | / \
  155. /// \ Merge1 Merge2
  156. /// \__________|
  157. input_node_ = NewNode("input", RELU, 0, 2);
  158. input_node1_ = NewNode("input_1", RELU, 0, 2);
  159. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  160. pred_node_ = NewNode("pred", GREATER, 2, 1);
  161. sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1);
  162. pred_node1_ = NewNode("pred_1", LESS, 2, 1);
  163. cast_node_ = NewNode("cast", CAST, 2, 2);
  164. cast_node1_ = NewNode("cast_1", CAST, 2, 2);
  165. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  166. AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
  167. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  168. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  169. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  170. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  171. switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1);
  172. AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  173. switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1);
  174. AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  175. output_false_node_ = NewNode("false_output", RELU, 2, 2);
  176. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  177. output_true_node_ = NewNode("true_output", RELU, 2, 2);
  178. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
  179. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  180. merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
  181. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  182. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  183. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  184. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  185. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  186. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  187. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  188. GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
  189. GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0));
  190. GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1));
  191. GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0));
  192. GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0));
  193. GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1));
  194. GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1));
  195. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  196. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  197. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
  198. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
  199. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  200. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  201. }
  202. ComputeGraphPtr graph_;
  203. ComputeGraphPtr sub_graph_;
  204. GeTensorDescPtr default_tensor_desc_;
  205. ParallelGroupPass pass_;
  206. NodePtr pred_node_;
  207. NodePtr pred_node1_;
  208. NodePtr cast_node_;
  209. NodePtr cast_node1_;
  210. NodePtr sqrt_node_;
  211. NodePtr sqrt_node1_;
  212. NodePtr input_node_;
  213. NodePtr input_node1_;
  214. NodePtr switch_node_t;
  215. NodePtr switch_node_f;
  216. NodePtr switch_node1_t;
  217. NodePtr switch_node1_f;
  218. NodePtr output_false_node_;
  219. NodePtr output_true_node_;
  220. NodePtr merge_node_;
  221. NodePtr merge_node1_;
  222. NodePtr relu_node_;
  223. };
  224. TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) {
  225. ComputeGraphPtr graph = nullptr;
  226. auto ret = pass_.Run(graph);
  227. EXPECT_EQ(ret, PARAM_INVALID);
  228. }
  229. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) {
  230. BuildDefaultGraph();
  231. auto ret = pass_.Run(graph_);
  232. EXPECT_EQ(ret, GRAPH_SUCCESS);
  233. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  234. EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
  235. EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor()));
  236. }
  237. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) {
  238. BuildDefaultGraph1();
  239. auto ret = pass_.Run(graph_);
  240. EXPECT_EQ(ret, GRAPH_SUCCESS);
  241. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  242. }
  243. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
  244. BuildDefaultGraph2();
  245. auto ret = pass_.Run(graph_);
  246. EXPECT_EQ(ret, GRAPH_SUCCESS);
  247. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  248. EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor()));
  249. }
  250. TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
  251. BuildDefaultGraph1();
  252. NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);
  253. NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true);
  254. NodePtr add = NewNode("add", ADD, 2, 1, true);
  255. AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  256. AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  257. sub_graph_->SetParentNode(input_node_);
  258. sub_graph_->SetParentGraph(graph_);
  259. auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_);
  260. EXPECT_EQ(ret, GRAPH_SUCCESS);
  261. ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName());
  262. EXPECT_EQ(ret, GRAPH_SUCCESS);
  263. ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName());
  264. EXPECT_EQ(ret, GRAPH_SUCCESS);
  265. ret = pass_.Run(sub_graph_);
  266. EXPECT_EQ(ret, GRAPH_SUCCESS);
  267. ret = pass_.Run(graph_);
  268. EXPECT_EQ(ret, GRAPH_SUCCESS);
  269. }
  270. } // namespace
  271. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：