diff --git a/ge/graph/passes/cond_remove_pass.cc b/ge/graph/passes/cond_remove_pass.cc index bf2e1170..ce5ff7c0 100644 --- a/ge/graph/passes/cond_remove_pass.cc +++ b/ge/graph/passes/cond_remove_pass.cc @@ -234,7 +234,7 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize(); // Create subgraph opdesc & node auto partitioncall_opdesc = - CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); + CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); // Link node's peerout anchors to new node's inanchors for (const auto &input_anchor : node->GetAllInAnchors()) { @@ -289,7 +289,8 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c /// @param [in] output_num /// @return OpDescPtr /// -OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) { +OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num, + size_t output_num) { OpDescBuilder op_desc_builder(name, PARTITIONEDCALL); op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num); @@ -299,6 +300,16 @@ OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t i size_t index = op_desc->GetSubgraphInstanceNames().size(); op_desc->AddSubgraphName("f"); op_desc->SetSubgraphInstanceName(static_cast(index), name); + + auto node_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr); + for (size_t i = 0; i < input_num; ++i) { + (void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1)); + } + for (size_t i = 0; i < output_num; ++i) { + (void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i)); + } + return op_desc; } diff --git a/ge/graph/passes/cond_remove_pass.h b/ge/graph/passes/cond_remove_pass.h index 72ca64b8..e466d684 100644 --- a/ge/graph/passes/cond_remove_pass.h +++ b/ge/graph/passes/cond_remove_pass.h @@ -70,7 +70,7 @@ class CondRemovePass : public BaseNodePass { /// Status ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch); - OpDescPtr CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num); + OpDescPtr CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num, size_t output_num); int32_t GetCondIndex(const ConstGeTensorPtr &tensor); };