diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 19d2ef49..7bd9d35c 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -1637,6 +1637,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem auto temp_graph = MakeShared<ComputeGraph>("temp"); GE_CHECK_NOTNULL(temp_graph); auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); + wrapper_op_desc->SetId(parent_node_item->node_id); GeModelPtr ge_model = subgraph_models_[subgraph_name]; GE_CHECK_NOTNULL(ge_model); hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); @@ -1916,7 +1917,6 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root NodeItem *node_item = nullptr; GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); - GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task node_item->input_start = input_start; @@ -2069,22 +2069,17 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { } Status HybridModelBuilder::ParseDependentByParallelGroup() { + for (auto &it : hybrid_model_.node_items_) { + GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get())); + } for (const auto &it : node_to_parallel_groups_) { auto node_item = it.first; - auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); + auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); for (const auto &parallel_group : it.second) { auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; NodeItem *nearest_dep_node = nullptr; int max_id = -1; for (auto &dep_node : dependent_nodes) { - if (node_item == dep_node) { - continue; - } - auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); - if
(src_engine_type == dst_engine_type) { - continue; - } - if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { nearest_dep_node = dep_node; max_id = dep_node->node_id; @@ -2092,10 +2087,12 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { } if (nearest_dep_node != nullptr) { - GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", - parallel_group.c_str(), - nearest_dep_node->NodeName().c_str(), - node_item->NodeName().c_str()); + GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str()); + auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node); + if (src_engine_type == dst_executor_type) { + GELOGD("No need to add dependency for nodes with same executor type"); + continue; + } auto &deps = node_item->dependents_for_execution; if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { GELOGD("%s->%s Already has dependency, skip it", @@ -2105,6 +2102,10 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { } nearest_dep_node->has_observer = true; deps.emplace_back(nearest_dep_node->node); + GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", + parallel_group.c_str(), + nearest_dep_node->NodeName().c_str(), + node_item->NodeName().c_str()); } } }