Browse Source

!1665 Bugfix: Insert identity before allreduce

From: @hugo1
Reviewed-by: 
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
f04ec9c790
9 changed files with 445 additions and 229 deletions
  1. +76
    -79
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  2. +24
    -67
      ge/graph/passes/hccl_continuous_memcpy_pass.cc
  3. +1
    -1
      ge/graph/passes/hccl_continuous_memcpy_pass.h
  4. +28
    -80
      ge/graph/passes/hccl_memcpy_pass.cc
  5. +1
    -1
      ge/graph/passes/hccl_memcpy_pass.h
  6. +6
    -1
      tests/ut/ge/CMakeLists.txt
  7. +150
    -0
      tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc
  8. +79
    -0
      tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc
  9. +80
    -0
      tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc

+ 76
- 79
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
@@ -22,6 +22,7 @@
#include "graph/optimize/graph_optimize.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"

namespace {
using namespace ge;
@@ -32,12 +33,14 @@ const int kCaseReadOnly = 0;
const int kCaseScopeWriteable = 2;
const int kCaseWriteable = 3;
const int kCaseInvalidRWType = 5;
// attr _input_mutable = true means node will modify its input in runtime
const char *const kModifyInput = "_input_mutable";

// rw type of input.
enum class InputRWType {
kReadOnly, // Normal op input only read
kWriteable, // Op like Assign/ApplyMomentum
kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput
kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput
kInvalidRWType
};
// rw type of output
@@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) {
return true;
}

NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) {
if (src_node.GetOpDesc() == nullptr) {
return nullptr;
}
@@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) {
auto next_num = identity_num.fetch_add(1);
// 1. create new identity op desc
string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num);
auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY);
if (identity_opdesc == nullptr) {
GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str());
return nullptr;
}
OpDescBuilder op_desc_builder(identity_name, IDENTITY);
auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx);
// 2. add input_desc & output_desc for new identity
Status ret = identity_opdesc->AddInputDesc("x", data_desc);
if (ret != SUCCESS) {
GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str());
return nullptr;
}
ret = identity_opdesc->AddOutputDesc("y", data_desc);
if (ret != SUCCESS) {
GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str());
return nullptr;
}
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc)
.AddOutput("y", data_desc)
.Build();

GELOGI("Insert new Identity node %s.", identity_name.c_str());
auto graph = src_node.GetOwnerComputeGraph();
if (graph == nullptr) {
GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str());
return nullptr;
}
return graph->AddNode(identity_opdesc);
return graph->AddNode(identity_op_desc);
}

OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) {
@@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) {
// single node without sub graph
return GetSingleNodeInputRWTypeByIndex(node, index);
} else {
// node with sub graph
std::set<int> node_rw_type_set;
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index);
// get all input data node in subgraph
std::set<int> anchor_rw_type_set;
@@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) {
auto parent_node = sub_graph->GetParentNode();
if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) {
// insert identity
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
if (ret != SUCCESS) {
GELOGE(ret, "Fail to insert identity");
return ret;
if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.",
identity_node->GetName().c_str(),
identity_node->GetType().c_str(),
pre_node->GetName().c_str(),
pre_node->GetType().c_str(),
node->GetName().c_str(),
node->GetType().c_str());
GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.",
identity_node->GetName().c_str(),
identity_node->GetType().c_str(),
pre_node->GetName().c_str(),
pre_node->GetType().c_str(),
node->GetName().c_str(),
node->GetType().c_str());
return FAILED;
}
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
@@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode();
GE_CHECK_NOTNULL(peer_in_data_node);
auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx());
auto ret = out_data_anchor->Unlink(peer_in_data_anchor);
auto old_identity = out_data_anchor->GetOwnerNode();
if (ret != SUCCESS) {
GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(),
peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
return ret;
}
if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) {
auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx());
auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx());
GE_CHECK_NOTNULL(new_identity);
if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS
|| GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s",
pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
return INTERNAL_ERROR;
}

