Browse Source

fix reuse input

tags/v1.3.0
chuxing 3 years ago
parent
commit
8dac43028b
8 changed files with 156 additions and 86 deletions
  1. +4
    -0
      ge/hybrid/model/hybrid_model.cc
  2. +2
    -0
      ge/hybrid/model/hybrid_model.h
  3. +13
    -76
      ge/hybrid/model/hybrid_model_builder.cc
  4. +1
    -2
      ge/hybrid/model/hybrid_model_builder.h
  5. +107
    -7
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
  6. +8
    -0
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h
  7. +1
    -1
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc
  8. +20
    -0
      tests/ut/ge/hybrid/known_node_executor_unittest.cc

+ 4
- 0
ge/hybrid/model/hybrid_model.cc View File

@@ -120,6 +120,10 @@ const GraphItem *HybridModel::GetRootGraphItem() const {
return root_graph_item_.get();
}

// Returns a const reference to the root compute graph held in root_graph_.
// No null check is performed here; callers that dereference the result are
// expected to validate it themselves.
const ComputeGraphPtr &HybridModel::GetRootGraph() const {
return root_graph_;
}

const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const {
GELOGD("To find subgraph item by name = %s", graph_name.c_str());
auto it = subgraph_items_.find(graph_name);


+ 2
- 0
ge/hybrid/model/hybrid_model.h View File

@@ -101,6 +101,8 @@ class HybridModel {

const GraphItem *GetRootGraphItem() const;

const ComputeGraphPtr &GetRootGraph() const;

const GraphItem *GetSubgraphItem(const std::string &graph_name) const;

const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const;


+ 13
- 76
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -831,13 +831,6 @@ Status HybridModelBuilder::LoadGraph() {
"[Invoke][LoadDynamicSubgraph]Failed to load subgraph: [%s]",
sub_graph->GetName().c_str());
} else {
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item),
"[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.",
parent_node_item->NodeName().c_str());
GE_CHK_STATUS_RET(IdentifySameInputs(*parent_node_item),
"[Invoke][IdentifySameInputs][%s] Failed to identify same outputs.",
parent_node_item->NodeName().c_str());

// if parent is function control op. need add a virtual partitioned call
if (parent_node_item->IsControlOp()) {
GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item),
@@ -846,7 +839,16 @@ Status HybridModelBuilder::LoadGraph() {
}
}
}

for (auto &it : hybrid_model_.known_shape_sub_models_) {
auto node_item = MutableNodeItem(it.first);
AscendString graph_name;
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name");
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString());
GE_CHECK_NOTNULL(subgraph);
GE_CHK_STATUS_RET(IdentifyVariableOutputs(*node_item, subgraph),
"[Invoke][IdentifyVariableOutputs][%s] Failed to identify ref outputs.",
node_item->NodeName().c_str());
}
GE_CHK_STATUS_RET(ParseDependentByParallelGroup(),
"[Invoke][ParseDependentByParallelGroup]Failed to establish dependencies for hccl ops,"
"model_name_:%s.", GetGraphName());
@@ -1478,50 +1480,8 @@ Status HybridModelBuilder::InitRuntimeParams() {
return SUCCESS;
}

