| @@ -1,125 +1,125 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/passes/subgraph_const_migration_pass.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| class UtestSubgraphConstMigrationPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (auto i = 0; i < in_num; ++i) { | |||||
| op_desc->AddInputDesc(test_desc); | |||||
| } | |||||
| for (auto i = 0; i < out_num; ++i) { | |||||
| op_desc->AddOutputDesc(test_desc); | |||||
| } | |||||
| if (type == "Const") { | |||||
| uint64_t const_value = 101; | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| return graph->AddNode(op_desc); | |||||
| } | |||||
| void make_original_graph(const ComputeGraphPtr &graph) { | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| { | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
| } | |||||
| auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); | |||||
| { | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||||
| GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||||
| } | |||||
| auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); | |||||
| { | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); | |||||
| GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); | |||||
| } | |||||
| auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
| } | |||||
| void make_multibatch_graph(const ComputeGraphPtr &graph) { | |||||
| auto index = MakeNode(graph, 1, 1, "index", "Data"); | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| auto case1 = MakeNode(graph, 4, 1, "case", "Case"); | |||||
| GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); | |||||
| auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); | |||||
| GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
| AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); | |||||
| case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); | |||||
| ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); | |||||
| make_original_graph(branch); | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| std::string name("_ascend_mbatch_batch_" + std::to_string(i)); | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); | |||||
| subgraph->SetName(name); | |||||
| subgraph->SetParentNode(case1); | |||||
| subgraph->SetParentGraph(graph); | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
| case1->GetOpDesc()->AddSubgraphName(name); | |||||
| case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| TEST_F(UtestSubgraphConstMigrationPass, graph_nullptr) { | |||||
| PassManager pass_manager; | |||||
| pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| make_multibatch_graph(graph); | |||||
| pass_manager.Run(graph); | |||||
| } | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "framework/omg/omg_inner_types.h" | |||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/passes/subgraph_const_migration_pass.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "register/op_registry.h" | |||||
| namespace ge { | |||||
| class UtestSubgraphConstMigrationPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| public: | |||||
| NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { | |||||
| GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (auto i = 0; i < in_num; ++i) { | |||||
| op_desc->AddInputDesc(test_desc); | |||||
| } | |||||
| for (auto i = 0; i < out_num; ++i) { | |||||
| op_desc->AddOutputDesc(test_desc); | |||||
| } | |||||
| if (type == "Const") { | |||||
| uint64_t const_value = 101; | |||||
| auto weight = make_shared<GeTensor>(op_desc->GetOutputDesc(0), (uint8_t *)&const_value, sizeof(uint64_t)); | |||||
| AttrUtils::SetTensor(op_desc, ge::ATTR_NAME_WEIGHTS, weight); | |||||
| } | |||||
| return graph->AddNode(op_desc); | |||||
| } | |||||
| void make_original_graph(const ComputeGraphPtr &graph) { | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| { | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
| } | |||||
| auto const1 = MakeNode(graph, 0, 1, "const1", "Const"); | |||||
| { | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 2); | |||||
| GraphUtils::AddEdge(data1->GetOutControlAnchor(), const1->GetInControlAnchor()); | |||||
| } | |||||
| auto const2 = MakeNode(graph, 0, 1, "const2", "Const"); | |||||
| { | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, 3); | |||||
| GraphUtils::AddEdge(data2->GetOutControlAnchor(), const2->GetInControlAnchor()); | |||||
| } | |||||
| auto conv2d_node = MakeNode(graph, 3, 1, "conv1", "Conv2D"); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(const1->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(const2->GetOutDataAnchor(0), conv2d_node->GetInDataAnchor(2)); | |||||
| } | |||||
| void make_multibatch_graph(const ComputeGraphPtr &graph) { | |||||
| auto index = MakeNode(graph, 1, 1, "index", "Data"); | |||||
| auto data = MakeNode(graph, 1, 1, "data", "Data"); | |||||
| auto data1 = MakeNode(graph, 1, 1, "data1", "Data"); | |||||
| auto data2 = MakeNode(graph, 1, 1, "data2", "Data"); | |||||
| AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_INDEX, 0); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(), ATTR_NAME_INDEX, 1); | |||||
| AttrUtils::SetInt(data2->GetOpDesc(), ATTR_NAME_INDEX, 2); | |||||
| auto case1 = MakeNode(graph, 4, 1, "case", "Case"); | |||||
| GraphUtils::AddEdge(index->GetOutDataAnchor(0), case1->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(data->GetOutDataAnchor(0), case1->GetInDataAnchor(1)); | |||||
| GraphUtils::AddEdge(data1->GetOutDataAnchor(0), case1->GetInDataAnchor(2)); | |||||
| GraphUtils::AddEdge(data2->GetOutDataAnchor(0), case1->GetInDataAnchor(3)); | |||||
| auto output_node = MakeNode(graph, 1, 0, "output", "NetOutput"); | |||||
| GraphUtils::AddEdge(case1->GetOutDataAnchor(0), output_node->GetInDataAnchor(0)); | |||||
| AttrUtils::SetInt(case1->GetOpDesc(), ATTR_NAME_BATCH_NUM, 2); | |||||
| case1->GetOpDesc()->RegisterSubgraphIrName("branches", kDynamic); | |||||
| ComputeGraphPtr branch = std::make_shared<ComputeGraph>("test_branch"); | |||||
| make_original_graph(branch); | |||||
| for (int i = 0; i < 2; ++i) { | |||||
| std::string name("_ascend_mbatch_batch_" + std::to_string(i)); | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| ComputeGraphPtr subgraph = GraphUtils::CloneGraph(branch, name, input_nodes, output_nodes); | |||||
| subgraph->SetName(name); | |||||
| subgraph->SetParentNode(case1); | |||||
| subgraph->SetParentGraph(graph); | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | |||||
| case1->GetOpDesc()->AddSubgraphName(name); | |||||
| case1->GetOpDesc()->SetSubgraphInstanceName(i, subgraph->GetName()); | |||||
| } | |||||
| } | |||||
| }; | |||||
| TEST_F(UtestSubgraphConstMigrationPass, subgraph_const_migration) { | |||||
| PassManager pass_manager; | |||||
| pass_manager.AddPass("SubgraphConstMigrationPass", new (std::nothrow) SubgraphConstMigrationPass); | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test_graph"); | |||||
| make_multibatch_graph(graph); | |||||
| EXPECT_EQ(pass_manager.Run(graph), SUCCESS); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||