
infer_base_pass_unittest.cc

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>

// Expose protected/private members of InferBasePass (e.g. options_) to the test.
#define protected public
#define private public
#include "graph/passes/infer_base_pass.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph_builder_utils.h"
#include "inc/external/graph/operator_reg.h"
#include "inc/external/graph/operator.h"
#include "inc/external/graph/operator_factory.h"
#include "inc/graph/operator_factory_impl.h"

using namespace std;
using namespace testing;

namespace ge {
// Test double for InferBasePass: implements the virtual hooks with a minimal
// shape-copy behaviour so the base-pass control flow can be exercised.
class InferBasePassStub : public InferBasePass {
 public:
  graphStatus Infer(NodePtr &node) override {
    auto op_desc = node->GetOpDesc();
    auto input_desc = op_desc->MutableInputDesc(0);
    auto output_desc = op_desc->MutableOutputDesc(0);
    // Copy the output shape back to the input and request a re-pass when they differ.
    if (input_desc->GetShape().GetDims() != output_desc->GetShape().GetDims()) {
      input_desc->SetShape(output_desc->GetShape());
      return GRAPH_NODE_NEED_REPASS;
    }
    return GRAPH_SUCCESS;
  }

 private:
  std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; }

  graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
    // Propagate the source shape and report whether anything actually changed.
    changed = (src->GetShape().GetDims() != dst->GetShape().GetDims());
    dst->SetShape(src->GetShape());
    return GRAPH_SUCCESS;
  }

  graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override {
    dst->SetShape(src[0]->GetShape());
    return GRAPH_SUCCESS;
  }

  graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
                                                    GeTensorDescPtr &dst) override {
    dst->SetShape(src[0]->GetShape());
    return GRAPH_SUCCESS;
  }
};
class UtestGraphInferBasePassStub : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};

/*
 *  data1  data2
 *     \    /
 *     merge
 *       |
 *   netoutput
 */
ut::GraphBuilder TestSubgraphBuilder() {
  ut::GraphBuilder builder = ut::GraphBuilder("branch_graph");

  std::vector<int64_t> shape1 = {1, 1};
  auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1);
  auto data1_desc = data1->GetOpDesc();
  EXPECT_NE(data1_desc, nullptr);
  AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);

  std::vector<int64_t> shape2 = {2, 2};
  auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2);
  auto data2_desc = data2->GetOpDesc();
  EXPECT_NE(data2_desc, nullptr);
  AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);

  auto merge = builder.AddNode("merge", "Merge", 2, 1);

  std::vector<int64_t> shape7 = {8, 8};
  auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7);
  auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  EXPECT_NE(input0_desc, nullptr);
  AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);

  builder.AddDataEdge(data1, 0, merge, 0);
  builder.AddDataEdge(data2, 0, merge, 1);
  builder.AddDataEdge(merge, 0, netoutput, 0);
  return builder;
}
/*
 *  data1  data2
 *     \    /
 *     case1
 *       |
 *   netoutput
 */