// Finds outputs of the node's subgraph that are fed by the very same producer
// output and records them in node_item.reuse_outputs.
//
// The NetOutput node of the subgraph at kSubgraphIndex is scanned input by
// input. The first NetOutput input connected to a given (producer op id,
// producer output index) pair is remembered; every later input connected to
// the same pair is stored as reuse_outputs[later input index] = first input
// index. Unconnected NetOutput inputs are ignored, and a subgraph without a
// NetOutput node is treated as success with nothing recorded.
Status HybridModelBuilder::IdentifySameInputs(NodeItem &node_item) {
  GELOGD("Start to parse same inputs on net output: %s", node_item.NodeName().c_str());
  auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex);
  GE_CHECK_NOTNULL(subgraph);

  auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  if (net_output_node == nullptr) {
    GELOGD("Subgraph [%s] does not have net output", subgraph->GetName().c_str());
    return SUCCESS;
  }

  auto net_output_desc = net_output_node->GetOpDesc();
  GE_CHECK_NOTNULL(net_output_desc);

  // Maps "<producer op id>_<producer output index>" to the first NetOutput
  // input index that was seen consuming it.
  std::map<std::string, int> first_consumer_by_source;
  for (const auto &dst_anchor : net_output_node->GetAllInDataAnchors()) {
    auto src_anchor = dst_anchor->GetPeerOutAnchor();
    if (src_anchor == nullptr) {
      continue;
    }
    auto src_node = src_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);

    const int consumer_index = dst_anchor->GetIdx();
    const int producer_index = src_anchor->GetIdx();
    const std::string source_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(producer_index);

    // emplace leaves an existing entry untouched, so `second` tells whether
    // this is the first consumer of the source or a repeated one.
    const auto emplaced = first_consumer_by_source.emplace(source_key, consumer_index);
    if (emplaced.second) {
      continue;
    }

    const int first_consumer_index = emplaced.first->second;
    GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(),
           consumer_index,
           first_consumer_index,
           src_node->GetName().c_str(),
           producer_index);
    node_item.reuse_outputs.emplace(consumer_index, first_consumer_index);
  }
  return SUCCESS;
}

Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph) {
GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str());
auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex);
GE_CHECK_NOTNULL(subgraph);
auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
if (net_output_node == nullptr) {
GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str());
@@ -1530,36 +1490,13 @@ Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) {
auto net_output_desc = net_output_node->GetOpDesc();
GE_CHECK_NOTNULL(net_output_desc);

// constant/variable connected to net output
// constants connected to net output
for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) {
auto src_node = GetPeerNode(in_data_anchor);
GE_CHECK_NOTNULL(src_node);
auto src_op_type = src_node->GetType();
GELOGD("Node %s, output %d, src node = %s, src node type = %s",
node_item.NodeName().c_str(),
in_data_anchor->GetIdx(),
src_node->GetName().c_str(),
src_op_type.c_str());
uint32_t parent_index = 0;
if (GetParentNodeOutputIndex(*net_output_desc, in_data_anchor->GetIdx(), parent_index) != SUCCESS) {
continue;
}
GELOGD("Got parent output index = %u", parent_index);
if (src_op_type == DATA) {
int ref_i = 0;
(void)AttrUtils::GetInt(src_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i);
node_item.reuse_inputs.emplace(static_cast<int>(parent_index), ref_i);
GELOGD("[%s] output[%u] resues input[%d]", node_item.NodeName().c_str(), parent_index, ref_i);
}

if (src_op_type != CONSTANTOP && src_op_type != CONSTANT && src_op_type != VARIABLE) {
continue;
}

GE_CHECK_LE(parent_index, INT32_MAX);
node_item.ref_outputs.emplace(static_cast<int>(parent_index), src_node);
if (src_op_type == CONSTANTOP || src_op_type == CONSTANT) {
known_subgraph_constant_output_refs_[&node_item].emplace(parent_index, src_node);
known_subgraph_constant_output_refs_[&node_item].emplace(in_data_anchor->GetIdx(), src_node);
}
}



+ 1
- 2
ge/hybrid/model/hybrid_model_builder.h View File

@@ -59,8 +59,7 @@ class HybridModelBuilder {
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
Status LoadTask(NodeItem &node_item);
Status LoadTasks();
Status IdentifyVariableOutputs(NodeItem &node_item);
Status IdentifySameInputs(NodeItem &node_item);
Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);


+ 107
- 7
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc View File

@@ -22,6 +22,7 @@
#include "common/ge/ge_util.h"
#include "graph/attr_value.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/graph_utils.h"
#include "graph/load/model_manager/model_utils.h"
#include "graph/load/model_manager/model_manager.h"
#include "hybrid/executor/hybrid_execution_context.h"
@@ -184,13 +185,15 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node
GELOGI("[%s] KnownNodeExecutor::LoadTask in.", node->GetName().c_str());
GE_CHECK_NOTNULL(node);

