diff --git a/ge/hybrid/executor/hybrid_execution_context.h b/ge/hybrid/executor/hybrid_execution_context.h index 003e8010..54840c6a 100644 --- a/ge/hybrid/executor/hybrid_execution_context.h +++ b/ge/hybrid/executor/hybrid_execution_context.h @@ -68,7 +68,7 @@ struct GraphExecutionContext { DumpProperties dump_properties; bool trace_enabled = false; bool dump_enabled = false; - std::atomic_bool is_eos_; + std::atomic_bool is_eos_{false}; long profiling_level = 0; long iteration = 0; void *global_step = nullptr; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 60fdf55a..0716068b 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -323,6 +323,19 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } } + for (const auto &src_node : ge_node->GetInControlNodes()) { + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + if (is_hccl_op || src_node_item->IsHcclOp()) { + GELOGD("[%s](%s) Add input control dependent node [%s](%s)", + ge_node->GetName().c_str(), + ge_node->GetType().c_str(), + src_node->GetName().c_str(), + src_node->GetType().c_str()); + dependent_for_execution.emplace(src_node); + } + } + // cond or branch need to be prepared before the execution of IF or CASE if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 9746585d..9b151550 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -589,3 +589,35 @@ TEST_F(UtestGeHybrid, test_key_for_kernel_bin) { EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA); EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname"); } + 
+TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
+  // Checks ParseDependentInputNodes for an HCCL node whose only input is a control edge from a non-HCCL node.
+  NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
+                                                             NodeExecutorManager::ExecutorType::HCCL);
+  ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");
+  // Control-edge source: an Add node; no kernel lib name is set on its op desc.
+  auto node = compute_graph->AddNode(CreateOpDesc("Add", "Add"));
+  std::unique_ptr<NodeItem> node_item;
+  ASSERT_EQ(NodeItem::Create(node, node_item), SUCCESS);  // stop here rather than dereference an unset item
+  node_item->node_id = 0;
+  // Control-edge destination: an AllReduce node bound to the kernel lib name mapped to HCCL above.
+  OpDescPtr op_desc_1 = CreateOpDesc("AllReduce", "AllReduce");
+  op_desc_1->SetOpKernelLibName("ops_kernel_info_hccl");
+  auto node_1 = compute_graph->AddNode(op_desc_1);
+  std::unique_ptr<NodeItem> node_item_1;
+  ASSERT_EQ(NodeItem::Create(node_1, node_item_1), SUCCESS);
+  node_item_1->node_id = 1;
+  // Only a control edge connects the two nodes (Add -> AllReduce); there is no data edge.
+  node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor());
+  // The model takes ownership of the source item only; node_item_1 stays local and is passed to the builder.
+  GeRootModelPtr root_model = MakeShared<GeRootModel>(compute_graph);
+  HybridModel model(root_model);
+  model.root_graph_ = compute_graph;
+  model.node_items_.emplace(node, std::move(node_item));
+  // Expect the source item to be flagged as observed and exactly one execution dependency on the HCCL item.
+  HybridModelBuilder builder(model);
+  std::vector<std::string> deps;
+  ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS);
+  ASSERT_TRUE(model.GetNodeItem(node)->has_observer);
+  ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1U);
+}
\ No newline at end of file