
pass_utils_unittest.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/passes/pass_utils.h"

#include <gtest/gtest.h>
#include <vector>

#include "common/types.h"
#include "graph/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph_builder_utils.h"
#include "inc/kernel.h"
#include "inc/kernel_factory.h"

using namespace domi;
using namespace ge;

class UTEST_graph_passes_pass_utils : public testing::Test {
 protected:
  void SetUp() {}
  void TearDown() {}
};

class NodeBuilder {
 public:
  NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }

  NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                            ge::DataType data_type = DT_FLOAT) {
    op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
    return *this;
  }

  NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                             ge::DataType data_type = DT_FLOAT) {
    op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
    return *this;
  }

  ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }

 private:
  ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
                                       ge::DataType data_type = DT_FLOAT) {
    GeShape ge_shape{std::vector<int64_t>(shape)};
    ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
    tensor_desc->SetShape(ge_shape);
    tensor_desc->SetFormat(format);
    tensor_desc->SetDataType(data_type);
    return tensor_desc;
  }

  ge::OpDescPtr op_desc_;
};

TEST_F(UTEST_graph_passes_pass_utils, set_out_node_weight) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // data
  ge::NodePtr node_data = NodeBuilder("data", DATA).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // const
  ge::NodePtr node_const =
      NodeBuilder("const", CONSTANT).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  // relu
  ge::NodePtr node_relu = NodeBuilder("node_relu1", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // sinh
  ge::NodePtr node_sinh = NodeBuilder("node_sinh", SINH)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  // relu
  ge::NodePtr node_relu2 = NodeBuilder("node_relu2", RELU)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);
  // sinh
  ge::NodePtr node_sinh2 = NodeBuilder("node_sinh2", SINH)
                               .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                               .Build(graph);
  // add edge
  ge::GraphUtils::AddEdge(node_data->GetOutControlAnchor(), node_const->GetInControlAnchor());
  ge::GraphUtils::AddEdge(node_const->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_sinh->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_relu2->GetInControlAnchor());
  ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_sinh2->GetInDataAnchor(0));

  for (auto node : graph->GetDirectNode()) {
    if (node->GetType() == CONSTANT) {
      int32_t weight[] = {1};
      GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
      GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
      vector<GeTensorPtr> tensor_vec = {tensor};
      OpDescUtils::SetWeights(node, tensor_vec);
    }
    if (!node->GetOutDataNodes().empty()) {
      auto out_data_anchor = node->GetOutDataNodes().at(0)->GetOutDataAnchor(0);
      Status status = PassUtils::SetOutNodeWeight(out_data_anchor, node);
      EXPECT_EQ(domi::SUCCESS, status);
    }
  }
}
// only some failure cases for coverage check
TEST_F(UTEST_graph_passes_pass_utils, is_constant_null) {
  ge::NodePtr node = nullptr;
  bool ret = PassUtils::IsConstant(node);
  EXPECT_EQ(false, ret);
}

TEST_F(UTEST_graph_passes_pass_utils, get_in_data_node_fail) {
  ge::NodePtr node = nullptr;
  NodePtr in_data_node = PassUtils::GetInDataNode(node, 0);
  EXPECT_EQ(nullptr, in_data_node);

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  NodePtr data_node = PassUtils::GetInDataNode(node_relu, 1);
  EXPECT_EQ(nullptr, data_node);
}

TEST_F(UTEST_graph_passes_pass_utils, get_unique_in_data_anchor_index_failed) {
  int invalid_index = -1;
  ge::NodePtr node = nullptr;
  int status = PassUtils::GetUniqueInDataAnchorIndex(node);
  EXPECT_EQ(invalid_index, status);

  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  int ret = PassUtils::GetUniqueInDataAnchorIndex(node_relu);
  EXPECT_EQ(invalid_index, ret);
}

TEST_F(UTEST_graph_passes_pass_utils, unlink_node_with_ctrl_copy_fail) {
  ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  // relu
  ge::NodePtr node_relu = NodeBuilder("relu", RELU)
                              .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
                              .Build(graph);
  Status status = PassUtils::UnlinkNodeWithControlCopy(node_relu, 1);
  EXPECT_EQ(ge::SUCCESS, status);
  Status ret = PassUtils::UnlinkNodeWithControlCopy(node_relu, 0);
  EXPECT_EQ(ge::FAILED, ret);
}

TEST_F(UTEST_graph_passes_pass_utils, null_input) {
  std::vector<NodePtr> deleted_nodes;
  std::vector<NodePtr> end_nodes;
  EXPECT_NE(PassUtils::RemoveInactiveBranchToMerge(nullptr, deleted_nodes, end_nodes), 0);
}

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
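As a rough illustration of the "graph in, graph out" idea described above, the following sketch assembles a tiny compute graph with the same graph classes that appear in the unit test in this file (ComputeGraph, OpDesc, GeTensorDesc, GraphUtils::AddEdge). It is a minimal sketch based only on the calls used in that test, not a complete or authoritative GE API example; the function name BuildTinyGraph and the graph name "tiny_graph" are made up for illustration.

// Minimal sketch (assumptions noted above): build a two-node graph
// (Const -> Relu) using the helpers seen in pass_utils_unittest.cc.
#include <memory>
#include "graph/utils/graph_utils.h"

using namespace ge;

ComputeGraphPtr BuildTinyGraph() {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("tiny_graph");

  // One FP32 NCHW tensor description shared by both ops.
  GeTensorDesc tensor_desc(GeShape({2, 2, 2, 2}), FORMAT_NCHW, DT_FLOAT);

  OpDescPtr const_desc = std::make_shared<OpDesc>("const", CONSTANT);
  const_desc->AddOutputDesc(tensor_desc);

  OpDescPtr relu_desc = std::make_shared<OpDesc>("relu", RELU);
  relu_desc->AddInputDesc(tensor_desc);
  relu_desc->AddOutputDesc(tensor_desc);

  // Add the nodes to the graph and wire const:0 -> relu:0 with a data edge.
  NodePtr const_node = graph->AddNode(const_desc);
  NodePtr relu_node = graph->AddNode(relu_desc);
  GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), relu_node->GetInDataAnchor(0));
  return graph;
}

A graph built this way (normally delivered by ME rather than hand-written) is what GE Core's optimization passes, such as the ones exercised by PassUtils in the test above, operate on.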