From c457cce7f5c9144ea50147b8b3a90864b91ea916 Mon Sep 17 00:00:00 2001 From: zhaoxinxin Date: Wed, 19 May 2021 11:14:22 +0800 Subject: [PATCH] modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc --- ge/graph/optimize/mem_rw_conflict_optimize.cc | 155 +++++++++--------- .../passes/hccl_continuous_memcpy_pass.cc | 91 +++------- ge/graph/passes/hccl_continuous_memcpy_pass.h | 2 +- ge/graph/passes/hccl_memcpy_pass.cc | 108 ++++-------- ge/graph/passes/hccl_memcpy_pass.h | 2 +- tests/ut/ge/CMakeLists.txt | 7 +- .../mem_rw_conflict_optimize_unittest.cc | 150 +++++++++++++++++ .../passes/hccl_continuous_pass_unittest.cc | 79 +++++++++ .../graph/passes/hccl_memcpy_pass_unittest.cc | 80 +++++++++ 9 files changed, 445 insertions(+), 229 deletions(-) create mode 100644 tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc create mode 100644 tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc create mode 100644 tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc diff --git a/ge/graph/optimize/mem_rw_conflict_optimize.cc b/ge/graph/optimize/mem_rw_conflict_optimize.cc index 077ed110..7c1fc4ab 100644 --- a/ge/graph/optimize/mem_rw_conflict_optimize.cc +++ b/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2020-2021 Huawei Technologies Co., Ltd * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at @@ -22,6 +22,7 @@ #include "graph/optimize/graph_optimize.h" #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" +#include "graph/utils/op_desc_utils.h" namespace { using namespace ge; @@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; const int kCaseScopeWriteable = 2; const int kCaseWriteable = 3; const int kCaseInvalidRWType = 5; +// attr _input_mutable = true means node will modify its input in runtime +const char *const kModifyInput = "_input_mutable"; // rw type of input. enum class InputRWType { kReadOnly, // Normal op input only read kWriteable, // Op like Assign/ApplyMomentum - kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput + kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput kInvalidRWType }; // rw type of output @@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { return true; } -NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { +NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { if (src_node.GetOpDesc() == nullptr) { return nullptr; } @@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { auto next_num = identity_num.fetch_add(1); // 1. create new identity op desc string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); - auto identity_opdesc = MakeShared(identity_name, IDENTITY); - if (identity_opdesc == nullptr) { - GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); - return nullptr; - } + OpDescBuilder op_desc_builder(identity_name, IDENTITY); auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); - // 2. add input_desc & output_desc for new identity - Status ret = identity_opdesc->AddInputDesc("x", data_desc); - if (ret != SUCCESS) { - GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); - return nullptr; - } - ret = identity_opdesc->AddOutputDesc("y", data_desc); - if (ret != SUCCESS) { - GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); - return nullptr; - } + auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) + .AddOutput("y", data_desc) + .Build(); + GELOGI("Insert new Identity node %s.", identity_name.c_str()); auto graph = src_node.GetOwnerComputeGraph(); if (graph == nullptr) { GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); return nullptr; } - return graph->AddNode(identity_opdesc); + return graph->AddNode(identity_op_desc); } OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { @@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { // single node without sub graph return GetSingleNodeInputRWTypeByIndex(node, index); } else { - // node with sub graph - std::set node_rw_type_set; auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); // get all input data node in subgraph std::set anchor_rw_type_set; @@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { auto parent_node = sub_graph->GetParentNode(); if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { // insert identity - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); - if (ret != SUCCESS) { - GELOGE(ret, "Fail to insert identity"); - return ret; + if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { + REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", + identity_node->GetName().c_str(), + identity_node->GetType().c_str(), + pre_node->GetName().c_str(), + pre_node->GetType().c_str(), + node->GetName().c_str(), + node->GetType().c_str()); + GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", + identity_node->GetName().c_str(), + identity_node->GetType().c_str(), + pre_node->GetName().c_str(), + pre_node->GetType().c_str(), + node->GetName().c_str(), + node->GetType().c_str()); + return FAILED; } GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), pre_node->GetName().c_str(), node->GetName().c_str()); @@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); GE_CHECK_NOTNULL(peer_in_data_node); auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); - auto ret = out_data_anchor->Unlink(peer_in_data_anchor); auto old_identity = out_data_anchor->GetOwnerNode(); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return ret; - } if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { - auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); + auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); GE_CHECK_NOTNULL(new_identity); - if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS - || GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", - pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), - peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); - return INTERNAL_ERROR; - } - - // 2. copy in-control-edge from dst to Identity - if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), - new_identity->GetName().c_str()); - return INTERNAL_ERROR; + auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, + kIdentityAnchorIndex); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", + new_identity->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetIdx()); + return ret; } GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); } else { + (void) out_data_anchor->Unlink(peer_in_data_anchor); // copy control edge to pre and peer node if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { @@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { GELOGD("No need insert Identity."); continue; case INSERT_IDENTITY: - auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); - if (identity_node == nullptr) { - GELOGE(FAILED, "Create identity node failed."); - return FAILED; - } - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); - if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), - peer_in_node->GetName().c_str()); - return INTERNAL_ERROR; + auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, + kIdentityAnchorIndex); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); + return ret; } GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), peer_in_node->GetName().c_str()); @@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { return SUCCESS; } Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { - for (const auto &node : compute_graph->GetDirectNode()) { - if (node->GetType() == HCOMALLREDUCE) { - std::set pre_out_anchor_set; - for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { - auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(pre_out_anchor); - if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { - pre_out_anchor_set.emplace(pre_out_anchor); - continue; - } - // need insert identity - auto pre_node = pre_out_anchor->GetOwnerNode(); - auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); - GE_CHECK_NOTNULL(identity_node); - auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); - GE_CHK_STATUS_RET(ret, "Fail to insert identity."); - GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), - pre_node->GetName().c_str(), node->GetName().c_str()); - } - } - } - return SUCCESS; + for (const auto &node : compute_graph->GetDirectNode()) { + bool mutable_input_flag = false; + (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); + if (!mutable_input_flag) { + continue; + } + std::set pre_out_anchor_set; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + if (pre_out_anchor_set.insert(pre_out_anchor).second) { + continue; + } + // need insert identity + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = + GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), + node->GetName().c_str(), in_data_anchor->GetIdx()); + return ret; + } + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); + } + } + return SUCCESS; } } // namespace diff --git a/ge/graph/passes/hccl_continuous_memcpy_pass.cc b/ge/graph/passes/hccl_continuous_memcpy_pass.cc index 790661bc..61066d63 100644 --- a/ge/graph/passes/hccl_continuous_memcpy_pass.cc +++ b/ge/graph/passes/hccl_continuous_memcpy_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,11 +24,12 @@ #include "common/ge/ge_util.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" namespace { -const int kAnchorNum = 0; const int32_t kAnchorAssignRefIndex = 0; const int32_t kAnchorAssignValueIndex = 1; +const int32_t kAnchorIdentityIndex = 0; } // namespace namespace ge { Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap std::string node_name = pre_node->GetName() + "_" + IDENTITY; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared(node_name.c_str(), IDENTITY); - if (op_desc == nullptr) { - REPORT_CALL_ERROR("E19999", "New OpDesc failed"); - GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); - return nullptr; - } - GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); - - graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); - return nullptr; - } - - ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); + OpDescBuilder op_desc_builder(node_name, IDENTITY); + auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); + auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); + if (identity_op_desc == nullptr) { return nullptr; } // because history reason ,this pass can not do work after constant fold so mark it - (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); + (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); - NodePtr memcpy_node = graph->AddNode(op_desc); - if (memcpy_node == nullptr) { + NodePtr identity_node = graph->AddNode(identity_op_desc); + if (identity_node == nullptr) { REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); + identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); return nullptr; } - - return memcpy_node; + return identity_node; } /// @@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor) { - GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), + GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); - GE_CHECK_NOTNULL(memcpy_node); + NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); + GE_CHECK_NOTNULL(identity_node); - Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); - if (ret1 != SUCCESS) { - REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", - src_out_anchor->GetOwnerNode()->GetName().c_str(), - src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetType().c_str(), - hccl_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - return FAILED; - } - auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); - GE_CHECK_NOTNULL(out_data_anchor_0); - ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); - if (ret1 != SUCCESS) { + auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); + if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", - out_data_anchor_0->GetOwnerNode()->GetName().c_str(), - out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), + "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", + identity_node->GetName().c_str(), identity_node->GetType().c_str(), hccl_in_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - return FAILED; - } - - Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", - src_out_anchor->GetOwnerNode()->GetName().c_str(), - src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), - memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), - kAnchorNum); - GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), - memcpy_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", + identity_node->GetName().c_str(), identity_node->GetType().c_str(), + hccl_in_anchor->GetOwnerNode()->GetName().c_str(), + hccl_in_anchor->GetOwnerNode()->GetType().c_str(), + hccl_in_anchor->GetIdx()); return FAILED; } return SUCCESS; diff --git a/ge/graph/passes/hccl_continuous_memcpy_pass.h b/ge/graph/passes/hccl_continuous_memcpy_pass.h index 538e89e9..5fbb6fd0 100644 --- a/ge/graph/passes/hccl_continuous_memcpy_pass.h +++ b/ge/graph/passes/hccl_continuous_memcpy_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/ge/graph/passes/hccl_memcpy_pass.cc b/ge/graph/passes/hccl_memcpy_pass.cc index 2d2f8220..dd251ea6 100755 --- a/ge/graph/passes/hccl_memcpy_pass.cc +++ b/ge/graph/passes/hccl_memcpy_pass.cc @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,13 +24,15 @@ #include "common/ge/ge_util.h" #include "framework/common/types.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/op_desc_utils.h" namespace { const int32_t kAnchorSize = 1; -const int kAnchorNum = 0; const int32_t kAnchorAssignRefIndex = 0; const int32_t kAnchorAssignValueIndex = 1; -const char *const kInputMutable = "_input_mutable"; +const int32_t kAnchorIdentityIndex = 0; +// attr _input_mutable = true means hccl node will modify its input in runtime +const char *const kModifyInput = "_input_mutable"; } // namespace namespace ge { Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { // need to inset memcpy node between. // also works on situation that input is variable or const. Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { - auto op_desc = node->GetOpDesc(); - bool node_input_mutable = false; - if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { - return SUCCESS; - } - - if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { - REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); - return FAILED; - } + (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); if (!node_input_mutable) { return SUCCESS; } - GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); + GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { if (hccl_in_anchor == nullptr) { continue; @@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O std::string node_name = pre_node->GetName() + "_" + IDENTITY; node_name = CheckDuplicateName(node_name); - OpDescPtr op_desc = MakeShared(node_name.c_str(), IDENTITY); - if (op_desc == nullptr) { - REPORT_CALL_ERROR("E19999", "New OpDesc failed"); - GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); - return nullptr; - } - GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); - - graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); - return nullptr; - } - - ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); - if (ret != GRAPH_SUCCESS) { - REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", - op_desc->GetName().c_str(), op_desc->GetType().c_str()); - GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); + OpDescBuilder op_desc_builder(node_name, IDENTITY); + auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); + auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); + if (identity_op_desc == nullptr) { return nullptr; } // because history reason ,this pass can not do work after constant fold so mark it - (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); + (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); - NodePtr memcpy_node = graph->AddNode(op_desc); - if (memcpy_node == nullptr) { + NodePtr identity_node = graph->AddNode(identity_op_desc); + if (identity_node == nullptr) { REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", - op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); + identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); return nullptr; } - - return memcpy_node; + return identity_node; } /// @@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const /// Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor) { - GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), + GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); - GE_CHECK_NOTNULL(memcpy_node); + NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); + GE_CHECK_NOTNULL(identity_node); - Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); - if (ret1 != SUCCESS) { + auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); + if (ret != SUCCESS) { REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", - src_out_anchor->GetOwnerNode()->GetName().c_str(), - src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - return FAILED; - } - auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); - GE_CHECK_NOTNULL(out_data_anchor_0); - ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); - if (ret1 != SUCCESS) { - REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", - out_data_anchor_0->GetOwnerNode()->GetName().c_str(), - out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), + "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", + identity_node->GetName().c_str(), identity_node->GetType().c_str(), hccl_in_anchor->GetOwnerNode()->GetName().c_str(), hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); - GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), - hccl_in_anchor->GetOwnerNode()->GetName().c_str()); - return FAILED; - } - - Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", - "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", - src_out_anchor->GetOwnerNode()->GetName().c_str(), - src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), - memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), - kAnchorNum); - GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), - memcpy_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", + identity_node->GetName().c_str(), identity_node->GetType().c_str(), + hccl_in_anchor->GetOwnerNode()->GetName().c_str(), + hccl_in_anchor->GetOwnerNode()->GetType().c_str(), + hccl_in_anchor->GetIdx()); return FAILED; } return SUCCESS; diff --git a/ge/graph/passes/hccl_memcpy_pass.h b/ge/graph/passes/hccl_memcpy_pass.h index 7ab63c59..b75b27d1 100755 --- a/ge/graph/passes/hccl_memcpy_pass.h +++ b/ge/graph/passes/hccl_memcpy_pass.h @@ -1,5 +1,5 @@ /** - * Copyright 2020 Huawei Technologies Co., Ltd + * Copyright 2021 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 16f3672b..40a94dd2 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES set(GRAPH_OPTIMIZE_COMMON_SRC_FILES "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" + "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" ) @@ -715,7 +716,10 @@ set(PASS_TEST_FILES "graph/passes/mark_node_unknown_shape_pass_unittest.cc" "graph/passes/reshape_recovery_pass_unittest.cc" "graph/passes/cast_remove_pass_unittest.cc" - "graph/passes/memcpy_addr_async_unittest.cc" + "graph/passes/memcpy_addr_async_unittest.cc" + "graph/passes/hccl_continuous_pass_unittest.cc" + "graph/passes/hccl_memcpy_pass_unittest.cc" + ) set(KERNEL_TEST_FILES @@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES "graph/manager/run_graph_unittest.cc" "graph/partition/dynamic_shape_partition_unittest.cc" "graph/manager/graph_manager_unittest.cc" + "graph/optimize/mem_rw_conflict_optimize_unittest.cc" "session/omg_omg_unittest.cc" "session/ge_api_unittest.cc" "session/inner_session_unittest.cc" diff --git a/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc b/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file mode 100644 index 00000000..22b0b2c0 --- /dev/null +++ b/tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc @@ -0,0 +1,150 @@ +/** + * Copyright 2021 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#define protected public +#define private public +#include "graph/optimize/graph_optimize.h" +#undef protected +#undef private +#include "../passes/graph_builder_utils.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; +namespace { +/* + * Data -cast - netoutput + */ +ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){ + auto sub_builder = ut::GraphBuilder(subraph_name); + auto data1 = sub_builder.AddNode("data1", DATA, 0,1); + auto cast = sub_builder.AddNode("cast", CAST, 1,1); + auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); + AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); + AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); + + sub_builder.AddDataEdge(data1,0,cast,0); + sub_builder.AddDataEdge(cast,0,netoutput,0); + return sub_builder.GetGraph(); +} +/* + * const - allreduce + * \ if + * insert identity + */ +ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { + auto builder = ut::GraphBuilder("test"); + auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); + auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); + auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); + auto if_node = builder.AddNode("if", IF, 1,0); + + builder.AddDataEdge(const1, 0, allreduce, 0); + builder.AddDataEdge(const1, 0, if_node, 0); + builder.AddControlEdge(ctrl_const, allreduce); + + auto root_graph = builder.GetGraph(); + string subgraph_name = "then_branch"; + ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); + then_branch_graph->SetParentNode(if_node); + then_branch_graph->SetParentGraph(root_graph); + if_node->GetOpDesc()->AddSubgraphName(subgraph_name); + if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); + root_graph->AddSubgraph(subgraph_name, then_branch_graph); + return root_graph; +} +/* const1---allreduce const1--identity - allreduce + * / / + * var-identity--cast1 ==> var-----cast1 + * \ \ + * if if + */ +ComputeGraphPtr BuildGraph_Identiyt_Split(){ + auto builder = ut::GraphBuilder("g1"); + auto var = builder.AddNode("var", VARIABLE, 0, 1); + auto identity = builder.AddNode("identity", IDENTITY, 1, 1); + auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); + auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); + auto cast1 = builder.AddNode("cast1", CAST, 1, 1); + auto if_node = builder.AddNode("if", IF, 1,0); + + builder.AddDataEdge(var, 0 , identity, 0); + builder.AddDataEdge(identity, 0 , allreduce, 0); + builder.AddDataEdge(identity, 0 , cast1, 0); + builder.AddDataEdge(identity, 0 , if_node, 0); + builder.AddControlEdge(const1, allreduce); + + auto root_graph = builder.GetGraph(); + string subgraph_name = "then_branch"; + ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); + then_branch_graph->SetParentNode(if_node); + then_branch_graph->SetParentGraph(root_graph); + if_node->GetOpDesc()->AddSubgraphName(subgraph_name); + if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); + root_graph->AddSubgraph(subgraph_name, then_branch_graph); + return root_graph; +} +/* + * mul == allreduce + * need insert identity + */ +ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { + auto builder = ut::GraphBuilder("test"); + auto mul = builder.AddNode("mul", MUL, 2,1); + auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); + AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); + builder.AddDataEdge(mul,0,allreduce,0); + builder.AddDataEdge(mul,0,allreduce,1); + return builder.GetGraph(); +} +} // namespace +// const -> allreduce +// const -> Identity -> allreduce +TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) { + ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); + GraphOptimize graph_optimizer; + auto ret = graph_optimizer.HandleMemoryRWConflict(graph); + EXPECT_EQ(ret, SUCCESS); + auto allreduce = graph->FindNode("allreduce"); + EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); +} +TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) { + ComputeGraphPtr graph = BuildGraph_Identiyt_Split(); + GraphOptimize graph_optimizer; + auto ret = graph_optimizer.HandleMemoryRWConflict(graph); + EXPECT_EQ(ret, SUCCESS); + auto allreduce = graph->FindNode("allreduce"); + auto allreduce_in_node = allreduce->GetInDataNodes().at(0); + EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); + EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); +} +TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) { + ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); + EXPECT_EQ(graph->GetDirectNodesSize(), 2); + GraphOptimize graph_optimizer; + auto ret = graph_optimizer.HandleMemoryRWConflict(graph); + EXPECT_EQ(ret, SUCCESS); + EXPECT_EQ(graph->GetDirectNodesSize(), 3); +} +} // namespace ge diff --git a/tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc b/tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file mode 100644 index 00000000..fb18162b --- /dev/null +++ b/tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc @@ -0,0 +1,79 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "common/ge_inner_error_codes.h" +#define protected public +#define private public +#include "graph/passes/hccl_continuous_memcpy_pass.h" +#undef protected +#undef private +#include "graph_builder_utils.h" + +namespace ge { +class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; +namespace { + +/* + * var var + * | \ | \ + * | assign | assign + * | // =======> | // + * allreduce identity + * | | + * netoutput allreduce + * | + * netoutput + */ +ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ + auto builder = ut::GraphBuilder("test"); + auto var = builder.AddNode("var", VARIABLE, 0, 1); + auto assign = builder.AddNode("assign", ASSIGN, 1, 1); + auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); + auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(var, 0, assign, 0); + builder.AddDataEdge(var,0,allreduce,0); + builder.AddControlEdge(assign, allreduce); + return builder.GetGraph(); +} +} // namespace + +// const -> allreduce +// const -> Identity -> allreduce +TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { + ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); + auto src_node = graph->FindNode("var"); + auto dst_node = graph->FindNode("allreduce"); + // test InsertIdentityBeforeHccl + HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; + hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); + + // check + dst_node = graph->FindNode("allreduce"); + auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); + EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); + EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); + EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); +} +} // namespace ge diff --git a/tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc b/tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc new file mode 100644 index 00000000..35edeb47 --- /dev/null +++ b/tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc @@ -0,0 +1,80 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "common/ge_inner_error_codes.h" +#define protected public +#define private public +#include "graph/passes/hccl_memcpy_pass.h" +#undef protected +#undef private +#include "graph_builder_utils.h" + +namespace ge { +class UtestGraphPassesHcclMemcpyPass : public testing::Test { + protected: + void SetUp() {} + void TearDown() {} +}; +namespace { + +/* + * var var + * | \ | \ + * | assign | assign + * | // =======> | // + * allreduce identity + * | | + * netoutput allreduce + * | + * netoutput + */ +ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ + auto builder = ut::GraphBuilder("test"); + auto var = builder.AddNode("var", VARIABLE, 0, 1); + auto assign = builder.AddNode("assign", ASSIGN, 1, 1); + auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); + auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); + + builder.AddDataEdge(var, 0, assign, 0); + builder.AddDataEdge(var,0,allreduce,0); + builder.AddControlEdge(assign, allreduce); + return builder.GetGraph(); +} +} // namespace + +// const -> allreduce +// const -> Identity -> allreduce +TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { + ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); + auto src_node = graph->FindNode("var"); + auto dst_node = graph->FindNode("allreduce"); + // test InsertIdentityBeforeHccl + HcclMemcpyPass hccl_memcpy_pass; + hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), + dst_node->GetInDataAnchor(0)); + + // check + dst_node = graph->FindNode("allreduce"); + auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); + EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); + EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); + EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); +} +} // namespace ge