// 2. copy in-control-edge from dst to Identity
if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(),
new_identity->GetName().c_str());
return INTERNAL_ERROR;
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex,
kIdentityAnchorIndex);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to insert Identity %s before %s %dth input.",
new_identity->GetName().c_str(),
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(),
peer_in_data_anchor->GetIdx());
return ret;
}
GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(),
InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(),
peer_in_data_anchor->GetOwnerNode()->GetName().c_str());
} else {
(void) out_data_anchor->Unlink(peer_in_data_anchor);
// copy control edge to pre and peer node
if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS
|| GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) {
@@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
GELOGD("No need insert Identity.");
continue;
case INSERT_IDENTITY:
auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx());
if (identity_node == nullptr) {
GELOGE(FAILED, "Create identity node failed.");
return FAILED;
}
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node);
if (ret != GRAPH_SUCCESS) {
GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(),
peer_in_node->GetName().c_str());
return INTERNAL_ERROR;
auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex,
kIdentityAnchorIndex);
if (ret != SUCCESS) {
GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(),
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx());
return ret;
}
GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(),
peer_in_node->GetName().c_str());
@@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) {
return SUCCESS;
}
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) {
for (const auto &node : compute_graph->GetDirectNode()) {
if (node->GetType() == HCOMALLREDUCE) {
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) {
pre_out_anchor_set.emplace(pre_out_anchor);
continue;
}
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node);
GE_CHK_STATUS_RET(ret, "Fail to insert identity.");
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
}
}
}
return SUCCESS;
for (const auto &node : compute_graph->GetDirectNode()) {
bool mutable_input_flag = false;
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag);
if (!mutable_input_flag) {
continue;
}
std::set<OutDataAnchorPtr> pre_out_anchor_set;
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) {
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor();
GE_CHECK_NOTNULL(pre_out_anchor);
if (pre_out_anchor_set.insert(pre_out_anchor).second) {
continue;
}
// need insert identity
auto pre_node = pre_out_anchor->GetOwnerNode();
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx());
GE_CHECK_NOTNULL(identity_node);
auto ret =
GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex);
if (ret != SUCCESS) {
GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(),
node->GetName().c_str(), in_data_anchor->GetIdx());
return ret;
}
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(),
pre_node->GetName().c_str(), node->GetName().c_str());
}
}
return SUCCESS;
}
} // namespace



+ 24
- 67
ge/graph/passes/hccl_continuous_memcpy_pass.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,11 +24,12 @@
#include "common/ge/ge_util.h"
#include "framework/common/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"

namespace {
const int kAnchorNum = 0;
const int32_t kAnchorAssignRefIndex = 0;
const int32_t kAnchorAssignValueIndex = 1;
const int32_t kAnchorIdentityIndex = 0;
} // namespace
namespace ge {
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) {
@@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap

std::string node_name = pre_node->GetName() + "_" + IDENTITY;
node_name = CheckDuplicateName(node_name);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY);
if (op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New OpDesc failed");
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail.");
return nullptr;
}
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str());

graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail.");
return nullptr;
}

ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail.");
OpDescBuilder op_desc_builder(node_name, IDENTITY);
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx());
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build();
if (identity_op_desc == nullptr) {
return nullptr;
}
// because history reason ,this pass can not do work after constant fold so mark it
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);

NodePtr memcpy_node = graph->AddNode(op_desc);
if (memcpy_node == nullptr) {
NodePtr identity_node = graph->AddNode(identity_op_desc);
if (identity_node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Insert Identity node fail.");
return nullptr;
}

return memcpy_node;
return identity_node;
}

///
@@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph,
const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(memcpy_node);
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(identity_node);

Status ret1 = src_out_anchor->Unlink(hccl_in_anchor);
if (ret1 != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum);
GE_CHECK_NOTNULL(out_data_anchor_0);
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor);
if (ret1 != SUCCESS) {
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed",
out_data_anchor_0->GetOwnerNode()->GetName().c_str(),
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(),
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.",
identity_node->GetName().c_str(), identity_node->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}

Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum));
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(),
kAnchorNum);
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
memcpy_node->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.",
identity_node->GetName().c_str(), identity_node->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
return FAILED;
}
return SUCCESS;


