From 7a3dba72af047cf0b4fb913e1b4e08975b93d326 Mon Sep 17 00:00:00 2001 From: lichun Date: Wed, 31 Mar 2021 10:20:02 +0800 Subject: [PATCH] Bugfix: fix null owner graph error --- ge/hybrid/model/hybrid_model.h | 1 + ge/hybrid/model/hybrid_model_builder.cc | 4 ++-- ge/hybrid/model/hybrid_model_builder.h | 1 - metadef | 2 +- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index fae53679..62095d42 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -135,6 +135,7 @@ class HybridModel { std::string model_name_; GeRootModelPtr ge_root_model_; std::map input_nodes_; + ComputeGraphPtr root_graph_; std::map device_variable_nodes_; //lint !e148 std::map host_variable_nodes_; //lint !e148 std::map> variable_tensors_; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index f52732c9..1be76331 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -764,7 +764,7 @@ Status HybridModelBuilder::LoadGraph() { root_graph->GetAllNodesSize()); } - root_graph_ = root_graph; + hybrid_model_.root_graph_ = root_graph; // Reset node id by topological order across all subgraphs int64_t index = 0; for (const auto &node : root_graph->GetAllNodes()) { @@ -2058,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { GELOGD("[%s] Start to get parallel group from subgraph: %s", node_item->NodeName().c_str(), subgraph_name.c_str()); - auto subgraph = root_graph_->GetSubgraph(subgraph_name); + auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name); GE_CHECK_NOTNULL(subgraph); for (const auto &sub_node : subgraph->GetAllNodes()) { std::string parallel_group; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 30241003..430637dc 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -100,7 +100,6 @@ class HybridModelBuilder { NodeItem *MutableNodeItem(const NodePtr &node); GeRootModelPtr ge_root_model_; - ComputeGraphPtr root_graph_; std::map subgraph_models_; std::map constant_op_nodes_; std::map> parallel_group_to_nodes_; diff --git a/metadef b/metadef index 86781b7e..4ff5e398 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e +Subproject commit 4ff5e3987f2e5d2980019defacaf0891861c84fc diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 57230f30..18bcd7da 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -276,9 +276,9 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); GeRootModelPtr root_model = MakeShared(compute_graph); HybridModel model(root_model); + model.root_graph_ = compute_graph; HybridModelBuilder builder(model); - builder.root_graph_ = compute_graph; ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1);