| @@ -185,13 +185,17 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { | |||||
| const auto &parent_graph = compute_graph->GetParentGraph(); | const auto &parent_graph = compute_graph->GetParentGraph(); | ||||
| GE_CHECK_NOTNULL(parent_graph); | GE_CHECK_NOTNULL(parent_graph); | ||||
| for (const NodePtr &node : compute_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { | |||||
| continue; | |||||
| } | |||||
| bool flag = false; | |||||
| (void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); | |||||
| if (!flag) { | |||||
| for (const NodePtr &node : compute_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { | |||||
| continue; | |||||
| } | |||||
| node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); | |||||
| node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); | |||||
| } | |||||
| } | } | ||||
| return PostParseSubgraph(compute_graph, subgraph_name, parent_node); | return PostParseSubgraph(compute_graph, subgraph_name, parent_node); | ||||
| @@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| subgraph->SetName("Batch_" + std::to_string(i)); | subgraph->SetName("Batch_" + std::to_string(i)); | ||||
| subgraph->SetParentNode(case_node_); | subgraph->SetParentNode(case_node_); | ||||
| subgraph->SetParentGraph(graph); | subgraph->SetParentGraph(graph); | ||||
| (void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); | |||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | graph->AddSubgraph(subgraph->GetName(), subgraph); | ||||
| all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| @@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| } | } | ||||
| } | } | ||||
| return PostProcSubgraph(graph); | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Assign parent index for branches. | |||||
| /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
| /// @return 0: SUCCESS / others: FAILED | |||||
| /// | |||||
| Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||||
| auto func_desc = case_node_->GetOpDesc(); | |||||
| domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; | |||||
| auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | |||||
| if (post_func == nullptr) { | |||||
| GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | |||||
| case_node_->GetType().c_str()); | |||||
| if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || | |||||
| parse_func_v2 == nullptr) { | |||||
| GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), | |||||
| case_node_->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | |||||
| const auto &subgraph = graph->GetSubgraph(name); | |||||
| if (subgraph == nullptr) { | |||||
| GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::string subgraph_name; | |||||
| GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name), | |||||
| "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | |||||
| auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||||
| Status ret = FAILED; | |||||
| if (post_func != nullptr) { | |||||
| ret = post_func(subgraph_name, graph); | |||||
| } else if (parse_func_v2 != nullptr) { | |||||
| ret = parse_func_v2(subgraph_name.c_str(), graph); | |||||
| } | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | |||||
| case_node_->GetName().c_str(), case_node_->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -131,14 +131,6 @@ class MultiBatchClonePass : public GraphPass { | |||||
| /// | /// | ||||
| Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); | Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Assign parent index for branches. | |||||
| /// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
| /// @return 0: SUCCESS / others: FAILED | |||||
| /// | |||||
| Status PostProcSubgraph(const ComputeGraphPtr &graph); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Remove subgraph supend output anchor. | /// @brief Remove subgraph supend output anchor. | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/passes/data_pass.h" | |||||
| #include "graph/passes/multi_batch_clone_pass.h" | #include "graph/passes/multi_batch_clone_pass.h" | ||||
| #include "graph/passes/prune_pass.h" | #include "graph/passes/prune_pass.h" | ||||
| #include "graph/preprocess/multi_batch_options.h" | #include "graph/preprocess/multi_batch_options.h" | ||||
| @@ -1697,6 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { | |||||
| if (multi_batch_with_switchn == nullptr) { | if (multi_batch_with_switchn == nullptr) { | ||||
| PassManager pass_manager; | PassManager pass_manager; | ||||
| GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | ||||
| GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); | |||||
| return pass_manager.Run(graph); | return pass_manager.Run(graph); | ||||
| } | } | ||||
| } | } | ||||