const GeModelPtr ge_model = model.GetGeModel(node);
GE_CHECK_NOTNULL(ge_model);

AscendString graph_name;
GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(graph_name), "Failed to get graph name");
auto weight_buffer = model.GetModelWeight(graph_name.GetString());

GeModelPtr ge_model;
ComputeGraphPtr compute_graph;
GE_CHK_STATUS_RET(GetModelAndGraph(model, node, ge_model, compute_graph),
"[%s] Failed to get model and graph",
node->GetName().c_str());
auto node_item = const_cast<NodeItem *>(model.GetNodeItem(node));
GE_CHECK_NOTNULL(node_item);
GE_CHK_STATUS_RET_NOLOG(ParseAttrForAllocatingOutputs(*node_item, *compute_graph));
auto weight_buffer = model.GetModelWeight(compute_graph->GetName());
std::shared_ptr<DavinciModel> davinci_model = MakeShared<DavinciModel>(0, nullptr);
GE_CHECK_NOTNULL(davinci_model);

@@ -223,5 +226,102 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context,
RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[KnownNodeExecutorExecuteTask] End");
return SUCCESS;
}

// Inspects the NetOutput node of a known-shape subgraph and records on
// node_item how each output can be provided without a fresh allocation:
//   - reuse_outputs[i] = j : NetOutput input i is fed by the same producer
//                            output as the earlier NetOutput input j.
//   - reuse_inputs[i]  = k : NetOutput input i is fed directly by a Data node;
//                            k is that node's position as computed by
//                            GetDataNodes().
//   - ref_outputs[i]   = n : NetOutput input i is fed directly by a
//                            CONSTANTOP / CONSTANT / VARIABLE node n.
//
// A graph without a NetOutput node is not an error: nothing is recorded and
// SUCCESS is returned. Unconnected NetOutput inputs are skipped.
//
// @param node_item  item whose reuse_outputs / reuse_inputs / ref_outputs maps
//                   are filled in (entries are added with emplace, so existing
//                   keys are left untouched).
// @param graph      the subgraph to analyse.
// @return SUCCESS, or the failing status of GetDataNodes / a null check.
Status KnownNodeExecutor::ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph) {
  GELOGD("[%s] Start to parse attributes for outputs", node_item.NodeName().c_str());
  auto net_output_node = graph.FindFirstNodeMatchType(NETOUTPUT);
  if (net_output_node == nullptr) {
    GELOGD("[%s] Subgraph do not got net output", graph.GetName().c_str());
    return SUCCESS;
  }

  auto net_output_desc = net_output_node->GetOpDesc();
  GE_CHECK_NOTNULL(net_output_desc);
  // Key: "<producer op id>_<producer output index>", value: first NetOutput
  // input index that consumes it.
  std::map<std::string, int> connected_inputs;
  std::map<NodePtr, int> data_indices;
  GE_CHK_STATUS_RET(GetDataNodes(graph, data_indices),
                    "[%s] Failed to get data node indices",
                    node_item.NodeName().c_str());
  for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) {
    auto out_data_anchor = in_data_anchor->GetPeerOutAnchor();
    if (out_data_anchor == nullptr) {
      continue;
    }
    auto src_node = out_data_anchor->GetOwnerNode();
    GE_CHECK_NOTNULL(src_node);
    auto op_desc = src_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    auto src_op_type = src_node->GetType();
    auto output_index = in_data_anchor->GetIdx();
    GELOGD("Node %s, output %d, src node = %s, src node type = %s",
           node_item.NodeName().c_str(),
           output_index,
           src_node->GetName().c_str(),
           src_op_type.c_str());
    // parse reuse outputs
    std::string input_key = std::to_string(op_desc->GetId()) + "_" + std::to_string(out_data_anchor->GetIdx());
    auto it = connected_inputs.find(input_key);
    if (it == connected_inputs.end()) {
      connected_inputs.emplace(input_key, output_index);
    } else {
      GELOGD("[%s] output [%d] reuse output [%d] input node = %s, idx = %d.", node_item.NodeName().c_str(),
             output_index,
             it->second,
             src_node->GetName().c_str(),
             out_data_anchor->GetIdx());
      node_item.reuse_outputs.emplace(output_index, it->second);
    }

    if (src_op_type == DATA) {
      // Look the node up without operator[]: a Data node missing from the map
      // (GetDataNodes keeps only one node per index key) must not be silently
      // default-inserted and reported as a reuse of input 0.
      const auto data_it = data_indices.find(src_node);
      if (data_it == data_indices.end()) {
        GELOGD("[%s] output[%d] is fed by data node [%s] which has no resolved input index, reuse is not recorded",
               node_item.NodeName().c_str(),
               output_index,
               src_node->GetName().c_str());
      } else {
        const int data_index = data_it->second;
        node_item.reuse_inputs.emplace(output_index, data_index);
        // output_index is a signed int, so it is printed with %d like in the other logs.
        GELOGD("[%s] output[%d] reuses input[%d]", node_item.NodeName().c_str(), output_index, data_index);
      }
    } else if (src_op_type == CONSTANTOP || src_op_type == CONSTANT || src_op_type == VARIABLE) {
      node_item.ref_outputs.emplace(output_index, src_node);
      GELOGD("[%s] output[%d] ref to node [%s]",
             node_item.NodeName().c_str(),
             output_index,
             src_node->GetName().c_str());
    }
  }

  GELOGD("[%s] Done parsing attributes for outputs successfully", node_item.NodeName().c_str());
  return SUCCESS;
}

