| @@ -120,6 +120,10 @@ const GraphItem *HybridModel::GetRootGraphItem() const { | |||||
| return root_graph_item_.get(); | return root_graph_item_.get(); | ||||
| } | } | ||||
| const ComputeGraphPtr &HybridModel::GetRootGraph() const { | |||||
| return root_graph_; | |||||
| } | |||||
| const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { | const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { | ||||
| GELOGD("To find subgraph item by name = %s", graph_name.c_str()); | GELOGD("To find subgraph item by name = %s", graph_name.c_str()); | ||||
| auto it = subgraph_items_.find(graph_name); | auto it = subgraph_items_.find(graph_name); | ||||
| @@ -101,6 +101,8 @@ class HybridModel { | |||||
| const GraphItem *GetRootGraphItem() const; | const GraphItem *GetRootGraphItem() const; | ||||
| const ComputeGraphPtr &GetRootGraph() const; | |||||
| const GraphItem *GetSubgraphItem(const std::string &graph_name) const; | const GraphItem *GetSubgraphItem(const std::string &graph_name) const; | ||||
| const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; | const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; | ||||
| @@ -831,13 +831,6 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| "[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", | "[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", | ||||
| sub_graph->GetName().c_str()); | sub_graph->GetName().c_str()); | ||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | |||||
| "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||||
| parent_node_item->NodeName().c_str()); | |||||
| GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), | |||||
| "[Invoke][IdentifySameInputs][%s] Failed to identify same outputs.", | |||||
| parent_node_item->NodeName().c_str()); | |||||
| // if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
| if (parent_node_item->IsControlOp()) { | if (parent_node_item->IsControlOp()) { | ||||
| GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | ||||
| @@ -846,7 +839,16 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| for (auto &it : hybrid_model_.known_shape_sub_models_) { | |||||
| auto node_item = MutableNodeItem(it.first); | |||||
| AscendString graph_name; | |||||
| GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||||
| auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph), | |||||
| "[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||||
| node_item->NodeName().c_str()); | |||||
| } | |||||
| GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), | GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), | ||||
| "[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," | "[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," | ||||
| "model_name_:%s.", GetGraphName()); | "model_name_:%s.", GetGraphName()); | ||||
| @@ -1478,50 +1480,8 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||||
| GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||||
| if (net_output_node == nullptr) { | |||||
| GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(net_output_desc); | |||||
| std::map<std::string, int> connected_inputs; | |||||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto src_node = out_data_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
| auto it = connected_inputs.find(input_key); | |||||
| if (it == connected_inputs.end()) { | |||||
| connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||||
| } else { | |||||
| GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
| in_data_anchor->GetIdx(), | |||||
| it->second, | |||||
| src_node->GetName().c_str(), | |||||
| out_data_anchor->GetIdx()); | |||||
| node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||||
| Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph) { | |||||
| GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | ||||
| auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| if (net_output_node == nullptr) { | if (net_output_node == nullptr) { | ||||
| GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); | GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); | ||||
| @@ -1530,36 +1490,13 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||||
| auto net_output_desc = net_output_node->GetOpDesc(); | auto net_output_desc = net_output_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(net_output_desc); | GE_CHECK_NOTNULL(net_output_desc); | ||||
| // constant/variable connected to net output | |||||
| // constants connected to net output | |||||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | ||||
| auto src_node = GetPeerNode(in_data_anchor); | auto src_node = GetPeerNode(in_data_anchor); | ||||
| GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
| auto src_op_type = src_node->GetType(); | auto src_op_type = src_node->GetType(); | ||||
| GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||||
| node_item.NodeName().c_str(), | |||||
| in_data_anchor->GetIdx(), | |||||
| src_node->GetName().c_str(), | |||||
| src_op_type.c_str()); | |||||
| uint32_t parent_index = 0; | |||||
| if (GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index) != SUCCESS) { | |||||
| continue; | |||||
| } | |||||
| GELOGD("Got parent output index = %u", parent_index); | |||||
| if (src_op_type == DATA) { | |||||
| int ref_i = 0; | |||||
| (void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); | |||||
| node_item.reuse_inputs.emplace(static_cast<int>(parent_index), ref_i); | |||||
| GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); | |||||
| } | |||||
| if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { | |||||
| continue; | |||||
| } | |||||
| GE_CHECK_LE(parent_index, INT32_MAX); | |||||
| node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node); | |||||
| if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { | if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { | ||||
| known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node); | |||||
| known_subgraph_constant_output_refs_[&node_item].emplace(in_data_anchor->GetIdx(), src_node); | |||||
| } | } | ||||
| } | } | ||||
| @@ -59,8 +59,7 @@ class HybridModelBuilder { | |||||
| Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
| Status LoadTask(NodeItem &node_item); | Status LoadTask(NodeItem &node_item); | ||||
| Status LoadTasks(); | Status LoadTasks(); | ||||
| Status IdentifyVariableOutputs(NodeItem &node_item); | |||||
| Status IdentifySameInputs(NodeItem &node_item); | |||||
| Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | |||||
| Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
| Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
| Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/load/model_manager/model_utils.h" | #include "graph/load/model_manager/model_utils.h" | ||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| @@ -184,13 +185,15 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node | |||||
| GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| const GeModelPtr ge_model = model.GetGeModel(node); | |||||
| GE_CHECK_NOTNULL(ge_model); | |||||
| AscendString graph_name; | |||||
| GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get graph name"); | |||||
| auto weight_buffer = model.GetModelWeight(graph_name.GetString()); | |||||
| GeModelPtr ge_model; | |||||
| ComputeGraphPtr compute_graph; | |||||
| GE_CHK_STATUS_RET(GetModelAndGraph(model, node, ge_model, compute_graph), | |||||
| "[%s] Failed to get model and graph", | |||||
| node->GetName().c_str()); | |||||
| auto node_item = const_cast<NodeItem *>(model.GetNodeItem(node)); | |||||
| GE_CHECK_NOTNULL(node_item); | |||||
| GE_CHK_STATUS_RET_NOLOG(ParseAttrForAllocatingOutputs(*node_item, *compute_graph)); | |||||
| auto weight_buffer = model.GetModelWeight(compute_graph->GetName()); | |||||
| std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); | std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); | ||||
| GE_CHECK_NOTNULL(davinci_model); | GE_CHECK_NOTNULL(davinci_model); | ||||
| @@ -223,5 +226,102 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, | |||||
| RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph) { | |||||
| GELOGD("[%s] Start to parse attributes for outputs", node_item.NodeName().c_str()); | |||||
| auto net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); | |||||
| if (net_output_node == nullptr) { | |||||
| GELOGD("[%s] Subgraph do not got net output", graph.GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto net_output_desc = net_output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(net_output_desc); | |||||
| std::map<std::string, int> connected_inputs; | |||||
| std::map<NodePtr, int> data_indices; | |||||
| GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), | |||||
| "[%s] Failed to get data node indices", | |||||
| node_item.NodeName().c_str()); | |||||
| for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if (out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto src_node = out_data_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto op_desc = src_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto src_op_type = src_node->GetType(); | |||||
| auto output_index = in_data_anchor->GetIdx(); | |||||
| GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index, | |||||
| src_node->GetName().c_str(), | |||||
| src_op_type.c_str()); | |||||
| // parse reuse outputs | |||||
| std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
| auto it = connected_inputs.find(input_key); | |||||
| if (it == connected_inputs.end()) { | |||||
| connected_inputs.emplace(input_key, output_index); | |||||
| } else { | |||||
| GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
| output_index, | |||||
| it->second, | |||||
| src_node->GetName().c_str(), | |||||
| out_data_anchor->GetIdx()); | |||||
| node_item.reuse_outputs.emplace(output_index, it->second); | |||||
| } | |||||
| if (src_op_type == DATA) { | |||||
| int data_index = data_indices[src_node]; | |||||
| node_item.reuse_inputs.emplace(output_index, data_index); | |||||
| GELOGD("[%s] output[%u] reuses input[%d]", node_item.NodeName().c_str(), output_index, data_index); | |||||
| } else if (src_op_type == CONSTANTOP || src_op_type == CONSTANT || src_op_type == VARIABLE) { | |||||
| node_item.ref_outputs.emplace(output_index, src_node); | |||||
| GELOGD("[%s] output[%d] ref to node [%s]", | |||||
| node_item.NodeName().c_str(), | |||||
| output_index, | |||||
| src_node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| GELOGD("[%s] Done parsing attributes for outputs successfully", node_item.NodeName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status KnownNodeExecutor::GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices) { | |||||
| std::map<int, NodePtr> ordered_data_nodes; | |||||
| for (const auto &node : graph.GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->GetType() == DATA) { | |||||
| int index = -1; | |||||
| (void) AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, index); | |||||
| ordered_data_nodes.emplace(index, node); | |||||
| } | |||||
| } | |||||
| // reindex | |||||
| int data_index = 0; | |||||
| for (const auto &it : ordered_data_nodes) { | |||||
| data_indices.emplace(it.second, data_index++); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status KnownNodeExecutor::GetModelAndGraph(const HybridModel &model, | |||||
| const NodePtr &node, | |||||
| GeModelPtr &ge_model, | |||||
| ComputeGraphPtr &graph) { | |||||
| ge_model = model.GetGeModel(node); | |||||
| GE_CHECK_NOTNULL(ge_model); | |||||
| const auto &root_graph = model.GetRootGraph(); | |||||
| GE_CHECK_NOTNULL(root_graph); | |||||
| AscendString graph_name; | |||||
| GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||||
| graph = root_graph->GetSubgraph(graph_name.GetString()); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -51,6 +51,14 @@ class KnownNodeExecutor : public NodeExecutor { | |||||
| Status PrepareTask(NodeTask &task, TaskContext &context) const; | Status PrepareTask(NodeTask &task, TaskContext &context) const; | ||||
| Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | ||||
| ~KnownNodeExecutor() {} | ~KnownNodeExecutor() {} | ||||
| private: | |||||
| static Status ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph); | |||||
| static Status GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices); | |||||
| static Status GetModelAndGraph(const HybridModel &model, | |||||
| const NodePtr &node, | |||||
| GeModelPtr &ge_model, | |||||
| ComputeGraphPtr &graph); | |||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -202,7 +202,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) { | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph); | ||||
| HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
| auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); | |||||
| auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph); | |||||
| ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
| } | } | ||||
| @@ -26,6 +26,7 @@ | |||||
| #undef private | #undef private | ||||
| #undef protected | #undef protected | ||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "../graph/passes/graph_builder_utils.h" | |||||
| using namespace std; | using namespace std; | ||||
| using namespace testing; | using namespace testing; | ||||
| @@ -69,3 +70,22 @@ TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) { | |||||
| model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); | model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); | ||||
| ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); | ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); | ||||
| } | } | ||||
| TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { | |||||
| ut::GraphBuilder builder("test-graph"); | |||||
| auto data_node = builder.AddNode("Data0", DATA, 1, 1); | |||||
| auto netoutput_node = builder.AddNode("NodeOutput", NETOUTPUT, 2, 2); | |||||
| builder.AddDataEdge(data_node, 0, netoutput_node, 0); | |||||
| auto const_node = builder.AddNode("Const0", CONSTANT, 0, 1); | |||||
| builder.AddDataEdge(const_node, 0, netoutput_node, 1); | |||||
| auto graph = builder.GetGraph(); | |||||
| ut::GraphBuilder builder2("root-graph"); | |||||
| auto partitioned_call = builder2.AddNode("Node0", PARTITIONEDCALL, 1, 2); | |||||
| NodeItem node_item(partitioned_call); | |||||
| ASSERT_EQ(KnownNodeExecutor::ParseAttrForAllocatingOutputs(node_item, *graph), SUCCESS); | |||||
| ASSERT_EQ(node_item.ref_outputs.size(), 1); | |||||
| ASSERT_EQ(node_item.ref_outputs[1], const_node); | |||||
| ASSERT_EQ(node_item.reuse_inputs.size(), 1); | |||||
| ASSERT_EQ(node_item.reuse_inputs[0], 0); | |||||
| } | |||||