+ 1
- 1
ge/graph/passes/hccl_continuous_memcpy_pass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+ 28
- 80
ge/graph/passes/hccl_memcpy_pass.cc View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -24,13 +24,15 @@
#include "common/ge/ge_util.h"
#include "framework/common/types.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"

namespace {
const int32_t kAnchorSize = 1;
const int kAnchorNum = 0;
const int32_t kAnchorAssignRefIndex = 0;
const int32_t kAnchorAssignValueIndex = 1;
const char *const kInputMutable = "_input_mutable";
const int32_t kAnchorIdentityIndex = 0;
// attr _input_mutable = true means hccl node will modify its input in runtime
const char *const kModifyInput = "_input_mutable";
} // namespace
namespace ge {
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
@@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) {
// need to inset memcpy node between.
// also works on situation that input is variable or const.
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) {
auto op_desc = node->GetOpDesc();

bool node_input_mutable = false;
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) {
return SUCCESS;
}

if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) {
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable,
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str());
return FAILED;
}
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable);
if (!node_input_mutable) {
return SUCCESS;
}

GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str());
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str());
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) {
if (hccl_in_anchor == nullptr) {
continue;
@@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O

std::string node_name = pre_node->GetName() + "_" + IDENTITY;
node_name = CheckDuplicateName(node_name);
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY);
if (op_desc == nullptr) {
REPORT_CALL_ERROR("E19999", "New OpDesc failed");
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail.");
return nullptr;
}
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str());

graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail.");
return nullptr;
}

ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()));
if (ret != GRAPH_SUCCESS) {
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y",
op_desc->GetName().c_str(), op_desc->GetType().c_str());
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail.");
OpDescBuilder op_desc_builder(node_name, IDENTITY);
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx());
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build();
if (identity_op_desc == nullptr) {
return nullptr;
}
// because history reason ,this pass can not do work after constant fold so mark it
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false);

NodePtr memcpy_node = graph->AddNode(op_desc);
if (memcpy_node == nullptr) {
NodePtr identity_node = graph->AddNode(identity_op_desc);
if (identity_node == nullptr) {
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Insert Identity node fail.");
return nullptr;
}

return memcpy_node;
return identity_node;
}

///
@@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const
///
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor,
const InDataAnchorPtr &hccl_in_anchor) {
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(memcpy_node);
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor);
GE_CHECK_NOTNULL(identity_node);

Status ret1 = src_out_anchor->Unlink(hccl_in_anchor);
if (ret1 != SUCCESS) {
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex);
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum);
GE_CHECK_NOTNULL(out_data_anchor_0);
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor);
if (ret1 != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed",
out_data_anchor_0->GetOwnerNode()->GetName().c_str(),
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(),
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.",
identity_node->GetName().c_str(), identity_node->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str());
return FAILED;
}

Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum));
if (ret != SUCCESS) {
REPORT_CALL_ERROR("E19999",
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed",
src_out_anchor->GetOwnerNode()->GetName().c_str(),
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(),
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(),
kAnchorNum);
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(),
memcpy_node->GetName().c_str());
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.",
identity_node->GetName().c_str(), identity_node->GetType().c_str(),
hccl_in_anchor->GetOwnerNode()->GetName().c_str(),
hccl_in_anchor->GetOwnerNode()->GetType().c_str(),
hccl_in_anchor->GetIdx());
return FAILED;
}
return SUCCESS;


+ 1
- 1
ge/graph/passes/hccl_memcpy_pass.h View File

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.


+ 6
- 1
tests/ut/ge/CMakeLists.txt View File

@@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES
set(GRAPH_OPTIMIZE_COMMON_SRC_FILES
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc"
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc"
"${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc"
)


