From d694f1dc21261a6f0880e1688f6e88bbeef148d8 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Mon, 19 Apr 2021 15:44:44 +0800 Subject: [PATCH 1/5] add support for train_mode tune --- ge/graph/build/model_builder.cc | 17 +++++++++++++++++ ge/graph/manager/graph_manager.cc | 16 +++++++++------- ge/graph/manager/graph_manager.h | 2 +- ge/graph/passes/global_step_insert_pass.cc | 8 ++++++++ 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 9ae6e6be..08ead6d9 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -647,6 +647,14 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) { std::vector data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); tbe_kernel = MakeShared(kernel_name, std::move(data)); GE_CHECK_NOTNULL(tbe_kernel); + GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(), + atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); + if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { + std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" + + FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); + return ge::FAILED; + } } } if (tbe_kernel == nullptr) { @@ -695,6 +703,15 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { GE_CHECK_NOTNULL(kernel_buffer.GetData()); std::vector data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); tbe_kernel = std::make_shared(kernel_name, std::move(data)); + GE_CHECK_NOTNULL(tbe_kernel); + GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), + node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); + if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { + std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" + + FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); + return ge::FAILED; + } } } GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index f2b4211d..896bdac9 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1686,7 +1686,8 @@ Status GraphManager::ParseOptions(const std::map &opti return GE_GRAPH_OPTIONS_INVALID); // ge.graphType - ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); + ret = + ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -1728,20 +1729,21 @@ Status GraphManager::ParseOptions(const std::map &opti return SUCCESS; } -Status GraphManager::ParseTrainGraphFlag(bool &options, bool &option) { +Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) { std::shared_ptr ge_instance_ptr = ge::GELib::GetInstance(); if (ge_instance_ptr == nullptr) { GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); - option = false; + train_flag = false; } else if (!ge_instance_ptr->isTrainMode()) { - option = false; + train_flag = false; } else { // ge_instance_ptr->isTrainMode() is true - if (!options) { + // tune mode no need check + if (!run_flag && !tune_flag) { GELOGE(GE_GRAPH_OPTIONS_INVALID, - "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", options); + "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); return GE_GRAPH_OPTIONS_INVALID; } - option = true; + train_flag = true; } return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 0533a0b6..1a8fa8b7 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -277,7 +277,7 @@ class GraphManager { static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); - static Status ParseTrainGraphFlag(bool &options, bool &option); + static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag); static bool IsPerfLevelInvalid(int32_t perf_level); diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 9fc1d066..896622ce 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -26,6 +26,8 @@ #include "common/ge/ge_util.h" #include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" +#include "graph/ge_context.h" +#include "graph/tuning_utils.h" namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, @@ -72,6 +74,12 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { + std::string build_mode; + if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) { + GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), + compute_graph->GetName().c_str()); + return SUCCESS; + } NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); if (output_node == nullptr) { GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); From 18d9518739982e0f97b74265514813b9053efc64 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Mon, 19 Apr 2021 20:14:16 +0800 Subject: [PATCH 2/5] add support for train_mode tune --- ge/graph/build/model_builder.cc | 6 ++---- .../ge/graph/build/model_builder_unittest.cc | 19 +++++++++++++++++++ 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/ge/graph/build/model_builder.cc b/ge/graph/build/model_builder.cc index 08ead6d9..99abd73a 100755 --- a/ge/graph/build/model_builder.cc +++ b/ge/graph/build/model_builder.cc @@ -650,8 +650,7 @@ Status ModelBuilder::SaveAtomicTBEKernel(const OpDescPtr &op_desc) { GELOGI("Node [%s][%s] start recovery extra attr %s from %s", atomic_op_desc->GetName().c_str(), atomic_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); if (!(atomic_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { - std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra attr" + - FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + std::string error = "Node" + FmtToStr(atomic_op_desc->GetName()) + "set extra tbeKernel attr failed"; GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); return ge::FAILED; } @@ -707,8 +706,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { GELOGI("Node [%s][%s] start recovery extra attr %s from %s", node_op_desc->GetName().c_str(), node_op_desc->GetType().c_str(), ge::OP_EXTATTR_NAME_TBE_KERNEL, ATTR_NAME_TBE_KERNEL_NAME.c_str()); if (!(node_op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel))) { - std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra attr" + - FmtToStr(ge::OP_EXTATTR_NAME_TBE_KERNEL) + "failed"; + std::string error = "Node" + FmtToStr(node_op_desc->GetName()) + "set extra tbeKernel attr failed"; GE_ERRORLOG_AND_ERRORMSG(ge::FAILED, error.c_str()); return ge::FAILED; } diff --git a/tests/ut/ge/graph/build/model_builder_unittest.cc b/tests/ut/ge/graph/build/model_builder_unittest.cc index b9204dbc..d5efc9bb 100644 --- a/tests/ut/ge/graph/build/model_builder_unittest.cc +++ b/tests/ut/ge/graph/build/model_builder_unittest.cc @@ -161,3 +161,22 @@ TEST_F(UtestModelBuilderTest, test_save_atomic_bin) { op_desc->SetExtAttr("atomic_clean_node_ptr", atomic_node); EXPECT_EQ(builder.SaveAtomicTBEKernel(op_desc), SUCCESS); } + +TEST_F(UtestModelBuilderTest, test_model_save) { + Graph2SubGraphInfoList subgraphs; + std::map stream_max_parallel_num; + ge::ComputeGraphPtr graph = make_shared(""); + ge::ModelBuilder builder(0, graph, subgraphs, stream_max_parallel_num, false); + + auto op_desc = make_shared("Conv2d", "Conv2d"); + auto kernel_buffer = static_cast(Buffer(10)); + AttrUtils::SetStr(op_desc, ATTR_NAME_TBE_KERNEL_NAME, "Conv2d"); + AttrUtils::SetBytes(op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer); + + ge::NodePtr node = graph->AddNode(op_desc); + ge::Model ge_model; + ge::GeModel ge_gemodel; + builder.SaveDataToModel(ge_model, ge_gemodel); + auto tbe_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); + EXPECT_NE(tbe_kernel, nullptr); +} From bd2173a9f3bf2e9efe82b0f9465f06a5d97a64e2 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Wed, 21 Apr 2021 10:24:52 +0800 Subject: [PATCH 3/5] add support for train_mode tune --- tests/ut/ge/CMakeLists.txt | 1 + .../global_step_insert_pass_unittest.cc | 74 +++++++++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index f2f08106..0c11c9d2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -691,6 +691,7 @@ set(PASS_TEST_FILES "graph/passes/stop_gradient_pass_unittest.cc" "graph/passes/prevent_gradient_pass_unittest.cc" "graph/passes/identity_pass_unittest.cc" + "graph/passes/global_step_insert_pass_unittest.cc" "graph/passes/placeholder_with_default_pass_unittest.cc" "graph/passes/snapshot_pass_unittest.cc" "graph/passes/guarantee_const_pass_unittest.cc" diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc new file mode 100644 index 00000000..98e303c7 --- /dev/null +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define protected public +#define private public +#include "graph/passes/global_step_insert_pass.h" + +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/passes/base_pass.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/tuning_utils.h" +#include "graph_builder_utils.h" +#include "graph/ge_context.h" +#include "graph/ge_local_context.h" +#include "inc/pass_manager.h" +#undef protected +#undef private + +using namespace std; +using namespace testing; +using namespace ge; + +class UtestGlobalStepInsertPass : public Test { + protected: +}; + +static ComputeGraphPtr BuildGraph1() { + ge::ut::GraphBuilder builder("g1"); + auto var1 = builder.AddNode("var1", "Variable", 0, 1); + auto var2 = builder.AddNode("var2", "Variable", 0, 1); + auto identity1 = builder.AddNode("identity1", "Identity", 1, 1); + auto out = builder.AddNode("out", "NetOutput", 1, 1); + + builder.AddDataEdge(var1, 0, identity1, 0); + builder.AddControlEdge(var2, identity1); + builder.AddDataEdge(identity1, 0, out, 0); + return builder.GetGraph(); +} + +TEST_F(UtestGlobalStepInsertPass, skip_tune) { + auto graph = BuildGraph1(); + std::string build_mode; + std::map options_map; + options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING}); + ge::GetThreadLocalContext().SetGraphOption(options_map); + GlobalStepInsertPass pass; + Status status = pass.Run(graph); + EXPECT_EQ(status, SUCCESS); + NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); + EXPECT_EQ(found_node, nullptr); +} From 08d850e4f92a9f4dd4114195a5fb0449056bcd02 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Thu, 22 Apr 2021 14:47:07 +0800 Subject: [PATCH 4/5] add support for train_mode tune --- ge/graph/manager/graph_manager.cc | 13 +++++-------- ge/graph/manager/graph_manager.h | 2 +- ge/graph/passes/global_step_insert_pass.cc | 5 ++--- .../passes/global_step_insert_pass_unittest.cc | 4 ++-- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index ca1fa9cf..230a29a8 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1665,7 +1665,7 @@ Status GraphManager::ParseOptions(const std::map &opti // ge.graphType ret = - ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING); + ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -1707,7 +1707,7 @@ Status GraphManager::ParseOptions(const std::map &opti return SUCCESS; } -Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) { +Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { std::shared_ptr ge_instance_ptr = ge::GELib::GetInstance(); if (ge_instance_ptr == nullptr) { GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); @@ -1715,13 +1715,10 @@ Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, } else if (!ge_instance_ptr->isTrainMode()) { train_flag = false; } else { // ge_instance_ptr->isTrainMode() is true - // tune mode no need check - if (!run_flag && !tune_flag) { - GELOGE(GE_GRAPH_OPTIONS_INVALID, - "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); - return GE_GRAPH_OPTIONS_INVALID; - } train_flag = true; + if (!run_flag) { + GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); + } } return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 4c2efb52..ef49cf90 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -277,7 +277,7 @@ class GraphManager { static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); - static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag); + static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); static bool IsPerfLevelInvalid(int32_t perf_level); diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 896622ce..6ed7a7ec 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -27,7 +27,6 @@ #include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" -#include "graph/tuning_utils.h" namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, @@ -74,8 +73,8 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { - std::string build_mode; - if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) { + std::string run_flag; + if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") { GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), compute_graph->GetName().c_str()); return SUCCESS; diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc index 98e303c7..9da2565d 100644 --- a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -60,11 +60,11 @@ static ComputeGraphPtr BuildGraph1() { return builder.GetGraph(); } -TEST_F(UtestGlobalStepInsertPass, skip_tune) { +TEST_F(UtestGlobalStepInsertPass, skip_insert) { auto graph = BuildGraph1(); std::string build_mode; std::map options_map; - options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING}); + options_map.insert({ge::RUN_FLAG, "0"}); ge::GetThreadLocalContext().SetGraphOption(options_map); GlobalStepInsertPass pass; Status status = pass.Run(graph); From 9a5e6d7e038c7029f5a4b4da581ea5e6ae311869 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Fri, 7 May 2021 10:55:21 +0800 Subject: [PATCH 5/5] add support for train_mode tune --- ge/graph/passes/global_step_insert_pass.cc | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 6ed7a7ec..d702e758 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -28,6 +28,10 @@ #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" +namespace { +const char *const kFlagOff = "0"; +} // namespace + namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, @@ -73,8 +77,9 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { + // run_flag off means offline, no need insert global step node which type is variable std::string run_flag; - if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") { + if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) { GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), compute_graph->GetName().c_str()); return SUCCESS;