From 1847efba531955b9757e424b5b382e115ac7d0ca Mon Sep 17 00:00:00 2001 From: lianghao Date: Fri, 11 Jun 2021 14:25:01 +0800 Subject: [PATCH] run_flag --- ge/common/profiling/ge_profiling.cc | 2 +- ge/graph/manager/graph_manager.cc | 35 ++++++------------- ge/graph/manager/graph_manager.h | 2 +- inc/framework/common/profiling/ge_profiling.h | 2 +- .../graph/manager/graph_manager_unittest.cc | 21 +++-------- 5 files changed, 18 insertions(+), 44 deletions(-) diff --git a/ge/common/profiling/ge_profiling.cc b/ge/common/profiling/ge_profiling.cc index feedf572..d0343326 100644 --- a/ge/common/profiling/ge_profiling.cc +++ b/ge/common/profiling/ge_profiling.cc @@ -216,6 +216,6 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le return ge::SUCCESS; } -GE_FUNC_VISIBILITY ge::Status ProSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { +GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream) { return ge::SUCCESS; } diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 4e53d950..bf04ed58 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -120,7 +120,6 @@ const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; const char *const kCheckPointGraph = "checkpoint_graph"; const char *const kVectorEngine = "VectorEngine"; const char *const kAIcoreEngine = "AIcoreEngine"; -const char *const kRunFlagOffline = "0"; const int32_t kDynamicDimsTypeIsGetNext = 0; const int32_t kDynamicDimsTypeIsData = 1; const char *const kGetNextName = "IteratorV2"; @@ -1789,8 +1788,7 @@ Status GraphManager::ParseOptions(const std::map &opti return GE_GRAPH_OPTIONS_INVALID); // ge.graphType - ret = - ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); + ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -2436,6 +2434,8 @@ Status GraphManager::RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute continue; } if (n->GetOpDesc()->GetType() == CONSTANT || n->GetOpDesc()->GetType() == CONSTANTOP) { + // reset const type depend on train_flag + options_.train_graph_flag ? n->GetOpDesc()->SetType(CONSTANTOP) : n->GetOpDesc()->SetType(CONSTANT); if (n->GetOutAllNodes().empty() && n->GetInAllNodes().empty()) { // it is an isolated constant, just remove it if (GraphUtils::RemoveJustNode(compute_graph, n) != GRAPH_SUCCESS) { @@ -2762,35 +2762,22 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { "Please pay attention to it."); } - GE_CHK_STATUS_RET(ChangeConstType(compute_graph)); + ChangeConstTypeWhenTraining(compute_graph); GELOGI("End optimize after merge sub graph."); return SUCCESS; } -Status GraphManager::ChangeConstType(const ComputeGraphPtr &compute_graph) { - // run_flag off means offline, on means online - string run_flag; - (void)ge::GetContext().GetOption(ge::RUN_FLAG, run_flag); - // The constant for online is CONSTANTOP, and is CONSTANT for offline. They will be unified in future. - if (run_flag == kRunFlagOffline) { - GELOGI("Offline mode, change all Constant to Const."); - } else { - GELOGI("Online mode, change all Const to Constant."); - } - for (NodePtr &n : compute_graph->GetAllNodes()) { - GE_CHECK_NOTNULL(n); - if (n->GetType() == CONSTANT || n->GetType() == CONSTANTOP) { - auto op_desc = n->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (run_flag == kRunFlagOffline) { - op_desc->SetType(CONSTANT); - } else { - op_desc->SetType(CONSTANTOP); +void GraphManager::ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph) { + // The constant for train is CONSTANTOP, and is CONSTANT for inference. They will be unified in future. + if (options_.train_graph_flag) { + for (NodePtr &n : compute_graph->GetAllNodes()) { + // This can ensure that n is not a null pointer + if (n->GetOpDesc()->GetType() == CONSTANT) { + n->GetOpDesc()->SetType(CONSTANTOP); } } } - return SUCCESS; } Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index c8459b16..945a5e5d 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -375,7 +375,7 @@ class GraphManager { static void ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_node, RunAsyncCallback callback, Status ret, const string &log); - Status ChangeConstType(const ComputeGraphPtr &compute_graph); + void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector &inputs, ge::ComputeGraphPtr &compute_graph, uint64_t session_id); diff --git a/inc/framework/common/profiling/ge_profiling.h b/inc/framework/common/profiling/ge_profiling.h index 5b3c75c5..a8de56a8 100644 --- a/inc/framework/common/profiling/ge_profiling.h +++ b/inc/framework/common/profiling/ge_profiling.h @@ -43,6 +43,6 @@ GE_FUNC_VISIBILITY ge::Status RegProfCtrlCallback(MsprofCtrlCallback func); GE_FUNC_VISIBILITY ge::Status RegProfSetDeviceCallback(MsprofSetDeviceCallback func); GE_FUNC_VISIBILITY ge::Status RegProfReporterCallback(MsprofReporterCallback func); GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t len); -GE_FUNC_VISIBILITY ge::Status ProSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream); +GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream); #endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_ diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index 5cc2a7f6..9bae10eb 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -579,29 +579,16 @@ TEST_F(UtestGraphManagerTest, test_prerunthread_failed_2) { // } TEST_F(UtestGraphManagerTest, ChangeAndDeleteConst_success) { - std::map options_map; - options_map.insert({ge::RUN_FLAG, "0"}); - ge::GetThreadLocalContext().SetGraphOption(options_map); - GraphId graph_id = 1; GraphManager graph_manager; graph_manager.options_.train_graph_flag = true; auto graph = CreateGraphWithIsolatedConst(); - Status status = graph_manager.ChangeConstType(graph); - EXPECT_EQ(status, ge::SUCCESS); - auto constant1 = graph->FindFirstNodeMatchType("Constant"); - EXPECT_EQ(constant1, nullptr); - - options_map.clear(); - options_map.insert({ge::RUN_FLAG, "1"}); - ge::GetThreadLocalContext().SetGraphOption(options_map); - status = graph_manager.ChangeConstType(graph); - EXPECT_EQ(status, ge::SUCCESS); - constant1 = graph->FindFirstNodeMatchType("Constant"); - EXPECT_NE(constant1, nullptr); + graph_manager.ChangeConstTypeWhenTraining(graph); + auto const1 = graph->FindFirstNodeMatchType("Const"); + EXPECT_EQ(const1, nullptr); - status = graph_manager.RemoveIsolatedConstInThisGraph(graph); + Status status = graph_manager.RemoveIsolatedConstInThisGraph(graph); EXPECT_EQ(status, ge::SUCCESS); auto all_nodes = graph->GetDirectNode(); EXPECT_EQ(all_nodes.size(), 3);