@@ -715,7 +716,10 @@ set(PASS_TEST_FILES
"graph/passes/mark_node_unknown_shape_pass_unittest.cc"
"graph/passes/reshape_recovery_pass_unittest.cc"
"graph/passes/cast_remove_pass_unittest.cc"
"graph/passes/memcpy_addr_async_unittest.cc"
"graph/passes/memcpy_addr_async_unittest.cc"
"graph/passes/hccl_continuous_pass_unittest.cc"
"graph/passes/hccl_memcpy_pass_unittest.cc"
)

set(KERNEL_TEST_FILES
@@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES
"graph/manager/run_graph_unittest.cc"
"graph/partition/dynamic_shape_partition_unittest.cc"
"graph/manager/graph_manager_unittest.cc"
"graph/optimize/mem_rw_conflict_optimize_unittest.cc"
"session/omg_omg_unittest.cc"
"session/ge_api_unittest.cc"
"session/inner_session_unittest.cc"


+ 150
- 0
tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc View File

@@ -0,0 +1,150 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <string>
#include <gtest/gtest.h>

#define protected public
#define private public
#include "graph/optimize/graph_optimize.h"
#undef protected
#undef private
#include "../passes/graph_builder_utils.h"
#include "graph/debug/ge_attr_define.h"

namespace ge {
class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {
/*
* Data -cast - netoutput
*/
ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){
auto sub_builder = ut::GraphBuilder(subraph_name);
auto data1 = sub_builder.AddNode("data1", DATA, 0,1);
auto cast = sub_builder.AddNode("cast", CAST, 1,1);
auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1);
AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1);
AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0);

sub_builder.AddDataEdge(data1,0,cast,0);
sub_builder.AddDataEdge(cast,0,netoutput,0);
return sub_builder.GetGraph();
}
/*
* const - allreduce
* \ if
* insert identity
*/
ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() {
auto builder = ut::GraphBuilder("test");
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1);
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1);
auto if_node = builder.AddNode("if", IF, 1,0);

builder.AddDataEdge(const1, 0, allreduce, 0);
builder.AddDataEdge(const1, 0, if_node, 0);
builder.AddControlEdge(ctrl_const, allreduce);

auto root_graph = builder.GetGraph();
string subgraph_name = "then_branch";
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name);
then_branch_graph->SetParentNode(if_node);
then_branch_graph->SetParentGraph(root_graph);
if_node->GetOpDesc()->AddSubgraphName(subgraph_name);
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name);
root_graph->AddSubgraph(subgraph_name, then_branch_graph);
return root_graph;
}
/* const1---allreduce const1--identity - allreduce
* / /
* var-identity--cast1 ==> var-----cast1
* \ \
* if if
*/
ComputeGraphPtr BuildGraph_Identiyt_Split(){
auto builder = ut::GraphBuilder("g1");
auto var = builder.AddNode("var", VARIABLE, 0, 1);
auto identity = builder.AddNode("identity", IDENTITY, 1, 1);
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1);
auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
auto if_node = builder.AddNode("if", IF, 1,0);

builder.AddDataEdge(var, 0 , identity, 0);
builder.AddDataEdge(identity, 0 , allreduce, 0);
builder.AddDataEdge(identity, 0 , cast1, 0);
builder.AddDataEdge(identity, 0 , if_node, 0);
builder.AddControlEdge(const1, allreduce);

