Browse Source

Pre Merge pull request !1902 from lichun/master

pull/1902/MERGE
lichun Gitee 4 years ago
parent
commit
6ee2912f7a
4 changed files with 34 additions and 2 deletions
  1. +1
    -0
      ge/graph/load/model_manager/davinci_model.h
  2. +2
    -0
      ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc
  3. +2
    -2
      ge/offline/main.cc
  4. +29
    -0
      tests/ut/ge/hybrid/known_node_executor_unittest.cc

+ 1
- 0
ge/graph/load/model_manager/davinci_model.h View File

@@ -300,6 +300,7 @@ class DavinciModel {
return op_list_.at(index); return op_list_.at(index);
} }


void SetGlobalStep(void *global_step) { global_step_addr_ = global_step; }
void *GetGlobalStep() const { return global_step_addr_; } void *GetGlobalStep() const { return global_step_addr_; }


// get task info for profiling // get task info for profiling


+ 2
- 0
ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc View File

@@ -204,6 +204,8 @@ Status KnownNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node
davinci_model->SetId(model.GetModelId()); davinci_model->SetId(model.GetModelId());
davinci_model->SetDumpModelName(model.GetModelName()); davinci_model->SetDumpModelName(model.GetModelName());
davinci_model->SetOmName(model.GetOmName()); davinci_model->SetOmName(model.GetOmName());
TensorValue *global_step_var = model.GetVariable(NODE_NAME_GLOBAL_STEP);
davinci_model->SetGlobalStep(global_step_var->MutableData());
// set model id as root node's node id // set model id as root node's node id
davinci_model->SetSubModelId(node->GetOpDesc()->GetId()); davinci_model->SetSubModelId(node->GetOpDesc()->GetId());
GELOGD("KnownNodeExecutor::LoadTask node id %ld.", node->GetOpDesc()->GetId()); GELOGD("KnownNodeExecutor::LoadTask node id %ld.", node->GetOpDesc()->GetId());


+ 2
- 2
ge/offline/main.cc View File

@@ -1149,9 +1149,9 @@ domi::Status GenerateSingleOp(const std::string& json_file_path) {
if (ret != SUCCESS) { if (ret != SUCCESS) {
DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index); DOMI_LOGE("Compile op failed. ge ret = %u, op index = %d", ret, index);
ret = domi::FAILED; ret = domi::FAILED;
break;
} else {
GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str());
} }
GELOGI("Compile op success. op index = %d, output = %s", index, output_path.c_str());
index += 1; index += 1;
} }




+ 29
- 0
tests/ut/ge/hybrid/known_node_executor_unittest.cc View File

@@ -88,4 +88,33 @@ TEST_F(UnknownNodeExecutorTest, TestParseAttrForAllocatingOutputs) {
ASSERT_EQ(node_item.ref_outputs[1], const_node); ASSERT_EQ(node_item.ref_outputs[1], const_node);
ASSERT_EQ(node_item.reuse_inputs.size(), 1); ASSERT_EQ(node_item.reuse_inputs.size(), 1);
ASSERT_EQ(node_item.reuse_inputs[0], 0); ASSERT_EQ(node_item.reuse_inputs[0], 0);
}

TEST_F(UnknownNodeExecutorTest, TestSetGlobalStepInLoadTask) {
OpDescPtr op_desc = CreateOpDesc("PartitionedCall", "PartitionedCall");
auto node = root_graph->AddNode(op_desc);
node->SetOwnerComputeGraph(root_graph);
auto sub_graph = BuildDataDirectConnectGraph();
sub_graph->SetParentGraph(root_graph);
sub_graph->SetParentNode(node);
node->GetOpDesc()->AddSubgraphName("subgraph");
node->GetOpDesc()->SetSubgraphInstanceName(0, "subgraph");
root_graph->AddSubgraph("subgraph", sub_graph);

ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("root_graph");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(root_graph);
HybridModel hybrid_model(ge_root_model);
hybrid_model.model_id_ = 0;
hybrid_model.model_name_ = "root_model";
hybrid_model.om_name_ = "temp_om";
auto *step_id = new int64_t[1];
step_id[0] = 520;
TensorValue tensor_value((void *)step_id, sizeof(step_id));
hybrid_model.variable_tensors_.insert({"ge_global_step", make_unique<TensorValue>(tensor_value)});

KnownNodeExecutor known_node_executor;
shared_ptr<KnownNodeTask> task = nullptr;
Status ret = known_node_executor.LoadTask(hybrid_model, node, task);
EXPECT_EQ(*(task->davinci_model_->global_step_addr_), 520);
EXPECT_EQ(known_node_executor.LoadTask(hybrid_model, node, task), SUCCESS);
} }

Loading…
Cancel
Save