| @@ -93,7 +93,7 @@ Status HcclMemcpyPass::MutableInputProcess(const ComputeGraphPtr &graph, const N | |||||
| int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size(); | int32_t src_out_anchor_size = src_out_anchor->GetPeerInDataAnchors().size(); | ||||
| if (src_out_anchor_size == kAnchorSize) { | if (src_out_anchor_size == kAnchorSize) { | ||||
| // Memcpyasync needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared. | |||||
| // Identity needs to be inserted between constant (/data) and hcomallreduce to avoid constant being cleared. | |||||
| if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) { | if (IsDataNode(src_out_anchor->GetOwnerNode()->GetType())) { | ||||
| Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); | Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -193,12 +193,12 @@ bool HcclMemcpyPass::IsDataNode(const std::string& node_type) { | |||||
| } | } | ||||
| /// | /// | ||||
| /// @brief Add MemcpyAsync Node | |||||
| /// @brief Add Identity Node | |||||
| /// @param [in] ge::ComputeGraphPtr graph | /// @param [in] ge::ComputeGraphPtr graph | ||||
| /// @param [in] ge::OutDataAnchorPtr in_node | /// @param [in] ge::OutDataAnchorPtr in_node | ||||
| /// @return ge::NodePtr | /// @return ge::NodePtr | ||||
| /// | /// | ||||
| NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { | |||||
| NodePtr HcclMemcpyPass::CreateIdentityNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { | |||||
| GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); | GE_IF_BOOL_EXEC(graph == nullptr, return nullptr); | ||||
| NodePtr pre_node = out_data_anchor->GetOwnerNode(); | NodePtr pre_node = out_data_anchor->GetOwnerNode(); | ||||
| OpDescPtr pre_op_desc = pre_node->GetOpDesc(); | OpDescPtr pre_op_desc = pre_node->GetOpDesc(); | ||||
| @@ -207,24 +207,24 @@ NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, cons | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::string node_name = pre_node->GetName() + "_" + MEMCPYASYNC; | |||||
| std::string node_name = pre_node->GetName() + "_" + IDENTITY; | |||||
| node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), MEMCPYASYNC); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(node_name.c_str(), IDENTITY); | |||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: MakeShared op_desc fail."); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: MakeShared op_desc fail."); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| GELOGI("Create MemcpyAsync op:%s.", op_desc->GetName().c_str()); | |||||
| GELOGI("Create Identity op:%s.", op_desc->GetName().c_str()); | |||||
| graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | graphStatus ret = op_desc->AddInputDesc("x", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | ||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add input desc fail."); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add input desc fail."); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | ret = op_desc->AddOutputDesc("y", pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())); | ||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "Create MemcpyAsync op: add output desc fail."); | |||||
| GELOGE(INTERNAL_ERROR, "Create Identity op: add output desc fail."); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // because history reason ,this pass can not do work after constant fold so mark it | // because history reason ,this pass can not do work after constant fold so mark it | ||||
| @@ -232,7 +232,7 @@ NodePtr HcclMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, cons | |||||
| NodePtr memcpy_node = graph->AddNode(op_desc); | NodePtr memcpy_node = graph->AddNode(op_desc); | ||||
| if (memcpy_node == nullptr) { | if (memcpy_node == nullptr) { | ||||
| GELOGE(INTERNAL_ERROR, "Insert MemcpyAsync node fail."); | |||||
| GELOGE(INTERNAL_ERROR, "Insert Identity node fail."); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -267,7 +267,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const | |||||
| const InDataAnchorPtr &hccl_in_anchor) { | const InDataAnchorPtr &hccl_in_anchor) { | ||||
| GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | GELOGI("Between op %s and op %s need insert memcpy async op.", src_out_anchor->GetOwnerNode()->GetName().c_str(), | ||||
| hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | hccl_in_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, src_out_anchor); | |||||
| NodePtr memcpy_node = CreateIdentityNode(graph, src_out_anchor); | |||||
| GE_CHECK_NOTNULL(memcpy_node); | GE_CHECK_NOTNULL(memcpy_node); | ||||
| Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | Status ret1 = src_out_anchor->Unlink(hccl_in_anchor); | ||||