| @@ -315,6 +315,20 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| } | } | ||||
| } | } | ||||
| if (is_hccl_op) { | |||||
| for (const auto &src_node : ge_node->GetInControlNodes()) { | |||||
| auto src_node_item = MutableNodeItem(src_node); | |||||
| GE_CHECK_NOTNULL(src_node_item); | |||||
| GELOGD("[%s](%s) Add input control dependent node [%s](%s)", | |||||
| ge_node->GetName().c_str(), | |||||
| ge_node->GetType().c_str(), | |||||
| src_node->GetName().c_str(), | |||||
| src_node->GetType().c_str()); | |||||
| src_node_item->has_observer = true; | |||||
| dependent_for_execution.emplace(src_node); | |||||
| } | |||||
| } | |||||
| // cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
| if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | ||||
| auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | ||||
| @@ -2030,8 +2044,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
| const auto &node = node_item->node; | const auto &node = node_item->node; | ||||
| auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | auto executor_type = NodeExecutorManager::GetInstance().ResolveExecutorType(*node); | ||||
| if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | if (executor_type == NodeExecutorManager::ExecutorType::HCCL) { | ||||
| std::string parallel_group; | |||||
| if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | |||||
| int64_t parallel_group_val = -1; | |||||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) { | |||||
| std::string parallel_group = std::to_string(parallel_group_val); | |||||
| GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); | GELOGD("[%s] Got parallel group = [%s]", node_item->NodeName().c_str(), parallel_group.c_str()); | ||||
| parallel_group_to_nodes_[parallel_group].emplace(node_item); | parallel_group_to_nodes_[parallel_group].emplace(node_item); | ||||
| std::set<std::string> group{parallel_group}; | std::set<std::string> group{parallel_group}; | ||||
| @@ -2047,8 +2062,9 @@ Status HybridModelBuilder::CollectParallelGroups(NodeItem *node_item) { | |||||
| auto subgraph = root_graph_->GetSubgraph(subgraph_name); | auto subgraph = root_graph_->GetSubgraph(subgraph_name); | ||||
| GE_CHECK_NOTNULL(subgraph); | GE_CHECK_NOTNULL(subgraph); | ||||
| for (const auto &sub_node : subgraph->GetAllNodes()) { | for (const auto &sub_node : subgraph->GetAllNodes()) { | ||||
| std::string parallel_group; | |||||
| if (AttrUtils::GetStr(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group)) { | |||||
| int64_t parallel_group_val = -1; | |||||
| if (AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, parallel_group_val)) { | |||||
| std::string parallel_group = std::to_string(parallel_group_val); | |||||
| GELOGD("[%s::%s] Got parallel group = %s", | GELOGD("[%s::%s] Got parallel group = %s", | ||||
| subgraph_name.c_str(), | subgraph_name.c_str(), | ||||
| sub_node->GetName().c_str(), | sub_node->GetName().c_str(), | ||||