@@ -135,6 +135,7 @@ class HybridModel { | |||||
std::string model_name_; | std::string model_name_; | ||||
GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
std::map<uint32_t, NodeItem *> input_nodes_; | std::map<uint32_t, NodeItem *> input_nodes_; | ||||
ComputeGraphPtr root_graph_; | |||||
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148 | ||||
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148 | ||||
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_; | ||||
@@ -764,7 +764,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
root_graph->GetAllNodesSize()); | root_graph->GetAllNodesSize()); | ||||
} | } | ||||
root_graph_ = root_graph; | |||||
hybrid_model_.root_graph_ = root_graph; | |||||
// Reset node id by topological order across all subgraphs | // Reset node id by topological order across all subgraphs | ||||
int64_t index = 0; | int64_t index = 0; | ||||
for (const auto &node : root_graph->GetAllNodes()) { | for (const auto &node : root_graph->GetAllNodes()) { | ||||
@@ -2058,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
GELOGD("[%s] Start to get parallel group from subgraph: %s", | GELOGD("[%s] Start to get parallel group from subgraph: %s", | ||||
node_item->NodeName().c_str(), | node_item->NodeName().c_str(), | ||||
subgraph_name.c_str()); | subgraph_name.c_str()); | ||||
auto subgraph = root_graph_->GetSubgraph(subgraph_name); | |||||
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name); | |||||
GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
for (const auto &sub_node : subgraph->GetAllNodes()) { | for (const auto &sub_node : subgraph->GetAllNodes()) { | ||||
std::string parallel_group; | std::string parallel_group; | ||||
@@ -100,7 +100,6 @@ class HybridModelBuilder { | |||||
NodeItem *MutableNodeItem(const NodePtr &node); | NodeItem *MutableNodeItem(const NodePtr &node); | ||||
GeRootModelPtr ge_root_model_; | GeRootModelPtr ge_root_model_; | ||||
ComputeGraphPtr root_graph_; | |||||
std::map<std::string, GeModelPtr> subgraph_models_; | std::map<std::string, GeModelPtr> subgraph_models_; | ||||
std::map<std::string, NodePtr> constant_op_nodes_; | std::map<std::string, NodePtr> constant_op_nodes_; | ||||
std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | std::map<std::string, std::set<NodeItem *>> parallel_group_to_nodes_; | ||||
@@ -1 +1 @@ | |||||
Subproject commit 86781b7e8ce21d2b901406cc3619d6bea2aeb18e | |||||
Subproject commit 4ff5e3987f2e5d2980019defacaf0891861c84fc |
@@ -276,9 +276,9 @@ TEST_F(UtestGeHybrid, test_parse_parallel_group) { | |||||
op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); | op_desc->SetOpKernelLibName("ops_kernel_info_hccl"); | ||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | ||||
HybridModel model(root_model); | HybridModel model(root_model); | ||||
model.root_graph_ = compute_graph; | |||||
HybridModelBuilder builder(model); | HybridModelBuilder builder(model); | ||||
builder.root_graph_ = compute_graph; | |||||
ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); | ASSERT_EQ(builder.CollectParallelGroups(node_item.get()), SUCCESS); | ||||
ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | ASSERT_EQ(builder.node_to_parallel_groups_.size(), 1); | ||||