From: @hugo1 Reviewed-by: Signed-off-by:tags/v1.3.0
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
@@ -22,6 +22,7 @@ | |||
#include "graph/optimize/graph_optimize.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/node_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
using namespace ge; | |||
@@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; | |||
const int kCaseScopeWriteable = 2; | |||
const int kCaseWriteable = 3; | |||
const int kCaseInvalidRWType = 5; | |||
// attr _input_mutable = true means node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
// rw type of input. | |||
enum class InputRWType { | |||
kReadOnly, // Normal op input only read | |||
kWriteable, // Op like Assign/ApplyMomentum | |||
kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput | |||
kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput | |||
kInvalidRWType | |||
}; | |||
// rw type of output | |||
@@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { | |||
return true; | |||
} | |||
NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { | |||
if (src_node.GetOpDesc() == nullptr) { | |||
return nullptr; | |||
} | |||
@@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
auto next_num = identity_num.fetch_add(1); | |||
// 1. create new identity op desc | |||
string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | |||
auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY); | |||
if (identity_opdesc == nullptr) { | |||
GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
OpDescBuilder op_desc_builder(identity_name, IDENTITY); | |||
auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | |||
// 2. add input_desc & output_desc for new identity | |||
Status ret = identity_opdesc->AddInputDesc("x", data_desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
ret = identity_opdesc->AddOutputDesc("y", data_desc); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); | |||
return nullptr; | |||
} | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) | |||
.AddOutput("y", data_desc) | |||
.Build(); | |||
GELOGI("Insert new Identity node %s.", identity_name.c_str()); | |||
auto graph = src_node.GetOwnerComputeGraph(); | |||
if (graph == nullptr) { | |||
GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | |||
return nullptr; | |||
} | |||
return graph->AddNode(identity_opdesc); | |||
return graph->AddNode(identity_op_desc); | |||
} | |||
OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | |||
@@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||
// single node without sub graph | |||
return GetSingleNodeInputRWTypeByIndex(node, index); | |||
} else { | |||
// node with sub graph | |||
std::set<int> node_rw_type_set; | |||
auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | |||
// get all input data node in subgraph | |||
std::set<int> anchor_rw_type_set; | |||
@@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { | |||
auto parent_node = sub_graph->GetParentNode(); | |||
if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | |||
// insert identity | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Fail to insert identity"); | |||
return ret; | |||
if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
identity_node->GetName().c_str(), | |||
identity_node->GetType().c_str(), | |||
pre_node->GetName().c_str(), | |||
pre_node->GetType().c_str(), | |||
node->GetName().c_str(), | |||
node->GetType().c_str()); | |||
GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
identity_node->GetName().c_str(), | |||
identity_node->GetType().c_str(), | |||
pre_node->GetName().c_str(), | |||
pre_node->GetType().c_str(), | |||
node->GetName().c_str(), | |||
node->GetType().c_str()); | |||
return FAILED; | |||
} | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
@@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I | |||
auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
GE_CHECK_NOTNULL(peer_in_data_node); | |||
auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | |||
auto ret = out_data_anchor->Unlink(peer_in_data_anchor); | |||
auto old_identity = out_data_anchor->GetOwnerNode(); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
return ret; | |||
} | |||
if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | |||
auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); | |||
auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(new_identity); | |||
if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS | |||
|| GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", | |||
pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
} | |||
// 2. copy in-control-edge from dst to Identity | |||
if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), | |||
new_identity->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, | |||
kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", | |||
new_identity->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | |||
InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
} else { | |||
(void) out_data_anchor->Unlink(peer_in_data_anchor); | |||
// copy control edge to pre and peer node | |||
if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | |||
|| GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | |||
@@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
GELOGD("No need insert Identity."); | |||
continue; | |||
case INSERT_IDENTITY: | |||
auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); | |||
if (identity_node == nullptr) { | |||
GELOGE(FAILED, "Create identity node failed."); | |||
return FAILED; | |||
} | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); | |||
if (ret != GRAPH_SUCCESS) { | |||
GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), | |||
peer_in_node->GetName().c_str()); | |||
return INTERNAL_ERROR; | |||
auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, | |||
kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), | |||
peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | |||
peer_in_node->GetName().c_str()); | |||
@@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
return SUCCESS; | |||
} | |||
Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
if (node->GetType() == HCOMALLREDUCE) { | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
pre_out_anchor_set.emplace(pre_out_anchor); | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
} | |||
return SUCCESS; | |||
for (const auto &node : compute_graph->GetDirectNode()) { | |||
bool mutable_input_flag = false; | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||
if (!mutable_input_flag) { | |||
continue; | |||
} | |||
std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
GE_CHECK_NOTNULL(pre_out_anchor); | |||
if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||
continue; | |||
} | |||
// need insert identity | |||
auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
GE_CHECK_NOTNULL(identity_node); | |||
auto ret = | |||
GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||
if (ret != SUCCESS) { | |||
GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||
node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
return ret; | |||
} | |||
GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
pre_node->GetName().c_str(), node->GetName().c_str()); | |||
} | |||
} | |||
return SUCCESS; | |||
} | |||
} // namespace | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,11 +24,12 @@ | |||
#include "common/ge/ge_util.h" | |||
#include "framework/common/types.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
const int kAnchorNum = 0; | |||
const int32_t kAnchorAssignRefIndex = 0; | |||
const int32_t kAnchorAssignValueIndex = 1; | |||
const int32_t kAnchorIdentityIndex = 0; | |||
} // namespace | |||
namespace ge { | |||
Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
@@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
node_name = CheckDuplicateName(node_name); | |||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
if (op_desc == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
return nullptr; | |||
} | |||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
return nullptr; | |||
} | |||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
if (identity_op_desc == nullptr) { | |||
return nullptr; | |||
} | |||
// because history reason ,this pass can not do work after constant fold so mark it | |||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||
if (memcpy_node == nullptr) { | |||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
if (identity_node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
return nullptr; | |||
} | |||
return memcpy_node; | |||
return identity_node; | |||
} | |||
/// | |||
@@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||
Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | |||
const OutDataAnchorPtr &src_out_anchor, | |||
const InDataAnchorPtr &hccl_in_anchor) { | |||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(memcpy_node); | |||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(identity_node); | |||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
kAnchorNum); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
memcpy_node->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -24,13 +24,15 @@ | |||
#include "common/ge/ge_util.h" | |||
#include "framework/common/types.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/utils/op_desc_utils.h" | |||
namespace { | |||
const int32_t kAnchorSize = 1; | |||
const int kAnchorNum = 0; | |||
const int32_t kAnchorAssignRefIndex = 0; | |||
const int32_t kAnchorAssignValueIndex = 1; | |||
const char *const kInputMutable = "_input_mutable"; | |||
const int32_t kAnchorIdentityIndex = 0; | |||
// attr _input_mutable = true means hccl node will modify its input in runtime | |||
const char *const kModifyInput = "_input_mutable"; | |||
} // namespace | |||
namespace ge { | |||
Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
@@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
// need to inset memcpy node between. | |||
// also works on situation that input is variable or const. | |||
Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | |||
auto op_desc = node->GetOpDesc(); | |||
bool node_input_mutable = false; | |||
if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||
return SUCCESS; | |||
} | |||
if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||
REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||
return FAILED; | |||
} | |||
(void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||
if (!node_input_mutable) { | |||
return SUCCESS; | |||
} | |||
GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||
GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||
for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | |||
if (hccl_in_anchor == nullptr) { | |||
continue; | |||
@@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||
std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
node_name = CheckDuplicateName(node_name); | |||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
if (op_desc == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
return nullptr; | |||
} | |||
GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
return nullptr; | |||
} | |||
ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
if (ret != GRAPH_SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
if (identity_op_desc == nullptr) { | |||
return nullptr; | |||
} | |||
// because history reason ,this pass can not do work after constant fold so mark it | |||
(void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
(void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
NodePtr memcpy_node = graph->AddNode(op_desc); | |||
if (memcpy_node == nullptr) { | |||
NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
if (identity_node == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
return nullptr; | |||
} | |||
return memcpy_node; | |||
return identity_node; | |||
} | |||
/// | |||
@@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||
/// | |||
Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | |||
const InDataAnchorPtr &hccl_in_anchor) { | |||
GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(memcpy_node); | |||
NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
GE_CHECK_NOTNULL(identity_node); | |||
Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
GE_CHECK_NOTNULL(out_data_anchor_0); | |||
ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
if (ret1 != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
"Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
return FAILED; | |||
} | |||
Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", | |||
"Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
kAnchorNum); | |||
GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
memcpy_node->GetName().c_str()); | |||
GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
hccl_in_anchor->GetIdx()); | |||
return FAILED; | |||
} | |||
return SUCCESS; | |||
@@ -1,5 +1,5 @@ | |||
/** | |||
* Copyright 2020 Huawei Technologies Co., Ltd | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
@@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||
set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | |||
"${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | |||
"${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | |||
"${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||
) | |||
@@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||
"graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||
"graph/passes/reshape_recovery_pass_unittest.cc" | |||
"graph/passes/cast_remove_pass_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/memcpy_addr_async_unittest.cc" | |||
"graph/passes/hccl_continuous_pass_unittest.cc" | |||
"graph/passes/hccl_memcpy_pass_unittest.cc" | |||
) | |||
set(KERNEL_TEST_FILES | |||
@@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||
"graph/manager/run_graph_unittest.cc" | |||
"graph/partition/dynamic_shape_partition_unittest.cc" | |||
"graph/manager/graph_manager_unittest.cc" | |||
"graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||
"session/omg_omg_unittest.cc" | |||
"session/ge_api_unittest.cc" | |||
"session/inner_session_unittest.cc" | |||
@@ -0,0 +1,150 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#define protected public | |||
#define private public | |||
#include "graph/optimize/graph_optimize.h" | |||
#undef protected | |||
#undef private | |||
#include "../passes/graph_builder_utils.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
namespace ge { | |||
class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* Data -cast - netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){ | |||
auto sub_builder = ut::GraphBuilder(subraph_name); | |||
auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||
auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||
auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||
AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||
sub_builder.AddDataEdge(data1,0,cast,0); | |||
sub_builder.AddDataEdge(cast,0,netoutput,0); | |||
return sub_builder.GetGraph(); | |||
} | |||
/* | |||
* const - allreduce | |||
* \ if | |||
* insert identity | |||
*/ | |||
ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||
auto builder = ut::GraphBuilder("test"); | |||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto if_node = builder.AddNode("if", IF, 1,0); | |||
builder.AddDataEdge(const1, 0, allreduce, 0); | |||
builder.AddDataEdge(const1, 0, if_node, 0); | |||
builder.AddControlEdge(ctrl_const, allreduce); | |||
auto root_graph = builder.GetGraph(); | |||
string subgraph_name = "then_branch"; | |||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
then_branch_graph->SetParentNode(if_node); | |||
then_branch_graph->SetParentGraph(root_graph); | |||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
return root_graph; | |||
} | |||
/* const1---allreduce const1--identity - allreduce | |||
* / / | |||
* var-identity--cast1 ==> var-----cast1 | |||
* \ \ | |||
* if if | |||
*/ | |||
ComputeGraphPtr BuildGraph_Identiyt_Split(){ | |||
auto builder = ut::GraphBuilder("g1"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||
auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||
auto if_node = builder.AddNode("if", IF, 1,0); | |||
builder.AddDataEdge(var, 0 , identity, 0); | |||
builder.AddDataEdge(identity, 0 , allreduce, 0); | |||
builder.AddDataEdge(identity, 0 , cast1, 0); | |||
builder.AddDataEdge(identity, 0 , if_node, 0); | |||
builder.AddControlEdge(const1, allreduce); | |||
auto root_graph = builder.GetGraph(); | |||
string subgraph_name = "then_branch"; | |||
ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
then_branch_graph->SetParentNode(if_node); | |||
then_branch_graph->SetParentGraph(root_graph); | |||
if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
return root_graph; | |||
} | |||
/* | |||
* mul == allreduce | |||
* need insert identity | |||
*/ | |||
ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||
auto builder = ut::GraphBuilder("test"); | |||
auto mul = builder.AddNode("mul", MUL, 2,1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||
AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||
builder.AddDataEdge(mul,0,allreduce,0); | |||
builder.AddDataEdge(mul,0,allreduce,1); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// const -> allreduce | |||
// const -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) { | |||
ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
auto allreduce = graph->FindNode("allreduce"); | |||
EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); | |||
} | |||
TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) { | |||
ComputeGraphPtr graph = BuildGraph_Identiyt_Split(); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
auto allreduce = graph->FindNode("allreduce"); | |||
auto allreduce_in_node = allreduce->GetInDataNodes().at(0); | |||
EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); | |||
EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); | |||
} | |||
TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) { | |||
ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); | |||
EXPECT_EQ(graph->GetDirectNodesSize(), 2); | |||
GraphOptimize graph_optimizer; | |||
auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
EXPECT_EQ(ret, SUCCESS); | |||
EXPECT_EQ(graph->GetDirectNodesSize(), 3); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,79 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
#undef protected | |||
#undef private | |||
#include "graph_builder_utils.h" | |||
namespace ge { | |||
class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* var var | |||
* | \ | \ | |||
* | assign | assign | |||
* | // =======> | // | |||
* allreduce identity | |||
* | | | |||
* netoutput allreduce | |||
* | | |||
* netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
auto builder = ut::GraphBuilder("test"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(var, 0, assign, 0); | |||
builder.AddDataEdge(var,0,allreduce,0); | |||
builder.AddControlEdge(assign, allreduce); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// const -> allreduce | |||
// const -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { | |||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
auto src_node = graph->FindNode("var"); | |||
auto dst_node = graph->FindNode("allreduce"); | |||
// test InsertIdentityBeforeHccl | |||
HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; | |||
hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); | |||
// check | |||
dst_node = graph->FindNode("allreduce"); | |||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
} | |||
} // namespace ge |
@@ -0,0 +1,80 @@ | |||
/** | |||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <cstdint> | |||
#include <string> | |||
#include <gtest/gtest.h> | |||
#include "common/ge_inner_error_codes.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/passes/hccl_memcpy_pass.h" | |||
#undef protected | |||
#undef private | |||
#include "graph_builder_utils.h" | |||
namespace ge { | |||
class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() {} | |||
}; | |||
namespace { | |||
/* | |||
* var var | |||
* | \ | \ | |||
* | assign | assign | |||
* | // =======> | // | |||
* allreduce identity | |||
* | | | |||
* netoutput allreduce | |||
* | | |||
* netoutput | |||
*/ | |||
ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
auto builder = ut::GraphBuilder("test"); | |||
auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
builder.AddDataEdge(var, 0, assign, 0); | |||
builder.AddDataEdge(var,0,allreduce,0); | |||
builder.AddControlEdge(assign, allreduce); | |||
return builder.GetGraph(); | |||
} | |||
} // namespace | |||
// const -> allreduce | |||
// const -> Identity -> allreduce | |||
TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { | |||
ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
auto src_node = graph->FindNode("var"); | |||
auto dst_node = graph->FindNode("allreduce"); | |||
// test InsertIdentityBeforeHccl | |||
HcclMemcpyPass hccl_memcpy_pass; | |||
hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), | |||
dst_node->GetInDataAnchor(0)); | |||
// check | |||
dst_node = graph->FindNode("allreduce"); | |||
auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
} | |||
} // namespace ge |