You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

infer_base_pass_unittest.cc 14 kB


  1. /**
  2. * Copyright 2021 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "graph/passes/infer_base_pass.h"
  18. #include "graph/debug/ge_attr_define.h"
  19. #include "graph/utils/tensor_utils.h"
  20. #include "graph/utils/graph_utils.h"
  21. #include "graph_builder_utils.h"
  22. using namespace std;
  23. using namespace testing;
  24. namespace ge {
  // Forward declaration so InferBasePassStub (defined next) can befriend the builder defined after it.
class ChildPassBuilder;

// Name of the int attribute that the stub's Infer() increments on each output tensor desc and that the stub's
// UpdateTensorDesc() copies from src to dst; tests read it back to count inferences/propagations.
// constexpr makes the pointer itself immutable (the original `static const char *` pointer was reassignable).
static constexpr const char *kInferTimes = "infer_times";
  27. class InferBasePassStub : public InferBasePass {
  28. public:
  29. friend class ChildPassBuilder;
  30. graphStatus Infer(NodePtr &node) override{
  31. call_infer_times++;
  32. for (size_t i = 0; i < node->GetOutDataNodesSize(); ++i) {
  33. auto output_td = node->GetOpDesc()->MutableOutputDesc(i);
  34. int times = 0;
  35. AttrUtils::GetInt(output_td, kInferTimes, times);
  36. AttrUtils::SetInt(output_td, kInferTimes, times + 1);
  37. }
  38. return infer_result_;
  39. };
  40. int32_t call_infer_times = 0;
  41. int32_t call_update_tensor_desc_times = 0;
  42. int32_t call_update_from_subgraph_times = 0;
  43. int32_t call_update_from_subgraph_multi_dims_times = 0;
  44. std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> update_td_pairs;
  45. private:
  46. bool NeedInfer(const NodePtr &node) const override {
  47. return need_infer_;
  48. };
  49. std::string SerialTensorInfo(const GeTensorDescPtr &tensor_desc) const override { return "test SerialTensorInfo"; };
  50. graphStatus UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) override {
  51. call_update_tensor_desc_times++;
  52. changed = td_changed_;
  53. int times = 0;
  54. if (AttrUtils::GetInt(src, kInferTimes, times)) {
  55. AttrUtils::SetInt(dst, kInferTimes, times);
  56. }
  57. update_td_pairs.emplace_back(src, dst);
  58. return GRAPH_SUCCESS;
  59. };
  60. graphStatus UpdateOutputFromSubgraphs(const std::vector<GeTensorDescPtr> &src, GeTensorDescPtr &dst) override {
  61. call_update_from_subgraph_times++;
  62. return GRAPH_SUCCESS;
  63. };
  64. graphStatus UpdateOutputFromSubgraphsForMultiDims(const std::vector<GeTensorDescPtr> &src,
  65. GeTensorDescPtr &dst) override {
  66. call_update_from_subgraph_multi_dims_times++;
  67. return GRAPH_SUCCESS;
  68. };
  69. bool td_changed_;
  70. bool need_infer_;
  71. graphStatus infer_result_;
  72. };
  73. class ChildPassBuilder {
  74. public:
  75. ChildPassBuilder &SetNeedInferFlag(bool flag) {
  76. need_infer_ = flag;
  77. return *this;
  78. }
  79. ChildPassBuilder &SetInferResult(graphStatus ret) {
  80. infer_result_ = ret;
  81. return *this;
  82. }
  83. ChildPassBuilder &SetTdChangedFlag(bool changed_flag) {
  84. td_changed_ = changed_flag;
  85. return *this;
  86. }
  87. InferBasePassStub Build() {
  88. InferBasePassStub ib;
  89. ib.td_changed_ = td_changed_;
  90. ib.need_infer_ = need_infer_;
  91. ib.infer_result_ = infer_result_;
  92. return ib;
  93. }
  94. private:
  95. bool td_changed_ = false;
  96. bool need_infer_ = true;
  97. graphStatus infer_result_ = GRAPH_SUCCESS;
  98. };
  99. class UtestGraphInferBasePassStub : public testing::Test {
  100. protected:
  101. void SetUp() {}
  102. void TearDown() {}
  103. };
  104. /*
  105. * data1 data2
  106. * \ /
  107. * sub1
  108. * |
  109. * netoutput
  110. */
  111. ut::GraphBuilder TestSubgraphBuilder() {
  112. ut::GraphBuilder builder = ut::GraphBuilder("branch_graph");
  113. std::vector<int64_t> shape1 = {1,1};
  114. auto data1 = builder.AddNode("data1_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1);
  115. auto data1_desc = data1->GetOpDesc();
  116. EXPECT_NE(data1_desc, nullptr);
  117. AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);
  118. std::vector<int64_t> shape2 = {2,2};
  119. auto data2 = builder.AddNode("data2_1", "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2);
  120. auto data2_desc = data2->GetOpDesc();
  121. EXPECT_NE(data2_desc, nullptr);
  122. AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);
  123. auto sub1 = builder.AddNode("Sub", "Sub", 2, 1);
  124. std::vector<int64_t> shape7 = {8,8};
  125. auto netoutput = builder.AddNode("output", NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7);
  126. auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  127. EXPECT_NE(input0_desc, nullptr);
  128. AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
  129. builder.AddDataEdge(data1, 0, sub1, 0);
  130. builder.AddDataEdge(data2, 0, sub1, 1);
  131. builder.AddDataEdge(sub1, 0, netoutput, 0);
  132. return builder;
  133. }
  134. /*
  135. * data1 data2
  136. * \ /
  137. * case1
  138. * |
  139. * netoutput
  140. */
  141. ut::GraphBuilder RootGraphBuilder() {
  142. ut::GraphBuilder builder = ut::GraphBuilder("root_graph");
  143. auto data1 = builder.AddNode("data1", "Data", 0, 1);
  144. auto data2 = builder.AddNode("data2", "Data", 0, 1);
  145. auto case1 = builder.AddNode("case1", CASE, 2, 1);
  146. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  147. builder.AddDataEdge(data1, 0, case1, 0);
  148. builder.AddDataEdge(data2, 0, case1, 1);
  149. builder.AddDataEdge(case1, 0, netoutput, 0);
  150. auto parent_graph = builder.GetGraph();
  151. auto subgraph_builder = TestSubgraphBuilder();
  152. auto subgraph = subgraph_builder.GetGraph();
  153. case1->GetOpDesc()->AddSubgraphName(subgraph->GetName());
  154. case1->GetOpDesc()->SetSubgraphInstanceName(0, subgraph->GetName());
  155. subgraph->SetParentNode(case1);
  156. subgraph->SetParentGraph(parent_graph);
  157. EXPECT_EQ(parent_graph->AddSubgraph(subgraph->GetName(), subgraph), GRAPH_SUCCESS);
  158. return builder;
  159. }
  160. /*
  161. * data1 data2
  162. * \ /
  163. * add1
  164. * |
  165. * netoutput
  166. */
  167. ut::GraphBuilder NoSubgraphBuilder() {
  168. ut::GraphBuilder builder = ut::GraphBuilder("no_subgraph");
  169. auto data1 = builder.AddNode("data1", "Data", 0, 1);
  170. auto data2 = builder.AddNode("data2", "Data", 0, 1);
  171. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  172. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  173. builder.AddDataEdge(data1, 0, add1, 0);
  174. builder.AddDataEdge(data2, 0, add1, 1);
  175. builder.AddDataEdge(add1, 0, netoutput, 0);
  176. return builder;
  177. }
  178. TEST_F(UtestGraphInferBasePassStub, CallInfer_WhenNeedInferReturnTrue) {
  179. auto builder = NoSubgraphBuilder();
  180. auto test_graph = builder.GetGraph();
  181. auto add_node = test_graph->FindNode("add1");
  182. EXPECT_NE(add_node, nullptr);
  183. ChildPassBuilder pass_builder;
  184. auto stub_base_pass = pass_builder.Build();
  185. // NeedInfer return true
  186. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  187. EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  188. int times = -1;
  189. EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
  190. EXPECT_EQ(times, 1);
  191. }
  192. TEST_F(UtestGraphInferBasePassStub, NotCallInfer_WhenNeedInferReturnFalse) {
  193. auto builder = NoSubgraphBuilder();
  194. auto test_graph = builder.GetGraph();
  195. auto add_node = test_graph->FindNode("add1");
  196. EXPECT_NE(add_node, nullptr);
  197. ChildPassBuilder pass_builder;
  198. auto stub_base_pass = pass_builder.SetNeedInferFlag(false).Build();
  199. // NeedInfer return false
  200. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  201. EXPECT_EQ(stub_base_pass.call_infer_times, 0);
  202. int times = -1;
  203. EXPECT_FALSE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
  204. }
  205. TEST_F(UtestGraphInferBasePassStub, NotAddCurNodeRepass_CallUpdatePeerNode_WhenInferReturnSuccess) {
  206. auto builder = NoSubgraphBuilder();
  207. auto test_graph = builder.GetGraph();
  208. auto add_node = test_graph->FindNode("add1");
  209. auto netoutput = test_graph->FindNode("netoutput");
  210. EXPECT_NE(add_node, nullptr);
  211. EXPECT_NE(netoutput, nullptr);
  212. ChildPassBuilder pass_builder;
  213. auto stub_base_pass = pass_builder.Build();
  214. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  215. EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  216. EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  217. std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> expected_updated_tensor_desc_pairs = {
  218. {add_node->GetOpDesc()->MutableOutputDesc(0), netoutput->GetOpDesc()->MutableInputDesc(0)}};
  219. EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs);
  220. EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({}));
  221. }
  222. TEST_F(UtestGraphInferBasePassStub, AddCurNodeRepass_NotCallUpdatePeerNode_WhenInferReturnNeedRepass) {
  223. auto builder = NoSubgraphBuilder();
  224. auto test_graph = builder.GetGraph();
  225. auto add_node = test_graph->FindNode("add1");
  226. EXPECT_NE(add_node, nullptr);
  227. ChildPassBuilder pass_builder;
  228. auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build();
  229. // do re_pass
  230. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  231. EXPECT_EQ(stub_base_pass.call_infer_times, 1);
  232. EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 0);
  233. // EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({add_node}));
  234. }
  235. TEST_F(UtestGraphInferBasePassStub, NotAddPeerNodeRepass_AfterUpdatePeerNode_WhenUnchanged) {
  236. auto builder = NoSubgraphBuilder();
  237. auto test_graph = builder.GetGraph();
  238. auto add_node = test_graph->FindNode("add1");
  239. auto netoutput = test_graph->FindNode("netoutput");
  240. EXPECT_NE(add_node, nullptr);
  241. EXPECT_NE(netoutput, nullptr);
  242. ChildPassBuilder pass_builder;
  243. auto stub_base_pass = pass_builder.Build();
  244. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  245. EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  246. EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({}));
  247. int times = -1;
  248. EXPECT_TRUE(AttrUtils::GetInt(add_node->GetOpDesc()->GetOutputDescPtr(0), kInferTimes, times));
  249. EXPECT_EQ(times, 1);
  250. times = -1;
  251. EXPECT_TRUE(AttrUtils::GetInt(netoutput->GetOpDesc()->GetInputDescPtr(0), kInferTimes, times));
  252. EXPECT_EQ(times, 1);
  253. }
  254. TEST_F(UtestGraphInferBasePassStub, AddPeerNodeRepass_AfterUpdatePeerNode_WhenChanged) {
  255. auto builder = NoSubgraphBuilder();
  256. auto test_graph = builder.GetGraph();
  257. auto add_node = test_graph->FindNode("add1");
  258. auto netoutput = test_graph->FindNode("netoutput");
  259. EXPECT_NE(add_node, nullptr);
  260. EXPECT_NE(netoutput, nullptr);
  261. ChildPassBuilder pass_builder;
  262. auto stub_base_pass = pass_builder.SetTdChangedFlag(true).Build();
  263. EXPECT_EQ(stub_base_pass.Run(add_node), SUCCESS);
  264. EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 1);
  265. // EXPECT_EQ(stub_base_pass.GetNodesNeedRePassImmediately(), std::unordered_set<NodePtr>({netoutput}));
  266. }
  267. TEST_F(UtestGraphInferBasePassStub, TestUpdateSubgraphData_WhenBeforeSubgraph) {
  268. auto builder = RootGraphBuilder();
  269. auto parent_graph = builder.GetGraph();
  270. auto subgraphs = parent_graph->GetAllSubgraphs();
  271. EXPECT_EQ(subgraphs.size(), 1);
  272. auto case_node = parent_graph->FindNode("case1");
  273. auto data1 = subgraphs[0]->FindNode("data1_1");
  274. auto data2 = subgraphs[0]->FindNode("data2_1");
  275. EXPECT_NE(case_node, nullptr);
  276. EXPECT_NE(data1, nullptr);
  277. EXPECT_NE(data2, nullptr);
  278. ChildPassBuilder pass_builder;
  279. auto stub_base_pass = pass_builder.SetInferResult(GRAPH_NODE_NEED_REPASS).Build();
  280. EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  281. // when GRAPH_NODE_NEED_REPASS, not update peer node, only update two data, update input and output, 2*2
  282. EXPECT_EQ(stub_base_pass.call_update_tensor_desc_times, 4);
  283. std::vector<std::pair<GeTensorDescPtr, GeTensorDescPtr>> expected_updated_tensor_desc_pairs = {
  284. {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableInputDesc(0)},
  285. {case_node->GetOpDesc()->MutableInputDesc(0), data1->GetOpDesc()->MutableOutputDesc(0)},
  286. {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableInputDesc(0)},
  287. {case_node->GetOpDesc()->MutableInputDesc(1), data2->GetOpDesc()->MutableOutputDesc(0)},
  288. };
  289. EXPECT_EQ(stub_base_pass.update_td_pairs, expected_updated_tensor_desc_pairs);
  290. }
  291. TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutput_WhenAfterSubgraph) {
  292. auto builder = RootGraphBuilder();
  293. auto parent_graph = builder.GetGraph();
  294. auto subgraphs = parent_graph->GetAllSubgraphs();
  295. EXPECT_EQ(subgraphs.size(), 1);
  296. auto case_node = parent_graph->FindNode("case1");
  297. EXPECT_NE(case_node, nullptr);
  298. ChildPassBuilder pass_builder;
  299. auto stub_base_pass = pass_builder.Build();
  300. stub_base_pass.SetOption(kOptimizeAfterSubGraph, "");
  301. EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  302. EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 1);
  303. EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 0);
  304. }
  305. TEST_F(UtestGraphInferBasePassStub, TestUpdateParentNodeOutputForMultiDims_WhenAfterSubgraph) {
  306. auto builder = RootGraphBuilder();
  307. auto parent_graph = builder.GetGraph();
  308. auto subgraphs = parent_graph->GetAllSubgraphs();
  309. EXPECT_EQ(subgraphs.size(), 1);
  310. auto case_node = parent_graph->FindNode("case1");
  311. auto set_ret = AttrUtils::SetInt(case_node->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2);
  312. EXPECT_EQ(set_ret, true);
  313. EXPECT_NE(case_node, nullptr);
  314. ChildPassBuilder pass_builder;
  315. auto stub_base_pass = pass_builder.Build();
  316. stub_base_pass.SetOption(kOptimizeAfterSubGraph, "");
  317. EXPECT_EQ(stub_base_pass.Run(case_node), SUCCESS);
  318. EXPECT_EQ(stub_base_pass.call_update_from_subgraph_times, 0);
  319. EXPECT_EQ(stub_base_pass.call_update_from_subgraph_multi_dims_times, 1);
  320. }
  321. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。