Browse Source

add support for train_mode tune

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
08d850e4f9
4 changed files with 10 additions and 14 deletions
  1. +5
    -8
      ge/graph/manager/graph_manager.cc
  2. +1
    -1
      ge/graph/manager/graph_manager.h
  3. +2
    -3
      ge/graph/passes/global_step_insert_pass.cc
  4. +2
    -2
      tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc

+ 5
- 8
ge/graph/manager/graph_manager.cc View File

@@ -1665,7 +1665,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti


// ge.graphType // ge.graphType
ret = ret =
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag, options_.build_mode == BUILD_MODE_TUNING);
ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS, GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid"); GELOGE(GE_GRAPH_OPTIONS_INVALID, "Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID); return GE_GRAPH_OPTIONS_INVALID);
@@ -1707,7 +1707,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return SUCCESS; return SUCCESS;
} }


Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag) {
Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) {
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance();
if (ge_instance_ptr == nullptr) { if (ge_instance_ptr == nullptr) {
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized");
@@ -1715,13 +1715,10 @@ Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag,
} else if (!ge_instance_ptr->isTrainMode()) { } else if (!ge_instance_ptr->isTrainMode()) {
train_flag = false; train_flag = false;
} else { // ge_instance_ptr->isTrainMode() is true } else { // ge_instance_ptr->isTrainMode() is true
// tune mode no need check
if (!run_flag && !tune_flag) {
GELOGE(GE_GRAPH_OPTIONS_INVALID,
"Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag);
return GE_GRAPH_OPTIONS_INVALID;
}
train_flag = true; train_flag = true;
if (!run_flag) {
GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag);
}
} }
return SUCCESS; return SUCCESS;
} }


+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -277,7 +277,7 @@ class GraphManager {


static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num); static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num);


static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag, const bool &tune_flag);
static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag);


static bool IsPerfLevelInvalid(int32_t perf_level); static bool IsPerfLevelInvalid(int32_t perf_level);




+ 2
- 3
ge/graph/passes/global_step_insert_pass.cc View File

@@ -27,7 +27,6 @@
#include "graph/manager/graph_var_manager.h" #include "graph/manager/graph_var_manager.h"
#include "graph/passes/pass_utils.h" #include "graph/passes/pass_utils.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"
#include "graph/tuning_utils.h"


namespace ge { namespace ge {
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
@@ -74,8 +73,8 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
} }


Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) {
std::string build_mode;
if (ge::GetContext().GetOption(ge::BUILD_MODE, build_mode) == GRAPH_SUCCESS && build_mode == BUILD_MODE_TUNING) {
std::string run_flag;
if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") {
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(),
compute_graph->GetName().c_str()); compute_graph->GetName().c_str());
return SUCCESS; return SUCCESS;


+ 2
- 2
tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc View File

@@ -60,11 +60,11 @@ static ComputeGraphPtr BuildGraph1() {
return builder.GetGraph(); return builder.GetGraph();
} }


TEST_F(UtestGlobalStepInsertPass, skip_tune) {
TEST_F(UtestGlobalStepInsertPass, skip_insert) {
auto graph = BuildGraph1(); auto graph = BuildGraph1();
std::string build_mode; std::string build_mode;
std::map<string, string> options_map; std::map<string, string> options_map;
options_map.insert({ge::BUILD_MODE, BUILD_MODE_TUNING});
options_map.insert({ge::RUN_FLAG, "0"});
ge::GetThreadLocalContext().SetGraphOption(options_map); ge::GetThreadLocalContext().SetGraphOption(options_map);
GlobalStepInsertPass pass; GlobalStepInsertPass pass;
Status status = pass.Run(graph); Status status = pass.Run(graph);


Loading…
Cancel
Save