| @@ -1665,7 +1665,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| // ge.graphType | // ge.graphType | ||||
| ret = | ret = | ||||
| ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING); | |||||
| ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | GE_IF_BOOL_EXEC(ret != SUCCESS, | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); | ||||
| return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
| @@ -1707,7 +1707,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) { | |||||
| Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { | |||||
| std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | ||||
| if (ge_instance_ptr == nullptr) { | if (ge_instance_ptr == nullptr) { | ||||
| GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | ||||
| @@ -1715,13 +1715,10 @@ Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, | |||||
| } else if (!ge_instance_ptr->isTrainMode()) { | } else if (!ge_instance_ptr->isTrainMode()) { | ||||
| train_flag = false; | train_flag = false; | ||||
| } else { // ge_instance_ptr->isTrainMode() is true | } else { // ge_instance_ptr->isTrainMode() is true | ||||
| // tune mode no need check | |||||
| if (!run_flag && !tune_flag) { | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, | |||||
| "Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | |||||
| } | |||||
| train_flag = true; | train_flag = true; | ||||
| if (!run_flag) { | |||||
| GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -277,7 +277,7 @@ class GraphManager { | |||||
| static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | ||||
| static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag); | |||||
| static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); | |||||
| static bool IsPerfLevelInvalid(int32_t perf_level); | static bool IsPerfLevelInvalid(int32_t perf_level); | ||||
| @@ -27,7 +27,6 @@ | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/tuning_utils.h" | |||||
| namespace ge { | namespace ge { | ||||
| NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | ||||
| @@ -74,8 +73,8 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | |||||
| } | } | ||||
| Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | ||||
| std::string build_mode; | |||||
| if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) { | |||||
| std::string run_flag; | |||||
| if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") { | |||||
| GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), | GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), | ||||
| compute_graph->GetName().c_str()); | compute_graph->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -60,11 +60,11 @@ static ComputeGraphPtr BuildGraph1() { | |||||
| return builder.GetGraph(); | return builder.GetGraph(); | ||||
| } | } | ||||
| TEST_F(UtestGlobalStepInsertPass, skip_tune) { | |||||
| TEST_F(UtestGlobalStepInsertPass, skip_insert) { | |||||
| auto graph = BuildGraph1(); | auto graph = BuildGraph1(); | ||||
| std::string build_mode; | std::string build_mode; | ||||
| std::map<string, string> options_map; | std::map<string, string> options_map; | ||||
| options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING}); | |||||
| options_map.insert({ge::RUN_FLAG, "0"}); | |||||
| ge::GetThreadLocalContext().SetGraphOption(options_map); | ge::GetThreadLocalContext().SetGraphOption(options_map); | ||||
| GlobalStepInsertPass pass; | GlobalStepInsertPass pass; | ||||
| Status status = pass.Run(graph); | Status status = pass.Run(graph); | ||||