Browse Source

CondRemovePass

tags/v1.2.0
lianghao 3 years ago
parent
commit
b598ea75cd
2 changed files with 14 additions and 3 deletions
  1. +13
    -2
      ge/graph/passes/cond_remove_pass.cc
  2. +1
    -1
      ge/graph/passes/cond_remove_pass.h

+ 13
- 2
ge/graph/passes/cond_remove_pass.cc View File

@@ -234,7 +234,7 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize();
// Create subgraph opdesc & node
auto partitioncall_opdesc =
CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
CreateSubgraphOpDesc(node, save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size);
auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc);
// Link node's peerout anchors to new node's inanchors
for (const auto &input_anchor : node->GetAllInAnchors()) {
@@ -289,7 +289,8 @@ Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, c
/// @param [in] output_num
/// @return OpDescPtr
///
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num) {
OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num,
size_t output_num) {
OpDescBuilder op_desc_builder(name, PARTITIONEDCALL);
op_desc_builder.AddDynamicInput("args", input_num).AddDynamicOutput("output", output_num);

@@ -299,6 +300,16 @@ OpDescPtr CondRemovePass::CreateSubgraphOpDesc(const std::string &name, size_t i
size_t index = op_desc->GetSubgraphInstanceNames().size();
op_desc->AddSubgraphName("f");
op_desc->SetSubgraphInstanceName(static_cast<uint32_t>(index), name);

auto node_desc = node->GetOpDesc();
GE_CHECK_NOTNULL_EXEC(node_desc, return nullptr);
for (size_t i = 0; i < input_num; ++i) {
(void)op_desc->UpdateInputDesc(i, node_desc->GetInputDesc(i + 1));
}
for (size_t i = 0; i < output_num; ++i) {
(void)op_desc->UpdateOutputDesc(i, node_desc->GetOutputDesc(i));
}

return op_desc;
}



+ 1
- 1
ge/graph/passes/cond_remove_pass.h View File

@@ -70,7 +70,7 @@ class CondRemovePass : public BaseNodePass {
///
Status ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch);

OpDescPtr CreateSubgraphOpDesc(const std::string &name, size_t input_num, size_t output_num);
OpDescPtr CreateSubgraphOpDesc(const NodePtr &node, const std::string &name, size_t input_num, size_t output_num);

int32_t GetCondIndex(const ConstGeTensorPtr &tensor);
};


Loading…
Cancel
Save