Browse Source

!1488 Fix hccl control dependencies

From: @xchu42
Reviewed-by: @wqtshg,@wqtshg,@ji_chen
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
0e94677cfd
3 changed files with 46 additions and 1 deletions
  1. +1
    -1
      ge/hybrid/executor/hybrid_execution_context.h
  2. +13
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  3. +32
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 1
ge/hybrid/executor/hybrid_execution_context.h View File

@@ -68,7 +68,7 @@ struct GraphExecutionContext {
DumpProperties dump_properties;
bool trace_enabled = false;
bool dump_enabled = false;
std::atomic_bool is_eos_;
std::atomic_bool is_eos_{false};
long profiling_level = 0;
long iteration = 0;
void *global_step = nullptr;


+ 13
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -323,6 +323,19 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}
}

for (const auto &src_node : ge_node->GetInControlNodes()) {
auto src_node_item = MutableNodeItem(src_node);
GE_CHECK_NOTNULL(src_node_item);
if (is_hccl_op || src_node_item->IsHcclOp()) {
GELOGD("[%s](%s) Add input control dependent node [%s](%s)",
ge_node->GetName().c_str(),
ge_node->GetType().c_str(),
src_node->GetName().c_str(),
src_node->GetType().c_str());
dependent_for_execution.emplace(src_node);
}
}

// cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
auto src_node = NodeUtils::GetInDataNodeByIndex(*ge_node, 0); // cond input


+ 32
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -589,3 +589,35 @@ TEST_F(UtestGeHybrid, test_key_for_kernel_bin) {
EXPECT_EQ(atomic_task->GetKeyForTvmMetaData(), ATOMIC_ATTR_TVM_METADATA);
EXPECT_EQ(atomic_task->GetKeyForKernelName(op_desc), "Sum_atomic_kernelname");
}

TEST_F(UtestGeHybrid, TestParseDependentInputNodesForHccl) {
NodeExecutorManager::GetInstance().engine_mapping_.emplace("ops_kernel_info_hccl",
NodeExecutorManager::ExecutorType::HCCL);
ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>("test");

OpDescPtr op_desc = CreateOpDesc("Add", "Add");
auto node = compute_graph->AddNode(op_desc);
std::unique_ptr<NodeItem> node_item;
NodeItem::Create(node, node_item);
node_item->node_id = 0;

OpDescPtr op_desc_1 = CreateOpDesc("AllReduce", "AllReduce");
op_desc_1->SetOpKernelLibName("ops_kernel_info_hccl");
auto node_1 = compute_graph->AddNode(op_desc_1);
std::unique_ptr<NodeItem> node_item_1;
NodeItem::Create(node_1, node_item_1);
node_item_1->node_id = 1;

node->GetOutControlAnchor()->LinkTo(node_1->GetInControlAnchor());

GeRootModelPtr root_model = MakeShared<ge::GeRootModel>(compute_graph);
HybridModel model(root_model);
model.root_graph_ = compute_graph;
model.node_items_.emplace(node, std::move(node_item));

HybridModelBuilder builder(model);
std::vector<std::string> deps;
ASSERT_EQ(builder.ParseDependentInputNodes(*node_item_1, deps), SUCCESS);
ASSERT_TRUE(model.GetNodeItem(node)->has_observer);
ASSERT_EQ(node_item_1->dependents_for_execution.size(), 1);
}

Loading…
Cancel
Save