Browse Source

AtomicAddrCleanPass::LinkToAllSecondNodes

tags/v1.3.0
lianghao 3 years ago
parent
commit
5efe0b6932
4 changed files with 76 additions and 4 deletions
  1. +33
    -0
      ge/graph/passes/atomic_addr_clean_pass.cc
  2. +8
    -0
      ge/graph/passes/atomic_addr_clean_pass.h
  3. +2
    -2
      ge/graph/passes/subgraph_const_migration_pass.cc
  4. +33
    -2
      tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc

+ 33
- 0
ge/graph/passes/atomic_addr_clean_pass.cc View File

@@ -222,6 +222,39 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect
} }
} }
} }
return LinkToPotentialPrecedenceNode(graph, clean_addr_node);
}

// Add control edges from atomic clean node to all potential precedence nodes which may execute before atomic clean
// node. We hope that atomic clean node can execute with the highest priority in the entire graph. Because of stream
// concurrency mechanism, only placing it at the head can not ensure that priority. Therefore, we need to add control
// edges from atomic clean node to the nodes that may be the first node on each stream. Generally, the first nodes on
// each stream are successors of Data/Variable, and Data/Variable won't generate task or execute, so we link to the
// successors of Data/Variable.
Status AtomicAddrCleanPass::LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node) {
GELOGD("Start to add control edges from %s to all second-nodes behind first-nodes which have no input.",
atomic_clean_node->GetName().c_str());
auto out_ctrl_anchor = atomic_clean_node->GetOutControlAnchor();
GE_CHECK_NOTNULL(out_ctrl_anchor);

for (const auto &node : graph->GetDirectNode()) {
GE_CHECK_NOTNULL(node);
bool need_handle = (node->GetType() == DATA || node->GetType() == VARIABLE) && node->GetInAllNodes().empty();
if (!need_handle) {
continue;
}
auto second_nodes = node->GetOutAllNodes();
for (const auto &second_node : second_nodes) {
GE_CHECK_NOTNULL(second_node);
auto in_ctrl_anchor = second_node->GetInControlAnchor();
GE_CHECK_NOTNULL(in_ctrl_anchor);
if (!out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor)) {
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor));
GELOGD("Add control edge from %s to %s.", atomic_clean_node->GetName().c_str(), second_node->GetName().c_str());
}
}
}

return SUCCESS; return SUCCESS;
} }




+ 8
- 0
ge/graph/passes/atomic_addr_clean_pass.h View File

@@ -67,6 +67,14 @@ class AtomicAddrCleanPass : public GraphPass {
*/ */
Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node);


/**
* Link atomic clean node to all potential precedence nodes which may execute before atomic clean node
* @param graph
* @param atomic_clean_node
* @return
*/
Status LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node);

/** /**
* Check if this node is atomic op. * Check if this node is atomic op.
* @param node * @param node


+ 2
- 2
ge/graph/passes/subgraph_const_migration_pass.cc View File

@@ -166,8 +166,8 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra
string node_full_name = peer_node->GetName(); string node_full_name = peer_node->GetName();
size_t pos = node_full_name.find(kMbatchNodeNameMark); size_t pos = node_full_name.find(kMbatchNodeNameMark);
if (pos == string::npos) { if (pos == string::npos) {
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
return FAILED;
GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str());
continue;
} }


string fixed_name = node_full_name.substr(0, pos); string fixed_name = node_full_name.substr(0, pos);


+ 33
- 2
tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc View File

@@ -48,18 +48,49 @@ public:
return node; return node;
} }


int CountOfAtomicCleanNode() {
int node_num = 0;
for (NodePtr &node : graph_->GetDirectNode()) {
if (node->GetType() == ATOMICADDRCLEAN) {
++node_num;
}
}
return node_num;
}

ComputeGraphPtr graph_; ComputeGraphPtr graph_;
}; };


// node1 -> node2 -> node3
/*
* Data Data Atomic_clean
* | | / |
* relu relu |
* | ==> | |
* relu(atomic) relu(atomic)
* | |
* netoutput netoutput
*/
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) {
auto node1 = NewNode("node1", DATA, 0, 1); auto node1 = NewNode("node1", DATA, 0, 1);

auto node2 = NewNode("node2", RELU, 1, 1); auto node2 = NewNode("node2", RELU, 1, 1);
auto node3 = NewNode("node3", NETOUTPUT, 1, 0);
auto node3 = NewNode("node3", RELU, 1, 1);
auto op_desc = node3->GetOpDesc();
vector<int64_t> atomic_input_index = {123, 456};
AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index);

auto node4 = NewNode("node4", NETOUTPUT, 1, 0);
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), node4->GetInDataAnchor(0));
AtomicAddrCleanPass atomi_addr_clean_pass; AtomicAddrCleanPass atomi_addr_clean_pass;
Status ret = atomi_addr_clean_pass.Run(graph_); Status ret = atomi_addr_clean_pass.Run(graph_);
EXPECT_EQ(ret, SUCCESS); EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(1, CountOfAtomicCleanNode());
auto atomic_clean = graph_->FindNode("atomic_addr_clean");
EXPECT_NE(atomic_clean, nullptr);
auto out_ctrl_nodes = atomic_clean->GetOutControlNodes();
EXPECT_EQ(out_ctrl_nodes.size(), 2);
} }
} // namespace ge } // namespace ge

Loading…
Cancel
Save