
subgraph_pass_unittest.cc

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cstdint>
#include <string>
#include <gtest/gtest.h>

#include "graph/passes/subgraph_pass.h"
#include "inc/pass_manager.h"

namespace ge {
namespace {
class UtestGraphPassesSubgraphPass : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};
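// Builds an OpDesc with the given name/type and the requested number of default-constructed
// input/output tensor descriptors.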
OpDescPtr CreateOpDesc(const std::string &name, const std::string &type, uint32_t input_num, uint32_t output_num) {
  OpDescPtr op_desc = std::shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, type));
  if (op_desc == nullptr) {
    return nullptr;
  }
  for (uint32_t i = 0; i < input_num; i++) {
    op_desc->AddInputDesc(GeTensorDesc());
  }
  for (uint32_t i = 0; i < output_num; i++) {
    op_desc->AddOutputDesc(GeTensorDesc());
  }
  return op_desc;
}
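// Returns true if the graph contains at least one Identity (memcpy) node among its direct nodes.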
bool CheckMemcpyExist(const ComputeGraphPtr &graph) {
  for (const auto &node : graph->GetDirectNode()) {
    if (node->GetType() == IDENTITY) {
      return true;
    }
  }
  return false;
}
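// Counts the Identity (memcpy) nodes among the graph's direct nodes.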
uint32_t CheckMemcpyNum(const ComputeGraphPtr &graph) {
  uint32_t num = 0;
  for (const auto &node : graph->GetDirectNode()) {
    if (node->GetType() == IDENTITY) {
      num++;
    }
  }
  return num;
}
}  // namespace
///
///  ****** root_graph ******   *   ****** subgraph branch1 *****   *   ****** subgraph branch2 *****
///                             *                                   *
///           Case              *              Const                *               Data
///          /    \             *                |                  *                |
///      data_0   data_1        *            NetOutput              *            NetOutput
///                             *                                   *
///  ****** root_graph ******   *   ****** subgraph branch1 *****   *   ****** subgraph branch2 *****
///
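// Runs SubgraphPass on the root graph and on both Case subgraphs: no Identity (memcpy) node is
// expected in the root graph, while each subgraph output (Const->NetOutput and Data->NetOutput)
// should end up with exactly one Identity node inserted.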
TEST(UtestGraphPassesSubgraphPass, add_memcpy_success) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("add_memcpy_success");
  NodePtr func_node = graph->AddNode(CreateOpDesc("Case", CASE, 2, 1));
  NodePtr data_node_0 = graph->AddNode(CreateOpDesc("data_0", DATA, 1, 1));
  NodePtr data_node_1 = graph->AddNode(CreateOpDesc("data_1", DATA, 1, 1));
  EXPECT_EQ(GraphUtils::AddEdge(data_node_0->GetOutDataAnchor(0), func_node->GetInDataAnchor(0)), GRAPH_SUCCESS);
  EXPECT_EQ(GraphUtils::AddEdge(data_node_1->GetOutDataAnchor(0), func_node->GetInDataAnchor(1)), GRAPH_SUCCESS);

  std::string subgraph_name_1 = "instance_branch_1";
  ComputeGraphPtr subgraph_1 = std::make_shared<ComputeGraph>(subgraph_name_1);
  subgraph_1->SetParentNode(func_node);
  subgraph_1->SetParentGraph(graph);
  size_t index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  EXPECT_EQ(index, 0);
  func_node->GetOpDesc()->AddSubgraphName("branch1");
  EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1);
  func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_1);
  EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 1);

  std::string subgraph_name_2 = "instance_branch_2";
  ComputeGraphPtr subgraph_2 = std::make_shared<ComputeGraph>(subgraph_name_2);
  subgraph_2->SetParentNode(func_node);
  subgraph_2->SetParentGraph(graph);
  index = func_node->GetOpDesc()->GetSubgraphInstanceNames().size();
  EXPECT_EQ(index, 1);
  func_node->GetOpDesc()->AddSubgraphName("branch2");
  EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2);
  func_node->GetOpDesc()->SetSubgraphInstanceName(index, subgraph_name_2);
  EXPECT_EQ(func_node->GetOpDesc()->GetSubgraphInstanceNames().size(), 2);
  {
    // Const->NetOutput in subgraph
    NodePtr const_node = subgraph_1->AddNode(CreateOpDesc("const", CONSTANTOP, 0, 1));
    NodePtr output_node = subgraph_1->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1));
    EXPECT_EQ(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS);
  }
  {
    // Data->NetOutput in subgraph (the subgraph is not a while-loop body)
    NodePtr data_node = subgraph_2->AddNode(CreateOpDesc("data", DATA, 1, 1));
    NodePtr output_node = subgraph_2->AddNode(CreateOpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT, 1, 1));
    EXPECT_EQ(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)), SUCCESS);
    EXPECT_TRUE(AttrUtils::SetInt(data_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1));
  }
  PassManager pass_manager;
  pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass);
  EXPECT_EQ(pass_manager.Run(graph), SUCCESS);
  EXPECT_FALSE(CheckMemcpyExist(graph));
  EXPECT_EQ(pass_manager.Run(subgraph_1), SUCCESS);
  EXPECT_EQ(CheckMemcpyNum(subgraph_1), 1);
  EXPECT_EQ(pass_manager.Run(subgraph_2), SUCCESS);
  EXPECT_EQ(CheckMemcpyNum(subgraph_2), 1);
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, acting as the bridge between the two. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the target hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
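As a rough illustration of that graph-in, optimized-graph-out flow, the sketch below shows how a single optimization pass can be attached to a PassManager and run over a ComputeGraph, using only the PassManager, SubgraphPass, and ComputeGraphPtr APIs already exercised by the unit test above. The helper name OptimizeGraph is hypothetical, and a real GE Core pipeline registers many more passes than this.

// Minimal sketch, not the real GE Core pipeline: register one optimization pass
// (here SubgraphPass, as in the test above) and run it over a compute graph.
#include "graph/passes/subgraph_pass.h"
#include "inc/pass_manager.h"

namespace ge {
Status OptimizeGraph(const ComputeGraphPtr &graph) {  // hypothetical helper for illustration
  PassManager pass_manager;
  // A production pipeline would add the full set of graph-optimization passes here.
  pass_manager.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass);
  return pass_manager.Run(graph);
}
}  // namespace ge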