// Assigns a dense, zero-based position to the Data nodes that are direct
// members of `graph`, ordered by their ATTR_NAME_INDEX attribute.
//
// The attribute lookup result is deliberately ignored; a node for which the
// lookup does not set a value is keyed with -1. Because the intermediate
// container is a std::map, only the first node seen for a given key is kept.
//
// @param graph         graph whose direct nodes are scanned.
// @param data_indices  receives node -> position entries (added with emplace,
//                      existing keys are left untouched).
// @return SUCCESS, or the status produced by the null check on a node.
Status KnownNodeExecutor::GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices) {
  std::map<int, NodePtr> nodes_by_attr_index;
  for (const auto &node : graph.GetDirectNode()) {
    GE_CHECK_NOTNULL(node);
    if (node->GetType() != DATA) {
      continue;
    }
    int attr_index = -1;
    (void) AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, attr_index);
    nodes_by_attr_index.emplace(attr_index, node);
  }

  // Re-number in ascending attribute order so positions are contiguous from 0.
  int next_position = 0;
  for (const auto &entry : nodes_by_attr_index) {
    data_indices.emplace(entry.second, next_position);
    ++next_position;
  }

  return SUCCESS;
}

// Resolves the GeModel bound to `node` and the matching subgraph of the root
// graph.
//
// The subgraph is looked up in the hybrid model's root graph by the name
// reported by the GeModel's own graph.
//
// @param model     hybrid model queried for the GeModel and the root graph.
// @param node      node whose compiled model is requested.
// @param ge_model  [out] assigned the result of model.GetGeModel(node) before
//                  it is null-checked.
// @param graph     [out] assigned the subgraph found by name before it is
//                  null-checked.
// @return SUCCESS when every lookup yields a non-null result; otherwise the
//         status returned by the failing check.
Status KnownNodeExecutor::GetModelAndGraph(const HybridModel &model,
                                           const NodePtr &node,
                                           GeModelPtr &ge_model,
                                           ComputeGraphPtr &graph) {
  // Step 1: the compiled model attached to this node.
  ge_model = model.GetGeModel(node);
  GE_CHECK_NOTNULL(ge_model);

  // Step 2: the root graph that owns all subgraphs.
  const auto &root_graph = model.GetRootGraph();
  GE_CHECK_NOTNULL(root_graph);

  // Step 3: find the subgraph carrying the same name as the model's graph.
  AscendString subgraph_name;
  GE_CHK_GRAPH_STATUS_RET(ge_model->GetGraph().GetName(subgraph_name), "Failed to get subgraph name");
  graph = root_graph->GetSubgraph(subgraph_name.GetString());
  GE_CHECK_NOTNULL(graph);

  return SUCCESS;
}
} // namespace hybrid
} // namespace ge

