Browse Source

!1043 Enable rtLabelCreateExV2

From: @zhangxiaokun9
Reviewed-by: @wangxiaotian22,@ji_chen
Signed-off-by: @ji_chen
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
67eddb6ef5
6 changed files with 134 additions and 22 deletions
  1. +1
    -0
      .gitignore
  2. +1
    -1
      ge/graph/build/run_context.cc
  3. +1
    -1
      ge/graph/load/model_manager/davinci_model.cc
  4. +37
    -8
      tests/depends/runtime/src/runtime_stub.cc
  5. +2
    -0
      tests/ut/ge/CMakeLists.txt
  6. +92
    -12
      tests/ut/ge/graph/load/davinci_model_unittest.cc

+ 1
- 0
.gitignore View File

@@ -2,6 +2,7 @@
/build
/output
/prebuilts
/cov
*.ir
*.out



+ 1
- 1
ge/graph/build/run_context.cc View File

@@ -90,7 +90,7 @@ Status RunContextUtil::CreateRtModelResources(uint32_t stream_num, uint32_t even
// Create rt label
for (uint32_t i = 0; i < label_num; ++i) {
rtLabel_t label = nullptr;
rt_ret = rtLabelCreate(&label);
rt_ret = rtLabelCreateV2(&label, rt_model_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "rtLabelCreate failed. rt_ret = %d, index = %u", static_cast<int>(rt_ret), i);
return RT_FAILED;


+ 1
- 1
ge/graph/load/model_manager/davinci_model.cc View File

@@ -1402,7 +1402,7 @@ Status DavinciModel::InitLabelSet(const OpDescPtr &op_desc) {
}

rtLabel_t rt_label = nullptr;
rtError_t rt_error = rtLabelCreateEx(&rt_label, stream);
rtError_t rt_error = rtLabelCreateExV2(&rt_label, rt_model_handle_, stream);
if (rt_error != RT_ERROR_NONE || rt_label == nullptr) {
GELOGE(INTERNAL_ERROR, "InitLabelSet: %s create label failed, error=0x%x.", op_desc->GetName().c_str(), rt_error);
return INTERNAL_ERROR;


+ 37
- 8
tests/depends/runtime/src/runtime_stub.cc View File

@@ -245,9 +245,35 @@ rtError_t rtProfilerInit(const char *prof_dir, const char *address, const char *

rtError_t rtProfilerStart(void) { return RT_ERROR_NONE; }

rtError_t rtLabelCreate(rtLabel_t *label) { return RT_ERROR_NONE; }
rtError_t rtLabelCreate(rtLabel_t *label) {
*label = new uint64_t;
return RT_ERROR_NONE;
}

rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream) {
*label = new uint64_t;
return RT_ERROR_NONE;
}

rtError_t rtLabelCreateV2(rtLabel_t *label, rtModel_t model) {
*label = new uint64_t;
return RT_ERROR_NONE;
}

rtError_t rtLabelDestroy(rtLabel_t label) { return RT_ERROR_NONE; }
rtError_t rtLabelCreateExV2(rtLabel_t *label, rtModel_t model, rtStream_t stream) {
*label = new uint64_t;
return RT_ERROR_NONE;
}

rtError_t rtLabelListCpy(rtLabel_t *label, uint32_t labelNumber, void *dst, uint32_t dstMax) {
return RT_ERROR_NONE;
}

rtError_t rtLabelDestroy(rtLabel_t label) {
uint64_t *stub = static_cast<uint64_t *>(label);
delete stub;
return RT_ERROR_NONE;
}

rtError_t rtLabelSet(rtLabel_t label, rtStream_t stream) { return RT_ERROR_NONE; }

@@ -255,8 +281,17 @@ rtError_t rtLabelSwitch(void *ptr, rtCondition_t condition, uint32_t value, rtLa
return RT_ERROR_NONE;
}

rtError_t rtLabelSwitchByIndex(void *ptr, uint32_t max, void *labelInfoPtr, rtStream_t stream) {
return RT_ERROR_NONE;
}

rtError_t rtLabelGoto(rtLabel_t label, rtStream_t stream) { return RT_ERROR_NONE; }

rtError_t rtLabelGotoEx(rtLabel_t label, rtStream_t stream) {
return RT_ERROR_NONE;
}


rtError_t rtInvalidCache(uint64_t base, uint32_t len) { return RT_ERROR_NONE; }

rtError_t rtModelLoadComplete(rtModel_t model) { return RT_ERROR_NONE; }
@@ -364,12 +399,6 @@ rtError_t rtSetCtxINFMode(bool mode)
return RT_ERROR_NONE;
}

rtError_t rtLabelCreateEx(rtLabel_t *label, rtStream_t stream)
{
*label = new uint32_t;
return RT_ERROR_NONE;
}

rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *value)
{
return RT_ERROR_NONE;


+ 2
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -404,6 +404,8 @@ set(DISTINCT_GRAPH_LOAD_SRC_FILES
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/kernel_ex_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/kernel_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/label_set_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/label_goto_ex_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/label_switch_by_index_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/memcpy_addr_async_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/memcpy_async_task_info.cc"
"${GE_CODE_DIR}/ge/graph/load/model_manager/task_info/profiler_trace_task_info.cc"


+ 92
- 12
tests/ut/ge/graph/load/davinci_model_unittest.cc View File

@@ -32,18 +32,6 @@ class UtestDavinciModel : public testing::Test {
void SetUp() {}

void TearDown() {}
public:
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) {
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT);
auto op_desc = std::make_shared<OpDesc>(name, type);
for (auto i = 0; i < in_num; ++i) {
op_desc->AddInputDesc(test_desc);
}
for (auto i = 0; i < out_num; ++i) {
op_desc->AddOutputDesc(test_desc);
}
return graph->AddNode(op_desc);
}
};

/*TEST_F(UtestDavinciModel, init_success) {
@@ -755,4 +743,96 @@ TEST_F(UtestDavinciModel, init_data_aipp_input_dims_normal) {
EXPECT_EQ(model.output_addrs_list_.size(), 0);
EXPECT_EQ(model.op_list_.size(), 1);
}

// test label_set_task Init
TEST_F(UtestDavinciModel, label_task_success) {
DavinciModel model(0, nullptr);
ComputeGraphPtr graph = make_shared<ComputeGraph>("default");

GeModelPtr ge_model = make_shared<GeModel>();
ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph));
AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, 5120000);
AttrUtils::SetInt(ge_model, ATTR_MODEL_STREAM_NUM, 1);

shared_ptr<domi::ModelTaskDef> model_task_def = make_shared<domi::ModelTaskDef>();
ge_model->SetModelTaskDef(model_task_def);

GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT32);
TensorUtils::SetSize(tensor, 64);

uint32_t op_index = 0;
{
OpDescPtr op_desc = CreateOpDesc("label_switch", LABELSWITCHBYINDEX);
op_desc->AddInputDesc(tensor);
op_desc->SetInputOffset({1024});
NodePtr node = graph->AddNode(op_desc); // op_index = 0
EXPECT_TRUE(AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, {0, 1}));

domi::TaskDef *task_def1 = model_task_def->add_task();
task_def1->set_stream_id(0);
task_def1->set_type(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX);
domi::LabelSwitchByIndexDef *label_task_def = task_def1->mutable_label_switch_by_index();
label_task_def->set_op_index(op_index++);
label_task_def->set_label_max(2);
}

{
OpDescPtr op_desc = CreateOpDesc("label_then", LABELSET);
NodePtr node = graph->AddNode(op_desc); // op_index = 1
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 1));

domi::TaskDef *task_def1 = model_task_def->add_task();
task_def1->set_stream_id(0);
task_def1->set_type(RT_MODEL_TASK_LABEL_SET);
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set();
label_task_def->set_op_index(op_index++);
}

{
OpDescPtr op_desc = CreateOpDesc("label_goto", LABELGOTOEX);
NodePtr node = graph->AddNode(op_desc); // op_index = 2
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 2));

domi::TaskDef *task_def2 = model_task_def->add_task();
task_def2->set_stream_id(0);
task_def2->set_type(RT_MODEL_TASK_STREAM_LABEL_GOTO);
domi::LabelGotoExDef *label_task_def = task_def2->mutable_label_goto_ex();
label_task_def->set_op_index(op_index++);
}


{
OpDescPtr op_desc = CreateOpDesc("label_else", LABELSET);
NodePtr node = graph->AddNode(op_desc); // op_index = 3
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 0));

domi::TaskDef *task_def1 = model_task_def->add_task();
task_def1->set_stream_id(0);
task_def1->set_type(RT_MODEL_TASK_LABEL_SET);
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set();
label_task_def->set_op_index(op_index++);
}


{
OpDescPtr op_desc = CreateOpDesc("label_leave", LABELSET);
NodePtr node = graph->AddNode(op_desc); // op_index = 4
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 2));

domi::TaskDef *task_def1 = model_task_def->add_task();
task_def1->set_stream_id(0);
task_def1->set_type(RT_MODEL_TASK_LABEL_SET);
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set();
label_task_def->set_op_index(op_index++);
}


EXPECT_TRUE(AttrUtils::SetInt(ge_model, ATTR_MODEL_LABEL_NUM, 3));
EXPECT_EQ(model.Assign(ge_model), SUCCESS);
EXPECT_EQ(model.Init(), SUCCESS);

EXPECT_EQ(model.input_addrs_list_.size(), 0);
EXPECT_EQ(model.output_addrs_list_.size(), 0);
EXPECT_EQ(model.task_list_.size(), 5);
}
} // namespace ge

Loading…
Cancel
Save