|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #include <gtest/gtest.h>
- #include "init/gelib.h"
- #include "opskernel_manager/ops_kernel_builder_manager.h"
- #include "external/ge/ge_api.h"
- #include "ge_running_env/ge_running_env_faker.h"
- #include "ge_graph_dsl/graph_dsl.h"
- #include "ge_running_env/fake_compound_engine.h"
- #include "ge_running_env/fake_op.h"
-
- #include "easy_graph/layout/graph_layout.h"
- #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h"
- #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h"
-
- #include "ge_graph_dsl/assert/graph_assert.h"
-
- using namespace std;
- using namespace ge;
-
- namespace {
- // Returns true as soon as one node of `graph` satisfies `filter`, false otherwise.
- //   graph            - graph whose nodes are scanned (must be non-null).
- //   filter           - predicate applied to each node; taken by const reference so the
- //                      callable is not copied on every call.
- //   direct_node_flag - forwarded unchanged to ComputeGraph::GetNodes().
- //                      NOTE(review): presumably `true` selects only the graph's direct nodes
- //                      (not nodes nested in subgraphs) - confirm against ComputeGraph.
- bool IfNodeExist(const ComputeGraphPtr &graph, const std::function<bool(const NodePtr &)> &filter,
-                  bool direct_node_flag = true) {
-   for (const auto &node : graph->GetNodes(direct_node_flag)) {
-     if (filter(node)) {
-       return true;
-     }
-   }
-   return false;
- }
-
- // Appends to `subgraphs` every entry of graph->GetAllSubgraphs() for which `filter` returns true.
- // Existing contents of `subgraphs` are preserved; the order of GetAllSubgraphs() is kept.
- void GetSubgraphsWithFilter(const ComputeGraphPtr &graph,
-                             const std::function<bool(const ComputeGraphPtr &)> &filter,
-                             std::vector<ComputeGraphPtr> &subgraphs) {
-   for (const auto &subgraph : graph->GetAllSubgraphs()) {
-     if (filter(subgraph)) {
-       subgraphs.emplace_back(subgraph);
-     }
-   }
- }
-
- // Returns true if every node in graph->GetAllNodes() satisfies `filter`; stops at the first
- // mismatch. A graph with no nodes vacuously matches.
- bool IsAllNodeMatch(const ComputeGraphPtr &graph, const std::function<bool(const NodePtr &)> &filter) {
-   for (const auto &node : graph->GetAllNodes()) {
-     if (!filter(node)) {
-       return false;
-     }
-   }
-   return true;
- }
- }  // namespace
-
- // Fixture for the FFTS+ tests. SetUp() configures the graph layout singleton with a
- // graph-easy executor and installs, on top of the default faked running environment:
- //   * a fake compound engine "ffts_plus" over the AIcoreEngine and DNN_VM_AICPU engines,
- //     with kernel info store "ffts_plus";
- //   * a fake GETNEXT op with info store/builder "AicpuLib";
- //   * a fake HCOMREDUCE op with info store/builder "HcclLib".
- class TestFftsPlus : public testing::Test {
-  protected:
-   GeRunningEnvFaker ge_env;
-   EG_NS::GraphEasyExecutor executor;
-
-   // `override` makes the compiler verify these really replace testing::Test's virtuals.
-   void SetUp() override {
-     EG_NS::GraphLayout::GetInstance().Config(executor, nullptr);
-     ge_env.InstallDefault()
-         .Install(FakeCompoundEngine("ffts_plus", {"AIcoreEngine", "DNN_VM_AICPU"}).KernelInfoStore("ffts_plus"))
-         .Install(FakeOp(GETNEXT).InfoStoreAndBuilder("AicpuLib"))
-         .Install(FakeOp(HCOMREDUCE).InfoStoreAndBuilder("HcclLib"));
-   }
-
-   void TearDown() override {}
- };
-
- /*
- *                         g1
- *
- * ┌──────────┐  (0,1)   ┌────────┐  (0,0)   ┌────────┐
- * │  const   │ ───────> │  less  │ ───────> │ reduce │
- * └──────────┘          └────────┘          └────────┘
- *                           ∧
- *                           │ (0,0)
- *                           │
- * ┌──────────┐  (0,0)   ┌────────┐  (0,1)   ┌────────┐
- * │ get_next │ ───────> │  add   │ <─────── │ data1  │
- * └──────────┘          └────────┘          └────────┘
- *
- * Edge labels are (source output index, destination input index).
- */
- // Builds g1 through a Session and checks, on the PreRunAfterBuild graph, that get_next/add/less
- // were moved out of the root graph into a single subgraph owned by a PARTITIONEDCALL-typed
- // graph's node carrying ATTR_NAME_FFTS_PLUS_SUB_GRAPH, and that every node of that subgraph has a
- // thread scope id and the same stream id as the subgraph's parent node.
- TEST_F(TestFftsPlus, test_ffts_plus) {
-   // Payload for the "const" node's "value" attribute.
-   auto tensor = std::make_shared<GeTensor>();
-   uint32_t value = 0;
-   tensor->SetData(reinterpret_cast<uint8_t *>(&value), sizeof(value));
-   // The order of the CHAIN statements fixes the input indices of "add" (get_next -> 0, data1 -> 1).
-   DEF_GRAPH(g1) {
-     CHAIN(NODE("get_next", GETNEXT)->NODE("add", ADD));
-     CHAIN(NODE("data1", DATA)->NODE("add")->NODE("less", LESS)->NODE("reduce", HCOMREDUCE));
-     CHAIN(NODE("const", OP_CFG(CONSTANTOP).Attr("value", tensor))->NODE("less"));
-   };
-
-   auto graph = ToGeGraph(g1);
-   // New session & add graph. ASSERT (not EXPECT): building a graph that was never added would
-   // only produce cascading, misleading failures below.
-   map<AscendString, AscendString> options;
-   Session session(options);
-   auto ret = session.AddGraph(1, graph, options);
-   ASSERT_EQ(ret, SUCCESS);
-
-   // Build through the session; no input tensors are supplied.
-   std::vector<InputTensorInfo> inputs;
-   ret = session.BuildGraph(1, inputs);
-   ASSERT_EQ(ret, SUCCESS);
-
-   CHECK_GRAPH(PreRunAfterBuild) {
-     // The partitioned nodes must be gone from the root graph, replaced by a PARTITIONEDCALL node.
-     ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "get_next"; }));
-     ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "add"; }));
-     ASSERT_FALSE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetName() == "less"; }));
-     ASSERT_TRUE(IfNodeExist(graph, [](const NodePtr &node) { return node->GetType() == PARTITIONEDCALL; }));
-
-     // Exactly one subgraph exists, and its parent node carries ATTR_NAME_FFTS_PLUS_SUB_GRAPH.
-     ASSERT_EQ(graph->GetAllSubgraphs().size(), 1U);
-     std::vector<ComputeGraphPtr> subgraphs;
-     GetSubgraphsWithFilter(graph,
-                            [](const ComputeGraphPtr &candidate) {
-                              const auto &parent_node = candidate->GetParentNode();
-                              if ((parent_node == nullptr) || (parent_node->GetOpDesc() == nullptr)) {
-                                return false;
-                              }
-                              return parent_node->GetOpDesc()->HasAttr(ATTR_NAME_FFTS_PLUS_SUB_GRAPH);
-                            },
-                            subgraphs);
-     ASSERT_EQ(subgraphs.size(), 1U);
-
-     // Every node of the subgraph must carry a thread scope id ...
-     const auto &subgraph = subgraphs[0];
-     ASSERT_TRUE(subgraph != nullptr);
-     ASSERT_TRUE(IsAllNodeMatch(subgraph, [](const NodePtr &node) {
-       return (node->GetOpDesc() != nullptr) && node->GetOpDesc()->HasAttr(ATTR_NAME_THREAD_SCOPE_ID);
-     }));
-     // ... and run on the same stream as the subgraph's parent node.
-     const auto &parent_node = subgraph->GetParentNode();
-     ASSERT_TRUE(parent_node != nullptr);
-     ASSERT_TRUE(parent_node->GetOpDesc() != nullptr);
-     const int64_t stream_id = parent_node->GetOpDesc()->GetStreamId();
-     ASSERT_TRUE(IsAllNodeMatch(subgraph, [stream_id](const NodePtr &node) {
-       return (node->GetOpDesc() != nullptr) && (node->GetOpDesc()->GetStreamId() == stream_id);
-     }));
-   };
- }
|