+ 8
- 0
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.h View File

@@ -51,6 +51,14 @@ class KnownNodeExecutor : public NodeExecutor {
Status PrepareTask(NodeTask &task, TaskContext &context) const;
Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function<void()> &callback) const;
~KnownNodeExecutor() {}

private:
static Status ParseAttrForAllocatingOutputs(NodeItem &node_item, ComputeGraph &graph);
static Status GetDataNodes(ComputeGraph &graph, std::map<NodePtr, int> &data_indices);
static Status GetModelAndGraph(const HybridModel &model,
const NodePtr &node,
GeModelPtr &ge_model,
ComputeGraphPtr &graph);
};
} // namespace hybrid
} // namespace ge


+ 1
- 1
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -202,7 +202,7 @@ TEST_F(UtestGeHybrid, data_direct_connect) {
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get());
auto ret = hybrid_model_builder.IdentifyVariableOutputs(*new_node.get(), sub_graph);
ASSERT_EQ(ret, SUCCESS);
}



+ 20
- 0
tests/ut/ge/hybrid/known_node_executor_unittest.cc View File

@@ -26,6 +26,7 @@
#undef private
#undef protected
#include "graph/manager/graph_mem_allocator.h"
#include "../graph/passes/graph_builder_utils.h"

using namespace std;
using namespace testing;
@@ -69,3 +70,22 @@ TEST_F(UnknownNodeExecutorTest, test_init_davinci_model) {
model.weight_buffer_map_.emplace("subgraph", TensorBuffer::Create(buffer, sizeof(buffer)));
ASSERT_EQ(mock.InitDavinciModel(model, model.GetModelWeight("subgraph")), SUCCESS);
}

// Verifies how ParseAttrForAllocatingOutputs classifies the inputs of a
// NetOutput node: input 0 comes straight from a Data node and must be recorded
// as reusing input 0, input 1 comes from a Const node and must be recorded as
// a reference to that node.
TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) {
  // Subgraph under test: Data0 -> NetOutput:0, Const0 -> NetOutput:1.
  ut::GraphBuilder subgraph_builder("test-graph");
  auto data_src = subgraph_builder.AddNode("Data0", DATA, 1, 1);
  auto net_output = subgraph_builder.AddNode("NodeOutput", NETOUTPUT, 2, 2);
  subgraph_builder.AddDataEdge(data_src, 0, net_output, 0);
  auto const_src = subgraph_builder.AddNode("Const0", CONSTANT, 0, 1);
  subgraph_builder.AddDataEdge(const_src, 0, net_output, 1);
  auto subgraph = subgraph_builder.GetGraph();

  // Parent node that owns the parsed attributes.
  ut::GraphBuilder root_builder("root-graph");
  auto call_node = root_builder.AddNode("Node0", PARTITIONEDCALL, 1, 2);
  NodeItem node_item(call_node);

  ASSERT_EQ(KnownNodeExecutor::ParseAttrForAllocatingOutputs(node_item, *subgraph), SUCCESS);

  // Output 1 refers to the constant.
  ASSERT_EQ(node_item.ref_outputs.size(), 1);
  ASSERT_EQ(node_item.ref_outputs[1], const_src);

  // Output 0 reuses the buffer of input 0.
  ASSERT_EQ(node_item.reuse_inputs.size(), 1);
  ASSERT_EQ(node_item.reuse_inputs[0], 0);
}

Loading…
Cancel
Save