You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

addn_pass_unittest.cc 8.2 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <cstdint>
  17. #include <string>
  18. #include <gtest/gtest.h>
  19. #include "common/ge_inner_error_codes.h"
  20. #include "graph/passes/addn_pass.h"
  21. using namespace domi;
  22. namespace ge {
  23. using namespace domi;
  24. namespace {
  25. GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW,
  26. DataType data_type = DT_FLOAT) {
  27. GeShape ge_shape{vector<int64_t>(shape)};
  28. GeTensorDescPtr tensor_desc = std::make_shared<GeTensorDesc>();
  29. tensor_desc->SetShape(ge_shape);
  30. tensor_desc->SetFormat(format);
  31. tensor_desc->SetDataType(data_type);
  32. return tensor_desc;
  33. }
  34. class NodeBuilder {
  35. public:
  36. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  37. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
  38. DataType data_type = DT_FLOAT) {
  39. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  40. return *this;
  41. }
  42. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
  43. DataType data_type = DT_FLOAT) {
  44. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  45. return *this;
  46. }
  47. NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) {
  48. op_desc_->AddOutputDesc(tensor_desc->Clone());
  49. return *this;
  50. }
  51. NodePtr Build(const ComputeGraphPtr &graph) {
  52. NodePtr node = graph->AddNode(op_desc_);
  53. return node;
  54. }
  55. private:
  56. OpDescPtr op_desc_;
  57. };
  58. } // namespace
  59. TEST(UTEST_graph_passes_addn_pass, NullPass) {
  60. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  61. GEPass pass(graph);
  62. AddNPass *addn_pass = nullptr;
  63. NamesToPass names_to_pass;
  64. names_to_pass.emplace_back("Test", addn_pass);
  65. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  66. }
  67. TEST(UTEST_graph_passes_addn_pass, NullGraph) {
  68. ComputeGraphPtr graph = nullptr;
  69. GEPass pass(graph);
  70. AddNPass addn_pass;
  71. NamesToPass names_to_pass;
  72. names_to_pass.emplace_back("Test", nullptr);
  73. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  74. }
  75. TEST(UTEST_graph_passes_addn_pass, EmptyPass) {
  76. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  77. GEPass pass(graph);
  78. AddNPass addn_pass;
  79. NamesToPass names_to_pass;
  80. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  81. }
  82. /// |
  83. /// AddN
  84. /// |
  85. TEST(UTEST_graph_passes_addn_pass, SingleAddNNode) {
  86. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  87. GeTensorDescPtr generalGeTensorDesc = std::make_shared<GeTensorDesc>();
  88. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).Build(graph);
  89. GEPass pass(graph);
  90. AddNPass addn_pass;
  91. NamesToPass names_to_pass;
  92. names_to_pass.emplace_back("Test", &addn_pass);
  93. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  94. EXPECT_EQ(graph->GetDirectNodesSize(), 1);
  95. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  96. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  97. }
  98. /// Op1
  99. /// |
  100. /// AddN
  101. /// |
  102. TEST(UTEST_graph_passes_addn_pass, NoOuput) {
  103. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  104. GeTensorDescPtr generalGeTensorDesc = std::make_shared<GeTensorDesc>();
  105. NodePtr node = NodeBuilder("node", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  106. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).AddInputDesc({1, 1, 224, 224}).Build(graph);
  107. GraphUtils::AddEdge(node->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  108. GEPass pass(graph);
  109. AddNPass addn_pass;
  110. NamesToPass names_to_pass;
  111. names_to_pass.emplace_back("Test", &addn_pass);
  112. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  113. EXPECT_FALSE(add_n_node->GetInDataNodes().empty());
  114. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  115. EXPECT_FALSE(node->GetOutDataNodes().empty());
  116. }
  117. /// |
  118. /// AddN
  119. /// |
  120. /// Op
  121. TEST(UTEST_graph_passes_addn_pass, NoInput) {
  122. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  123. GeTensorDescPtr generalGeTensorDesc = std::make_shared<GeTensorDesc>();
  124. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  125. NodePtr node = NodeBuilder("node2", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  126. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node->GetInDataAnchor(0));
  127. GEPass pass(graph);
  128. AddNPass addn_pass;
  129. NamesToPass names_to_pass;
  130. names_to_pass.emplace_back("Test", &addn_pass);
  131. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  132. EXPECT_EQ(graph->GetDirectNodesSize(), 2);
  133. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  134. EXPECT_EQ(node->GetInDataNodes().at(0)->GetName(), add_n_node->GetName());
  135. }
  136. /// Op1
  137. /// |
  138. /// AddN
  139. /// |
  140. /// Op2
  141. TEST(UTEST_graph_passes_addn_pass, SingleInputRemoveAddnSuccess) {
  142. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  143. GeTensorDescPtr generalGeTensorDesc = std::make_shared<GeTensorDesc>();
  144. NodePtr node1 =
  145. NodeBuilder("node1", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  146. NodePtr add_n_node =
  147. NodeBuilder("add_n_node", ADDN).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  148. NodePtr node2 =
  149. NodeBuilder("node2", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  150. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  151. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  152. EXPECT_EQ(graph->GetDirectNodesSize(), 3);
  153. GEPass pass(graph);
  154. AddNPass addn_pass;
  155. NamesToPass names_to_pass;
  156. names_to_pass.emplace_back("Test", &addn_pass);
  157. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  158. EXPECT_EQ(node1->GetOutDataNodes().at(0)->GetName(), node2->GetName());
  159. EXPECT_EQ(node2->GetInDataNodes().at(0)->GetName(), node1->GetName());
  160. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  161. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  162. }
  163. /// Op1 Op2
  164. /// \ /
  165. /// AddN
  166. /// |
  167. /// Op3
  168. TEST(UTEST_graph_passes_addn_pass, MultipleInputsDoNotRemove) {
  169. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  170. GeTensorDescPtr generalGeTensorDesc = std::make_shared<GeTensorDesc>();
  171. NodePtr node1 =
  172. NodeBuilder("node1", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  173. NodePtr node2 =
  174. NodeBuilder("node2", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  175. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN)
  176. .AddInputDesc({1, 1, 224, 224})
  177. .AddInputDesc({1, 1, 224, 224})
  178. .AddOutputDesc({1, 1, 224, 224})
  179. .Build(graph);
  180. NodePtr node3 =
  181. NodeBuilder("node3", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  182. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  183. GraphUtils::AddEdge(node2->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(1));
  184. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  185. EXPECT_EQ(graph->GetDirectNodesSize(), 4);
  186. GEPass pass(graph);
  187. AddNPass addn_pass;
  188. NamesToPass names_to_pass;
  189. names_to_pass.emplace_back("Test", &addn_pass);
  190. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  191. EXPECT_EQ(graph->GetDirectNodesSize(), 4);
  192. }
  193. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：