modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.cc modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cc modified: ge/graph/optimize/mem_rw_conflict_optimize.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.cc modified: ge/graph/passes/hccl_continuous_memcpy_pass.h modified: ge/graph/passes/hccl_memcpy_pass.cc modified: ge/graph/passes/hccl_memcpy_pass.h modified: tests/ut/ge/CMakeLists.txt new file: tests/ut/ge/graph/optimize/mem_rw_conflict_optimize_unittest.cc new file: tests/ut/ge/graph/passes/hccl_continuous_pass_unittest.cc new file: tests/ut/ge/graph/passes/hccl_memcpy_pass_unittest.cctags/v1.3.0
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2020-2021 Huawei Technologies Co., Ltd | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| * You may obtain a copy of the License at | * You may obtain a copy of the License at | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "graph/optimize/graph_optimize.h" | #include "graph/optimize/graph_optimize.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | |||||
| namespace { | namespace { | ||||
| using namespace ge; | using namespace ge; | ||||
| @@ -32,12 +33,14 @@ const int kCaseReadOnly = 0; | |||||
| const int kCaseScopeWriteable = 2; | const int kCaseScopeWriteable = 2; | ||||
| const int kCaseWriteable = 3; | const int kCaseWriteable = 3; | ||||
| const int kCaseInvalidRWType = 5; | const int kCaseInvalidRWType = 5; | ||||
| // attr _input_mutable = true means node will modify its input in runtime | |||||
| const char *const kModifyInput = "_input_mutable"; | |||||
| // rw type of input. | // rw type of input. | ||||
| enum class InputRWType { | enum class InputRWType { | ||||
| kReadOnly, // Normal op input only read | kReadOnly, // Normal op input only read | ||||
| kWriteable, // Op like Assign/ApplyMomentum | kWriteable, // Op like Assign/ApplyMomentum | ||||
| kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput | |||||
| kScopeWriteable, // Op like hcom_allreduce/while, it will modify input ,but not expect take effect on pre ouput | |||||
| kInvalidRWType | kInvalidRWType | ||||
| }; | }; | ||||
| // rw type of output | // rw type of output | ||||
| @@ -154,7 +157,7 @@ bool IsSubgraphOutputNode(const NodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||||
| NodePtr AddIdentityToGraph(const Node &src_node, int out_anchor_idx) { | |||||
| if (src_node.GetOpDesc() == nullptr) { | if (src_node.GetOpDesc() == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -162,30 +165,19 @@ NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { | |||||
| auto next_num = identity_num.fetch_add(1); | auto next_num = identity_num.fetch_add(1); | ||||
| // 1. create new identity op desc | // 1. create new identity op desc | ||||
| string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); | ||||
| auto identity_opdesc = MakeShared<OpDesc>(identity_name, IDENTITY); | |||||
| if (identity_opdesc == nullptr) { | |||||
| GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| OpDescBuilder op_desc_builder(identity_name, IDENTITY); | |||||
| auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); | ||||
| // 2. add input_desc & output_desc for new identity | |||||
| Status ret = identity_opdesc->AddInputDesc("x", data_desc); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| ret = identity_opdesc->AddOutputDesc("y", data_desc); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc) | |||||
| .AddOutput("y", data_desc) | |||||
| .Build(); | |||||
| GELOGI("Insert new Identity node %s.", identity_name.c_str()); | GELOGI("Insert new Identity node %s.", identity_name.c_str()); | ||||
| auto graph = src_node.GetOwnerComputeGraph(); | auto graph = src_node.GetOwnerComputeGraph(); | ||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return graph->AddNode(identity_opdesc); | |||||
| return graph->AddNode(identity_op_desc); | |||||
| } | } | ||||
| OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { | ||||
| @@ -274,8 +266,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||||
| // single node without sub graph | // single node without sub graph | ||||
| return GetSingleNodeInputRWTypeByIndex(node, index); | return GetSingleNodeInputRWTypeByIndex(node, index); | ||||
| } else { | } else { | ||||
| // node with sub graph | |||||
| std::set<int> node_rw_type_set; | |||||
| auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | ||||
| // get all input data node in subgraph | // get all input data node in subgraph | ||||
| std::set<int> anchor_rw_type_set; | std::set<int> anchor_rw_type_set; | ||||
| @@ -345,12 +335,24 @@ Status MarkRWTypeForSubgraph(const ComputeGraphPtr &sub_graph) { | |||||
| auto parent_node = sub_graph->GetParentNode(); | auto parent_node = sub_graph->GetParentNode(); | ||||
| if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | if (pre_output_rw_type == OutputRWType::kWriteable && parent_node->GetType() != PARTITIONEDCALL) { | ||||
| // insert identity | // insert identity | ||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | GE_CHECK_NOTNULL(identity_node); | ||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Fail to insert identity"); | |||||
| return ret; | |||||
| if (GraphUtils::InsertNodeAfter(pre_out_anchor, {in_data_anchor}, identity_node) != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||||
| identity_node->GetName().c_str(), | |||||
| identity_node->GetType().c_str(), | |||||
| pre_node->GetName().c_str(), | |||||
| pre_node->GetType().c_str(), | |||||
| node->GetName().c_str(), | |||||
| node->GetType().c_str()); | |||||
| GELOGE(FAILED, "Insert Identity node %s(%s) between %s(%s) -> %s(%s) failed.", | |||||
| identity_node->GetName().c_str(), | |||||
| identity_node->GetType().c_str(), | |||||
| pre_node->GetName().c_str(), | |||||
| pre_node->GetType().c_str(), | |||||
| node->GetName().c_str(), | |||||
| node->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | } | ||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | ||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | pre_node->GetName().c_str(), node->GetName().c_str()); | ||||
| @@ -505,34 +507,24 @@ Status SplitIdentityAlongAnchor(const OutDataAnchorPtr &out_data_anchor, const I | |||||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(peer_in_data_node); | GE_CHECK_NOTNULL(peer_in_data_node); | ||||
| auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); | ||||
| auto ret = out_data_anchor->Unlink(peer_in_data_anchor); | |||||
| auto old_identity = out_data_anchor->GetOwnerNode(); | auto old_identity = out_data_anchor->GetOwnerNode(); | ||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to unlink from %s %dth out to %s.", old_identity->GetName().c_str(), out_data_anchor->GetIdx(), | |||||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { | ||||
| auto new_identity = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); | |||||
| auto new_identity = AddIdentityToGraph(*pre_node, pre_out_data_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(new_identity); | GE_CHECK_NOTNULL(new_identity); | ||||
| if (GraphUtils::AddEdge(pre_out_data_anchor, new_identity->GetInDataAnchor(kIdentityAnchorIndex)) != SUCCESS | |||||
| || GraphUtils::AddEdge(new_identity->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", | |||||
| pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| // 2. copy in-control-edge from dst to Identity | |||||
| if (GraphUtils::CopyInCtrlEdges(peer_in_data_node, new_identity) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to copy in_control edges from node %s to %s", peer_in_data_node->GetName().c_str(), | |||||
| new_identity->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, new_identity, kIdentityAnchorIndex, | |||||
| kIdentityAnchorIndex); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to insert Identity %s before %s %dth input.", | |||||
| new_identity->GetName().c_str(), | |||||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| peer_in_data_anchor->GetIdx()); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), | ||||
| InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), | ||||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| } else { | } else { | ||||
| (void) out_data_anchor->Unlink(peer_in_data_anchor); | |||||
| // copy control edge to pre and peer node | // copy control edge to pre and peer node | ||||
| if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | if (GraphUtils::CopyInCtrlEdges(old_identity, peer_in_data_node) != SUCCESS | ||||
| || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | || GraphUtils::CopyOutCtrlEdges(old_identity, pre_node) != SUCCESS) { | ||||
| @@ -613,16 +605,14 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
| GELOGD("No need insert Identity."); | GELOGD("No need insert Identity."); | ||||
| continue; | continue; | ||||
| case INSERT_IDENTITY: | case INSERT_IDENTITY: | ||||
| auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); | |||||
| if (identity_node == nullptr) { | |||||
| GELOGE(FAILED, "Create identity node failed."); | |||||
| return FAILED; | |||||
| } | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), | |||||
| peer_in_node->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| auto identity_node = AddIdentityToGraph(*node, out_data_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBefore(peer_in_data_anchor, identity_node, kIdentityAnchorIndex, | |||||
| kIdentityAnchorIndex); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Fail to insert %s before %s %dth input.", identity_node->GetName().c_str(), | |||||
| peer_in_data_anchor->GetOwnerNode()->GetName().c_str(), peer_in_data_anchor->GetIdx()); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), | ||||
| peer_in_node->GetName().c_str()); | peer_in_node->GetName().c_str()); | ||||
| @@ -633,28 +623,35 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | ||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| if (node->GetType() == HCOMALLREDUCE) { | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||||
| continue; | |||||
| } | |||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| bool mutable_input_flag = false; | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||||
| if (!mutable_input_flag) { | |||||
| continue; | |||||
| } | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||||
| continue; | |||||
| } | |||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = AddIdentityToGraph(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = | |||||
| GraphUtils::InsertNodeBefore(in_data_anchor, identity_node, kIdentityAnchorIndex, kIdentityAnchorIndex); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to insert node %s before %s %dth input.", identity_node->GetName().c_str(), | |||||
| node->GetName().c_str(), in_data_anchor->GetIdx()); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -24,11 +24,12 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | |||||
| namespace { | namespace { | ||||
| const int kAnchorNum = 0; | |||||
| const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
| const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
| const int32_t kAnchorIdentityIndex = 0; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclContinuousMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
| @@ -161,41 +162,23 @@ NodePtr HcclContinuousMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &grap | |||||
| std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
| node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
| if (op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
| return nullptr; | |||||
| } | |||||
| GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
| graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
| return nullptr; | |||||
| } | |||||
| ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
| OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
| auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
| if (identity_op_desc == nullptr) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // because history reason ,this pass can not do work after constant fold so mark it | // because history reason ,this pass can not do work after constant fold so mark it | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
| (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
| NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
| if (memcpy_node == nullptr) { | |||||
| NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
| if (identity_node == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return memcpy_node; | |||||
| return identity_node; | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -256,50 +239,24 @@ Status HcclContinuousMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &gra | |||||
| Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | Status HcclContinuousMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, | ||||
| const OutDataAnchorPtr &src_out_anchor, | const OutDataAnchorPtr &src_out_anchor, | ||||
| const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
| GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
| GE_CHECK_NOTNULL(memcpy_node); | |||||
| NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
| if (ret1 != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | |||||
| "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
| hccl_in_anchor->GetIdx()); | |||||
| GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
| GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
| ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
| if (ret1 != SUCCESS) { | |||||
| auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
| out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
| out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
| "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
| hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | |||||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
| memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
| kAnchorNum); | |||||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| memcpy_node->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
| hccl_in_anchor->GetIdx()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -24,13 +24,15 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | |||||
| namespace { | namespace { | ||||
| const int32_t kAnchorSize = 1; | const int32_t kAnchorSize = 1; | ||||
| const int kAnchorNum = 0; | |||||
| const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
| const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
| const char *const kInputMutable = "_input_mutable"; | |||||
| const int32_t kAnchorIdentityIndex = 0; | |||||
| // attr _input_mutable = true means hccl node will modify its input in runtime | |||||
| const char *const kModifyInput = "_input_mutable"; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
| @@ -58,24 +60,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||||
| // need to inset memcpy node between. | // need to inset memcpy node between. | ||||
| // also works on situation that input is variable or const. | // also works on situation that input is variable or const. | ||||
| Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | ||||
| auto op_desc = node->GetOpDesc(); | |||||
| bool node_input_mutable = false; | bool node_input_mutable = false; | ||||
| if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||||
| REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||||
| if (!node_input_mutable) { | if (!node_input_mutable) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||||
| GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||||
| for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | ||||
| if (hccl_in_anchor == nullptr) { | if (hccl_in_anchor == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -127,41 +118,23 @@ NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const O | |||||
| std::string node_name = pre_node->GetName() + "_" + IDENTITY; | std::string node_name = pre_node->GetName() + "_" + IDENTITY; | ||||
| node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
| if (op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "New OpDesc failed"); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
| return nullptr; | |||||
| } | |||||
| GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
| graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed, name:x", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
| return nullptr; | |||||
| } | |||||
| ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed, name:y", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
| OpDescBuilder op_desc_builder(node_name, IDENTITY); | |||||
| auto data_desc = pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx()); | |||||
| auto identity_op_desc = op_desc_builder.AddInput("x", data_desc).AddOutput("y", data_desc).Build(); | |||||
| if (identity_op_desc == nullptr) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // because history reason ,this pass can not do work after constant fold so mark it | // because history reason ,this pass can not do work after constant fold so mark it | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
| (void)AttrUtils::SetBool(identity_op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | |||||
| NodePtr memcpy_node = graph->AddNode(op_desc); | |||||
| if (memcpy_node == nullptr) { | |||||
| NodePtr identity_node = graph->AddNode(identity_op_desc); | |||||
| if (identity_node == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), graph->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return memcpy_node; | |||||
| return identity_node; | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -220,49 +193,24 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||||
| /// | /// | ||||
| Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | Status HcclMemcpyPass::InsertIdentityBeforeHccl(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, | ||||
| const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
| GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| GELOGI("Between op %s and op %s need insert identity op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
| GE_CHECK_NOTNULL(memcpy_node); | |||||
| NodePtr identity_node = CreateIdentityNode(graph, src_out_anchor); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | |||||
| if (ret1 != SUCCESS) { | |||||
| auto ret = GraphUtils::InsertNodeBefore(hccl_in_anchor, identity_node, kAnchorIdentityIndex, kAnchorIdentityIndex); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | REPORT_CALL_ERROR("E19999", | ||||
| "Op:%s(%s) out index:%d unlink from op:%s(%s) in index:%d failed", | |||||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), hccl_in_anchor->GetIdx()); | |||||
| GELOGE(INTERNAL_ERROR, "The op %s Unlink anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto out_data_anchor_0 = memcpy_node->GetOutDataAnchor(kAnchorNum); | |||||
| GE_CHECK_NOTNULL(out_data_anchor_0); | |||||
| ret1 = out_data_anchor_0->LinkTo(hccl_in_anchor); | |||||
| if (ret1 != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | |||||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%d failed", | |||||
| out_data_anchor_0->GetOwnerNode()->GetName().c_str(), | |||||
| out_data_anchor_0->GetOwnerNode()->GetType().c_str(), out_data_anchor_0->GetIdx(), | |||||
| "Op:Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | ||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | ||||
| hccl_in_anchor->GetIdx()); | hccl_in_anchor->GetIdx()); | ||||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", memcpy_node->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| Status ret = src_out_anchor->LinkTo(memcpy_node->GetInDataAnchor(kAnchorNum)); | |||||
| if (ret != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", | |||||
| "Op:%s(%s) out index:%d link to op:%s(%s) in index:%u failed", | |||||
| src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| src_out_anchor->GetOwnerNode()->GetType().c_str(), src_out_anchor->GetIdx(), | |||||
| memcpy_node->GetName().c_str(), memcpy_node->GetType().c_str(), | |||||
| kAnchorNum); | |||||
| GELOGE(INTERNAL_ERROR, "The op %s link anchor %s fail.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| memcpy_node->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Fail to insert %s(%s) before %s(%s) on index:%d input anchor.", | |||||
| identity_node->GetName().c_str(), identity_node->GetType().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| hccl_in_anchor->GetOwnerNode()->GetType().c_str(), | |||||
| hccl_in_anchor->GetIdx()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -1,5 +1,5 @@ | |||||
| /** | /** | ||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | * | ||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | * Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| * you may not use this file except in compliance with the License. | * you may not use this file except in compliance with the License. | ||||
| @@ -372,6 +372,7 @@ set(COMMON_FORMAT_SRC_FILES | |||||
| set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | ||||
| "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/graph_optimize.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | "${GE_CODE_DIR}/ge/graph/optimize/summary_optimize.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/optimize/mem_rw_conflict_optimize.cc" | |||||
| ) | ) | ||||
| @@ -715,7 +716,10 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | "graph/passes/mark_node_unknown_shape_pass_unittest.cc" | ||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/memcpy_addr_async_unittest.cc" | |||||
| "graph/passes/hccl_continuous_pass_unittest.cc" | |||||
| "graph/passes/hccl_memcpy_pass_unittest.cc" | |||||
| ) | ) | ||||
| set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
| @@ -798,6 +802,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/manager/run_graph_unittest.cc" | "graph/manager/run_graph_unittest.cc" | ||||
| "graph/partition/dynamic_shape_partition_unittest.cc" | "graph/partition/dynamic_shape_partition_unittest.cc" | ||||
| "graph/manager/graph_manager_unittest.cc" | "graph/manager/graph_manager_unittest.cc" | ||||
| "graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||||
| "session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
| "session/ge_api_unittest.cc" | "session/ge_api_unittest.cc" | ||||
| "session/inner_session_unittest.cc" | "session/inner_session_unittest.cc" | ||||
| @@ -0,0 +1,150 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include <string> | |||||
| #include <gtest/gtest.h> | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/optimize/graph_optimize.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| #include "../passes/graph_builder_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace ge { | |||||
| class UTest_Graph_Mem_RW_Conflict_Optimize : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| namespace { | |||||
| /* | |||||
| * Data -cast - netoutput | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_Readonly_Subgraph(const string subraph_name){ | |||||
| auto sub_builder = ut::GraphBuilder(subraph_name); | |||||
| auto data1 = sub_builder.AddNode("data1", DATA, 0,1); | |||||
| auto cast = sub_builder.AddNode("cast", CAST, 1,1); | |||||
| auto netoutput = sub_builder.AddNode("netoutput",NETOUTPUT, 1,1); | |||||
| AttrUtils::SetInt(data1->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX, 1); | |||||
| AttrUtils::SetInt(netoutput->GetOpDesc(),ATTR_NAME_PARENT_NODE_INDEX,0); | |||||
| sub_builder.AddDataEdge(data1,0,cast,0); | |||||
| sub_builder.AddDataEdge(cast,0,netoutput,0); | |||||
| return sub_builder.GetGraph(); | |||||
| } | |||||
| /* | |||||
| * const - allreduce | |||||
| * \ if | |||||
| * insert identity | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_Readonly_ScopeWrite() { | |||||
| auto builder = ut::GraphBuilder("test"); | |||||
| auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
| auto ctrl_const = builder.AddNode("ctrl_const", CONSTANT, 0, 1); | |||||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
| auto if_node = builder.AddNode("if", IF, 1,0); | |||||
| builder.AddDataEdge(const1, 0, allreduce, 0); | |||||
| builder.AddDataEdge(const1, 0, if_node, 0); | |||||
| builder.AddControlEdge(ctrl_const, allreduce); | |||||
| auto root_graph = builder.GetGraph(); | |||||
| string subgraph_name = "then_branch"; | |||||
| ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
| then_branch_graph->SetParentNode(if_node); | |||||
| then_branch_graph->SetParentGraph(root_graph); | |||||
| if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
| if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
| root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
| return root_graph; | |||||
| } | |||||
| /* const1---allreduce const1--identity - allreduce | |||||
| * / / | |||||
| * var-identity--cast1 ==> var-----cast1 | |||||
| * \ \ | |||||
| * if if | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_Identiyt_Split(){ | |||||
| auto builder = ut::GraphBuilder("g1"); | |||||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
| auto identity = builder.AddNode("identity", IDENTITY, 1, 1); | |||||
| auto const1 = builder.AddNode("const1", CONSTANT, 0, 1); | |||||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
| auto cast1 = builder.AddNode("cast1", CAST, 1, 1); | |||||
| auto if_node = builder.AddNode("if", IF, 1,0); | |||||
| builder.AddDataEdge(var, 0 , identity, 0); | |||||
| builder.AddDataEdge(identity, 0 , allreduce, 0); | |||||
| builder.AddDataEdge(identity, 0 , cast1, 0); | |||||
| builder.AddDataEdge(identity, 0 , if_node, 0); | |||||
| builder.AddControlEdge(const1, allreduce); | |||||
| auto root_graph = builder.GetGraph(); | |||||
| string subgraph_name = "then_branch"; | |||||
| ComputeGraphPtr then_branch_graph = BuildGraph_Readonly_Subgraph(subgraph_name); | |||||
| then_branch_graph->SetParentNode(if_node); | |||||
| then_branch_graph->SetParentGraph(root_graph); | |||||
| if_node->GetOpDesc()->AddSubgraphName(subgraph_name); | |||||
| if_node->GetOpDesc()->SetSubgraphInstanceName(0,subgraph_name); | |||||
| root_graph->AddSubgraph(subgraph_name, then_branch_graph); | |||||
| return root_graph; | |||||
| } | |||||
| /* | |||||
| * mul == allreduce | |||||
| * need insert identity | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_mul_1To2_ScopeWrite() { | |||||
| auto builder = ut::GraphBuilder("test"); | |||||
| auto mul = builder.AddNode("mul", MUL, 2,1); | |||||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 2,0); | |||||
| AttrUtils::SetBool(allreduce->GetOpDesc(), "_input_mutable", true); | |||||
| builder.AddDataEdge(mul,0,allreduce,0); | |||||
| builder.AddDataEdge(mul,0,allreduce,1); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| } // namespace | |||||
| // const -> allreduce | |||||
| // const -> Identity -> allreduce | |||||
| TEST(UtestGraphPassesHcclMemcpyPass, testReadonlyScopeWriteConflict) { | |||||
| ComputeGraphPtr graph = BuildGraph_Readonly_ScopeWrite(); | |||||
| GraphOptimize graph_optimizer; | |||||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| auto allreduce = graph->FindNode("allreduce"); | |||||
| EXPECT_EQ(allreduce->GetInDataNodes().at(0)->GetType(), IDENTITY); | |||||
| } | |||||
| TEST(UtestGraphPassesHcclMemcpyPass, testIdentiytSplit) { | |||||
| ComputeGraphPtr graph = BuildGraph_Identiyt_Split(); | |||||
| GraphOptimize graph_optimizer; | |||||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| auto allreduce = graph->FindNode("allreduce"); | |||||
| auto allreduce_in_node = allreduce->GetInDataNodes().at(0); | |||||
| EXPECT_EQ(allreduce_in_node->GetType(), IDENTITY); | |||||
| EXPECT_EQ(allreduce_in_node->GetInControlNodes().at(0)->GetType(), CONSTANT); | |||||
| } | |||||
| TEST(UtestGraphPassesHcclMemcpyPass, testMul_1To2_ScopeWrite) { | |||||
| ComputeGraphPtr graph = BuildGraph_mul_1To2_ScopeWrite(); | |||||
| EXPECT_EQ(graph->GetDirectNodesSize(), 2); | |||||
| GraphOptimize graph_optimizer; | |||||
| auto ret = graph_optimizer.HandleMemoryRWConflict(graph); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| EXPECT_EQ(graph->GetDirectNodesSize(), 3); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include <string> | |||||
| #include <gtest/gtest.h> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/passes/hccl_continuous_memcpy_pass.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| #include "graph_builder_utils.h" | |||||
| namespace ge { | |||||
| class UtestGraphPassesHcclContinuousMemcpyPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| namespace { | |||||
| /* | |||||
| * var var | |||||
| * | \ | \ | |||||
| * | assign | assign | |||||
| * | // =======> | // | |||||
| * allreduce identity | |||||
| * | | | |||||
| * netoutput allreduce | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
| auto builder = ut::GraphBuilder("test"); | |||||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
| auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
| auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| builder.AddDataEdge(var, 0, assign, 0); | |||||
| builder.AddDataEdge(var,0,allreduce,0); | |||||
| builder.AddControlEdge(assign, allreduce); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| } // namespace | |||||
| // const -> allreduce | |||||
| // const -> Identity -> allreduce | |||||
| TEST(UtestGraphPassesHcclContinuousMemcpyPass, testInsertIdentityBeforeHccl) { | |||||
| ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||||
| auto src_node = graph->FindNode("var"); | |||||
| auto dst_node = graph->FindNode("allreduce"); | |||||
| // test InsertIdentityBeforeHccl | |||||
| HcclContinuousMemcpyPass hccl_continuous_memcpy_pass; | |||||
| hccl_continuous_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(0)); | |||||
| // check | |||||
| dst_node = graph->FindNode("allreduce"); | |||||
| auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,80 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <cstdint> | |||||
| #include <string> | |||||
| #include <gtest/gtest.h> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/passes/hccl_memcpy_pass.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| #include "graph_builder_utils.h" | |||||
| namespace ge { | |||||
| class UtestGraphPassesHcclMemcpyPass : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() {} | |||||
| }; | |||||
| namespace { | |||||
| /* | |||||
| * var var | |||||
| * | \ | \ | |||||
| * | assign | assign | |||||
| * | // =======> | // | |||||
| * allreduce identity | |||||
| * | | | |||||
| * netoutput allreduce | |||||
| * | | |||||
| * netoutput | |||||
| */ | |||||
| ComputeGraphPtr BuildGraph_Allreduce_Read_Var_After_Assign(){ | |||||
| auto builder = ut::GraphBuilder("test"); | |||||
| auto var = builder.AddNode("var", VARIABLE, 0, 1); | |||||
| auto assign = builder.AddNode("assign", ASSIGN, 1, 1); | |||||
| auto allreduce = builder.AddNode("allreduce", HCOMALLREDUCE, 1, 1); | |||||
| auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0); | |||||
| builder.AddDataEdge(var, 0, assign, 0); | |||||
| builder.AddDataEdge(var,0,allreduce,0); | |||||
| builder.AddControlEdge(assign, allreduce); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| } // namespace | |||||
| // const -> allreduce | |||||
| // const -> Identity -> allreduce | |||||
| TEST(UtestGraphPassesHcclMemcpyPass, testInsertIdentityBeforeHccl) { | |||||
| ComputeGraphPtr graph = BuildGraph_Allreduce_Read_Var_After_Assign(); | |||||
| auto src_node = graph->FindNode("var"); | |||||
| auto dst_node = graph->FindNode("allreduce"); | |||||
| // test InsertIdentityBeforeHccl | |||||
| HcclMemcpyPass hccl_memcpy_pass; | |||||
| hccl_memcpy_pass.InsertIdentityBeforeHccl(graph, src_node->GetOutDataAnchor(0), | |||||
| dst_node->GetInDataAnchor(0)); | |||||
| // check | |||||
| dst_node = graph->FindNode("allreduce"); | |||||
| auto in_node_before_dst_node = dst_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetType(), IDENTITY); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().size(), 1); | |||||
| EXPECT_EQ(in_node_before_dst_node->GetInControlNodes().at(0)->GetName(), "assign"); | |||||
| } | |||||
| } // namespace ge | |||||