Browse Source

add support for train_mode tune

tags/v1.3.0
gengchao4@huawei.com 3 years ago
parent
commit
9a5e6d7e03
1 changed files with 6 additions and 1 deletions
  1. +6
    -1
      ge/graph/passes/global_step_insert_pass.cc

+ 6
- 1
ge/graph/passes/global_step_insert_pass.cc View File

@@ -28,6 +28,10 @@
#include "graph/passes/pass_utils.h" #include "graph/passes/pass_utils.h"
#include "graph/ge_context.h" #include "graph/ge_context.h"


namespace {
const char *const kFlagOff = "0";
} // namespace

namespace ge { namespace ge {
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
const string &node_type, const string &node_type,
@@ -73,8 +77,9 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
} }


Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) {
// run_flag off means offline, no need insert global step node which type is variable
std::string run_flag; std::string run_flag;
if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == "0") {
if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) {
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(),
compute_graph->GetName().c_str()); compute_graph->GetName().c_str());
return SUCCESS; return SUCCESS;


Loading…
Cancel
Save