| @@ -42,7 +42,7 @@ Status MergeInputMemcpyPass::Run(ComputeGraphPtr graph) { | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, | Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, | ||||
| bool multi_batch_flag) { | |||||
| bool multi_batch_flag) { | |||||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
| @@ -74,7 +74,7 @@ Status MergeInputMemcpyPass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, c | |||||
| /// @return ge::NodePtr | /// @return ge::NodePtr | ||||
| /// | /// | ||||
| NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | NodePtr MergeInputMemcpyPass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, | ||||
| const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { | |||||
| const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { | |||||
| OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | ||||
| GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | ||||
| @@ -32,7 +32,7 @@ Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { | |||||
| OpDescPtr merge_op_desc = node->GetOpDesc(); | OpDescPtr merge_op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(merge_op_desc); | GE_CHECK_NOTNULL(merge_op_desc); | ||||
| if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | ||||
| GE_CHK_STATUS_RET(AddActiveNodes(graph, node, true), "Merge add active node failed."); | |||||
| GE_CHK_STATUS_RET(AddActiveNodes(graph, node), "Merge add active node failed."); | |||||
| GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); | ||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); | ||||
| @@ -99,18 +99,16 @@ Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, co | |||||
| } | } | ||||
| } | } | ||||
| return AddActiveNodes(graph, stream_merge, false); | |||||
| return AddActiveNodes(graph, stream_merge); | |||||
| } | } | ||||
| /// | /// | ||||
| /// @brief Add StreamActive Op before StreamMerge/Merge | /// @brief Add StreamActive Op before StreamMerge/Merge | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] multi_batch_flag | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
| bool multi_batch_flag) { | |||||
| Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | ||||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| @@ -37,10 +37,9 @@ class MergeToStreamMergePass : public GraphPass { | |||||
| /// @brief Add StreamActive Op as StreamMerge in_node | /// @brief Add StreamActive Op as StreamMerge in_node | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] multi_batch_flag | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); | |||||
| Status AddActiveNodes(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
| /// | /// | ||||
| /// @brief Create Active Op | /// @brief Create Active Op | ||||