@@ -185,13 +185,17 @@ Status DataPass::Run(ComputeGraphPtr compute_graph) { | |||||
const auto &parent_graph = compute_graph->GetParentGraph(); | const auto &parent_graph = compute_graph->GetParentGraph(); | ||||
GE_CHECK_NOTNULL(parent_graph); | GE_CHECK_NOTNULL(parent_graph); | ||||
for (const NodePtr &node : compute_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { | |||||
continue; | |||||
} | |||||
bool flag = false; | |||||
(void)AttrUtils::GetBool(compute_graph, "_no_reset_name", flag); | |||||
if (!flag) { | |||||
for (const NodePtr &node : compute_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
if ((node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2) || (node->GetType() == NETOUTPUT)) { | |||||
continue; | |||||
} | |||||
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); | |||||
node->GetOpDesc()->SetName(parent_node->GetName() + "_" + compute_graph->GetName() + "/" + node->GetName()); | |||||
} | |||||
} | } | ||||
return PostParseSubgraph(compute_graph, subgraph_name, parent_node); | return PostParseSubgraph(compute_graph, subgraph_name, parent_node); | ||||
@@ -578,6 +578,7 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
subgraph->SetName("Batch_" + std::to_string(i)); | subgraph->SetName("Batch_" + std::to_string(i)); | ||||
subgraph->SetParentNode(case_node_); | subgraph->SetParentNode(case_node_); | ||||
subgraph->SetParentGraph(graph); | subgraph->SetParentGraph(graph); | ||||
(void)AttrUtils::SetBool(subgraph, "_no_reset_name", true); | |||||
graph->AddSubgraph(subgraph->GetName(), subgraph); | graph->AddSubgraph(subgraph->GetName(), subgraph); | ||||
all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
@@ -599,55 +600,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
} | } | ||||
} | } | ||||
return PostProcSubgraph(graph); | |||||
} | |||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Assign parent index for branches. | |||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
/// @return 0: SUCCESS / others: FAILED | |||||
/// | |||||
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) { | |||||
auto func_desc = case_node_->GetOpDesc(); | |||||
domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr; | |||||
auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType()); | |||||
if (post_func == nullptr) { | |||||
GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(), | |||||
case_node_->GetType().c_str()); | |||||
if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS || | |||||
parse_func_v2 == nullptr) { | |||||
GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(), | |||||
case_node_->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
for (const auto &name : func_desc->GetSubgraphInstanceNames()) { | |||||
const auto &subgraph = graph->GetSubgraph(name); | |||||
if (subgraph == nullptr) { | |||||
GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str()); | |||||
return FAILED; | |||||
} | |||||
std::string subgraph_name; | |||||
GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name), | |||||
"Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str()); | |||||
auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||||
Status ret = FAILED; | |||||
if (post_func != nullptr) { | |||||
ret = post_func(subgraph_name, graph); | |||||
} else if (parse_func_v2 != nullptr) { | |||||
ret = parse_func_v2(subgraph_name.c_str(), graph); | |||||
} | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(), | |||||
case_node_->GetName().c_str(), case_node_->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -131,14 +131,6 @@ class MultiBatchClonePass : public GraphPass { | |||||
/// | /// | ||||
Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); | Status CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch); | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Assign parent index for branches. | |||||
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph. | |||||
/// @return 0: SUCCESS / others: FAILED | |||||
/// | |||||
Status PostProcSubgraph(const ComputeGraphPtr &graph); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Remove subgraph supend output anchor. | /// @brief Remove subgraph supend output anchor. | ||||
@@ -29,6 +29,7 @@ | |||||
#include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/passes/data_pass.h" | |||||
#include "graph/passes/multi_batch_clone_pass.h" | #include "graph/passes/multi_batch_clone_pass.h" | ||||
#include "graph/passes/prune_pass.h" | #include "graph/passes/prune_pass.h" | ||||
#include "graph/preprocess/multi_batch_options.h" | #include "graph/preprocess/multi_batch_options.h" | ||||
@@ -1697,6 +1698,7 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { | |||||
if (multi_batch_with_switchn == nullptr) { | if (multi_batch_with_switchn == nullptr) { | ||||
PassManager pass_manager; | PassManager pass_manager; | ||||
GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass)); | ||||
GE_CHK_STATUS_RET(pass_manager.AddPass("DataPass", new (std::nothrow) DataPass)); | |||||
return pass_manager.Run(graph); | return pass_manager.Run(graph); | ||||
} | } | ||||
} | } | ||||