| @@ -122,6 +122,7 @@ const char *const kVectorEngine = "VectorEngine"; | |||||
| const char *const kAIcoreEngine = "AIcoreEngine"; | const char *const kAIcoreEngine = "AIcoreEngine"; | ||||
| const int32_t kDynamicDimsTypeIsGetNext = 0; | const int32_t kDynamicDimsTypeIsGetNext = 0; | ||||
| const int32_t kDynamicDimsTypeIsData = 1; | const int32_t kDynamicDimsTypeIsData = 1; | ||||
| const int32_t kBase = 10; | |||||
| const char *const kGetNextName = "IteratorV2"; | const char *const kGetNextName = "IteratorV2"; | ||||
| const uint32_t kInitGraphCount = 1; | const uint32_t kInitGraphCount = 1; | ||||
| const uint32_t kNotAdded = 0; | const uint32_t kNotAdded = 0; | ||||
| @@ -1788,7 +1789,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
| // ge.graphType | // ge.graphType | ||||
| ret = ParseTrainGraphFlag(options_.run_graph_flag, options_.train_graph_flag); | |||||
| ret = ParseTrainGraphFlag(options_.train_graph_flag); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | GE_IF_BOOL_EXEC(ret != SUCCESS, | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); | GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Parse][TrainGraphFlag] Key:ge.runFlag value is invalid"); | ||||
| return GE_GRAPH_OPTIONS_INVALID); | return GE_GRAPH_OPTIONS_INVALID); | ||||
| @@ -1833,19 +1834,17 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::ParseTrainGraphFlag(const bool &run_flag, bool &train_flag) { | |||||
| std::shared_ptr<GELib> ge_instance_ptr = ge::GELib::GetInstance(); | |||||
| if (ge_instance_ptr == nullptr) { | |||||
| GELOGW("[Initialize] set train_graph_flag to 0 when GE is not initialized or finalized"); | |||||
| train_flag = false; | |||||
| } else if (!ge_instance_ptr->isTrainMode()) { | |||||
| train_flag = false; | |||||
| } else { // ge_instance_ptr->isTrainMode() is true | |||||
| train_flag = true; | |||||
| if (!run_flag) { | |||||
| GELOGW("Key:ge.runFlag, its value %d is invalid, it must be 1 when GElib::is_train_mode_ flag is 1", run_flag); | |||||
| // OPTION_GRAPH_RUN_MODE is supposed to be a session-level option, but it used to be set to global-level in the past. | |||||
| // If can not parse from session, it can parse from global by GetContext(). | |||||
| Status GraphManager::ParseTrainGraphFlag(bool &train_flag) { | |||||
| train_flag = false; | |||||
| string run_mode; | |||||
| if (GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == SUCCESS && !run_mode.empty()) { | |||||
| if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, kBase)) >= TRAIN) { | |||||
| train_flag = true; | |||||
| } | } | ||||
| } | } | ||||
| GELOGI("Is train flag: %d.", train_flag); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -292,7 +292,7 @@ class GraphManager { | |||||
| static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | static Status ParseParallelNum(const std::string ¶llel_num, const std::string &key, int &num); | ||||
| static Status ParseTrainGraphFlag(const bool &run_flag, bool &train_flag); | |||||
| static Status ParseTrainGraphFlag(bool &train_flag); | |||||
| static bool IsPerfLevelInvalid(int32_t perf_level); | static bool IsPerfLevelInvalid(int32_t perf_level); | ||||
| @@ -28,10 +28,6 @@ | |||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| namespace { | |||||
| const char *const kFlagOff = "0"; | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | ||||
| const string &node_type, | const string &node_type, | ||||
| @@ -80,13 +76,6 @@ NodePtr GlobalStepInsertPass::InsertOp(ComputeGraphPtr &compute_graph, | |||||
| } | } | ||||
| Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | Status GlobalStepInsertPass::Run(ComputeGraphPtr compute_graph) { | ||||
| // run_flag off means offline, no need insert global step node which type is variable | |||||
| std::string run_flag; | |||||
| if (ge::GetContext().GetOption(ge::RUN_FLAG, run_flag) == GRAPH_SUCCESS && run_flag == kFlagOff) { | |||||
| GELOGI("compute_graph [%u] [%s] skip insert global step", compute_graph->GetGraphID(), | |||||
| compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | NodePtr output_node = compute_graph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| if (output_node == nullptr) { | if (output_node == nullptr) { | ||||
| GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | GELOGD("Node type %s can't be found in graph %u", NETOUTPUT, compute_graph->GetGraphID()); | ||||
| @@ -574,6 +574,7 @@ graphStatus Impl::Init(const Graph &graph, const std::map<std::string, std::stri | |||||
| options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::RUN_FLAG), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::TRAIN_FLAG), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); | options_.insert(std::pair<string, string>(string(ge::SAVE_ORIGINAL_MODEL), to_string(0))); | ||||
| options_.insert(std::pair<string, string>(string(ge::OPTION_GRAPH_RUN_MODE), to_string(0))); | |||||
| // print ge option map | // print ge option map | ||||
| ge::PrintOptionMap(options_, "ge option"); | ge::PrintOptionMap(options_, "ge option"); | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "graph/build/memory/buffer_pool_mem_assigner.h" | #include "graph/build/memory/buffer_pool_mem_assigner.h" | ||||
| #include "graph/build/memory/graph_mem_assigner.h" | #include "graph/build/memory/graph_mem_assigner.h" | ||||
| #include "graph/build/stream_allocator.h" | #include "graph/build/stream_allocator.h" | ||||
| #include "graph/ge_local_context.h" | |||||
| #undef protected | #undef protected | ||||
| #undef private | #undef private | ||||
| @@ -260,6 +261,10 @@ TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_serial_graph_assign_success) | |||||
| } | } | ||||
| TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) { | TEST_F(UtestBufferPoolMemAssignerTest, buffer_pool_subgraph_with_inner_dependency_assign_success) { | ||||
| std::string build_mode; | |||||
| std::map<string, string> options_map; | |||||
| options_map.insert({ge::OPTION_GRAPH_RUN_MODE, "1"}); | |||||
| ge::GetThreadLocalContext().SetGraphOption(options_map); | |||||
| ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); | ut::BufferPoolGraphBuilder builder("SubgraphWithInnerDependency"); | ||||
| ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); | ge::ComputeGraphPtr graph = builder.BuildSubgraphWithInnerDependency(); | ||||
| BufferPoolMemoryPass buffer_pool_mem_pass; | BufferPoolMemoryPass buffer_pool_mem_pass; | ||||
| @@ -34,7 +34,6 @@ | |||||
| #include "graph/tuning_utils.h" | #include "graph/tuning_utils.h" | ||||
| #include "graph_builder_utils.h" | #include "graph_builder_utils.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/ge_local_context.h" | |||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #undef protected | #undef protected | ||||
| #undef private | #undef private | ||||
| @@ -62,13 +61,9 @@ static ComputeGraphPtr BuildGraph1() { | |||||
| TEST_F(UtestGlobalStepInsertPass, skip_insert) { | TEST_F(UtestGlobalStepInsertPass, skip_insert) { | ||||
| auto graph = BuildGraph1(); | auto graph = BuildGraph1(); | ||||
| std::string build_mode; | |||||
| std::map<string, string> options_map; | |||||
| options_map.insert({ge::RUN_FLAG, "0"}); | |||||
| ge::GetThreadLocalContext().SetGraphOption(options_map); | |||||
| GlobalStepInsertPass pass; | GlobalStepInsertPass pass; | ||||
| Status status = pass.Run(graph); | Status status = pass.Run(graph); | ||||
| EXPECT_EQ(status, SUCCESS); | EXPECT_EQ(status, SUCCESS); | ||||
| NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); | NodePtr found_node = graph->FindNode(NODE_NAME_GLOBAL_STEP); | ||||
| EXPECT_EQ(found_node, nullptr); | |||||
| EXPECT_NE(found_node, nullptr); | |||||
| } | } | ||||