/** * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include "init/gelib.h" #include "opskernel_manager/ops_kernel_builder_manager.h" #include "external/ge/ge_api.h" #include "ge_running_env/ge_running_env_faker.h" #include "ge_graph_dsl/graph_dsl.h" #include "ge_running_env/fake_compound_engine.h" #include "ge_running_env/fake_op.h" #include "easy_graph/layout/graph_layout.h" #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" #include "ge_graph_dsl/assert/graph_assert.h" using namespace std; using namespace ge; namespace { bool IfNodeExist(const ComputeGraphPtr &graph, std::function filter, bool direct_node_flag = true) { for (const auto &node : graph->GetNodes(direct_node_flag)) { if (filter(node)) { return true; } } return false; } void GetSubgraphsWithFilter(const ComputeGraphPtr &graph, std::function filter, std::vector &subgraphs) { for (const auto &subgraph : graph->GetAllSubgraphs()) { if (filter(subgraph)) { subgraphs.emplace_back(subgraph); } } } bool IsAllNodeMatch(const ComputeGraphPtr &graph, std::function filter) { for (const auto &node : graph->GetAllNodes()) { if (!filter(node)) { return false; } } return true; } } class TestFftsPlus : public testing::Test { protected: GeRunningEnvFaker ge_env; EG_NS::GraphEasyExecutor executor; void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); ge_env.InstallDefault() .Install(FakeCompoundEngine("FFTS+", {"AIcoreEngine", "DNN_VM_AICPU"}).KernelInfoStore("FFTS+")) .Install(FakeOp(GETNEXT).InfoStoreAndBuilder("AicpuLib")) .Install(FakeOp(HCOMREDUCE).InfoStoreAndBuilder("HcclLib")); } void TearDown() {} }; /* * g1 * * ┌──────────┐ (0,1) ┌────────┐ (0,0) ┌────────┐ * │ const │ ───────> │ less │ ───────> │ reduce │ * └──────────┘ └────────┘ └────────┘ * ∧ * │ (0,0) * │ * ┌──────────┐ (0,0) ┌────────┐ (0,1) ┌────────┐ * │ get_next │ ───────> │ add │ <─────── │ data1 │ * └──────────┘ └────────┘ └────────┘ * */ TEST_F(TestFftsPlus, test_ffts_plus) { auto tensor = std::make_shared(); uint32_t value = 0; tensor->SetData((uint8_t *)&value, sizeof(uint32_t)); DEF_GRAPH(g1) { CHAIN(NODE("get_next", GETNEXT)->NODE("add", ADD)); CHAIN(NODE("data1", DATA)->NODE("add")->NODE("less", LESS)->NODE("reduce", HCOMREDUCE)); CHAIN(NODE("const", OP_CFG(CONSTANTOP).Attr("value", tensor))->Node("less")); }; auto graph = ToGeGraph(g1); // new session & add graph map options; Session session(options); auto ret = session.AddGraph(1, graph, options); EXPECT_EQ(ret, SUCCESS); // build input tensor std::vector inputs; // build_graph through session ret = session.BuildGraph(1, inputs); EXPECT_EQ(ret, SUCCESS); CHECK_GRAPH(PreRunAfterBuild) { // node exist ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "get_next"; })); ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add"; })); ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "less"; })); ASSERT_TRUE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetType() == PARTITIONEDCALL; })); // subgraph exit ASSERT_EQ(graph->GetAllSubgraphs().size(), 1); std::vector subgraphs; GetSubgraphsWithFilter(graph, [](const ComputeGraphPtr &graph) { const auto &parent_node = graph->GetParentNode(); if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) { return false; } return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH); }, subgraphs); ASSERT_EQ(subgraphs.size(), 1); // subgraph node check const auto &subgraph = subgraphs[0]; ASSERT_TRUE(subgraph != nullptr); ASSERT_TRUE(IsAllNodeMatch(subgraph, [](const NodePtr &node) { return node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID); })); const auto &parent_node = subgraph->GetParentNode(); ASSERT_TRUE(parent_node != nullptr); ASSERT_TRUE(parent_node->GetOpDesc() != nullptr); int64_t stream_id = parent_node->GetOpDesc()->GetStreamId(); ASSERT_TRUE(IsAllNodeMatch(subgraph, [stream_id](const NodePtr &node) { return node->GetOpDesc()->GetStreamId() == stream_id; })); }; }