From: @xchu42 Reviewed-by: @wan_xuelei,@wqtshg,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -121,6 +121,10 @@ const GraphItem *HybridModel::GetRootGraphItem() const { | |||||
return root_graph_item_.get(); | return root_graph_item_.get(); | ||||
} | } | ||||
const ComputeGraphPtr &HybridModel::GetRootGraph() const { | |||||
return root_graph_; | |||||
} | |||||
const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { | const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { | ||||
GELOGD("To find subgraph item by name = %s", graph_name.c_str()); | GELOGD("To find subgraph item by name = %s", graph_name.c_str()); | ||||
auto it = subgraph_items_.find(graph_name); | auto it = subgraph_items_.find(graph_name); | ||||
@@ -101,6 +101,8 @@ class HybridModel { | |||||
const GraphItem *GetRootGraphItem() const; | const GraphItem *GetRootGraphItem() const; | ||||
const ComputeGraphPtr &GetRootGraph() const; | |||||
const GraphItem *GetSubgraphItem(const std::string &graph_name) const; | const GraphItem *GetSubgraphItem(const std::string &graph_name) const; | ||||
const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; | const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; | ||||
@@ -842,13 +842,6 @@ Status HybridModelBuilder::LoadGraph() { | |||||
"[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", | "[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]", | ||||
sub_graph->GetName().c_str()); | sub_graph->GetName().c_str()); | ||||
} else { | } else { | ||||
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), | |||||
"[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||||
parent_node_item->NodeName().c_str()); | |||||
GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item), | |||||
"[Invoke][IdentifySameInputs][%s] Failed to identify same outputs.", | |||||
parent_node_item->NodeName().c_str()); | |||||
// if parent is function control op. need add a virtual partitioned call | // if parent is function control op. need add a virtual partitioned call | ||||
if (parent_node_item->IsControlFlowV2Op()) { | if (parent_node_item->IsControlFlowV2Op()) { | ||||
GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), | ||||
@@ -857,7 +850,16 @@ Status HybridModelBuilder::LoadGraph() { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
for (auto &it : hybrid_model_.known_shape_sub_models_) { | |||||
auto node_item = MutableNodeItem(it.first); | |||||
AscendString graph_name; | |||||
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||||
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString()); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph), | |||||
"[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.", | |||||
node_item->NodeName().c_str()); | |||||
} | |||||
GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), | GE_CHK_STATUS_RET(ParseDependentByParallelGroup(), | ||||
"[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," | "[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops," | ||||
"model_name_:%s.", GetGraphName()); | "model_name_:%s.", GetGraphName()); | ||||
@@ -1493,50 +1495,8 @@ Status HybridModelBuilder::InitRuntimeParams() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) { | |||||
GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str()); | |||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | |||||
if (net_output_node == nullptr) { | |||||
GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto net_output_desc = net_output_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(net_output_desc); | |||||
std::map<std::string, int> connected_inputs; | |||||
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (out_data_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto src_node = out_data_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
auto op_desc = src_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
auto it = connected_inputs.find(input_key); | |||||
if (it == connected_inputs.end()) { | |||||
connected_inputs.emplace(input_key, in_data_anchor->GetIdx()); | |||||
} else { | |||||
GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
in_data_anchor->GetIdx(), | |||||
it->second, | |||||
src_node->GetName().c_str(), | |||||
out_data_anchor->GetIdx()); | |||||
node_item.reuse_outputs.emplace(in_data_anchor->GetIdx(), it->second); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||||
Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph) { | |||||
GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); | ||||
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
if (net_output_node == nullptr) { | if (net_output_node == nullptr) { | ||||
GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); | GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); | ||||
@@ -1545,36 +1505,13 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { | |||||
auto net_output_desc = net_output_node->GetOpDesc(); | auto net_output_desc = net_output_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(net_output_desc); | GE_CHECK_NOTNULL(net_output_desc); | ||||
// constant/variable connected to net output | |||||
// constants connected to net output | |||||
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | ||||
auto src_node = GetPeerNode(in_data_anchor); | auto src_node = GetPeerNode(in_data_anchor); | ||||
GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
auto src_op_type = src_node->GetType(); | auto src_op_type = src_node->GetType(); | ||||
GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||||
node_item.NodeName().c_str(), | |||||
in_data_anchor->GetIdx(), | |||||
src_node->GetName().c_str(), | |||||
src_op_type.c_str()); | |||||
uint32_t parent_index = 0; | |||||
if (GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index) != SUCCESS) { | |||||
continue; | |||||
} | |||||
GELOGD("Got parent output index = %u", parent_index); | |||||
if (src_op_type == DATA) { | |||||
int ref_i = 0; | |||||
(void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i); | |||||
node_item.reuse_inputs.emplace(static_cast<int>(parent_index), ref_i); | |||||
GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i); | |||||
} | |||||
if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) { | |||||
continue; | |||||
} | |||||
GE_CHECK_LE(parent_index, INT32_MAX); | |||||
node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node); | |||||
if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { | if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) { | ||||
known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node); | |||||
known_subgraph_constant_output_refs_[&node_item].emplace(in_data_anchor->GetIdx(), src_node); | |||||
} | } | ||||
} | } | ||||
@@ -59,8 +59,7 @@ class HybridModelBuilder { | |||||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | ||||
Status LoadTask(NodeItem &node_item); | Status LoadTask(NodeItem &node_item); | ||||
Status LoadTasks(); | Status LoadTasks(); | ||||
Status IdentifyVariableOutputs(NodeItem &node_item); | |||||
Status IdentifySameInputs(NodeItem &node_item); | |||||
Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | |||||
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | ||||
@@ -22,6 +22,7 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/load/model_manager/model_utils.h" | #include "graph/load/model_manager/model_utils.h" | ||||
#include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
#include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
@@ -184,13 +185,15 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node | |||||
GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str()); | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
const GeModelPtr ge_model = model.GetGeModel(node); | |||||
GE_CHECK_NOTNULL(ge_model); | |||||
AscendString graph_name; | |||||
GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get graph name"); | |||||
auto weight_buffer = model.GetModelWeight(graph_name.GetString()); | |||||
GeModelPtr ge_model; | |||||
ComputeGraphPtr compute_graph; | |||||
GE_CHK_STATUS_RET(GetModelAndGraph(model, node, ge_model, compute_graph), | |||||
"[%s] Failed to get model and graph", | |||||
node->GetName().c_str()); | |||||
auto node_item = const_cast<NodeItem *>(model.GetNodeItem(node)); | |||||
GE_CHECK_NOTNULL(node_item); | |||||
GE_CHK_STATUS_RET_NOLOG(ParseAttrForAllocatingOutputs(*node_item, *compute_graph)); | |||||
auto weight_buffer = model.GetModelWeight(compute_graph->GetName()); | |||||
std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); | std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr); | ||||
GE_CHECK_NOTNULL(davinci_model); | GE_CHECK_NOTNULL(davinci_model); | ||||
@@ -223,5 +226,102 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, | |||||
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); | RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph) { | |||||
GELOGD("[%s] Start to parse attributes for outputs", node_item.NodeName().c_str()); | |||||
auto net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT); | |||||
if (net_output_node == nullptr) { | |||||
GELOGD("[%s] Subgraph do not got net output", graph.GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
auto net_output_desc = net_output_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(net_output_desc); | |||||
std::map<std::string, int> connected_inputs; | |||||
std::map<NodePtr, int> data_indices; | |||||
GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices), | |||||
"[%s] Failed to get data node indices", | |||||
node_item.NodeName().c_str()); | |||||
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { | |||||
auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
if (out_data_anchor == nullptr) { | |||||
continue; | |||||
} | |||||
auto src_node = out_data_anchor->GetOwnerNode(); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
auto op_desc = src_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
auto src_op_type = src_node->GetType(); | |||||
auto output_index = in_data_anchor->GetIdx(); | |||||
GELOGD("Node %s, output %d, src node = %s, src node type = %s", | |||||
node_item.NodeName().c_str(), | |||||
output_index, | |||||
src_node->GetName().c_str(), | |||||
src_op_type.c_str()); | |||||
// parse reuse outputs | |||||
std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx()); | |||||
auto it = connected_inputs.find(input_key); | |||||
if (it == connected_inputs.end()) { | |||||
connected_inputs.emplace(input_key, output_index); | |||||
} else { | |||||
GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(), | |||||
output_index, | |||||
it->second, | |||||
src_node->GetName().c_str(), | |||||
out_data_anchor->GetIdx()); | |||||
node_item.reuse_outputs.emplace(output_index, it->second); | |||||
} | |||||
if (src_op_type == DATA) { | |||||
int data_index = data_indices[src_node]; | |||||
node_item.reuse_inputs.emplace(output_index, data_index); | |||||
GELOGD("[%s] output[%u] reuses input[%d]", node_item.NodeName().c_str(), output_index, data_index); | |||||
} else if (src_op_type == CONSTANTOP || src_op_type == CONSTANT || src_op_type == VARIABLE) { | |||||
node_item.ref_outputs.emplace(output_index, src_node); | |||||
GELOGD("[%s] output[%d] ref to node [%s]", | |||||
node_item.NodeName().c_str(), | |||||
output_index, | |||||
src_node->GetName().c_str()); | |||||
} | |||||
} | |||||
GELOGD("[%s] Done parsing attributes for outputs successfully", node_item.NodeName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status KnownNodeExecutor::GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices) { | |||||
std::map<int, NodePtr> ordered_data_nodes; | |||||
for (const auto &node : graph.GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node); | |||||
if (node->GetType() == DATA) { | |||||
int index = -1; | |||||
(void) AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, index); | |||||
ordered_data_nodes.emplace(index, node); | |||||
} | |||||
} | |||||
// reindex | |||||
int data_index = 0; | |||||
for (const auto &it : ordered_data_nodes) { | |||||
data_indices.emplace(it.second, data_index++); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status KnownNodeExecutor::GetModelAndGraph(const HybridModel &model, | |||||
const NodePtr &node, | |||||
GeModelPtr &ge_model, | |||||
ComputeGraphPtr &graph) { | |||||
ge_model = model.GetGeModel(node); | |||||
GE_CHECK_NOTNULL(ge_model); | |||||
const auto &root_graph = model.GetRootGraph(); | |||||
GE_CHECK_NOTNULL(root_graph); | |||||
AscendString graph_name; | |||||
GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get subgraph name"); | |||||
graph = root_graph->GetSubgraph(graph_name.GetString()); | |||||
GE_CHECK_NOTNULL(graph); | |||||
return SUCCESS; | |||||
} | |||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge |
@@ -51,6 +51,14 @@ class KnownNodeExecutor : public NodeExecutor { | |||||
Status PrepareTask(NodeTask &task, TaskContext &context) const; | Status PrepareTask(NodeTask &task, TaskContext &context) const; | ||||
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const; | ||||
~KnownNodeExecutor() {} | ~KnownNodeExecutor() {} | ||||
private: | |||||
static Status ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph); | |||||
static Status GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices); | |||||
static Status GetModelAndGraph(const HybridModel &model, | |||||
const NodePtr &node, | |||||
GeModelPtr &ge_model, | |||||
ComputeGraphPtr &graph); | |||||
}; | }; | ||||
} // namespace hybrid | } // namespace hybrid | ||||
} // namespace ge | } // namespace ge | ||||
@@ -202,7 +202,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) { | |||||
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph); | GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph); | ||||
HybridModel hybrid_model(ge_root_model); | HybridModel hybrid_model(ge_root_model); | ||||
HybridModelBuilder hybrid_model_builder(hybrid_model); | HybridModelBuilder hybrid_model_builder(hybrid_model); | ||||
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get()); | |||||
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph); | |||||
ASSERT_EQ(ret, SUCCESS); | ASSERT_EQ(ret, SUCCESS); | ||||
} | } | ||||
@@ -26,6 +26,7 @@ | |||||
#undef private | #undef private | ||||
#undef protected | #undef protected | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "../graph/passes/graph_builder_utils.h" | |||||
using namespace std; | using namespace std; | ||||
using namespace testing; | using namespace testing; | ||||
@@ -69,3 +70,22 @@ TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) { | |||||
model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); | model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer))); | ||||
ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); | ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS); | ||||
} | } | ||||
TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) { | |||||
ut::GraphBuilder builder("test-graph"); | |||||
auto data_node = builder.AddNode("Data0", DATA, 1, 1); | |||||
auto netoutput_node = builder.AddNode("NodeOutput", NETOUTPUT, 2, 2); | |||||
builder.AddDataEdge(data_node, 0, netoutput_node, 0); | |||||
auto const_node = builder.AddNode("Const0", CONSTANT, 0, 1); | |||||
builder.AddDataEdge(const_node, 0, netoutput_node, 1); | |||||
auto graph = builder.GetGraph(); | |||||
ut::GraphBuilder builder2("root-graph"); | |||||
auto partitioned_call = builder2.AddNode("Node0", PARTITIONEDCALL, 1, 2); | |||||
NodeItem node_item(partitioned_call); | |||||
ASSERT_EQ(KnownNodeExecutor::ParseAttrForAllocatingOutputs(node_item, *graph), SUCCESS); | |||||
ASSERT_EQ(node_item.ref_outputs.size(), 1); | |||||
ASSERT_EQ(node_item.ref_outputs[1], const_node); | |||||
ASSERT_EQ(node_item.reuse_inputs.size(), 1); | |||||
ASSERT_EQ(node_item.reuse_inputs[0], 0); | |||||
} |