diff --git a/ge/graph/passes/atomic_addr_clean_pass.cc b/ge/graph/passes/atomic_addr_clean_pass.cc index 7c6ed8ce..16d3c129 100755 --- a/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/ge/graph/passes/atomic_addr_clean_pass.cc @@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6 bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) { OpDescPtr op_desc = node->GetOpDesc(); - std::map> node_workspace_offset; + std::map> atomic_workspace_index_size; bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); - node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); - if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) { + atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); + if (!has_atomic_input && has_atomic_output && atomic_workspace_index_size.empty()) { std::vector atomic_output_index; (void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index); bool is_all_output_peer_also_atomic = true; @@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { } // 2.Check atomic attr in node - std::map> node_workspace_offset; + std::map> atomic_workspace_index_size; bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX); bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX); - node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset); - if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) { + atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size); + if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) { return false; } diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index eb1c1340..141d75a7 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -737,6 +737,7 @@ set(KERNEL_TEST_FILES "graph/passes/folding_kernel/gather_v2_kernel_unittest.cc" "graph/passes/folding_kernel/slice_kernel_unittest.cc" "graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc" + "graph/passes/atomic_addr_clean_pass_unittest.cc" ) set(MULTI_PARTS_TEST_FILES diff --git a/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc new file mode 100644 index 00000000..59636511 --- /dev/null +++ b/tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "graph/passes/atomic_addr_clean_pass.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "graph/anchor.h" +#include "graph/attr_value.h" +#include "graph/compute_graph.h" +#include "graph/op_desc.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" +#include "inc/pass_manager.h" +using namespace testing; + +namespace ge { +class UtestGraphPassesAtomicAddrCleanPass : public Test { +public: + UtestGraphPassesAtomicAddrCleanPass() { + graph_ = std::make_shared("test"); + } + + NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) { + OpDescPtr op_desc = std::make_shared(name, type); + for (int i = 0; i < input_cnt; ++i) { + op_desc->AddInputDesc(GeTensorDesc()); + } + for (int i = 0; i < output_cnt; ++i) { + op_desc->AddOutputDesc(GeTensorDesc()); + } + NodePtr node = graph_->AddNode(op_desc); + return node; + } + + ComputeGraphPtr graph_; +}; + +// node1 -> node2 -> node3 +TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) { + auto node1 = NewNode("node1", DATA, 0, 1); + auto node2 = NewNode("node2", RELU, 1, 1); + auto node3 = NewNode("node3", NETOUTPUT, 1, 0); + GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0)); + GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0)); + AtomicAddrCleanPass atomi_addr_clean_pass; + Status ret = atomi_addr_clean_pass.Run(graph_); + EXPECT_EQ(ret, SUCCESS); +} +} // namespace ge