|
@@ -234,7 +234,7 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c |
|
|
const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize(); |
|
|
const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize(); |
|
|
// Create subgraph opdesc & node |
|
|
// Create subgraph opdesc & node |
|
|
auto partitioncall_opdesc = |
|
|
auto partitioncall_opdesc = |
|
|
CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); |
|
|
|
|
|
|
|
|
CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); |
|
|
auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); |
|
|
auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); |
|
|
// Link node's peerout anchors to new node's inanchors |
|
|
// Link node's peerout anchors to new node's inanchors |
|
|
for (const auto &input_anchor : node->GetAllInAnchors()) { |
|
|
for (const auto &input_anchor : node->GetAllInAnchors()) { |
|
@@ -289,7 +289,8 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c |
|
|
/// @param [in] output_num |
|
|
/// @param [in] output_num |
|
|
/// @return OpDescPtr |
|
|
/// @return OpDescPtr |
|
|
/// |
|
|
/// |
|
|
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) { |
|
|
|
|
|
|
|
|
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num, |
|
|
|
|
|
size_t output_num) { |
|
|
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL); |
|
|
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL); |
|
|
op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num); |
|
|
op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num); |
|
|
|
|
|
|
|
@@ -299,6 +300,16 @@ OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t i |
|
|
size_t index = op_desc->GetSubgraphInstanceNames().size(); |
|
|
size_t index = op_desc->GetSubgraphInstanceNames().size(); |
|
|
op_desc->AddSubgraphName("f"); |
|
|
op_desc->AddSubgraphName("f"); |
|
|
op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name); |
|
|
op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name); |
|
|
|
|
|
|
|
|
|
|
|
auto node_desc = node->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr); |
|
|
|
|
|
for (size_t i = 0; i < input_num; ++i) { |
|
|
|
|
|
(void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1)); |
|
|
|
|
|
} |
|
|
|
|
|
for (size_t i = 0; i < output_num; ++i) { |
|
|
|
|
|
(void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
return op_desc; |
|
|
return op_desc; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|