
test_ffts_plus.cc

/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <gtest/gtest.h>
#include "init/gelib.h"
#include "opskernel_manager/ops_kernel_builder_manager.h"
#include "external/ge/ge_api.h"
#include "ge_running_env/ge_running_env_faker.h"
#include "ge_graph_dsl/graph_dsl.h"
#include "ge_running_env/fake_compound_engine.h"
#include "ge_running_env/fake_op.h"
#include "easy_graph/layout/graph_layout.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
#include "ge_graph_dsl/assert/graph_assert.h"

using namespace std;
using namespace ge;

namespace {
bool IfNodeExistWithType(const ComputeGraphPtr &graph, const std::string &type, bool direct_node_flag) {
  for (const auto &node : graph->GetNodes(direct_node_flag)) {
    if (node->GetType() == type) {
      return true;
    }
  }
  return false;
}

bool IfNodeExistWithName(const ComputeGraphPtr &graph, const std::string &name, bool direct_node_flag) {
  for (const auto &node : graph->GetNodes(direct_node_flag)) {
    if (node->GetName() == name) {
      return true;
    }
  }
  return false;
}

void GetSubgraphsWithFilter(const ComputeGraphPtr &graph, std::function<bool(const ComputeGraphPtr &)> filter,
                            std::vector<ComputeGraphPtr> &subgraphs) {
  for (const auto &subgraph : graph->GetAllSubgraphs()) {
    if (filter(subgraph)) {
      subgraphs.emplace_back(subgraph);
    }
  }
}

bool IsAllNodeMatch(const ComputeGraphPtr &graph, std::function<bool(const NodePtr &)> filter) {
  for (const auto &node : graph->GetAllNodes()) {
    if (!filter(node)) {
      return false;
    }
  }
  return true;
}
}  // namespace

class TestFftsPlus : public testing::Test {
 protected:
  GeRunningEnvFaker ge_env;
  EG_NS::GraphEasyExecutor executor;

  void SetUp() {
    EG_NS::GraphLayout::GetInstance().Config(executor, nullptr);
    ge_env.InstallDefault()
        .Install(FakeCompoundEngine("FFTS+", {"AIcoreEngine", "DNN_VM_AICPU"}).KernelInfoStore("FFTS+"))
        .Install(FakeOp(GETNEXT).InfoStoreAndBuilder("AicpuLib"))
        .Install(FakeOp(HCOMREDUCE).InfoStoreAndBuilder("HcclLib"));
  }

  void TearDown() {}
};

/*
 *                               g1
 *
 *  ┌──────────┐  (0,1)   ┌────────┐  (0,0)   ┌────────┐
 *  │  const   │ ───────> │  less  │ ───────> │ reduce │
 *  └──────────┘          └────────┘          └────────┘
 *                             ∧
 *                             │ (0,0)
 *                             │
 *  ┌──────────┐  (0,0)   ┌────────┐  (0,1)   ┌────────┐
 *  │ get_next │ ───────> │  add   │ <─────── │ data1  │
 *  └──────────┘          └────────┘          └────────┘
 *
 */
TEST_F(TestFftsPlus, test_ffts_plus) {
  auto tensor = std::make_shared<GeTensor>();
  uint32_t value = 0;
  tensor->SetData((uint8_t *)&value, sizeof(uint32_t));

  DEF_GRAPH(g1) {
    CHAIN(NODE("get_next", GETNEXT)->NODE("add", ADD));
    CHAIN(NODE("data1", DATA)->NODE("add")->NODE("less", LESS)->NODE("reduce", HCOMREDUCE));
    CHAIN(NODE("const", OP_CFG(CONSTANTOP).Attr("value", tensor))->Node("less"));
  };
  auto graph = ToGeGraph(g1);

  // new session & add graph
  map<AscendString, AscendString> options;
  Session session(options);
  auto ret = session.AddGraph(1, graph, options);
  EXPECT_EQ(ret, SUCCESS);

  // build input tensor
  std::vector<InputTensorInfo> inputs;
  // build graph through session
  ret = session.BuildGraph(1, inputs);
  EXPECT_EQ(ret, SUCCESS);

  CHECK_GRAPH(PreRunAfterBuild) {
    // node exist
    ASSERT_FALSE(IfNodeExistWithName(graph, "get_next", true));
    ASSERT_FALSE(IfNodeExistWithName(graph, "add", true));
    ASSERT_FALSE(IfNodeExistWithName(graph, "less", true));
    ASSERT_TRUE(IfNodeExistWithType(graph, PARTITIONEDCALL, true));

    // subgraph exist
    ASSERT_EQ(graph->GetAllSubgraphs().size(), 1);
    std::vector<ComputeGraphPtr> subgraphs;
    GetSubgraphsWithFilter(graph,
                           [](const ComputeGraphPtr &graph) {
                             const auto &parent_node = graph->GetParentNode();
                             if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) {
                               return false;
                             }
                             return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);
                           },
                           subgraphs);
    ASSERT_EQ(subgraphs.size(), 1);

    // subgraph node check
    const auto &subgraph = subgraphs[0];
    ASSERT_TRUE(subgraph != nullptr);
    ASSERT_TRUE(IsAllNodeMatch(subgraph,
                               [](const NodePtr &node) {
                                 if (node->GetOpDesc() == nullptr) {
                                   return false;
                                 }
                                 return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID);
                               }));

    const auto &parent_node = subgraph->GetParentNode();
    ASSERT_TRUE(parent_node != nullptr);
    ASSERT_TRUE(parent_node->GetOpDesc() != nullptr);
    int64_t stream_id = parent_node->GetOpDesc()->GetStreamId();
    ASSERT_TRUE(IsAllNodeMatch(subgraph,
                               [stream_id](const NodePtr &node) {
                                 if (node->GetOpDesc() == nullptr) {
                                   return false;
                                 }
                                 return node->GetOpDesc()->GetStreamId() == stream_id;
                               }));
  };
}

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware and acts as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations on it, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; its detailed architecture is shown in the diagram below.
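As a rough companion to that description, here is a minimal sketch of how a client drives GE through the GE API layer, using the same Session/AddGraph/BuildGraph calls that appear in the test above. The GEInitialize/GEFinalize wrappers and the RunGraph step are assumptions based on the public external/ge/ge_api.h header rather than on this test, so treat the snippet as an illustration of the call sequence, not a verified program.

#include <map>
#include <string>
#include <vector>
#include "external/ge/ge_api.h"

// Sketch only: build and run one graph through the GE API layer.
ge::Status RunWithGe(const ge::Graph &graph) {
  // Assumed: process-level initialization declared in ge_api.h.
  std::map<std::string, std::string> init_options;
  ge::Status ret = ge::GEInitialize(init_options);
  if (ret != ge::SUCCESS) {
    return ret;
  }

  {
    // Same pattern as the test above: create a session and register the graph.
    std::map<ge::AscendString, ge::AscendString> options;
    ge::Session session(options);
    ret = session.AddGraph(1, graph, options);
    if (ret != ge::SUCCESS) {
      return ret;
    }

    // Compile the graph; GE runs its optimization passes during this step.
    std::vector<ge::InputTensorInfo> inputs;
    ret = session.BuildGraph(1, inputs);
    if (ret != ge::SUCCESS) {
      return ret;
    }

    // Assumed: execute the built graph and collect its outputs.
    std::vector<ge::Tensor> run_inputs;
    std::vector<ge::Tensor> run_outputs;
    ret = session.RunGraph(1, run_inputs, run_outputs);
    if (ret != ge::SUCCESS) {
      return ret;
    }
  }

  // Assumed: matching teardown for GEInitialize.
  return ge::GEFinalize();
}

The sanity test above stops after BuildGraph: compiling the graph is enough to trigger the FFTS+ partitioning, which the CHECK_GRAPH(PreRunAfterBuild) block then inspects without ever executing the graph.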