| @@ -222,6 +222,39 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return LinkToPotentialPrecedenceNode(graph, clean_addr_node); | |||||
| } | |||||
| // Adds control edges from the atomic clean node to every node that may otherwise execute before it. | |||||
| // The atomic clean node should run with the highest priority in the entire graph. Because streams execute | |||||
| // concurrently, placing the node at the head of the graph alone cannot guarantee that priority, so it is also | |||||
| // linked to the nodes that may be the first node on each stream. Generally those are the successors of | |||||
| // Data/Variable: Data/Variable won't generate a task or execute, so the edges go to their successors instead. | |||||
| // @param graph graph whose direct nodes are scanned for input-less Data/Variable nodes | |||||
| // @param atomic_clean_node node whose out control anchor is the source of every control edge added here | |||||
| // @return SUCCESS once every candidate has been handled; otherwise whatever the failing GE_CHECK_NOTNULL / | |||||
| //         GE_CHK_STATUS_RET macro returns (NOTE(review): exact codes depend on those macros - confirm) | |||||
| Status AtomicAddrCleanPass::LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node) { | |||||
|   // Validate both arguments up front: the log statement right below already dereferences atomic_clean_node. | |||||
|   GE_CHECK_NOTNULL(graph); | |||||
|   GE_CHECK_NOTNULL(atomic_clean_node); | |||||
|   GELOGD("Start to add control edges from %s to all second-nodes behind first-nodes which have no input.", | |||||
|          atomic_clean_node->GetName().c_str()); | |||||
|   auto out_ctrl_anchor = atomic_clean_node->GetOutControlAnchor(); | |||||
|   GE_CHECK_NOTNULL(out_ctrl_anchor); | |||||
|   for (const auto &node : graph->GetDirectNode()) { | |||||
|     GE_CHECK_NOTNULL(node); | |||||
|     // Only Data/Variable nodes without any input are considered; query the type once for both comparisons. | |||||
|     const auto &node_type = node->GetType(); | |||||
|     bool need_handle = (node_type == DATA || node_type == VARIABLE) && node->GetInAllNodes().empty(); | |||||
|     if (!need_handle) { | |||||
|       continue; | |||||
|     } | |||||
|     auto second_nodes = node->GetOutAllNodes(); | |||||
|     for (const auto &second_node : second_nodes) { | |||||
|       GE_CHECK_NOTNULL(second_node); | |||||
|       auto in_ctrl_anchor = second_node->GetInControlAnchor(); | |||||
|       GE_CHECK_NOTNULL(in_ctrl_anchor); | |||||
|       // A successor may already be linked (e.g. it follows several Data nodes); never add a duplicate edge. | |||||
|       if (!out_ctrl_anchor->IsLinkedWith(in_ctrl_anchor)) { | |||||
|         GE_CHK_STATUS_RET(out_ctrl_anchor->LinkTo(in_ctrl_anchor)); | |||||
|         GELOGD("Add control edge from %s to %s.", atomic_clean_node->GetName().c_str(), second_node->GetName().c_str()); | |||||
|       } | |||||
|     } | |||||
|   } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -67,6 +67,14 @@ class AtomicAddrCleanPass : public GraphPass { | |||||
| */ | */ | ||||
| Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); | Status LinkToAtomicNode(const NodePtr &atomic_node, NodePtr &atomic_clean_node); | ||||
| /** | |||||
| * Link atomic clean node to all potential precedence nodes which may execute before atomic clean node. | |||||
| * A control edge is added from atomic_clean_node to every successor of a DATA/VARIABLE node that has no inputs; | |||||
| * successors that are already linked to atomic_clean_node are left as they are. | |||||
| * @param graph graph whose direct nodes are scanned for such DATA/VARIABLE nodes | |||||
| * @param atomic_clean_node node whose out control anchor is linked to the in control anchor of each successor | |||||
| * @return SUCCESS when every candidate has been handled; presumably an error status coming from the null/status | |||||
| *         check macros otherwise (confirm against the macro definitions) | |||||
| */ | |||||
| Status LinkToPotentialPrecedenceNode(ComputeGraphPtr &graph, NodePtr &atomic_clean_node); | |||||
| /** | /** | ||||
| * Check if this node is atomic op. | * Check if this node is atomic op. | ||||
| * @param node | * @param node | ||||
| @@ -606,6 +606,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/variable_prepare_pass_unittest.cc" | "graph/passes/variable_prepare_pass_unittest.cc" | ||||
| "graph/passes/variable_ref_delete_pass_unittest.cc" | "graph/passes/variable_ref_delete_pass_unittest.cc" | ||||
| "graph/passes/dimension_adjust_pass_unittest.cc" | "graph/passes/dimension_adjust_pass_unittest.cc" | ||||
| "graph/passes/atomic_addr_clean_pass_unittest.cc" | |||||
| "graph/passes/pass_utils_unittest.cc" | "graph/passes/pass_utils_unittest.cc" | ||||
| "graph/passes/net_output_pass_unittest.cc" | "graph/passes/net_output_pass_unittest.cc" | ||||
| "graph/passes/no_use_reshape_remove_pass_unittest.cc" | "graph/passes/no_use_reshape_remove_pass_unittest.cc" | ||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "graph/passes/atomic_addr_clean_pass.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "inc/pass_manager.h" | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| // Test fixture owning a fresh ComputeGraph named "test", with helpers to populate and inspect it. | |||||
| class UtestGraphPassesAtomicAddrCleanPass : public Test { | |||||
|  public: | |||||
|   UtestGraphPassesAtomicAddrCleanPass() : graph_(std::make_shared<ComputeGraph>("test")) {} | |||||
|   // Builds an op desc with the given name/type carrying input_cnt default input descs and output_cnt default | |||||
|   // output descs, adds it to graph_ and returns the node handed back by AddNode. | |||||
|   NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { | |||||
|     OpDescPtr desc = std::make_shared<OpDesc>(name, type); | |||||
|     for (int inputs_left = input_cnt; inputs_left > 0; --inputs_left) { | |||||
|       desc->AddInputDesc(GeTensorDesc()); | |||||
|     } | |||||
|     for (int outputs_left = output_cnt; outputs_left > 0; --outputs_left) { | |||||
|       desc->AddOutputDesc(GeTensorDesc()); | |||||
|     } | |||||
|     return graph_->AddNode(desc); | |||||
|   } | |||||
|   // Returns how many direct nodes of graph_ have type ATOMICADDRCLEAN. | |||||
|   int CountOfAtomicCleanNode() { | |||||
|     int atomic_clean_cnt = 0; | |||||
|     for (const auto &candidate : graph_->GetDirectNode()) { | |||||
|       atomic_clean_cnt += (candidate->GetType() == ATOMICADDRCLEAN) ? 1 : 0; | |||||
|     } | |||||
|     return atomic_clean_cnt; | |||||
|   } | |||||
|   // Graph under test; recreated for every test case by the constructor. | |||||
|   ComputeGraphPtr graph_; | |||||
| }; | |||||
| /* | |||||
| * Runs the pass on a Data -> relu -> relu(atomic) -> netoutput chain and expects exactly one atomic clean node | |||||
| * with two outgoing control edges. | |||||
| * | |||||
| *      Data                     Data      Atomic_clean | |||||
| *       |                        |         /     | | |||||
| *      relu                     relu <----       | | |||||
| *       |          ==>           |               | | |||||
| *   relu(atomic)             relu(atomic) <------ | |||||
| *       |                        | | |||||
| *   netoutput                netoutput | |||||
| */ | |||||
| TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { | |||||
|   auto node1 = NewNode("node1", DATA, 0, 1); | |||||
|   auto node2 = NewNode("node2", RELU, 1, 1); | |||||
|   auto node3 = NewNode("node3", RELU, 1, 1); | |||||
|   // The "atomic_input_index" attribute is what marks node3 as the atomic op of this graph. | |||||
|   auto op_desc = node3->GetOpDesc(); | |||||
|   vector<int64_t> atomic_input_index = {123, 456}; | |||||
|   AttrUtils::SetListInt(op_desc, "atomic_input_index", atomic_input_index); | |||||
|   auto node4 = NewNode("node4", NETOUTPUT, 1, 0); | |||||
|   GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); | |||||
|   GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); | |||||
|   GraphUtils::AddEdge(node3->GetOutDataAnchor(0), node4->GetInDataAnchor(0)); | |||||
|   AtomicAddrCleanPass atomic_addr_clean_pass; | |||||
|   Status ret = atomic_addr_clean_pass.Run(graph_); | |||||
|   EXPECT_EQ(ret, SUCCESS); | |||||
|   EXPECT_EQ(1, CountOfAtomicCleanNode()); | |||||
|   auto atomic_clean = graph_->FindNode("atomic_addr_clean"); | |||||
|   // Fatal assertion: the node is dereferenced on the next line, so a missing node must end the test here | |||||
|   // instead of crashing the whole test binary. | |||||
|   ASSERT_NE(atomic_clean, nullptr); | |||||
|   auto out_ctrl_nodes = atomic_clean->GetOutControlNodes(); | |||||
|   // Two control edges as sketched above: one to the successor of Data and one to the atomic relu. | |||||
|   EXPECT_EQ(out_ctrl_nodes.size(), 2U); | |||||
| } | |||||
| } // namespace ge | |||||