From 08d850e4f92a9f4dd4114195a5fb0449056bcd02 Mon Sep 17 00:00:00 2001 From: "gengchao4@huawei.com" Date: Thu, 22 Apr 2021 14:47:07 +0800 Subject: [PATCH] add support for train_mode tune --- ge/graph/manager/graph_manager.cc | 13 +++++-------- ge/graph/manager/graph_manager.h | 2 +- ge/graph/passes/global_step_insert_pass.cc | 5 ++--- .../passes/global_step_insert_pass_unittest.cc | 4 ++-- 4 files changed, 10 insertions(+), 14 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index ca1fa9cf..230a29a8 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -1665,7 +1665,7 @@ Status GraphManager::ParseOptions(const std::map &opti // ge.graphType ret = - ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING); + ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); return GE_GRAPH_OPTIONS_INVALID); @@ -1707,7 +1707,7 @@ Status GraphManager::ParseOptions(const std::map &opti return SUCCESS; } -Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) { +Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { std::shared_ptr ge_instance_ptr = ge::GELib::GetInstance(); if (ge_instance_ptr == nullptr) { GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); @@ -1715,13 +1715,10 @@ Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, } else if (!ge_instance_ptr->isTrainMode()) { train_flag = false; } else { // ge_instance_ptr->isTrainMode() is true - // tune mode no need check - if (!run_flag && !tune_flag) { - GELOGE(GE_GRAPH_OPTIONS_INVALID, - "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); - return GE_GRAPH_OPTIONS_INVALID; - } train_flag = true; + if (!run_flag) { + GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); + } } return SUCCESS; } diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index 4c2efb52..ef49cf90 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -277,7 +277,7 @@ class GraphManager { static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); - static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag); + static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); static bool IsPerfLevelInvalid(int32_t perf_level); diff --git a/ge/graph/passes/global_step_insert_pass.cc b/ge/graph/passes/global_step_insert_pass.cc index 896622ce..6ed7a7ec 100755 --- a/ge/graph/passes/global_step_insert_pass.cc +++ b/ge/graph/passes/global_step_insert_pass.cc @@ -27,7 +27,6 @@ #include "graph/manager/graph_var_manager.h" #include "graph/passes/pass_utils.h" #include "graph/ge_context.h" -#include "graph/tuning_utils.h" namespace ge { NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, @@ -74,8 +73,8 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, } Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { - std::string build_mode; - if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) { + std::string run_flag; + if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") { GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), compute_graph->GetName().c_str()); return SUCCESS; diff --git a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc index 98e303c7..9da2565d 100644 --- a/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc +++ b/tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc @@ -60,11 +60,11 @@ static ComputeGraphPtr BuildGraph1() { return builder.GetGraph(); } -TEST_F(UtestGlobalStepInsertPass, skip_tune) { +TEST_F(UtestGlobalStepInsertPass, skip_insert) { auto graph = BuildGraph1(); std::string build_mode; std::map options_map; - options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING}); + options_map.insert({ge::RUN_FLAG, "0"}); ge::GetThreadLocalContext().SetGraphOption(options_map); GlobalStepInsertPass pass; Status status = pass.Run(graph);