| @@ -222,6 +222,39 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect | |||
| } | |||
| } | |||
| } | |||
| return LinkToPotentialPrecedenceNode(graph, clean_addr_node); | |||
| } | |||
| // Add control edges from atomic clean node to all potential precedence nodes which may execute before atomic clean | |||
| // node. We hope that atomic clean node can execute with the highest priority in the entire graph. Because of stream | |||
| // concurrency mechanism, only placing it at the head can not ensure that priority. Therefore, we need to add control | |||
| // edges from atomic clean node to the nodes that may be the first node on each stream. Generally, the first nodes on | |||
| // each stream are successors of Data/Variable, and Data/Variable won't generate task or execute, so we link to the | |||
| // successors of Data/Variable. | |||
| Status AtomicAddrCleanPass::LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node) { | |||
| GELOGD("Start to add control edges from %s to all second-nodes behind first-nodes which have no input.", | |||
| atomic_clean_node->GetName().c_str()); | |||
| auto out_ctrl_anchor = atomic_clean_node->GetOutControlAnchor(); | |||
| GE_CHECK_NOTNULL(out_ctrl_anchor); | |||
| for (const auto &node : graph->GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node); | |||
| bool need_handle = (node->GetType() == DATA || node->GetType() == VARIABLE) && node->GetInAllNodes().empty(); | |||
| if (!need_handle) { | |||
| continue; | |||
| } | |||
| auto second_nodes = node->GetOutAllNodes(); | |||
| for (const auto &second_node : second_nodes) { | |||
| GE_CHECK_NOTNULL(second_node); | |||
| auto in_ctrl_anchor = second_node->GetInControlAnchor(); | |||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
| if (!out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor)) { | |||
| GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor)); | |||
| GELOGD("Add control edge from %s to %s.", atomic_clean_node->GetName().c_str(), second_node->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -67,6 +67,14 @@ class AtomicAddrCleanPass : public GraphPass { | |||
| */ | |||
| Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); | |||
| /** | |||
| * Link atomic clean node to all potential precedence nodes which may execute before atomic clean node | |||
| * @param graph | |||
| * @param atomic_clean_node | |||
| * @return | |||
| */ | |||
| Status LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node); | |||
| /** | |||
| * Check if this node is atomic op. | |||
| * @param node | |||
| @@ -166,8 +166,8 @@ Status SubgraphConstMigrationPass::ClassifyGraphNodes(const ComputeGraphPtr &gra | |||
| string node_full_name = peer_node->GetName(); | |||
| size_t pos = node_full_name.find(kMbatchNodeNameMark); | |||
| if (pos == string::npos) { | |||
| GELOGE(FAILED, "find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||
| return FAILED; | |||
| GELOGI("Can not find: %s of multi-batch in node: %s", kMbatchNodeNameMark.c_str(), node_full_name.c_str()); | |||
| continue; | |||
| } | |||
| string fixed_name = node_full_name.substr(0, pos); | |||
| @@ -48,18 +48,49 @@ public: | |||
| return node; | |||
| } | |||
| int CountOfAtomicCleanNode() { | |||
| int node_num = 0; | |||
| for (NodePtr &node : graph_->GetDirectNode()) { | |||
| if (node->GetType() == ATOMICADDRCLEAN) { | |||
| ++node_num; | |||
| } | |||
| } | |||
| return node_num; | |||
| } | |||
| ComputeGraphPtr graph_; | |||
| }; | |||
| // node1 -> node2 -> node3 | |||
| /* | |||
| * Data Data Atomic_clean | |||
| * | | / | | |||
| * relu relu | | |||
| * | ==> | | | |||
| * relu(atomic) relu(atomic) | |||
| * | | | |||
| * netoutput netoutput | |||
| */ | |||
| TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { | |||
| auto node1 = NewNode("node1", DATA, 0, 1); | |||
| auto node2 = NewNode("node2", RELU, 1, 1); | |||
| auto node3 = NewNode("node3", NETOUTPUT, 1, 0); | |||
| auto node3 = NewNode("node3", RELU, 1, 1); | |||
| auto op_desc = node3->GetOpDesc(); | |||
| vector<int64_t> atomic_input_index = {123, 456}; | |||
| AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index); | |||
| auto node4 = NewNode("node4", NETOUTPUT, 1, 0); | |||
| GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); | |||
| GraphUtils::AddEdge(node3->GetOutDataAnchor(0), node4->GetInDataAnchor(0)); | |||
| AtomicAddrCleanPass atomi_addr_clean_pass; | |||
| Status ret = atomi_addr_clean_pass.Run(graph_); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| EXPECT_EQ(1, CountOfAtomicCleanNode()); | |||
| auto atomic_clean = graph_->FindNode("atomic_addr_clean"); | |||
| EXPECT_NE(atomic_clean, nullptr); | |||
| auto out_ctrl_nodes = atomic_clean->GetOutControlNodes(); | |||
| EXPECT_EQ(out_ctrl_nodes.size(), 2); | |||
| } | |||
| } // namespace ge | |||