Browse Source

add copy graph

tags/v1.5.1
wuweikang 3 years ago
parent
commit
a6b6229967
6 changed files with 63 additions and 15 deletions
  1. +1
    -1
      ge/graph/manager/graph_manager.cc
  2. +1
    -0
      ge/hybrid/model/hybrid_model.h
  3. +36
    -9
      ge/hybrid/model/hybrid_model_builder.cc
  4. +1
    -0
      ge/hybrid/model/hybrid_model_builder.h
  5. +3
    -0
      tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc
  6. +21
    -5
      tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc

+ 1
- 1
ge/graph/manager/graph_manager.cc View File

@@ -3131,10 +3131,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) {
}
// Avoid repeatively prerun for graphs owns same graph_id in online inference concurrency
if (count > 1 && graph_node->GetBuildFlag()) {
graph_node->Lock();
GELOGD("Avoid repeatively prerun, graph_id:%u.", args.graph_id);
// In online inference concurrency senario, graph_node is allowed to be locked for 'count' times
graph_node->SetSemSize(count);
graph_node->Lock();
graph_manager->run_args_q_.Push(RunArgs( { graph_node, args.graph_id, args.session_id, args.error_context,
args.input_tensor, graph_node->GetGeRootModel(), GetThreadLocalContext(), args.callback }));
GELOGI("[PreRunThread] Loop end. Start to run with cached build model.");


+ 1
- 0
ge/hybrid/model/hybrid_model.h View File

@@ -147,6 +147,7 @@ class HybridModel {
GeRootModelPtr ge_root_model_;
std::map<uint32_t, NodeItem *> input_nodes_;
ComputeGraphPtr root_graph_;
ComputeGraphPtr orig_root_graph_;
std::map<std::string, NodePtr> device_variable_nodes_; //lint !e148
std::map<std::string, NodePtr> host_variable_nodes_; //lint !e148
std::map<std::string, std::unique_ptr<TensorValue>> variable_tensors_;


+ 36
- 9
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -147,6 +147,7 @@ Status HybridModelBuilder::Build() {
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName());
hybrid_model_.model_name_ = ge_root_model_->GetModelName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
GE_CHK_STATUS_RET(CopyGraph(), "[Invoke][CopyGraph] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(InitRuntimeParams(), "[Invoke][InitRuntimeParams] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(RecoverGraphUnknownFlag(),
"[Invoke][RecoverGraphUnknownFlag] failed, model_name_:[%s]", GetGraphName());
@@ -171,11 +172,12 @@ Status HybridModelBuilder::Build() {

Status HybridModelBuilder::BuildForSingleOp() {
GE_CHK_STATUS_RET(ValidateParams(), "[Invoke][ValidateParams] failed, model_name_:[%s]", GetGraphName());
hybrid_model_.root_graph_ = ge_root_model_->GetRootGraph();
hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName();
GELOGI("[%s] Start to build hybrid model.", GetGraphName());
auto ret = ge_root_model_->GetSubgraphInstanceNameToModel();
const GeModelPtr ge_model = ret[ge_root_model_->GetRootGraph()->GetName()];
GE_CHK_STATUS_RET(IndexTaskDefs(ge_root_model_->GetRootGraph(), ge_model),
const GeModelPtr ge_model = ret[hybrid_model_.root_graph_->GetName()];
GE_CHK_STATUS_RET(IndexTaskDefs(hybrid_model_.root_graph_, ge_model),
"[Invoke][IndexTaskDefs] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(LoadGraph(), "[Invoke][LoadGraph] failed, model_name_:[%s]", GetGraphName());
GE_CHK_STATUS_RET(InitWeights(), "[Invoke][InitWeights] failed, model_name_:[%s]", GetGraphName());
@@ -190,6 +192,27 @@ Status HybridModelBuilder::ValidateParams() {
return SUCCESS;
}

Status HybridModelBuilder::CopyGraph() {
GELOGD("Copy compute graph begin.");
auto root_graph = ge_root_model_->GetRootGraph();

std::string new_graph_name = ge_root_model_->GetRootGraph()->GetName();
ComputeGraphPtr new_root_graph = MakeShared<ComputeGraph>(new_graph_name);
GE_CHECK_NOTNULL(new_root_graph);
int32_t depth = 0;
std::map<ConstNodePtr, NodePtr> node_old_2_new;
std::map<ConstOpDescPtr, OpDescPtr> op_desc_old_2_new;
graphStatus ret = GraphUtils::CopyComputeGraph(root_graph, new_root_graph, node_old_2_new, op_desc_old_2_new, depth);
if (ret != GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "Copy compute graph failed.");
return GRAPH_FAILED;
}
hybrid_model_.root_graph_ = new_root_graph;

GELOGD("Copy compute graph[%s] success.", new_graph_name.c_str());
return SUCCESS;
}

Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item),
@@ -810,12 +833,13 @@ Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item,
}

Status HybridModelBuilder::LoadGraph() {
auto root_graph = ge_root_model_->GetRootGraph();
auto root_graph = hybrid_model_.root_graph_;
if (!GetContext().GetHostExecFlag()) {
std::shared_ptr<ComputeGraph> merged_graph;
GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu",
root_graph->GetDirectNodesSize(),
root_graph->GetAllNodesSize());
hybrid_model_.orig_root_graph_ = root_graph;
GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(root_graph, merged_graph),
"[Invoke][UnfoldSubgraphs]Failed to unfold subgraphs, model_name_:%s.", GetGraphName());
root_graph = std::move(merged_graph);
@@ -873,6 +897,7 @@ Status HybridModelBuilder::LoadGraph() {
}
for (auto &it : hybrid_model_.known_shape_sub_models_) {
auto node_item = MutableNodeItem(it.first);
GE_CHECK_NOTNULL(node_item);
AscendString graph_name;
GE_CHK_GRAPH_STATUS_RET(it.second->GetGraph().GetName(graph_name), "Failed to get subgraph name");
auto subgraph = hybrid_model_.GetRootGraph()->GetSubgraph(graph_name.GetString());
@@ -1121,7 +1146,9 @@ Status HybridModelBuilder::InitWeights() {
sub_weight_buffer->GetSize());
auto subgraph = GraphUtils::GetComputeGraph(subgraph_model.second->GetGraph());
if (subgraph != ge_root_model_->GetRootGraph()) {
subgraph = ge_root_model_->GetRootGraph()->GetSubgraph(subgraph_model.first);
subgraph = hybrid_model_.root_graph_->GetSubgraph(subgraph_model.first);
} else {
subgraph = hybrid_model_.root_graph_;
}
GE_CHECK_NOTNULL(subgraph);
hybrid_model_.weight_buffer_map_.emplace(subgraph->GetName(), std::move(sub_weight_buffer));
@@ -1300,7 +1327,7 @@ Status HybridModelBuilder::IndexTaskDefs(const ComputeGraphPtr &sub_graph, const
}

Status HybridModelBuilder::IndexTaskDefs() {
const auto root_graph = ge_root_model_->GetRootGraph();
const auto &root_graph = hybrid_model_.root_graph_;
const auto &root_graph_name = root_graph->GetName();
if (SetOutputNameAttr(*root_graph) != SUCCESS) {
GELOGW("Set output name attr failed.");
@@ -1334,7 +1361,7 @@ Status HybridModelBuilder::IndexTaskDefs() {

Status HybridModelBuilder::IndexSpecialNodes() {
GELOGD("Start to index special nodes");
const auto &root_graph = ge_root_model_->GetRootGraph();
const auto &root_graph = hybrid_model_.root_graph_;
for (auto &node : root_graph->GetAllNodes()) {
GE_CHECK_NOTNULL(node);
GE_CHECK_NOTNULL(node->GetOpDesc());
@@ -1489,7 +1516,7 @@ Status HybridModelBuilder::InitRuntimeParams() {
runtime_param_.session_id = ret ? static_cast<uint64_t>(value) : 0;
ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_TASK_GEN_VAR_ADDR, value);
runtime_param_.logic_var_base = ret ? static_cast<uint64_t>(value) : 0;
runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID();
runtime_param_.graph_id = hybrid_model_.root_graph_->GetGraphID();
value = 0;
for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) {
(void) ge::AttrUtils::GetInt(it.second, ATTR_MODEL_VAR_SIZE, value);
@@ -1626,7 +1653,7 @@ Status HybridModelBuilder::TransAllVarData() {
}

Status HybridModelBuilder::CopyVarData() {
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(ge_root_model_->GetRootGraph(),
GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(hybrid_model_.root_graph_,
runtime_param_.session_id,
hybrid_model_.device_id_),
"[Invoke][CopyVarData] failed.");
@@ -1709,7 +1736,7 @@ Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem
}

Status HybridModelBuilder::RecoverGraphUnknownFlag() {
const auto &root_graph = ge_root_model_->GetRootGraph();
const auto &root_graph = hybrid_model_.root_graph_;
for (auto &sub_graph : root_graph->GetAllSubgraphs()) {
GE_CHECK_NOTNULL(sub_graph);
for (const auto &node : sub_graph->GetDirectNode()) {


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -56,6 +56,7 @@ class HybridModelBuilder {
Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph);
Status ValidateParams();
Status LoadGraph();
Status CopyGraph();
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model);
static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model);
Status LoadTask(NodeItem &node_item);


+ 3
- 0
tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc View File

@@ -249,6 +249,9 @@ TEST_F(UtestSubgraphExecutor, cond_graph_schedule_tasks) {
graph_context.callback_manager = std::unique_ptr<CallbackManager>(new CallbackManager());
ASSERT_EQ(graph_context.callback_manager->Init(), SUCCESS);

auto root_graph = hybrid_model.root_graph_;
switch_t = root_graph->FindNode("switch_t");
switch_f = root_graph->FindNode("switch_f");
const auto node_it_t = hybrid_model.node_items_.find(switch_t);
const auto node_it_f = hybrid_model.node_items_.find(switch_f);
ASSERT_NE(hybrid_model.node_items_.end(), node_it_t);


+ 21
- 5
tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc View File

@@ -214,11 +214,17 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) {
ASSERT_EQ(it->second->frame_index_, index);
ASSERT_EQ(it->second->parent_frame_, -1);
};
TestFrameGroup(enter1, control_group_index);
TestFrameGroup(active1, control_group_index);
TestFrameGroup(active2, control_group_index);
TestFrameGroup(active3, control_group_index);
TestFrameGroup(output1, -1);
auto root_graph = hybrid_model.root_graph_;
auto enter1_node = root_graph->FindNode("enter");
auto active1_node = root_graph->FindNode("active1");
auto active2_node = root_graph->FindNode("active2");
auto active3_node = root_graph->FindNode("active3");
auto output1_node = root_graph->FindNode("net_output");
TestFrameGroup(enter1_node, control_group_index);
TestFrameGroup(active1_node, control_group_index);
TestFrameGroup(active2_node, control_group_index);
TestFrameGroup(active3_node, control_group_index);
TestFrameGroup(output1_node, -1);

engine_mapping.clear();
task_executor.clear();
@@ -373,4 +379,14 @@ TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) {
NodeExecutorManager::GetInstance().builders_.erase(NodeExecutorManager::ExecutorType::HCCL);
ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS);
}

TEST_F(UtestHybridModelBuilder, copy_graph_success) {
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);

Status st = hybrid_model_builder.CopyGraph();
EXPECT_EQ(st, SUCCESS);
}
} // namespace ge

Loading…
Cancel
Save