From e0fd62bc23de71aca8a5be5f1eb7ecdffb5f477f Mon Sep 17 00:00:00 2001 From: wuweikang Date: Thu, 29 Apr 2021 17:40:14 +0800 Subject: [PATCH] fix AddGraph and AddGraphWithCopy --- ge/graph/manager/graph_manager.cc | 83 ++++++++----------- ge/graph/manager/graph_manager.h | 2 + .../graph/manager/graph_manager_unittest.cc | 61 ++++++++++++++ 3 files changed, 99 insertions(+), 47 deletions(-) diff --git a/ge/graph/manager/graph_manager.cc b/ge/graph/manager/graph_manager.cc index 47ce4e52..6facd2f9 100755 --- a/ge/graph/manager/graph_manager.cc +++ b/ge/graph/manager/graph_manager.cc @@ -477,10 +477,13 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, } // Do add graph SetAddGraphCondition(graph_id, kStartAdd); + if (CheckGraphAdded(graph_id, graph) != SUCCESS) { + GELOGE(FAILED, "AddGraph failed."); + return FAILED; + } auto compute_graph = GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); - compute_graph->SetGraphID(graph_id); - + (void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); SetSessionGraphId(compute_graph, graph_id); if (CreateGraphNode(graph_id, graph, options) != SUCCESS) { @@ -512,13 +515,7 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph, return SUCCESS; } -Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, - const std::map &options, - const OmgContext &omg_context) { - if (HasGraphNode(graph_id)) { - GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u.", graph_id); - return GE_GRAPH_GRAPH_ALREADY_EXIST; - } +Status GraphManager::CheckGraphAdded(const GraphId &graph_id, const Graph &graph) { auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph != nullptr) { compute_graph->SetGraphID(graph_id); @@ -533,54 +530,48 @@ Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &grap GELOGE(FAILED, "compute graph is null"); return FAILED; } - std::vector input_nodes; - std::vector output_nodes; - auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); - std::string session_graph_id; - if (!AttrUtils::GetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id) || - session_graph_id.empty()) { - session_graph_id = "-1_" + to_string(graph_id); - if (!AttrUtils::SetStr(*new_compute_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { - GELOGW("Set attribute of compute graph failed."); - } - for (auto &subgraph : new_compute_graph->GetAllSubgraphs()) { - (void)AttrUtils::SetStr(*subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); - } - GELOGD("Get graph session_graph_id attr failed, set session id to default value: [0]"); - } + return SUCCESS; +} - GraphNodePtr graph_node = MakeShared(graph_id); - if (graph_node == nullptr) { - GELOGE(FAILED, "GraphNode make shared failed"); +Status GraphManager::AddGraphWithCopy(const GraphId &graph_id, const Graph &graph, + const std::map &options, + const OmgContext &omg_context) { + if (HasGraphNode(graph_id)) { + GELOGE(GE_GRAPH_GRAPH_ALREADY_EXIST, "[GraphManager] graph exists, graph_id = %u", graph_id); + return GE_GRAPH_GRAPH_ALREADY_EXIST; + } + if (CheckGraphAdded(graph_id, graph) != SUCCESS) { + GELOGE(FAILED, "AddGraphWithCopy failed."); return FAILED; } - std::shared_ptr graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph); - if (graph_ptr == nullptr) { - GELOGE(FAILED, "GraphPtr make shared failed"); + IncreaseGraphCount(graph_id); + // Do add graph + auto compute_graph = GraphUtils::GetComputeGraph(graph); + std::vector input_nodes; + std::vector output_nodes; + auto new_compute_graph = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); + GE_CHECK_NOTNULL(new_compute_graph); + new_compute_graph->SetGraphID(graph_id); + SetSessionGraphId(new_compute_graph, graph_id); + std::shared_ptr new_graph_ptr = GraphUtils::CreateGraphPtrFromComputeGraph(new_compute_graph); + if (CreateGraphNode(graph_id, *new_graph_ptr, options) != SUCCESS) { + GELOGE(FAILED, "Failed to create graph_node."); return FAILED; } - // update option about tuning graph - ParseOption(options, BUILD_MODE, options_.build_mode); - ParseOption(options, BUILD_STEP, options_.build_step); - ParseOption(options, TUNING_PATH, options_.tuning_path); - - graph_node->SetGraph(graph_ptr); - graph_node->SetOptions(options); - AddGraphNode(graph_id, graph_node); AddLocalOmgContext(graph_id, omg_context); if (!options_.output_datatype.empty()) { GetLocalOmgContext().output_type = options_.output_datatype; } + if (InitDynamicParams(new_compute_graph) != SUCCESS) { + GELOGE(GRAPH_PARAM_INVALID, "Failed to init params when online infer is dynamic."); + return GRAPH_PARAM_INVALID; + } - CompilerStages &stages = GetCompilerStages(graph_id); - stages.preparer.SetOptions(options_); - Status status = stages.optimizer.SetOptions(options_); - if (status != SUCCESS) { - GELOGE(status, "Graph optimizer set options failed."); - return status; + if (SetStagesOptions(graph_id, options_) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Set stage options failed."); + return INTERNAL_ERROR; } - stages.builder.SetOptions(options_); var_acc_ctrl_.AddGraph(graph_id, new_compute_graph); return SUCCESS; @@ -1019,7 +1010,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: if (!graph_node->IsAsync()) { ret = LoadGraph(ge_root_model, graph_node); } else { - GE_CHECK_NOTNULL(ge_root_model); ret = LoadGraphAsync(ge_root_model, graph_node); } if (ret != SUCCESS) { @@ -1034,7 +1024,6 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: if (!graph_node->IsAsync()) { ret = LoadGraph(ge_root_model_ptr, graph_node); } else { - GE_CHECK_NOTNULL(ge_root_model); ret = LoadGraphAsync(ge_root_model_ptr, graph_node); } if (ret != SUCCESS) { diff --git a/ge/graph/manager/graph_manager.h b/ge/graph/manager/graph_manager.h index c9a15126..a9199c06 100644 --- a/ge/graph/manager/graph_manager.h +++ b/ge/graph/manager/graph_manager.h @@ -412,6 +412,8 @@ class GraphManager { void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); + static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph); + std::atomic_bool thread_run_flag_; BlockingQueue prerun_args_q_{}; BlockingQueue run_args_q_{}; diff --git a/tests/ut/ge/graph/manager/graph_manager_unittest.cc b/tests/ut/ge/graph/manager/graph_manager_unittest.cc index dad55f3d..d54a0133 100644 --- a/tests/ut/ge/graph/manager/graph_manager_unittest.cc +++ b/tests/ut/ge/graph/manager/graph_manager_unittest.cc @@ -206,6 +206,37 @@ TEST_F(UtestGraphManagerTest, test_add_graph_3) { EXPECT_EQ(status2, ge::SUCCESS); } +TEST_F(UtestGraphManagerTest, test_add_graph_4) { + GraphId graph_id = 1; + GraphManager graph_manager; + // create graph + Graph graph("test_graph"); + CreateGraph(graph); + auto compute_graph = GraphUtils::GetComputeGraph(graph); + (void)AttrUtils::SetBool(*compute_graph, ATTR_NAME_GRAPH_HAS_BEEN_ADDED, true); + + std::map options; + OmgContext context; + Status status = graph_manager.AddGraph(graph_id, graph, options, context); + EXPECT_NE(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_1) { + GraphId graph_id = 1; + GraphManager graph_manager; + + // create graph + Graph graph("test_graph"); + CreateGraph(graph); + GraphNodePtr graph_node = MakeShared(graph_id); + graph_manager.graph_map_.insert({1, graph_node}); + + std::map options; + OmgContext context; + Status status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context); + EXPECT_NE(status, ge::SUCCESS); +} + TEST_F(UtestGraphManagerTest, test_remove_graph_1) { GraphId graph_id = 1; GraphManager graph_manager; @@ -373,3 +404,33 @@ TEST_F(UtestGraphManagerTest, test_check_incre_build_and_pre_run_3) { Status status = graph_manager.CheckIncreBuildAndPreRun(&graph_manager, arg, graph_node, ge_root_model); EXPECT_NE(status, ge::SUCCESS); } + +TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_success) { + GraphId graph_id = 1; + GraphManager graph_manager; + // create graph + ComputeGraphPtr compute_graph = MakeShared("test_graph"); + Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + std::map options; + OmgContext context; + Status status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context); + EXPECT_EQ(status, ge::SUCCESS); +} + +TEST_F(UtestGraphManagerTest, test_add_graph_with_copy_fail) { + GraphId graph_id = 1; + GraphManager graph_manager; + // create graph + ComputeGraphPtr compute_graph = MakeShared("test_graph"); + Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); + + std::map options; + OmgContext context; + Status status = graph_manager.AddGraph(graph_id, graph, options, context); + EXPECT_EQ(status, ge::SUCCESS); + status = graph_manager.RemoveGraph(graph_id); + EXPECT_EQ(status, ge::SUCCESS); + status = graph_manager.AddGraphWithCopy(graph_id, graph, options, context); + EXPECT_NE(status, ge::SUCCESS); +} \ No newline at end of file