From 8dac43028b175eb9358c5b3f9f2908a1e6e58925 Mon Sep 17 00:00:00 2001 From: chuxing Date: Wed, 14 Apr 2021 15:29:35 +0800 Subject: [PATCH] fix reuse input --- ge/hybrid/model/hybrid_model.cc | 4 + ge/hybrid/model/hybrid_model.h | 2 + ge/hybrid/model/hybrid_model_builder.cc | 89 ++------------ ge/hybrid/model/hybrid_model_builder.h | 3 +- .../compiledsubgraph/known_node_executor.cc | 114 ++++++++++++++++-- .../compiledsubgraph/known_node_executor.h | 8 ++ tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 2 +- .../ge/hybrid/known_node_executor_unittest.cc | 20 +++ 8 files changed, 156 insertions(+), 86 deletions(-) diff --git a/ge/hybrid/model/hybrid_model.cc b/ge/hybrid/model/hybrid_model.cc index a669c06f..07268f56 100644 --- a/ge/hybrid/model/hybrid_model.cc +++ b/ge/hybrid/model/hybrid_model.cc @@ -120,6 +120,10 @@ const GraphItem *HybridModel::GetRootGraphItem() const { return root_graph_item_.get(); } +const ComputeGraphPtr &HybridModel::GetRootGraph() const { + return root_graph_; +} + const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { GELOGD("To find subgraph item by name = %s", graph_name.c_str()); auto it = subgraph_items_.find(graph_name); diff --git a/ge/hybrid/model/hybrid_model.h b/ge/hybrid/model/hybrid_model.h index 18daed4f..012571b8 100644 --- a/ge/hybrid/model/hybrid_model.h +++ b/ge/hybrid/model/hybrid_model.h @@ -101,6 +101,8 @@ class HybridModel { const GraphItem *GetRootGraphItem() const; + const ComputeGraphPtr &GetRootGraph() const; + const GraphItem *GetSubgraphItem(const std::string &graph_name) const; const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index a047a05b..4a603e14 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -831,13 +831,6 @@ Status HybridModelBuilder::LoadGraph() { "[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", sub_graph->GetName().c_str()); } else { - GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), - "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", - parent_node_item->NodeName().c_str()); - GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), - "[Invoke][IdentifySameInputs][%s] Failed to identify same outputs.", - parent_node_item->NodeName().c_str()); - // if parent is function control op. need add a virtual partitioned call if (parent_node_item->IsControlOp()) { GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), @@ -846,7 +839,16 @@ Status HybridModelBuilder::LoadGraph() { } } } - + for (auto &it : hybrid_model_.known_shape_sub_models_) { + auto node_item = MutableNodeItem(it.first); + AscendString graph_name; + GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); + auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); + GE_CHECK_NOTNULL(subgraph); + GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph), + "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", + node_item->NodeName().c_str()); + } GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), "[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," "model_name_:%s.", GetGraphName()); @@ -1478,50 +1480,8 @@ Status HybridModelBuilder::InitRuntimeParams() { return SUCCESS; } -Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { - GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); - auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); - GE_CHECK_NOTNULL(subgraph); - auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); - if (net_output_node == nullptr) { - GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); - return SUCCESS; - } - - auto net_output_desc = net_output_node->GetOpDesc(); - GE_CHECK_NOTNULL(net_output_desc); - - std::map connected_inputs; - for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { - auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (out_data_anchor == nullptr) { - continue; - } - auto src_node = out_data_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_node); - auto op_desc = src_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - - std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); - auto it = connected_inputs.find(input_key); - if (it == connected_inputs.end()) { - connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); - } else { - GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), - in_data_anchor->GetIdx(), - it->second, - src_node->GetName().c_str(), - out_data_anchor->GetIdx()); - node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); - } - } - return SUCCESS; -} - -Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { +Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph) { GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); - auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); - GE_CHECK_NOTNULL(subgraph); auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); if (net_output_node == nullptr) { GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); @@ -1530,36 +1490,13 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { auto net_output_desc = net_output_node->GetOpDesc(); GE_CHECK_NOTNULL(net_output_desc); - // constant/variable connected to net output + // constants connected to net output for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto src_node = GetPeerNode(in_data_anchor); GE_CHECK_NOTNULL(src_node); auto src_op_type = src_node->GetType(); - GELOGD("Node %s, output %d, src node = %s, src node type = %s", - node_item.NodeName().c_str(), - in_data_anchor->GetIdx(), - src_node->GetName().c_str(), - src_op_type.c_str()); - uint32_t parent_index = 0; - if (GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index) != SUCCESS) { - continue; - } - GELOGD("Got parent output index = %u", parent_index); - if (src_op_type == DATA) { - int ref_i = 0; - (void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); - node_item.reuse_inputs.emplace(static_cast(parent_index), ref_i); - GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); - } - - if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { - continue; - } - - GE_CHECK_LE(parent_index, INT32_MAX); - node_item.ref_outputs.emplace(static_cast(parent_index), src_node); if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { - known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node); + known_subgraph_constant_output_refs_[&node_item].emplace(in_data_anchor->GetIdx(), src_node); } } diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 3e467dc8..041e9dbb 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -59,8 +59,7 @@ class HybridModelBuilder { Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadTask(NodeItem &node_item); Status LoadTasks(); - Status IdentifyVariableOutputs(NodeItem &node_item); - Status IdentifySameInputs(NodeItem &node_item); + Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index 48c3ab9e..b88932a3 100755 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -22,6 +22,7 @@ #include "common/ge/ge_util.h" #include "graph/attr_value.h" #include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" #include "graph/load/model_manager/model_utils.h" #include "graph/load/model_manager/model_manager.h" #include "hybrid/executor/hybrid_execution_context.h" @@ -184,13 +185,15 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); GE_CHECK_NOTNULL(node); - const GeModelPtr ge_model = model.GetGeModel(node); - GE_CHECK_NOTNULL(ge_model); - - AscendString graph_name; - GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get graph name"); - auto weight_buffer = model.GetModelWeight(graph_name.GetString()); - + GeModelPtr ge_model; + ComputeGraphPtr compute_graph; + GE_CHK_STATUS_RET(GetModelAndGraph(model, node, ge_model, compute_graph), + "[%s] Failed to get model and graph", + node->GetName().c_str()); + auto node_item = const_cast(model.GetNodeItem(node)); + GE_CHECK_NOTNULL(node_item); + GE_CHK_STATUS_RET_NOLOG(ParseAttrForAllocatingOutputs(*node_item, *compute_graph)); + auto weight_buffer = model.GetModelWeight(compute_graph->GetName()); std::shared_ptr davinci_model = MakeShared(0, nullptr); GE_CHECK_NOTNULL(davinci_model); @@ -223,5 +226,102 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); return SUCCESS; } + +Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph) { + GELOGD("[%s] Start to parse attributes for outputs", node_item.NodeName().c_str()); + auto net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); + if (net_output_node == nullptr) { + GELOGD("[%s] Subgraph do not got net output", graph.GetName().c_str()); + return SUCCESS; + } + + auto net_output_desc = net_output_node->GetOpDesc(); + GE_CHECK_NOTNULL(net_output_desc); + std::map connected_inputs; + std::map data_indices; + GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), + "[%s] Failed to get data node indices", + node_item.NodeName().c_str()); + for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { + auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); + if (out_data_anchor == nullptr) { + continue; + } + auto src_node = out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + auto op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto src_op_type = src_node->GetType(); + auto output_index = in_data_anchor->GetIdx(); + GELOGD("Node %s, output %d, src node = %s, src node type = %s", + node_item.NodeName().c_str(), + output_index, + src_node->GetName().c_str(), + src_op_type.c_str()); + // parse reuse outputs + std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); + auto it = connected_inputs.find(input_key); + if (it == connected_inputs.end()) { + connected_inputs.emplace(input_key, output_index); + } else { + GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), + output_index, + it->second, + src_node->GetName().c_str(), + out_data_anchor->GetIdx()); + node_item.reuse_outputs.emplace(output_index, it->second); + } + + if (src_op_type == DATA) { + int data_index = data_indices[src_node]; + node_item.reuse_inputs.emplace(output_index, data_index); + GELOGD("[%s] output[%u] reuses input[%d]", node_item.NodeName().c_str(), output_index, data_index); + } else if (src_op_type == CONSTANTOP || src_op_type == CONSTANT || src_op_type == VARIABLE) { + node_item.ref_outputs.emplace(output_index, src_node); + GELOGD("[%s] output[%d] ref to node [%s]", + node_item.NodeName().c_str(), + output_index, + src_node->GetName().c_str()); + } + } + + GELOGD("[%s] Done parsing attributes for outputs successfully", node_item.NodeName().c_str()); + return SUCCESS; +} + +Status KnownNodeExecutor::GetDataNodes(ComputeGraph &graph, std::map &data_indices) { + std::map ordered_data_nodes; + for (const auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + if (node->GetType() == DATA) { + int index = -1; + (void) AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, index); + ordered_data_nodes.emplace(index, node); + } + } + + // reindex + int data_index = 0; + for (const auto &it : ordered_data_nodes) { + data_indices.emplace(it.second, data_index++); + } + + return SUCCESS; +} + +Status KnownNodeExecutor::GetModelAndGraph(const HybridModel &model, + const NodePtr &node, + GeModelPtr &ge_model, + ComputeGraphPtr &graph) { + ge_model = model.GetGeModel(node); + GE_CHECK_NOTNULL(ge_model); + const auto &root_graph = model.GetRootGraph(); + GE_CHECK_NOTNULL(root_graph); + AscendString graph_name; + GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get subgraph name"); + graph = root_graph->GetSubgraph(graph_name.GetString()); + GE_CHECK_NOTNULL(graph); + return SUCCESS; +} } // namespace hybrid } // namespace ge diff --git a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h index 629cb543..11cda846 100644 --- a/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h +++ b/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h @@ -51,6 +51,14 @@ class KnownNodeExecutor : public NodeExecutor { Status PrepareTask(NodeTask &task, TaskContext &context) const; Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; ~KnownNodeExecutor() {} + + private: + static Status ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph); + static Status GetDataNodes(ComputeGraph &graph, std::map &data_indices); + static Status GetModelAndGraph(const HybridModel &model, + const NodePtr &node, + GeModelPtr &ge_model, + ComputeGraphPtr &graph); }; } // namespace hybrid } // namespace ge diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index b5aac527..fbb2aa40 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -202,7 +202,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) { GeRootModelPtr ge_root_model = make_shared(root_graph); HybridModel hybrid_model(ge_root_model); HybridModelBuilder hybrid_model_builder(hybrid_model); - auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); + auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph); ASSERT_EQ(ret, SUCCESS); } diff --git a/tests/ut/ge/hybrid/known_node_executor_unittest.cc b/tests/ut/ge/hybrid/known_node_executor_unittest.cc index 16bbe3a0..98e985f7 100644 --- a/tests/ut/ge/hybrid/known_node_executor_unittest.cc +++ b/tests/ut/ge/hybrid/known_node_executor_unittest.cc @@ -26,6 +26,7 @@ #undef private #undef protected #include "graph/manager/graph_mem_allocator.h" +#include "../graph/passes/graph_builder_utils.h" using namespace std; using namespace testing; @@ -69,3 +70,22 @@ TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) { model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); } + +TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { + ut::GraphBuilder builder("test-graph"); + auto data_node = builder.AddNode("Data0", DATA, 1, 1); + auto netoutput_node = builder.AddNode("NodeOutput", NETOUTPUT, 2, 2); + builder.AddDataEdge(data_node, 0, netoutput_node, 0); + auto const_node = builder.AddNode("Const0", CONSTANT, 0, 1); + builder.AddDataEdge(const_node, 0, netoutput_node, 1); + auto graph = builder.GetGraph(); + + ut::GraphBuilder builder2("root-graph"); + auto partitioned_call = builder2.AddNode("Node0", PARTITIONEDCALL, 1, 2); + NodeItem node_item(partitioned_call); + ASSERT_EQ(KnownNodeExecutor::ParseAttrForAllocatingOutputs(node_item, *graph), SUCCESS); + ASSERT_EQ(node_item.ref_outputs.size(), 1); + ASSERT_EQ(node_item.ref_outputs[1], const_node); + ASSERT_EQ(node_item.reuse_inputs.size(), 1); + ASSERT_EQ(node_item.reuse_inputs[0], 0); +} \ No newline at end of file