@@ -222,6 +222,39 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect | |||
} | |||
} | |||
} | |||
return LinkToPotentialPrecedenceNode(graph, clean_addr_node); | |||
} | |||
// Add control edges from atomic clean node to all potential precedence nodes which may execute before atomic clean | |||
// node. We hope that atomic clean node can execute with the highest priority in the entire graph. Because of stream | |||
// concurrency mechanism, only placing it at the head can not ensure that priority. Therefore, we need to add control | |||
// edges from atomic clean node to the nodes that may be the first node on each stream. Generally, the first nodes on | |||
// each stream are successors of Data/Variable, and Data/Variable won't generate task or execute, so we link to the | |||
// successors of Data/Variable. | |||
Status AtomicAddrCleanPass::LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node) { | |||
GELOGD("Start to add control edges from %s to all second-nodes behind first-nodes which have no input.", | |||
atomic_clean_node->GetName().c_str()); | |||
auto out_ctrl_anchor = atomic_clean_node->GetOutControlAnchor(); | |||
GE_CHECK_NOTNULL(out_ctrl_anchor); | |||
for (const auto &node : graph->GetDirectNode()) { | |||
GE_CHECK_NOTNULL(node); | |||
bool need_handle = (node->GetType() == DATA || node->GetType() == VARIABLE) && node->GetInAllNodes().empty(); | |||
if (!need_handle) { | |||
continue; | |||
} | |||
auto second_nodes = node->GetOutAllNodes(); | |||
for (const auto &second_node : second_nodes) { | |||
GE_CHECK_NOTNULL(second_node); | |||
auto in_ctrl_anchor = second_node->GetInControlAnchor(); | |||
GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
if (!out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor)) { | |||
GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor)); | |||
GELOGD("Add control edge from %s to %s.", atomic_clean_node->GetName().c_str(), second_node->GetName().c_str()); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
@@ -67,6 +67,14 @@ class AtomicAddrCleanPass : public GraphPass { | |||
*/ | |||
Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); | |||
/** | |||
* Link atomic clean node to all potential precedence nodes which may execute before atomic clean node | |||
* @param graph | |||
* @param atomic_clean_node | |||
* @return | |||
*/ | |||
Status LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node); | |||
/** | |||
* Check if this node is atomic op. | |||
* @param node | |||
@@ -166,8 +166,8 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra | |||
string node_full_name = peer_node->GetName(); | |||
size_t pos = node_full_name.find(kMbatchNodeNameMark); | |||
if (pos == string::npos) { | |||
GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||
return FAILED; | |||
GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||
continue; | |||
} | |||
string fixed_name = node_full_name.substr(0, pos); | |||
@@ -48,18 +48,49 @@ public: | |||
return node; | |||
} | |||
int CountOfAtomicCleanNode() { | |||
int node_num = 0; | |||
for (NodePtr &node : graph_->GetDirectNode()) { | |||
if (node->GetType() == ATOMICADDRCLEAN) { | |||
++node_num; | |||
} | |||
} | |||
return node_num; | |||
} | |||
ComputeGraphPtr graph_; | |||
}; | |||
// node1 -> node2 -> node3 | |||
/* | |||
* Data Data Atomic_clean | |||
* | | / | | |||
* relu relu | | |||
* | ==> | | | |||
* relu(atomic) relu(atomic) | |||
* | | | |||
* netoutput netoutput | |||
*/ | |||
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { | |||
auto node1 = NewNode("node1", DATA, 0, 1); | |||
auto node2 = NewNode("node2", RELU, 1, 1); | |||
auto node3 = NewNode("node3", NETOUTPUT, 1, 0); | |||
auto node3 = NewNode("node3", RELU, 1, 1); | |||
auto op_desc = node3->GetOpDesc(); | |||
vector<int64_t> atomic_input_index = {123, 456}; | |||
AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index); | |||
auto node4 = NewNode("node4", NETOUTPUT, 1, 0); | |||
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); | |||
GraphUtils::AddEdge(node3->GetOutDataAnchor(0), node4->GetInDataAnchor(0)); | |||
AtomicAddrCleanPass atomi_addr_clean_pass; | |||
Status ret = atomi_addr_clean_pass.Run(graph_); | |||
EXPECT_EQ(ret, SUCCESS); | |||
EXPECT_EQ(1, CountOfAtomicCleanNode()); | |||
auto atomic_clean = graph_->FindNode("atomic_addr_clean"); | |||
EXPECT_NE(atomic_clean, nullptr); | |||
auto out_ctrl_nodes = atomic_clean->GetOutControlNodes(); | |||
EXPECT_EQ(out_ctrl_nodes.size(), 2); | |||
} | |||
} // namespace ge |