ut::GraphBuilder RootGraphBuilder() {
  ut::GraphBuilder builder = ut::GraphBuilder("root_graph");
  auto data1 = builder.AddNode("data1", "Data", 0, 1);
  auto data2 = builder.AddNode("data2", "Data", 0, 1);
  auto case1 = builder.AddNode("case1", CASE, 2, 1);
  auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  builder.AddDataEdge(data1, 0, case1, 0);
  builder.AddDataEdge(data2, 0, case1, 1);
  builder.AddDataEdge(case1, 0, netoutput, 0);

  // Attach the branch graph to case1 as its subgraph instance 0.
  auto parent_graph = builder.GetGraph();
  auto subgraph_builder = TestSubgraphBuilder();
  auto subgraph = subgraph_builder.GetGraph();
  case1->GetOpDesc()->AddSubgraphName(subgraph->GetName());
  case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName());
  subgraph->SetParentNode(case1);
  subgraph->SetParentGraph(parent_graph);
  EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS);
  return builder;
}
TEST_F(UtestGraphInferBasePassStub, infer_base_before_subgraph) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);

  // check base pass run
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  InferBasePassStub base_pass;
  EXPECT_EQ(base_pass.Run(case_node), SUCCESS);

  // check subgraph data update
  auto data_node = subgraphs[0]->FindNode("data1_1");
  auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
  auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
  EXPECT_EQ(data_out_0_dims.size(), 4);
  std::vector<int64_t> data_target_dims = {1, 1, 224, 224};
  EXPECT_EQ(data_out_0_dims, data_target_dims);

  // check peer input update
  auto netoutput_node = parent_graph->FindNode("netoutput");
  EXPECT_NE(netoutput_node, nullptr);
  auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
  auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
  EXPECT_EQ(netoutput_in_0_dims.size(), 4);
  std::vector<int64_t> target_dims = {1, 1, 224, 224};
  EXPECT_EQ(netoutput_in_0_dims, target_dims);
}
TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_need_repass) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);

  // check base pass run
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  InferBasePassStub base_pass;
  base_pass.options_[kOptimizeAfterSubGraph] = "yes";
  EXPECT_EQ(base_pass.Run(case_node), SUCCESS);

  // check subgraph data update
  auto data_node = subgraphs[0]->FindNode("data1_1");
  auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
  auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
  EXPECT_EQ(data_out_0_dims.size(), 2);
  std::vector<int64_t> data_target_dims = {1, 1};
  EXPECT_EQ(data_out_0_dims, data_target_dims);

  // check peer input update
  auto netoutput_node = parent_graph->FindNode("netoutput");
  EXPECT_NE(netoutput_node, nullptr);
  auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
  auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
  EXPECT_EQ(netoutput_in_0_dims.size(), 4);
  std::vector<int64_t> target_dims = {1, 1, 224, 224};
  EXPECT_EQ(netoutput_in_0_dims, target_dims);
}
TEST_F(UtestGraphInferBasePassStub, infer_base_after_subgraph_no_repass) {
  auto builder = RootGraphBuilder();
  auto parent_graph = builder.GetGraph();
  auto subgraphs = parent_graph->GetAllSubgraphs();
  EXPECT_EQ(subgraphs.size(), 1);

  // check base pass run
  auto case_node = parent_graph->FindNode("case1");
  EXPECT_NE(case_node, nullptr);
  // update case in shape, do not re_pass
  auto case_in_shape_no_repass = GeShape({8, 8});
  case_node->GetOpDesc()->MutableInputDesc(0)->SetShape(case_in_shape_no_repass);
  InferBasePassStub base_pass;
  base_pass.options_[kOptimizeAfterSubGraph] = "yes";
  EXPECT_EQ(base_pass.Run(case_node), SUCCESS);

  // check subgraph data update
  auto data_node = subgraphs[0]->FindNode("data1_1");
  auto data_out_0_desc = data_node->GetOpDesc()->MutableOutputDesc(0);
  auto data_out_0_dims = data_out_0_desc->GetShape().GetDims();
  EXPECT_EQ(data_out_0_dims.size(), 2);
  std::vector<int64_t> data_target_dims = {1, 1};
  EXPECT_EQ(data_out_0_dims, data_target_dims);

  // check peer input update
  auto netoutput_node = parent_graph->FindNode("netoutput");
  EXPECT_NE(netoutput_node, nullptr);
  auto netoutput_in_0_desc = netoutput_node->GetOpDesc()->MutableInputDesc(0);
  auto netoutput_in_0_dims = netoutput_in_0_desc->GetShape().GetDims();
  EXPECT_EQ(netoutput_in_0_dims.size(), 2);
  std::vector<int64_t> target_dims = {8, 8};
  EXPECT_EQ(netoutput_in_0_dims, target_dims);
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and acts as the bridge between them: GE takes the graph delivered by ME as input, performs a series of deep graph-optimization operations on it, and outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.