From: @xchu42 Reviewed-by: @wan_xuelei, @wqtshg, @ji_chen Signed-off-by: @ji_chen tags/v1.3.0
| @@ -121,6 +121,10 @@ const GraphItem *HybridModel::GetRootGraphItem() const { | |||
| return root_graph_item_.get(); | |||
| } | |||
| const ComputeGraphPtr &HybridModel::GetRootGraph() const { | |||
| return root_graph_; | |||
| } | |||
| const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { | |||
| GELOGD("To find subgraph item by name = %s", graph_name.c_str()); | |||
| auto it = subgraph_items_.find(graph_name); | |||
| @@ -101,6 +101,8 @@ class HybridModel { | |||
| const GraphItem *GetRootGraphItem() const; | |||
| const ComputeGraphPtr &GetRootGraph() const; | |||
| const GraphItem *GetSubgraphItem(const std::string &graph_name) const; | |||
| const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; | |||
| @@ -842,13 +842,6 @@ Status HybridModelBuilder::LoadGraph() { | |||
| "[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", | |||
| sub_graph->GetName().c_str()); | |||
| } else { | |||
| GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | |||
| "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||
| parent_node_item->NodeName().c_str()); | |||
| GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), | |||
| "[Invoke][IdentifySameInputs][%s] Failed to identify same outputs.", | |||
| parent_node_item->NodeName().c_str()); | |||
| // if parent is function control op. need add a virtual partitioned call | |||
| if (parent_node_item->IsControlFlowV2Op()) { | |||
| GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | |||
| @@ -857,7 +850,16 @@ Status HybridModelBuilder::LoadGraph() { | |||
| } | |||
| } | |||
| } | |||
| for (auto &it : hybrid_model_.known_shape_sub_models_) { | |||
| auto node_item = MutableNodeItem(it.first); | |||
| AscendString graph_name; | |||
| GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||
| auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | |||
| GE_CHECK_NOTNULL(subgraph); | |||
| GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph), | |||
| "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||
| node_item->NodeName().c_str()); | |||
| } | |||
| GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), | |||
| "[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," | |||
| "model_name_:%s.", GetGraphName()); | |||
| @@ -1493,50 +1495,8 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||
| GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||
| GE_CHECK_NOTNULL(subgraph); | |||
| auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||
| if (net_output_node == nullptr) { | |||
| GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(net_output_desc); | |||
| std::map<std::string, int> connected_inputs; | |||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (out_data_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| auto src_node = out_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(src_node); | |||
| auto op_desc = src_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||
| auto it = connected_inputs.find(input_key); | |||
| if (it == connected_inputs.end()) { | |||
| connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||
| } else { | |||
| GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||
| in_data_anchor->GetIdx(), | |||
| it->second, | |||
| src_node->GetName().c_str(), | |||
| out_data_anchor->GetIdx()); | |||
| node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||
| Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph) { | |||
| GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | |||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||
| GE_CHECK_NOTNULL(subgraph); | |||
| auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||
| if (net_output_node == nullptr) { | |||
| GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); | |||
| @@ -1545,36 +1505,13 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(net_output_desc); | |||
| // constant/variable connected to net output | |||
| // constants connected to net output | |||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||
| auto src_node = GetPeerNode(in_data_anchor); | |||
| GE_CHECK_NOTNULL(src_node); | |||
| auto src_op_type = src_node->GetType(); | |||
| GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||
| node_item.NodeName().c_str(), | |||
| in_data_anchor->GetIdx(), | |||
| src_node->GetName().c_str(), | |||
| src_op_type.c_str()); | |||
| uint32_t parent_index = 0; | |||
| if (GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index) != SUCCESS) { | |||
| continue; | |||
| } | |||
| GELOGD("Got parent output index = %u", parent_index); | |||
| if (src_op_type == DATA) { | |||
| int ref_i = 0; | |||
| (void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); | |||
| node_item.reuse_inputs.emplace(static_cast<int>(parent_index), ref_i); | |||
| GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); | |||
| } | |||
| if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { | |||
| continue; | |||
| } | |||
| GE_CHECK_LE(parent_index, INT32_MAX); | |||
| node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node); | |||
| if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { | |||
| known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node); | |||
| known_subgraph_constant_output_refs_[&node_item].emplace(in_data_anchor->GetIdx(), src_node); | |||
| } | |||
| } | |||
| @@ -59,8 +59,7 @@ class HybridModelBuilder { | |||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | |||
| Status LoadTask(NodeItem &node_item); | |||
| Status LoadTasks(); | |||
| Status IdentifyVariableOutputs(NodeItem &node_item); | |||
| Status IdentifySameInputs(NodeItem &node_item); | |||
| Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | |||
| Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | |||
| Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | |||
| Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | |||
| @@ -22,6 +22,7 @@ | |||
| #include "common/ge/ge_util.h" | |||
| #include "graph/attr_value.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/load/model_manager/model_utils.h" | |||
| #include "graph/load/model_manager/model_manager.h" | |||
| #include "hybrid/executor/hybrid_execution_context.h" | |||
| @@ -184,13 +185,15 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node | |||
| GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | |||
| GE_CHECK_NOTNULL(node); | |||
| const GeModelPtr ge_model = model.GetGeModel(node); | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| AscendString graph_name; | |||
| GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get graph name"); | |||
| auto weight_buffer = model.GetModelWeight(graph_name.GetString()); | |||
| GeModelPtr ge_model; | |||
| ComputeGraphPtr compute_graph; | |||
| GE_CHK_STATUS_RET(GetModelAndGraph(model, node, ge_model, compute_graph), | |||
| "[%s] Failed to get model and graph", | |||
| node->GetName().c_str()); | |||
| auto node_item = const_cast<NodeItem *>(model.GetNodeItem(node)); | |||
| GE_CHECK_NOTNULL(node_item); | |||
| GE_CHK_STATUS_RET_NOLOG(ParseAttrForAllocatingOutputs(*node_item, *compute_graph)); | |||
| auto weight_buffer = model.GetModelWeight(compute_graph->GetName()); | |||
| std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); | |||
| GE_CHECK_NOTNULL(davinci_model); | |||
| @@ -223,5 +226,102 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, | |||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); | |||
| return SUCCESS; | |||
| } | |||
| Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph) { | |||
| GELOGD("[%s] Start to parse attributes for outputs", node_item.NodeName().c_str()); | |||
| auto net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); | |||
| if (net_output_node == nullptr) { | |||
| GELOGD("[%s] Subgraph do not got net output", graph.GetName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(net_output_desc); | |||
| std::map<std::string, int> connected_inputs; | |||
| std::map<NodePtr, int> data_indices; | |||
| GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), | |||
| "[%s] Failed to get data node indices", | |||
| node_item.NodeName().c_str()); | |||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (out_data_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| auto src_node = out_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(src_node); | |||
| auto op_desc = src_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto src_op_type = src_node->GetType(); | |||
| auto output_index = in_data_anchor->GetIdx(); | |||
| GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||
| node_item.NodeName().c_str(), | |||
| output_index, | |||
| src_node->GetName().c_str(), | |||
| src_op_type.c_str()); | |||
| // parse reuse outputs | |||
| std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||
| auto it = connected_inputs.find(input_key); | |||
| if (it == connected_inputs.end()) { | |||
| connected_inputs.emplace(input_key, output_index); | |||
| } else { | |||
| GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||
| output_index, | |||
| it->second, | |||
| src_node->GetName().c_str(), | |||
| out_data_anchor->GetIdx()); | |||
| node_item.reuse_outputs.emplace(output_index, it->second); | |||
| } | |||
| if (src_op_type == DATA) { | |||
| int data_index = data_indices[src_node]; | |||
| node_item.reuse_inputs.emplace(output_index, data_index); | |||
| GELOGD("[%s] output[%u] reuses input[%d]", node_item.NodeName().c_str(), output_index, data_index); | |||
| } else if (src_op_type == CONSTANTOP || src_op_type == CONSTANT || src_op_type == VARIABLE) { | |||
| node_item.ref_outputs.emplace(output_index, src_node); | |||
| GELOGD("[%s] output[%d] ref to node [%s]", | |||
| node_item.NodeName().c_str(), | |||
| output_index, | |||
| src_node->GetName().c_str()); | |||
| } | |||
| } | |||
| GELOGD("[%s] Done parsing attributes for outputs successfully", node_item.NodeName().c_str()); | |||
| return SUCCESS; | |||
| } | |||
| Status KnownNodeExecutor::GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices) { | |||
| std::map<int, NodePtr> ordered_data_nodes; | |||
| for (const auto &node : graph.GetDirectNode()) { | |||
| GE_CHECK_NOTNULL(node); | |||
| if (node->GetType() == DATA) { | |||
| int index = -1; | |||
| (void) AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, index); | |||
| ordered_data_nodes.emplace(index, node); | |||
| } | |||
| } | |||
| // reindex | |||
| int data_index = 0; | |||
| for (const auto &it : ordered_data_nodes) { | |||
| data_indices.emplace(it.second, data_index++); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status KnownNodeExecutor::GetModelAndGraph(const HybridModel &model, | |||
| const NodePtr &node, | |||
| GeModelPtr &ge_model, | |||
| ComputeGraphPtr &graph) { | |||
| ge_model = model.GetGeModel(node); | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| const auto &root_graph = model.GetRootGraph(); | |||
| GE_CHECK_NOTNULL(root_graph); | |||
| AscendString graph_name; | |||
| GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||
| graph = root_graph->GetSubgraph(graph_name.GetString()); | |||
| GE_CHECK_NOTNULL(graph); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace hybrid | |||
| } // namespace ge | |||
| @@ -51,6 +51,14 @@ class KnownNodeExecutor : public NodeExecutor { | |||
| Status PrepareTask(NodeTask &task, TaskContext &context) const; | |||
| Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | |||
| ~KnownNodeExecutor() {} | |||
| private: | |||
| static Status ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph); | |||
| static Status GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices); | |||
| static Status GetModelAndGraph(const HybridModel &model, | |||
| const NodePtr &node, | |||
| GeModelPtr &ge_model, | |||
| ComputeGraphPtr &graph); | |||
| }; | |||
| } // namespace hybrid | |||
| } // namespace ge | |||
| @@ -202,7 +202,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) { | |||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph); | |||
| HybridModel hybrid_model(ge_root_model); | |||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||
| auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); | |||
| auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph); | |||
| ASSERT_EQ(ret, SUCCESS); | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #undef private | |||
| #undef protected | |||
| #include "graph/manager/graph_mem_allocator.h" | |||
| #include "../graph/passes/graph_builder_utils.h" | |||
| using namespace std; | |||
| using namespace testing; | |||
| @@ -69,3 +70,22 @@ TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) { | |||
| model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); | |||
| ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); | |||
| } | |||
| TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { | |||
| ut::GraphBuilder builder("test-graph"); | |||
| auto data_node = builder.AddNode("Data0", DATA, 1, 1); | |||
| auto netoutput_node = builder.AddNode("NodeOutput", NETOUTPUT, 2, 2); | |||
| builder.AddDataEdge(data_node, 0, netoutput_node, 0); | |||
| auto const_node = builder.AddNode("Const0", CONSTANT, 0, 1); | |||
| builder.AddDataEdge(const_node, 0, netoutput_node, 1); | |||
| auto graph = builder.GetGraph(); | |||
| ut::GraphBuilder builder2("root-graph"); | |||
| auto partitioned_call = builder2.AddNode("Node0", PARTITIONEDCALL, 1, 2); | |||
| NodeItem node_item(partitioned_call); | |||
| ASSERT_EQ(KnownNodeExecutor::ParseAttrForAllocatingOutputs(node_item, *graph), SUCCESS); | |||
| ASSERT_EQ(node_item.ref_outputs.size(), 1); | |||
| ASSERT_EQ(node_item.ref_outputs[1], const_node); | |||
| ASSERT_EQ(node_item.reuse_inputs.size(), 1); | |||
| ASSERT_EQ(node_item.reuse_inputs[0], 0); | |||
| } | |||