| @@ -1637,6 +1637,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem | |||
| auto temp_graph = MakeShared<ComputeGraph>("temp"); | |||
| GE_CHECK_NOTNULL(temp_graph); | |||
| auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); | |||
| wrapper_op_desc->SetId(parent_node_item->node_id); | |||
| GeModelPtr ge_model = subgraph_models_[subgraph_name]; | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); | |||
| @@ -1916,7 +1917,6 @@ Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root | |||
| NodeItem *node_item = nullptr; | |||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); | |||
| GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); | |||
| GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(node_item)); | |||
| GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task | |||
| node_item->input_start = input_start; | |||
| @@ -2069,22 +2069,17 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||
| } | |||
| Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||
| for (auto &it : hybrid_model_.node_items_) { | |||
| GE_CHK_STATUS_RET_NOLOG(CollectParallelGroups(it.second.get())); | |||
| } | |||
| for (const auto &it : node_to_parallel_groups_) { | |||
| auto node_item = it.first; | |||
| auto dst_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||
| auto dst_executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node_item->node); | |||
| for (const auto ¶llel_group : it.second) { | |||
| auto &dependent_nodes = parallel_group_to_nodes_[parallel_group]; | |||
| NodeItem *nearest_dep_node = nullptr; | |||
| int max_id = -1; | |||
| for (auto &dep_node : dependent_nodes) { | |||
| if (node_item == dep_node) { | |||
| continue; | |||
| } | |||
| auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*dep_node->node); | |||
| if (src_engine_type == dst_engine_type) { | |||
| continue; | |||
| } | |||
| if (dep_node->node_id < node_item->node_id && dep_node->node_id > max_id) { | |||
| nearest_dep_node = dep_node; | |||
| max_id = dep_node->node_id; | |||
| @@ -2092,10 +2087,12 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||
| } | |||
| if (nearest_dep_node != nullptr) { | |||
| GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", | |||
| parallel_group.c_str(), | |||
| nearest_dep_node->NodeName().c_str(), | |||
| node_item->NodeName().c_str()); | |||
| GELOGD("[%s] Nearest node = [%s]", node_item->NodeName().c_str(), nearest_dep_node->NodeName().c_str()); | |||
| auto src_engine_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*nearest_dep_node->node); | |||
| if (src_engine_type == dst_executor_type) { | |||
| GELOGD("No need to add dependency for nodes with same executor type"); | |||
| continue; | |||
| } | |||
| auto &deps = node_item->dependents_for_execution; | |||
| if (std::find(deps.begin(), deps.end(), nearest_dep_node->node) != deps.end()) { | |||
| GELOGD("%s->%s Already has dependency, skip it", | |||
| @@ -2105,6 +2102,10 @@ Status HybridModelBuilder::ParseDependentByParallelGroup() { | |||
| } | |||
| nearest_dep_node->has_observer = true; | |||
| deps.emplace_back(nearest_dep_node->node); | |||
| GELOGD("Add dependency for nodes with the same parallel group[%s], src = [%s], dst = [%s]", | |||
| parallel_group.c_str(), | |||
| nearest_dep_node->NodeName().c_str(), | |||
| node_item->NodeName().c_str()); | |||
| } | |||
| } | |||
| } | |||