Browse Source

!1801 train_graph_flag

From: @dimitri_rose
Reviewed-by: 
Signed-off-by:
tags/v1.5.1
mindspore-ci-bot Gitee 3 years ago
parent
commit
aba820b44f
6 changed files with 19 additions and 30 deletions
  1. +11
    -12
      ge/graph/manager/graph_manager.cc
  2. +1
    -1
      ge/graph/manager/graph_manager.h
  3. +0
    -11
      ge/graph/passes/global_step_insert_pass.cc
  4. +1
    -0
      ge/ir_build/ge_ir_build.cc
  5. +5
    -0
      tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc
  6. +1
    -6
      tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc

+ 11
- 12
ge/graph/manager/graph_manager.cc View File

@@ -122,6 +122,7 @@ const char *const kVectorEngine = "VectorEngine";
const char *const kAIcoreEngine = "AIcoreEngine";
const int32_t kDynamicDimsTypeIsGetNext = 0;
const int32_t kDynamicDimsTypeIsData = 1;
const int32_t kBase = 10;
const char *const kGetNextName = "IteratorV2";
const uint32_t kInitGraphCount = 1;
const uint32_t kNotAdded = 0;
@@ -1788,7 +1789,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return GE_GRAPH_OPTIONS_INVALID);

// ge.graphType
ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag);
ret = ParseTrainGraphFlag(options_.train_graph_flag);
GE_IF_BOOL_EXEC(ret != SUCCESS,
GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid");
return GE_GRAPH_OPTIONS_INVALID);
@@ -1833,19 +1834,17 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti
return SUCCESS;
}

Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) {
std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance();
if (ge_instance_ptr == nullptr) {
GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized");
train_flag = false;
} else if (!ge_instance_ptr->isTrainMode()) {
train_flag = false;
} else { // ge_instance_ptr->isTrainMode() is true
train_flag = true;
if (!run_flag) {
GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag);
// OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set to global-level in the past.
// If can not parse from session, it can parse from global by GetContext().
Status GraphManager::ParseTrainGraphFlag(bool &train_flag) {
train_flag = false;
string run_mode;
if (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) {
if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) {
train_flag = true;
}
}
GELOGI("Is train flag: %d.", train_flag);
return SUCCESS;
}



+ 1
- 1
ge/graph/manager/graph_manager.h View File

@@ -292,7 +292,7 @@ class GraphManager {

static Status ParseParallelNum(const std::string &parallel_num, const std::string &key, int &num);

static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag);
static Status ParseTrainGraphFlag(bool &train_flag);

static bool IsPerfLevelInvalid(int32_t perf_level);



+ 0
- 11
ge/graph/passes/global_step_insert_pass.cc View File

@@ -28,10 +28,6 @@
#include "graph/passes/pass_utils.h"
#include "graph/ge_context.h"

namespace {
const char *const kFlagOff = "0";
} // namespace

namespace ge {
NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
const string &node_type,
@@ -80,13 +76,6 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph,
}

Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) {
// run_flag off means offline, no need insert global step node which type is variable
std::string run_flag;
if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) {
GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(),
compute_graph->GetName().c_str());
return SUCCESS;
}
NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT);
if (output_node == nullptr) {
GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID());


+ 1
- 0
ge/ir_build/ge_ir_build.cc View File

@@ -574,6 +574,7 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri
options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0)));
options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0)));
options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0)));
options_.insert(std::pair<string, string>(string(ge::OPTION_GRAPH_RUN_MODE), to_string(0)));
// print ge option map
ge::PrintOptionMap(options_, "ge option");



+ 5
- 0
tests/ut/ge/graph/build/buffer_pool_mem_assigner_unittest.cc View File

@@ -29,6 +29,7 @@
#include "graph/build/memory/buffer_pool_mem_assigner.h"
#include "graph/build/memory/graph_mem_assigner.h"
#include "graph/build/stream_allocator.h"
#include "graph/ge_local_context.h"
#undef protected
#undef private

@@ -260,6 +261,10 @@ TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success)
}

TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) {
std::string build_mode;
std::map<string, string> options_map;
options_map.insert({ge::OPTION_GRAPH_RUN_MODE, "1"});
ge::GetThreadLocalContext().SetGraphOption(options_map);
ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency");
ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency();
BufferPoolMemoryPass buffer_pool_mem_pass;


+ 1
- 6
tests/ut/ge/graph/passes/global_step_insert_pass_unittest.cc View File

@@ -34,7 +34,6 @@
#include "graph/tuning_utils.h"
#include "graph_builder_utils.h"
#include "graph/ge_context.h"
#include "graph/ge_local_context.h"
#include "inc/pass_manager.h"
#undef protected
#undef private
@@ -62,13 +61,9 @@ static ComputeGraphPtr BuildGraph1() {

TEST_F(UtestGlobalStepInsertPass, skip_insert) {
auto graph = BuildGraph1();
std::string build_mode;
std::map<string, string> options_map;
options_map.insert({ge::RUN_FLAG, "0"});
ge::GetThreadLocalContext().SetGraphOption(options_map);
GlobalStepInsertPass pass;
Status status = pass.Run(graph);
EXPECT_EQ(status, SUCCESS);
NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP);
EXPECT_EQ(found_node, nullptr);
EXPECT_NE(found_node, nullptr);
}

Loading…
Cancel
Save