From: @xchu42 Reviewed-by: @wqtshg,@wqtshg,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -68,7 +68,7 @@ struct GraphExecutionContext { | |||||
DumpProperties dump_properties; | DumpProperties dump_properties; | ||||
bool trace_enabled = false; | bool trace_enabled = false; | ||||
bool dump_enabled = false; | bool dump_enabled = false; | ||||
std::atomic_bool is_eos_; | |||||
std::atomic_bool is_eos_{false}; | |||||
long profiling_level = 0; | long profiling_level = 0; | ||||
long iteration = 0; | long iteration = 0; | ||||
void *global_step = nullptr; | void *global_step = nullptr; | ||||
@@ -323,6 +323,19 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
} | } | ||||
} | } | ||||
for (const auto &src_node : ge_node->GetInControlNodes()) { | |||||
auto src_node_item = MutableNodeItem(src_node); | |||||
GE_CHECK_NOTNULL(src_node_item); | |||||
if (is_hccl_op || src_node_item->IsHcclOp()) { | |||||
GELOGD("[%s](%s) Add input control dependent node [%s](%s)", | |||||
ge_node->GetName().c_str(), | |||||
ge_node->GetType().c_str(), | |||||
src_node->GetName().c_str(), | |||||
src_node->GetType().c_str()); | |||||
dependent_for_execution.emplace(src_node); | |||||
} | |||||
} | |||||
// cond or branch need to be prepared before the execution of IF or CASE | // cond or branch need to be prepared before the execution of IF or CASE | ||||
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) { | ||||
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input | ||||
@@ -589,3 +589,35 @@ TEST_F(UtestGeHybrid, test_key_for_kernel_bin) { | |||||
EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA); | EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA); | ||||
EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname"); | EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname"); | ||||
} | } | ||||
TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) { | |||||
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl", | |||||
NodeExecutorManager::ExecutorType::HCCL); | |||||
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test"); | |||||
OpDescPtr op_desc = CreateOpDesc("Add", "Add"); | |||||
auto node = compute_graph->AddNode(op_desc); | |||||
std::unique_ptr<NodeItem> node_item; | |||||
NodeItem::Create(node, node_item); | |||||
node_item->node_id = 0; | |||||
OpDescPtr op_desc_1 = CreateOpDesc("AllReduce", "AllReduce"); | |||||
op_desc_1->SetOpKernelLibName("ops_kernel_info_hccl"); | |||||
auto node_1 = compute_graph->AddNode(op_desc_1); | |||||
std::unique_ptr<NodeItem> node_item_1; | |||||
NodeItem::Create(node_1, node_item_1); | |||||
node_item_1->node_id = 1; | |||||
node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor()); | |||||
GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph); | |||||
HybridModel model(root_model); | |||||
model.root_graph_ = compute_graph; | |||||
model.node_items_.emplace(node, std::move(node_item)); | |||||
HybridModelBuilder builder(model); | |||||
std::vector<std::string> deps; | |||||
ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS); | |||||
ASSERT_TRUE(model.GetNodeItem(node)->has_observer); | |||||
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1); | |||||
} |