Browse Source

!1360 bugfix for atomic_addr_clean_pass

From: @yangyongqiang5033
Reviewed-by: @xchu42,@ji_chen
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
9635b0c9b9
3 changed files with 72 additions and 6 deletions
  1. +6
    -6
      ge/graph/passes/atomic_addr_clean_pass.cc
  2. +1
    -0
      tests/ut/ge/CMakeLists.txt
  3. +65
    -0
      tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc

+ 6
- 6
ge/graph/passes/atomic_addr_clean_pass.cc View File

@@ -126,11 +126,11 @@ bool AtomicAddrCleanPass::IsOutputIndexPeerInputAtomic(const NodePtr &node, int6

bool AtomicAddrCleanPass::CheckSkipInsertInLoopGraph(const NodePtr &node) {
OpDescPtr op_desc = node->GetOpDesc();
std::map<string, std::map<int, int>> node_workspace_offset;
std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && has_atomic_output && node_workspace_offset.empty()) {
atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size);
if (!has_atomic_input && has_atomic_output && atomic_workspace_index_size.empty()) {
std::vector<int64_t> atomic_output_index;
(void) ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_index);
bool is_all_output_peer_also_atomic = true;
@@ -332,11 +332,11 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) {
}

// 2.Check atomic attr in node
std::map<string, std::map<int, int>> node_workspace_offset;
std::map<string, std::map<int64_t, int64_t>> atomic_workspace_index_size;
bool has_atomic_input = op_desc->HasAttr(ATOMIC_ATTR_INPUT_INDEX);
bool has_atomic_output = op_desc->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX);
node_workspace_offset = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_OFFSET, node_workspace_offset);
if (!has_atomic_input && !has_atomic_output && node_workspace_offset.empty()) {
atomic_workspace_index_size = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_index_size);
if (!has_atomic_input && !has_atomic_output && atomic_workspace_index_size.empty()) {
return false;
}



+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -737,6 +737,7 @@ set(KERNEL_TEST_FILES
"graph/passes/folding_kernel/gather_v2_kernel_unittest.cc"
"graph/passes/folding_kernel/slice_kernel_unittest.cc"
"graph/passes/folding_kernel/dynamic_stitch_kernel_unittest.cc"
"graph/passes/atomic_addr_clean_pass_unittest.cc"
)

set(MULTI_PARTS_TEST_FILES


+ 65
- 0
tests/ut/ge/graph/passes/atomic_addr_clean_pass_unittest.cc View File

@@ -0,0 +1,65 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include "graph/passes/atomic_addr_clean_pass.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "graph/anchor.h"
#include "graph/attr_value.h"
#include "graph/compute_graph.h"
#include "graph/op_desc.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/pass_manager.h"
using namespace testing;

namespace ge {
class UtestGraphPassesAtomicAddrCleanPass : public Test {
public:
UtestGraphPassesAtomicAddrCleanPass() {
graph_ = std::make_shared<ComputeGraph>("test");
}

NodePtr NewNode(const string &name, const string &type, int input_cnt, int output_cnt) {
OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
for (int i = 0; i < input_cnt; ++i) {
op_desc->AddInputDesc(GeTensorDesc());
}
for (int i = 0; i < output_cnt; ++i) {
op_desc->AddOutputDesc(GeTensorDesc());
}
NodePtr node = graph_->AddNode(op_desc);
return node;
}

ComputeGraphPtr graph_;
};

// node1 -> node2 -> node3
TEST_F(UtestGraphPassesAtomicAddrCleanPass, pass_run_success) {
auto node1 = NewNode("node1", DATA, 0, 1);
auto node2 = NewNode("node2", RELU, 1, 1);
auto node3 = NewNode("node3", NETOUTPUT, 1, 0);
GraphUtils::AddEdge(node1->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
GraphUtils::AddEdge(node2->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
AtomicAddrCleanPass atomi_addr_clean_pass;
Status ret = atomi_addr_clean_pass.Run(graph_);
EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save