| @@ -32,6 +32,8 @@ const int kCaseReadOnly = 0; | |||||
| const int kCaseScopeWriteable = 2; | const int kCaseScopeWriteable = 2; | ||||
| const int kCaseWriteable = 3; | const int kCaseWriteable = 3; | ||||
| const int kCaseInvalidRWType = 5; | const int kCaseInvalidRWType = 5; | ||||
| // Attr "_input_mutable" set to true means the node will modify its input at runtime. | |||||
| const char *const kModifyInput = "_input_mutable"; | |||||
| // rw type of input. | // rw type of input. | ||||
| enum class InputRWType { | enum class InputRWType { | ||||
| @@ -274,8 +276,6 @@ InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { | |||||
| // single node without sub graph | // single node without sub graph | ||||
| return GetSingleNodeInputRWTypeByIndex(node, index); | return GetSingleNodeInputRWTypeByIndex(node, index); | ||||
| } else { | } else { | ||||
| // node with sub graph | |||||
| std::set<int> node_rw_type_set; | |||||
| auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); | ||||
| // get all input data node in subgraph | // get all input data node in subgraph | ||||
| std::set<int> anchor_rw_type_set; | std::set<int> anchor_rw_type_set; | ||||
| @@ -633,28 +633,30 @@ Status InsertIdentityAsNeeded(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| // For every direct node of compute_graph whose kModifyInput attr is true, makes sure no two of its | |||||
| // data inputs are fed by the same producer output anchor: each repeated use is routed through a newly | |||||
| // inserted Identity node. Returns SUCCESS, or the failing status of a null check / the insertion. | |||||
| Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | Status HandleAllreduceDuplicateInput(ComputeGraphPtr &compute_graph) { | ||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| if (node->GetType() == HCOMALLREDUCE) { | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| if (pre_out_anchor_set.find(pre_out_anchor) == pre_out_anchor_set.end()) { | |||||
| pre_out_anchor_set.emplace(pre_out_anchor); | |||||
| continue; | |||||
| } | |||||
| // need insert identity | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| // The GetBool result is ignored on purpose: unless it sets the flag, the default (false) skips the node. | |||||
| bool mutable_input_flag = false; | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, mutable_input_flag); | |||||
| if (!mutable_input_flag) { | |||||
| continue; | |||||
| } | |||||
| // Producer output anchors already consumed by an earlier input of this node. | |||||
| std::set<OutDataAnchorPtr> pre_out_anchor_set; | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| GE_CHECK_NOTNULL(in_data_anchor); | |||||
| auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(pre_out_anchor); | |||||
| // insert().second is true the first time this producer output is seen: nothing to do. | |||||
| if (pre_out_anchor_set.insert(pre_out_anchor).second) { | |||||
| continue; | |||||
| } | |||||
| // The same producer output feeds an earlier input of this node: route this input through a new Identity. | |||||
| auto pre_node = pre_out_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(pre_node); | |||||
| auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(identity_node); | |||||
| auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); | |||||
| GE_CHK_STATUS_RET(ret, "Fail to insert identity."); | |||||
| GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), | |||||
| pre_node->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -30,7 +30,8 @@ const int32_t kAnchorSize = 1; | |||||
| const int kAnchorNum = 0; | const int kAnchorNum = 0; | ||||
| const int32_t kAnchorAssignRefIndex = 0; | const int32_t kAnchorAssignRefIndex = 0; | ||||
| const int32_t kAnchorAssignValueIndex = 1; | const int32_t kAnchorAssignValueIndex = 1; | ||||
| const char *const kInputMutable = "_input_mutable"; | |||||
| // Attr "_input_mutable" set to true means the hccl node will modify its input at runtime. | |||||
| const char *const kModifyInput = "_input_mutable"; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | ||||
| @@ -58,24 +59,13 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { | |||||
| // need to insert memcpy node between. | // need to insert memcpy node between. | ||||
| // This also applies when the input is a variable or a const. | // This also applies when the input is a variable or a const. | ||||
| Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const NodePtr node) { | ||||
| auto op_desc = node->GetOpDesc(); | |||||
| bool node_input_mutable = false; | bool node_input_mutable = false; | ||||
| if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| if (!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable)) { | |||||
| REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", kInputMutable, | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::GetBool(node->GetOpDesc(), kModifyInput, node_input_mutable); | |||||
| if (!node_input_mutable) { | if (!node_input_mutable) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGI("input mutable hcom op is:%s.", op_desc->GetName().c_str()); | |||||
| GELOGI("input mutable hcom op is:%s.", node->GetName().c_str()); | |||||
| for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | for (auto &hccl_in_anchor : node->GetAllInDataAnchors()) { | ||||
| if (hccl_in_anchor == nullptr) { | if (hccl_in_anchor == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -716,6 +716,7 @@ set(PASS_TEST_FILES | |||||
| "graph/passes/reshape_recovery_pass_unittest.cc" | "graph/passes/reshape_recovery_pass_unittest.cc" | ||||
| "graph/passes/cast_remove_pass_unittest.cc" | "graph/passes/cast_remove_pass_unittest.cc" | ||||
| "graph/passes/memcpy_addr_async_unittest.cc" | "graph/passes/memcpy_addr_async_unittest.cc" | ||||
| "graph/optimize/mem_rw_conflict_optimize_unittest.cc" | |||||
| ) | ) | ||||
| set(KERNEL_TEST_FILES | set(KERNEL_TEST_FILES | ||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/optimize/graph_optimize.h" | |||||
| #include <gtest/gtest.h> | |||||
| #include "graph/passes/graph_builder_utils.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| namespace ge { | |||||
| // Test fixture for the memory read/write conflict optimization tests; it holds no shared state. | |||||
| class MemRwConflictOptimizeTest : public ::testing::Test { | |||||
| protected: | |||||
| // Nothing to prepare before a test. | |||||
| void SetUp() override {} | |||||
| // Nothing to release after a test. | |||||
| void TearDown() override {} | |||||
| }; | |||||
| namespace { | |||||
| /// | |||||
| /// Builds a graph in which the single output of `add` feeds both inputs of HcomAllReduce, | |||||
| /// and HcomAllReduce carries the attr "_input_mutable" = true. | |||||
| /// | |||||
| /// HcomAllReduce | |||||
| /// \ / | |||||
| /// add | |||||
| /// / \ | |||||
| /// var | |||||
| /// | |||||
| ComputeGraphPtr build_all_reduce_repeat_input_graph() { | |||||
| auto builder = ut::GraphBuilder("build_all_reduce_repeat_input_graph"); | |||||
| // AddNode arguments appear to be (name, type, input count, output count) - confirm against GraphBuilder. | |||||
| auto var = builder.AddNode("var", VARIABLEV2, 0, 1); | |||||
| auto add = builder.AddNode("add", ADD, 2, 1); | |||||
| auto hcom_all_reduce = builder.AddNode("HcomAllReduce", HCOMALLREDUCE, 2, 1); | |||||
| AttrUtils::SetBool(hcom_all_reduce->GetOpDesc(), "_input_mutable", true); | |||||
| // var is created with a single output, so both inputs of add read from output index 0. | |||||
| builder.AddDataEdge(var, 0, add, 0); | |||||
| builder.AddDataEdge(var, 0, add, 1); | |||||
| // The same output of add is wired to both inputs of HcomAllReduce: this is the duplicate input under test. | |||||
| builder.AddDataEdge(add, 0, hcom_all_reduce, 0); | |||||
| builder.AddDataEdge(add, 0, hcom_all_reduce, 1); | |||||
| return builder.GetGraph(); | |||||
| } | |||||
| } // namespace | |||||
| // After HandleMemoryRWConflict, the second (duplicated) input of HcomAllReduce must come from an | |||||
| // Identity node while the first input still comes from add. | |||||
| TEST_F(MemRwConflictOptimizeTest, test_handle_allreduce_duplicate_input) { | |||||
| auto graph = build_all_reduce_repeat_input_graph(); | |||||
| // Fatal assertion: graph is dereferenced below. | |||||
| ASSERT_NE(graph, nullptr); | |||||
| GraphOptimize optimize; | |||||
| EXPECT_EQ(optimize.HandleMemoryRWConflict(graph), SUCCESS); | |||||
| auto all_reduce = graph->FindNode("HcomAllReduce"); | |||||
| // Fatal assertion: all_reduce is dereferenced below. | |||||
| ASSERT_NE(all_reduce, nullptr); | |||||
| auto in_data_nodes = all_reduce->GetInDataNodes(); | |||||
| // Fatal assertion: at(0) / at(1) below are only valid when exactly two producers are present. | |||||
| ASSERT_EQ(in_data_nodes.size(), 2U); | |||||
| EXPECT_EQ(in_data_nodes.at(0)->GetType(), ADD); | |||||
| EXPECT_EQ(in_data_nodes.at(1)->GetType(), IDENTITY); | |||||
| } | |||||
| } // namespace ge | |||||