From 9aeb93b58f955b4a930ca6d21e7737d3140e2165 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=B6=9B?= Date: Fri, 25 Jun 2021 14:37:39 +0800 Subject: [PATCH] =?UTF-8?q?=E5=9B=9E=E9=80=80=20'Pull=20Request=20!1801=20?= =?UTF-8?q?:=20train=5Fgraph=5Fflag'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ge/graph/manager/graph_manager.cc | 23 ++++++++++--------- ge/graph/manager/graph_manager.h | 2 +- ge/graph/passes/global_step_insert_pass.cc | 11 +++++++++ ge/ir_build/ge_ir_build.cc | 1 - .../buffer_pool_mem_assigner_unittest.cc | 5 ---- .../global_step_insert_pass_unittest.cc | 7 +++++- 6 files changed, 30 insertions(+), 19 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 66026f8d..b862a7d6 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -122,7 +122,6 @@ const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; const int32_t kDynamicDimsTypeIsGetNext = 0; const int32_t kDynamicDimsTypeIsData = 1; -const int32_t kBase = 10; const char *const kGetNextName = "IteratorV2"; const uint32_t kInitGraphCount = 1; const uint32_t kNotAdded = 0; @@ -1789,7 +1788,7 @@ Status GraphManager::ParseOptions(const std::map &opti return GE_GRAPH_OPTIONS_INVALID); // ge.graphType - ret = ParseTrainGraphFlag(options_.train_graph_flag); + ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -1834,17 +1833,19 @@ Status GraphManager::ParseOptions(const std::map &opti return SUCCESS; } -// OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set to global-level in the past. -// If can not parse from session, it can parse from global by GetContext(). -Status GraphManager::ParseTrainGraphFlag(bool &train_flag) { - train_flag = false; - string run_mode; - if (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) { - if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) { - train_flag = true; +Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { + std::shared_ptr ge_instance_ptr = ge::GELib::GetInstance(); + if (ge_instance_ptr == nullptr) { + GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); + train_flag = false; + } else if (!ge_instance_ptr->isTrainMode()) { + train_flag = false; + } else { // ge_instance_ptr->isTrainMode() is true + train_flag = true; + if (!run_flag) { + GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); } } - GELOGI("Is train flag: %d.", train_flag); return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 3475da6d..93ce354a 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -292,7 +292,7 @@ class GraphManager { static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); - static Status ParseTrainGraphFlag(bool &train_flag); + static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); static bool IsPerfLevelInvalid(int32_t perf_level); diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 297e4ee2..f27641fc 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -28,6 +28,10 @@ #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" +namespace { +const char *const kFlagOff = "0"; +} // namespace + namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, const string &node_type, @@ -76,6 +80,13 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { + // run_flag off means offline, no need insert global step node which type is variable + std::string run_flag; + if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) { + GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), + compute_graph->GetName().c_str()); + return SUCCESS; + } NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); if (output_node == nullptr) { GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 168bcd34..052af2f6 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -601,7 +601,6 @@ graphStatus Impl::Init(const Graph &graph, const std::map(string(ge::RUN_FLAG), to_string(0))); options_.insert(std::pair(string(ge::TRAIN_FLAG), to_string(0))); options_.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); - options_.insert(std::pair(string(ge::OPTION_GRAPH_RUN_MODE), to_string(0))); // print ge option map ge::PrintOptionMap(options_, "ge option"); diff --git a/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc index 05141785..96283250 100644 --- a/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc +++ b/tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc @@ -29,7 +29,6 @@ #include "graph/build/memory/buffer_pool_mem_assigner.h" #include "graph/build/memory/graph_mem_assigner.h" #include "graph/build/stream_allocator.h" -#include "graph/ge_local_context.h" #undef protected #undef private @@ -261,10 +260,6 @@ TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) } TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) { - std::string build_mode; - std::map options_map; - options_map.insert({ge::OPTION_GRAPH_RUN_MODE, "1"}); - ge::GetThreadLocalContext().SetGraphOption(options_map); ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); BufferPoolMemoryPass buffer_pool_mem_pass; diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc index cc9a4077..9da2565d 100644 --- a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -34,6 +34,7 @@ #include "graph/tuning_utils.h" #include "graph_builder_utils.h" #include "graph/ge_context.h" +#include "graph/ge_local_context.h" #include "inc/pass_manager.h" #undef protected #undef private @@ -61,9 +62,13 @@ static ComputeGraphPtr BuildGraph1() { TEST_F(UtestGlobalStepInsertPass, skip_insert) { auto graph = BuildGraph1(); + std::string build_mode; + std::map options_map; + options_map.insert({ge::RUN_FLAG, "0"}); + ge::GetThreadLocalContext().SetGraphOption(options_map); GlobalStepInsertPass pass; Status status = pass.Run(graph); EXPECT_EQ(status, SUCCESS); NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); - EXPECT_NE(found_node, nullptr); + EXPECT_EQ(found_node, nullptr); }