You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_prepare_pass_unittest.cc 9.0 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/variable_prepare_op_pass.h"
  17. #include <gtest/gtest.h>
  18. #include <string>
  19. using namespace domi;
  20. using namespace ge;
  21. class UTEST_graph_passes_variable_prepare_pass : public testing::Test {
  22. protected:
  23. void SetUp() {}
  24. void TearDown() {}
  25. };
  26. class NodeBuilder {
  27. public:
  28. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  29. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  30. ge::DataType data_type = DT_FLOAT) {
  31. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  32. return *this;
  33. }
  34. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  35. ge::DataType data_type = DT_FLOAT) {
  36. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  37. return *this;
  38. }
  39. ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
  40. private:
  41. ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  42. ge::DataType data_type = DT_FLOAT) {
  43. GeShape ge_shape{std::vector<int64_t>(shape)};
  44. ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
  45. tensor_desc->SetShape(ge_shape);
  46. tensor_desc->SetFormat(format);
  47. tensor_desc->SetDataType(data_type);
  48. return tensor_desc;
  49. }
  50. ge::OpDescPtr op_desc_;
  51. };
  52. /// variable -- const
  53. /// \ /
  54. /// \ /
  55. /// assign
  56. TEST_F(UTEST_graph_passes_variable_prepare_pass, variable_prepare_pass_succ1) {
  57. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  58. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  59. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  60. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  61. .Build(graph);
  62. ge::NodePtr const_node = NodeBuilder("const", CONSTANT)
  63. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  64. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  65. .Build(graph);
  66. ge::NodePtr apply_assign_node = NodeBuilder("assign", ASSIGN)
  67. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  68. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  69. .Build(graph);
  70. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), apply_assign_node->GetInDataAnchor(0));
  71. ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), apply_assign_node->GetInDataAnchor(1));
  72. ge::VariablePrepareOpPass pass_;
  73. ge::Status status = pass_.Run(graph);
  74. EXPECT_EQ(apply_assign_node->GetOutDataNodes().size(), 0);
  75. EXPECT_EQ(domi::SUCCESS, status);
  76. }
  77. /// variable -- applyMoment
  78. TEST_F(UTEST_graph_passes_variable_prepare_pass, variable_prepare_pass_succ2) {
  79. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  80. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  81. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  82. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  83. .Build(graph);
  84. ge::NodePtr apply_monetum_node = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  85. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  86. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  87. .Build(graph);
  88. ge::NodePtr sinh_node = NodeBuilder("sinh", SINH)
  89. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  90. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  91. .Build(graph);
  92. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), apply_monetum_node->GetInDataAnchor(0));
  93. ge::GraphUtils::AddEdge(apply_monetum_node->GetOutControlAnchor(), sinh_node->GetInControlAnchor());
  94. ge::VariablePrepareOpPass pass_;
  95. ge::Status status = pass_.Run(graph);
  96. EXPECT_EQ(apply_monetum_node->GetOutDataNodes().size(), 0);
  97. EXPECT_EQ(domi::SUCCESS, status);
  98. }
  99. /// variable -- const1
  100. /// \ /
  101. /// \ /
  102. /// assign_add1 -- const2
  103. /// \ /
  104. /// \ /
  105. /// assign_sub -- const3
  106. /// \ /
  107. /// \ /
  108. /// assign_add2 -- const4
  109. /// \ /
  110. /// \ /
  111. /// assign_add3
  112. TEST_F(UTEST_graph_passes_variable_prepare_pass, variable_prepare_pass_succ3) {
  113. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  114. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  115. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  116. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  117. .Build(graph);
  118. ge::NodePtr const_node1 = NodeBuilder("const1", CONSTANT)
  119. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  120. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  121. .Build(graph);
  122. ge::NodePtr const_node2 = NodeBuilder("const2", CONSTANT)
  123. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  124. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  125. .Build(graph);
  126. ge::NodePtr const_node3 = NodeBuilder("const3", CONSTANT)
  127. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  128. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  129. .Build(graph);
  130. ge::NodePtr const_node4 = NodeBuilder("const4", CONSTANT)
  131. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  132. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  133. .Build(graph);
  134. ge::NodePtr assign_add1 = NodeBuilder("assign_add1", ASSIGNADD)
  135. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  136. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  137. .Build(graph);
  138. ge::NodePtr assign_sub = NodeBuilder("assign_sub", ASSIGNSUB)
  139. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  140. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  141. .Build(graph);
  142. ge::NodePtr assign_add2 = NodeBuilder("assign_add2", ASSIGNADD)
  143. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  144. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  145. .Build(graph);
  146. ge::NodePtr assign_add3 = NodeBuilder("assign_add3", ASSIGNADD)
  147. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  148. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  149. .Build(graph);
  150. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), assign_add1->GetInDataAnchor(0));
  151. ge::GraphUtils::AddEdge(const_node1->GetOutDataAnchor(0), assign_add1->GetInDataAnchor(1));
  152. ge::GraphUtils::AddEdge(assign_add1->GetOutDataAnchor(0), assign_sub->GetInDataAnchor(0));
  153. ge::GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), assign_sub->GetInDataAnchor(1));
  154. ge::GraphUtils::AddEdge(assign_sub->GetOutDataAnchor(0), assign_add2->GetInDataAnchor(0));
  155. ge::GraphUtils::AddEdge(const_node3->GetOutDataAnchor(0), assign_add2->GetInDataAnchor(1));
  156. ge::GraphUtils::AddEdge(assign_add2->GetOutDataAnchor(0), assign_add3->GetInDataAnchor(0));
  157. ge::GraphUtils::AddEdge(const_node4->GetOutDataAnchor(0), assign_add3->GetInDataAnchor(1));
  158. ge::VariablePrepareOpPass pass_;
  159. ge::Status status = pass_.Run(graph);
  160. EXPECT_EQ(assign_add3->GetOutDataNodes().size(), 0);
  161. EXPECT_EQ(domi::SUCCESS, status);
  162. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：