From: @yangyongqiang5033 Reviewed-by: @xchu42,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
| @@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6 | |||||
| bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | ||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { | |||||
| atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); | |||||
| if (!has_atomic_input && has_atomic_output && atomic_workspace_index_size.empty()) { | |||||
| std::vector<int64_t> atomic_output_index; | std::vector<int64_t> atomic_output_index; | ||||
| (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); | ||||
| bool is_all_output_peer_also_atomic = true; | bool is_all_output_peer_also_atomic = true; | ||||
| @@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { | |||||
| } | } | ||||
| // 2.Check atomic attr in node | // 2.Check atomic attr in node | ||||
| std::map<string, std::map<int, int>> node_workspace_offset; | |||||
| std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size; | |||||
| bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); | ||||
| bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); | ||||
| node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); | |||||
| if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) { | |||||
| atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); | |||||
| if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -737,6 +737,7 @@ set(KERNEL_TEST_FILES | |||||
| "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" | "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" | ||||
| "graph/passes/folding_kernel/slice_kernel_unittest.cc" | "graph/passes/folding_kernel/slice_kernel_unittest.cc" | ||||
| "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" | "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" | ||||
| "graph/passes/atomic_addr_clean_pass_unittest.cc" | |||||
| ) | ) | ||||
| set(MULTI_PARTS_TEST_FILES | set(MULTI_PARTS_TEST_FILES | ||||
| @@ -0,0 +1,65 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "graph/passes/atomic_addr_clean_pass.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/types.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "inc/pass_manager.h" | |||||
| using namespace testing; | |||||
| namespace ge { | |||||
| class UtestGraphPassesAtomicAddrCleanPass : public Test { | |||||
| public: | |||||
| UtestGraphPassesAtomicAddrCleanPass() { | |||||
| graph_ = std::make_shared<ComputeGraph>("test"); | |||||
| } | |||||
| NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { | |||||
| OpDescPtr op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < input_cnt; ++i) { | |||||
| op_desc->AddInputDesc(GeTensorDesc()); | |||||
| } | |||||
| for (int i = 0; i < output_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(GeTensorDesc()); | |||||
| } | |||||
| NodePtr node = graph_->AddNode(op_desc); | |||||
| return node; | |||||
| } | |||||
| ComputeGraphPtr graph_; | |||||
| }; | |||||
| // node1 -> node2 -> node3 | |||||
| TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { | |||||
| auto node1 = NewNode("node1", DATA, 0, 1); | |||||
| auto node2 = NewNode("node2", RELU, 1, 1); | |||||
| auto node3 = NewNode("node3", NETOUTPUT, 1, 0); | |||||
| GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); | |||||
| GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); | |||||
| AtomicAddrCleanPass atomi_addr_clean_pass; | |||||
| Status ret = atomi_addr_clean_pass.Run(graph_); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||