auto root_graph = builder.GetGraph();
string subgraph_name = "then_branch";
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name);
then_branch_graph->SetParentNode(if_node);
then_branch_graph->SetParentGraph(root_graph);
if_node->GetOpDesc()->AddSubgraphName(subgraph_name);
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name);
root_graph->AddSubgraph(subgraph_name, then_branch_graph);
return root_graph;
}
/*
* mul == allreduce
* need insert identity
*/
ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() {
auto builder = ut::GraphBuilder("test");
auto mul = builder.AddNode("mul", MUL, 2,1);
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0);
AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true);
builder.AddDataEdge(mul,0,allreduce,0);
builder.AddDataEdge(mul,0,allreduce,1);
return builder.GetGraph();
}
} // namespace
// const -> allreduce
// const -> Identity -> allreduce
TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) {
ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite();
GraphOptimize graph_optimizer;
auto ret = graph_optimizer.HandleMemoryRWConflict(graph);
EXPECT_EQ(ret, SUCCESS);
auto allreduce = graph->FindNode("allreduce");
EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY);
}
TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) {
ComputeGraphPtr graph = BuildGraph_Identiyt_Split();
GraphOptimize graph_optimizer;
auto ret = graph_optimizer.HandleMemoryRWConflict(graph);
EXPECT_EQ(ret, SUCCESS);
auto allreduce = graph->FindNode("allreduce");
auto allreduce_in_node = allreduce->GetInDataNodes().at(0);
EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY);
EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT);
}
TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) {
ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite();
EXPECT_EQ(graph->GetDirectNodesSize(), 2);
GraphOptimize graph_optimizer;
auto ret = graph_optimizer.HandleMemoryRWConflict(graph);
EXPECT_EQ(ret, SUCCESS);
EXPECT_EQ(graph->GetDirectNodesSize(), 3);
}
} // namespace ge

+ 79
- 0
tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc View File

@@ -0,0 +1,79 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <string>
#include <gtest/gtest.h>

#include "common/ge_inner_error_codes.h"
#define protected public
#define private public
#include "graph/passes/hccl_continuous_memcpy_pass.h"
#undef protected
#undef private
#include "graph_builder_utils.h"

namespace ge {
class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {

/*
* var var
* | \ | \
* | assign | assign
* | // =======> | //
* allreduce identity
* | |
* netoutput allreduce
* |
* netoutput
*/
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){
auto builder = ut::GraphBuilder("test");
auto var = builder.AddNode("var", VARIABLE, 0, 1);
auto assign = builder.AddNode("assign", ASSIGN, 1, 1);
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1);
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

builder.AddDataEdge(var, 0, assign, 0);
builder.AddDataEdge(var,0,allreduce,0);
builder.AddControlEdge(assign, allreduce);
return builder.GetGraph();
}
} // namespace

// const -> allreduce
// const -> Identity -> allreduce
TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) {
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign();
auto src_node = graph->FindNode("var");
auto dst_node = graph->FindNode("allreduce");
// test InsertIdentityBeforeHccl
HcclContinuousMemcpyPass hccl_continuous_memcpy_pass;
hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0));

// check
dst_node = graph->FindNode("allreduce");
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode();
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY);
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1);
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign");
}
} // namespace ge

+ 80
- 0
tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc View File

@@ -0,0 +1,80 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <string>
#include <gtest/gtest.h>

#include "common/ge_inner_error_codes.h"
#define protected public
#define private public
#include "graph/passes/hccl_memcpy_pass.h"
#undef protected
#undef private
#include "graph_builder_utils.h"

namespace ge {
class UtestGraphPassesHcclMemcpyPass : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
};
namespace {

/*
* var var
* | \ | \
* | assign | assign
* | // =======> | //
* allreduce identity
* | |
* netoutput allreduce
* |
* netoutput
*/
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){
auto builder = ut::GraphBuilder("test");
auto var = builder.AddNode("var", VARIABLE, 0, 1);
auto assign = builder.AddNode("assign", ASSIGN, 1, 1);
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1);
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);

builder.AddDataEdge(var, 0, assign, 0);
builder.AddDataEdge(var,0,allreduce,0);
builder.AddControlEdge(assign, allreduce);
return builder.GetGraph();
}
} // namespace

// const -> allreduce
// const -> Identity -> allreduce
TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) {
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign();
auto src_node = graph->FindNode("var");
auto dst_node = graph->FindNode("allreduce");
// test InsertIdentityBeforeHccl
HcclMemcpyPass hccl_memcpy_pass;
hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0),
dst_node->GetInDataAnchor(0));

// check
dst_node = graph->FindNode("allreduce");
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode();
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY);
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1);
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign");
}
} // namespace ge

Loading…
Cancel
Save