|
|
@@ -32,18 +32,6 @@ class UtestDavinciModel : public testing::Test { |
|
|
|
void SetUp() {} |
|
|
|
|
|
|
|
void TearDown() {} |
|
|
|
public: |
|
|
|
NodePtr MakeNode(const ComputeGraphPtr &graph, uint32_t in_num, uint32_t out_num, string name, string type) { |
|
|
|
GeTensorDesc test_desc(GeShape(), FORMAT_NCHW, DT_FLOAT); |
|
|
|
auto op_desc = std::make_shared<OpDesc>(name, type); |
|
|
|
for (auto i = 0; i < in_num; ++i) { |
|
|
|
op_desc->AddInputDesc(test_desc); |
|
|
|
} |
|
|
|
for (auto i = 0; i < out_num; ++i) { |
|
|
|
op_desc->AddOutputDesc(test_desc); |
|
|
|
} |
|
|
|
return graph->AddNode(op_desc); |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
/*TEST_F(UtestDavinciModel, init_success) { |
|
|
@@ -755,4 +743,96 @@ TEST_F(UtestDavinciModel, init_data_aipp_input_dims_normal) { |
|
|
|
EXPECT_EQ(model.output_addrs_list_.size(), 0); |
|
|
|
EXPECT_EQ(model.op_list_.size(), 1); |
|
|
|
} |
|
|
|
|
|
|
|
// test label_set_task Init |
|
|
|
TEST_F(UtestDavinciModel, label_task_success) { |
|
|
|
DavinciModel model(0, nullptr); |
|
|
|
ComputeGraphPtr graph = make_shared<ComputeGraph>("default"); |
|
|
|
|
|
|
|
GeModelPtr ge_model = make_shared<GeModel>(); |
|
|
|
ge_model->SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); |
|
|
|
AttrUtils::SetInt(ge_model, ATTR_MODEL_MEMORY_SIZE, 5120000); |
|
|
|
AttrUtils::SetInt(ge_model, ATTR_MODEL_STREAM_NUM, 1); |
|
|
|
|
|
|
|
shared_ptr<domi::ModelTaskDef> model_task_def = make_shared<domi::ModelTaskDef>(); |
|
|
|
ge_model->SetModelTaskDef(model_task_def); |
|
|
|
|
|
|
|
GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT32); |
|
|
|
TensorUtils::SetSize(tensor, 64); |
|
|
|
|
|
|
|
uint32_t op_index = 0; |
|
|
|
{ |
|
|
|
OpDescPtr op_desc = CreateOpDesc("label_switch", LABELSWITCHBYINDEX); |
|
|
|
op_desc->AddInputDesc(tensor); |
|
|
|
op_desc->SetInputOffset({1024}); |
|
|
|
NodePtr node = graph->AddNode(op_desc); // op_index = 0 |
|
|
|
EXPECT_TRUE(AttrUtils::SetListInt(op_desc, ATTR_NAME_LABEL_SWITCH_LIST, {0, 1})); |
|
|
|
|
|
|
|
domi::TaskDef *task_def1 = model_task_def->add_task(); |
|
|
|
task_def1->set_stream_id(0); |
|
|
|
task_def1->set_type(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX); |
|
|
|
domi::LabelSwitchByIndexDef *label_task_def = task_def1->mutable_label_switch_by_index(); |
|
|
|
label_task_def->set_op_index(op_index++); |
|
|
|
label_task_def->set_label_max(2); |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
OpDescPtr op_desc = CreateOpDesc("label_then", LABELSET); |
|
|
|
NodePtr node = graph->AddNode(op_desc); // op_index = 1 |
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 1)); |
|
|
|
|
|
|
|
domi::TaskDef *task_def1 = model_task_def->add_task(); |
|
|
|
task_def1->set_stream_id(0); |
|
|
|
task_def1->set_type(RT_MODEL_TASK_LABEL_SET); |
|
|
|
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set(); |
|
|
|
label_task_def->set_op_index(op_index++); |
|
|
|
} |
|
|
|
|
|
|
|
{ |
|
|
|
OpDescPtr op_desc = CreateOpDesc("label_goto", LABELGOTOEX); |
|
|
|
NodePtr node = graph->AddNode(op_desc); // op_index = 2 |
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 2)); |
|
|
|
|
|
|
|
domi::TaskDef *task_def2 = model_task_def->add_task(); |
|
|
|
task_def2->set_stream_id(0); |
|
|
|
task_def2->set_type(RT_MODEL_TASK_STREAM_LABEL_GOTO); |
|
|
|
domi::LabelGotoExDef *label_task_def = task_def2->mutable_label_goto_ex(); |
|
|
|
label_task_def->set_op_index(op_index++); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
{ |
|
|
|
OpDescPtr op_desc = CreateOpDesc("label_else", LABELSET); |
|
|
|
NodePtr node = graph->AddNode(op_desc); // op_index = 3 |
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 0)); |
|
|
|
|
|
|
|
domi::TaskDef *task_def1 = model_task_def->add_task(); |
|
|
|
task_def1->set_stream_id(0); |
|
|
|
task_def1->set_type(RT_MODEL_TASK_LABEL_SET); |
|
|
|
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set(); |
|
|
|
label_task_def->set_op_index(op_index++); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
{ |
|
|
|
OpDescPtr op_desc = CreateOpDesc("label_leave", LABELSET); |
|
|
|
NodePtr node = graph->AddNode(op_desc); // op_index = 4 |
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, 2)); |
|
|
|
|
|
|
|
domi::TaskDef *task_def1 = model_task_def->add_task(); |
|
|
|
task_def1->set_stream_id(0); |
|
|
|
task_def1->set_type(RT_MODEL_TASK_LABEL_SET); |
|
|
|
domi::LabelSetDef *label_task_def = task_def1->mutable_label_set(); |
|
|
|
label_task_def->set_op_index(op_index++); |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
EXPECT_TRUE(AttrUtils::SetInt(ge_model, ATTR_MODEL_LABEL_NUM, 3)); |
|
|
|
EXPECT_EQ(model.Assign(ge_model), SUCCESS); |
|
|
|
EXPECT_EQ(model.Init(), SUCCESS); |
|
|
|
|
|
|
|
EXPECT_EQ(model.input_addrs_list_.size(), 0); |
|
|
|
EXPECT_EQ(model.output_addrs_list_.size(), 0); |
|
|
|
EXPECT_EQ(model.task_list_.size(), 5); |
|
|
|
} |
|
|
|
} // namespace ge |