|
|
@@ -764,7 +764,7 @@ Status HybridModelBuilder::LoadGraph() { |
|
|
|
root_graph->GetAllNodesSize()); |
|
|
|
} |
|
|
|
|
|
|
|
root_graph_ = root_graph; |
|
|
|
hybrid_model_.root_graph_ = root_graph; |
|
|
|
// Reset node id by topological order across all subgraphs |
|
|
|
int64_t index = 0; |
|
|
|
for (const auto &node : root_graph->GetAllNodes()) { |
|
|
@@ -2058,7 +2058,7 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { |
|
|
|
GELOGD("[%s] Start to get parallel group from subgraph: %s", |
|
|
|
node_item->NodeName().c_str(), |
|
|
|
subgraph_name.c_str()); |
|
|
|
auto subgraph = root_graph_->GetSubgraph(subgraph_name); |
|
|
|
auto subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_name); |
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
for (const auto &sub_node : subgraph->GetAllNodes()) { |
|
|
|
std::string parallel_group; |
|
|
|