modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: 
tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc 
new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: 
ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cctags/v1.3.0
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| @@ -22,6 +22,7 @@ | |||
| #include "graph/optimize/graph_optimize.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| namespace { | |||
| using namespace ge; | |||
| @@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; | |||
| const int kCaseScopeWriteable = 2; | |||
| const int kCaseWriteable = 3; | |||
| const int kCaseInvalidRWType = 5; | |||
| // attr _input_mutable = true means node will modify its input in runtime | |||
| const char *const kModifyInput = "_input_mutable"; | |||
| // rw type of input. | |||
| enum class InputRWType { | |||
| kReadOnly, // Normal op input only read | |||
| kWriteable, // Op like Assign/ApplyMomentum | |||
| kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput | |||
| kScopeWriteable, // Op like hcom_allreduce/while: it will modify its input, but the change is not expected to take effect on the previous node's output | |||
| kInvalidRWType | |||
| }; | |||
| // rw type of output | |||
| @@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { | |||
| return true; | |||
| } | |||
| NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
| NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { | |||
| if (src_node.GetOpDesc() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| @@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||
| auto next_num = identity_num.fetch_add(1); | |||
| // 1. create new identity op desc | |||
| string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | |||
| auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY); | |||
| if (identity_opdesc == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); | |||
| return nullptr; | |||
| } | |||
| OpDescBuilder op_desc_builder(identity_name, IDENTITY); | |||
| auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | |||
| // 2. add input_desc & output_desc for new identity | |||
| Status ret = identity_opdesc->AddInputDesc("x", data_desc); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); | |||
| return nullptr; | |||
| } | |||
| ret = identity_opdesc->AddOutputDesc("y", data_desc); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); | |||
| return nullptr; | |||
| } | |||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) | |||
| .AddOutput("y", data_desc) | |||
| .Build(); | |||
| GELOGI("Insert new Identity node %s.", identity_name.c_str()); | |||
| auto graph = src_node.GetOwnerComputeGraph(); | |||
| if (graph == nullptr) { | |||
| GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | |||
| return nullptr; | |||
| } | |||
| return graph->AddNode(identity_opdesc); | |||
| return graph->AddNode(identity_op_desc); | |||
| } | |||
| OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | |||
| @@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||
| // single node without sub graph | |||
| return GetSingleNodeInputRWTypeByIndex(node, index); | |||
| } else { | |||
| // node with sub graph | |||
| std::set<int> node_rw_type_set; | |||
| auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | |||
| // get all input data node in subgraph | |||
| std::set<int> anchor_rw_type_set; | |||
| @@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { | |||
| auto parent_node = sub_graph->GetParentNode(); | |||
| if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | |||
| // insert identity | |||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
| auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Fail to insert identity"); | |||
| return ret; | |||
| if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
| identity_node->GetName().c_str(), | |||
| identity_node->GetType().c_str(), | |||
| pre_node->GetName().c_str(), | |||
| pre_node->GetType().c_str(), | |||
| node->GetName().c_str(), | |||
| node->GetType().c_str()); | |||
| GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||
| identity_node->GetName().c_str(), | |||
| identity_node->GetType().c_str(), | |||
| pre_node->GetName().c_str(), | |||
| pre_node->GetType().c_str(), | |||
| node->GetName().c_str(), | |||
| node->GetType().c_str()); | |||
| return FAILED; | |||
| } | |||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||
| @@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I | |||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(peer_in_data_node); | |||
| auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | |||
| auto ret = out_data_anchor->Unlink(peer_in_data_anchor); | |||
| auto old_identity = out_data_anchor->GetOwnerNode(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), | |||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | |||
| auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); | |||
| auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(new_identity); | |||
| if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS | |||
| || GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", | |||
| pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| // 2. copy in-control-edge from dst to Identity | |||
| if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), | |||
| new_identity->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, | |||
| kIdentityAnchorIndex); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", | |||
| new_identity->GetName().c_str(), | |||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
| peer_in_data_anchor->GetIdx()); | |||
| return ret; | |||
| } | |||
| GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | |||
| InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||
| } else { | |||
| (void) out_data_anchor->Unlink(peer_in_data_anchor); | |||
| // copy control edge to pre and peer node | |||
| if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | |||
| || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | |||
| @@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
| GELOGD("No need insert Identity."); | |||
| continue; | |||
| case INSERT_IDENTITY: | |||
| auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); | |||
| if (identity_node == nullptr) { | |||
| GELOGE(FAILED, "Create identity node failed."); | |||
| return FAILED; | |||
| } | |||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), | |||
| peer_in_node->GetName().c_str()); | |||
| return INTERNAL_ERROR; | |||
| auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, | |||
| kIdentityAnchorIndex); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), | |||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); | |||
| return ret; | |||
| } | |||
| GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | |||
| peer_in_node->GetName().c_str()); | |||
| @@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||
| return SUCCESS; | |||
| } | |||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | |||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||
| if (node->GetType() == HCOMALLREDUCE) { | |||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||
| continue; | |||
| } | |||
| // need insert identity | |||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||
| bool mutable_input_flag = false; | |||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||
| if (!mutable_input_flag) { | |||
| continue; | |||
| } | |||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||
| if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||
| continue; | |||
| } | |||
| // need insert identity | |||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||
| auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| auto ret = | |||
| GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||
| node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
| return ret; | |||
| } | |||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -24,11 +24,12 @@ | |||
| #include "common/ge/ge_util.h" | |||
| #include "framework/common/types.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| namespace { | |||
| const int kAnchorNum = 0; | |||
| const int32_t kAnchorAssignRefIndex = 0; | |||
| const int32_t kAnchorAssignValueIndex = 1; | |||
| const int32_t kAnchorIdentityIndex = 0; | |||
| } // namespace | |||
| namespace ge { | |||
| Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
| @@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||
| std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
| node_name = CheckDuplicateName(node_name); | |||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
| if (op_desc == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
| return nullptr; | |||
| } | |||
| GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
| graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
| return nullptr; | |||
| } | |||
| ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
| OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
| auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
| if (identity_op_desc == nullptr) { | |||
| return nullptr; | |||
| } | |||
| // For historical reasons, this pass cannot work after constant folding, so mark it | |||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
| (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
| NodePtr memcpy_node = graph->AddNode(op_desc); | |||
| if (memcpy_node == nullptr) { | |||
| NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
| if (identity_node == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
| return nullptr; | |||
| } | |||
| return memcpy_node; | |||
| return identity_node; | |||
| } | |||
| /// | |||
| @@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||
| Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | |||
| const OutDataAnchorPtr &src_out_anchor, | |||
| const InDataAnchorPtr &hccl_in_anchor) { | |||
| GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
| GE_CHECK_NOTNULL(memcpy_node); | |||
| NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
| if (ret1 != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
| hccl_in_anchor->GetIdx()); | |||
| GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
| GE_CHECK_NOTNULL(out_data_anchor_0); | |||
| ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
| if (ret1 != SUCCESS) { | |||
| auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
| if (ret != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
| out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
| out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
| "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
| hccl_in_anchor->GetIdx()); | |||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
| if (ret != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
| memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
| kAnchorNum); | |||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| memcpy_node->GetName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
| hccl_in_anchor->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -24,13 +24,15 @@ | |||
| #include "common/ge/ge_util.h" | |||
| #include "framework/common/types.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| namespace { | |||
| const int32_t kAnchorSize = 1; | |||
| const int kAnchorNum = 0; | |||
| const int32_t kAnchorAssignRefIndex = 0; | |||
| const int32_t kAnchorAssignValueIndex = 1; | |||
| const char *const kInputMutable = "_input_mutable"; | |||
| const int32_t kAnchorIdentityIndex = 0; | |||
| // attr _input_mutable = true means hccl node will modify its input in runtime | |||
| const char *const kModifyInput = "_input_mutable"; | |||
| } // namespace | |||
| namespace ge { | |||
| Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
| @@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||
| // need to insert memcpy node between. | |||
| // also works on situation that input is variable or const. | |||
| Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| bool node_input_mutable = false; | |||
| if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||
| return SUCCESS; | |||
| } | |||
| if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||
| REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||
| if (!node_input_mutable) { | |||
| return SUCCESS; | |||
| } | |||
| GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||
| GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||
| for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | |||
| if (hccl_in_anchor == nullptr) { | |||
| continue; | |||
| @@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||
| std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||
| node_name = CheckDuplicateName(node_name); | |||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||
| if (op_desc == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||
| return nullptr; | |||
| } | |||
| GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||
| graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||
| return nullptr; | |||
| } | |||
| ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||
| OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||
| auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||
| if (identity_op_desc == nullptr) { | |||
| return nullptr; | |||
| } | |||
| // For historical reasons, this pass cannot work after constant folding, so mark it | |||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
| (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||
| NodePtr memcpy_node = graph->AddNode(op_desc); | |||
| if (memcpy_node == nullptr) { | |||
| NodePtr identity_node = graph->AddNode(identity_op_desc); | |||
| if (identity_node == nullptr) { | |||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||
| return nullptr; | |||
| } | |||
| return memcpy_node; | |||
| return identity_node; | |||
| } | |||
| /// | |||
| @@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||
| /// | |||
| Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | |||
| const InDataAnchorPtr &hccl_in_anchor) { | |||
| GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||
| GE_CHECK_NOTNULL(memcpy_node); | |||
| NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||
| GE_CHECK_NOTNULL(identity_node); | |||
| Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||
| if (ret1 != SUCCESS) { | |||
| auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||
| if (ret != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||
| GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||
| GE_CHECK_NOTNULL(out_data_anchor_0); | |||
| ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||
| if (ret1 != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||
| out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||
| out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||
| "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
| hccl_in_anchor->GetIdx()); | |||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return FAILED; | |||
| } | |||
| Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||
| if (ret != SUCCESS) { | |||
| REPORT_CALL_ERROR("E19999", | |||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||
| memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||
| kAnchorNum); | |||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||
| memcpy_node->GetName().c_str()); | |||
| GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||
| hccl_in_anchor->GetIdx()); | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| @@ -1,5 +1,5 @@ | |||
| /** | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| @@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||
| set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | |||
| "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | |||
| "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | |||
| "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||
| ) | |||
| @@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | |||
| "graph/passes/reshape_recovery_pass_unittest.cc" | |||
| "graph/passes/cast_remove_pass_unittest.cc" | |||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||
| "graph/passes/hccl_continuous_pass_unittest.cc" | |||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | |||
| ) | |||
| set(KERNEL_TEST_FILES | |||
| @@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||
| "graph/manager/run_graph_unittest.cc" | |||
| "graph/partition/dynamic_shape_partition_unittest.cc" | |||
| "graph/manager/graph_manager_unittest.cc" | |||
| "graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||
| "session/omg_omg_unittest.cc" | |||
| "session/ge_api_unittest.cc" | |||
| "session/inner_session_unittest.cc" | |||
| @@ -0,0 +1,150 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstdint> | |||
| #include <string> | |||
| #include <gtest/gtest.h> | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/optimize/graph_optimize.h" | |||
| #undef protected | |||
| #undef private | |||
| #include "../passes/graph_builder_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| namespace ge { | |||
| class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| namespace { | |||
| /* | |||
| * Data -cast - netoutput | |||
| */ | |||
| ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){ | |||
| auto sub_builder = ut::GraphBuilder(subraph_name); | |||
| auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||
| auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||
| auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||
| AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||
| AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||
| sub_builder.AddDataEdge(data1,0,cast,0); | |||
| sub_builder.AddDataEdge(cast,0,netoutput,0); | |||
| return sub_builder.GetGraph(); | |||
| } | |||
| /* | |||
| * const - allreduce | |||
| * \ if | |||
| * insert identity | |||
| */ | |||
| ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
| auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
| auto if_node = builder.AddNode("if", IF, 1,0); | |||
| builder.AddDataEdge(const1, 0, allreduce, 0); | |||
| builder.AddDataEdge(const1, 0, if_node, 0); | |||
| builder.AddControlEdge(ctrl_const, allreduce); | |||
| auto root_graph = builder.GetGraph(); | |||
| string subgraph_name = "then_branch"; | |||
| ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
| then_branch_graph->SetParentNode(if_node); | |||
| then_branch_graph->SetParentGraph(root_graph); | |||
| if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
| if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
| root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
| return root_graph; | |||
| } | |||
| /* const1---allreduce const1--identity - allreduce | |||
| * / / | |||
| * var-identity--cast1 ==> var-----cast1 | |||
| * \ \ | |||
| * if if | |||
| */ | |||
| ComputeGraphPtr BuildGraph_Identiyt_Split(){ | |||
| auto builder = ut::GraphBuilder("g1"); | |||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
| auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||
| auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
| auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||
| auto if_node = builder.AddNode("if", IF, 1,0); | |||
| builder.AddDataEdge(var, 0 , identity, 0); | |||
| builder.AddDataEdge(identity, 0 , allreduce, 0); | |||
| builder.AddDataEdge(identity, 0 , cast1, 0); | |||
| builder.AddDataEdge(identity, 0 , if_node, 0); | |||
| builder.AddControlEdge(const1, allreduce); | |||
| auto root_graph = builder.GetGraph(); | |||
| string subgraph_name = "then_branch"; | |||
| ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||
| then_branch_graph->SetParentNode(if_node); | |||
| then_branch_graph->SetParentGraph(root_graph); | |||
| if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||
| if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||
| root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||
| return root_graph; | |||
| } | |||
| /* | |||
| * mul == allreduce | |||
| * need insert identity | |||
| */ | |||
| ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||
| auto builder = ut::GraphBuilder("test"); | |||
| auto mul = builder.AddNode("mul", MUL, 2,1); | |||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||
| AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||
| builder.AddDataEdge(mul,0,allreduce,0); | |||
| builder.AddDataEdge(mul,0,allreduce,1); | |||
| return builder.GetGraph(); | |||
| } | |||
| } // namespace | |||
| // const -> allreduce | |||
| // const -> Identity -> allreduce | |||
| TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) { | |||
| ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); | |||
| GraphOptimize graph_optimizer; | |||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| auto allreduce = graph->FindNode("allreduce"); | |||
| EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); | |||
| } | |||
| TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) { | |||
| ComputeGraphPtr graph = BuildGraph_Identiyt_Split(); | |||
| GraphOptimize graph_optimizer; | |||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| auto allreduce = graph->FindNode("allreduce"); | |||
| auto allreduce_in_node = allreduce->GetInDataNodes().at(0); | |||
| EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); | |||
| EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); | |||
| } | |||
| TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) { | |||
| ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); | |||
| EXPECT_EQ(graph->GetDirectNodesSize(), 2); | |||
| GraphOptimize graph_optimizer; | |||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| EXPECT_EQ(graph->GetDirectNodesSize(), 3); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstdint> | |||
| #include <string> | |||
| #include <gtest/gtest.h> | |||
| #include "common/ge_inner_error_codes.h" | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/passes/hccl_continuous_memcpy_pass.h" | |||
| #undef protected | |||
| #undef private | |||
| #include "graph_builder_utils.h" | |||
| namespace ge { | |||
| class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| namespace { | |||
| /* | |||
| * var var | |||
| * | \ | \ | |||
| * | assign | assign | |||
| * | // =======> | // | |||
| * allreduce identity | |||
| * | | | |||
| * netoutput allreduce | |||
| * | | |||
| * netoutput | |||
| */ | |||
| ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
| auto builder = ut::GraphBuilder("test"); | |||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
| auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
| auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
| builder.AddDataEdge(var, 0, assign, 0); | |||
| builder.AddDataEdge(var,0,allreduce,0); | |||
| builder.AddControlEdge(assign, allreduce); | |||
| return builder.GetGraph(); | |||
| } | |||
| } // namespace | |||
| // const -> allreduce | |||
| // const -> Identity -> allreduce | |||
| TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { | |||
| ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
| auto src_node = graph->FindNode("var"); | |||
| auto dst_node = graph->FindNode("allreduce"); | |||
| // test InsertIdentityBeforeHccl | |||
| HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; | |||
| hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); | |||
| // check | |||
| dst_node = graph->FindNode("allreduce"); | |||
| auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,80 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <cstdint> | |||
| #include <string> | |||
| #include <gtest/gtest.h> | |||
| #include "common/ge_inner_error_codes.h" | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/passes/hccl_memcpy_pass.h" | |||
| #undef protected | |||
| #undef private | |||
| #include "graph_builder_utils.h" | |||
| namespace ge { | |||
| class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||
| protected: | |||
| void SetUp() {} | |||
| void TearDown() {} | |||
| }; | |||
| namespace { | |||
| /* | |||
| * var var | |||
| * | \ | \ | |||
| * | assign | assign | |||
| * | // =======> | // | |||
| * allreduce identity | |||
| * | | | |||
| * netoutput allreduce | |||
| * | | |||
| * netoutput | |||
| */ | |||
| ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||
| auto builder = ut::GraphBuilder("test"); | |||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||
| auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||
| auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||
| builder.AddDataEdge(var, 0, assign, 0); | |||
| builder.AddDataEdge(var,0,allreduce,0); | |||
| builder.AddControlEdge(assign, allreduce); | |||
| return builder.GetGraph(); | |||
| } | |||
| } // namespace | |||
| // const -> allreduce | |||
| // const -> Identity -> allreduce | |||
| TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { | |||
| ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||
| auto src_node = graph->FindNode("var"); | |||
| auto dst_node = graph->FindNode("allreduce"); | |||
| // test InsertIdentityBeforeHccl | |||
| HcclMemcpyPass hccl_memcpy_pass; | |||
| hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), | |||
| dst_node->GetInDataAnchor(0)); | |||
| // check | |||
| dst_node = graph->FindNode("allreduce"); | |||
| auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||
| } | |||
| } // namespace ge | |||