Browse Source

dts: when input node is dynamic, no need to cal memory size

tags/v1.3.0
zhengyuanhua 3 years ago
parent
commit
2210a7177c
3 changed files with 55 additions and 11 deletions
  1. +22
    -10
      ge/graph/preprocess/graph_preprocess.cc
  2. +3
    -1
      ge/offline/main.cc
  3. +30
    -0
      tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc

+ 22
- 10
ge/graph/preprocess/graph_preprocess.cc View File

@@ -619,19 +619,25 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr
return SUCCESS;
}
input->SetDataType(dt_set);
int64_t input_shape_size = 0;
int64_t output_shape_size = 0;
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size);
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size);
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "GetTensorSize failed!");
return FAILED;
}
ge::TensorUtils::SetSize(*input, input_shape_size);
const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0);
GE_CHECK_NOTNULL(output);
output->SetDataType(dt_set);
ge::TensorUtils::SetSize(*output, output_shape_size);

GeShape shape = input->GetShape();
if (!shape.IsUnknownShape()) {
int64_t input_shape_size = 0;
int64_t output_shape_size = 0;
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size);
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size);
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) {
GELOGE(GRAPH_FAILED, "[Process][InputOp] Get tensor size of op [%s] failed!", node_ptr->GetName().c_str());
return FAILED;
}
ge::TensorUtils::SetSize(*input, input_shape_size);
ge::TensorUtils::SetSize(*output, output_shape_size);
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str());
}

if (is_dynamic_batch) {
GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str());
auto switchn_op_desc = switchn_node->GetOpDesc();
@@ -1255,6 +1261,12 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) {
return GE_GRAPH_GRAPH_NODE_NULL;
}
GeTensorDesc output = op_desc_ptr->GetOutputDesc(0);
GeShape output_shape = output.GetShape();
if (output_shape.IsUnknownShape()) {
GELOGD("[Adjust][DataOpOutput] Shape of op [%s] output is unknown.", node->GetName().c_str());
return SUCCESS;
}

int64_t tensor_size = 0;
graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size);
if (graph_status != GRAPH_SUCCESS) {


+ 3
- 1
ge/offline/main.cc View File

@@ -244,9 +244,11 @@ class GFlagUtils {
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n"
" --input_format Format of input data. E.g.: \"NCHW\"\n"
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). "
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)."
"Use double quotation marks (\") to enclose each argument.\n"
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n"
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)."
"Use double quotation marks (\") to enclose each argument.\n"
" E.g.: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2,c2~c3,h2,w2]\"\n"
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n"
" --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). "
"Use double quotation marks (\") to enclose each argument.\n"


+ 30
- 0
tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc View File

@@ -50,6 +50,28 @@ ComputeGraphPtr BuildGraph1(){
return builder.GetGraph();
}

ComputeGraphPtr BuildGraph2() {
auto builder = ut::GraphBuilder("g2");
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({22, -1}));
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8");
auto data_opdesc = data1->GetOpDesc();
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0);

data1->UpdateOpDesc(data_opdesc);
return builder.GetGraph();
}

ComputeGraphPtr BuildGraph3() {
auto builder = ut::GraphBuilder("g3");
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT);
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8");
auto data_opdesc = data1->GetOpDesc();
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0);

data1->UpdateOpDesc(data_opdesc);
return builder.GetGraph();
}

TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph1();
@@ -88,4 +110,12 @@ TEST_F(UtestGraphPreproces, test_check_user_input) {
Status ret = graph_prepare.CheckUserInput(user_input);
EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED);
}

TEST_F(UtestGraphPreproces, test_update_input_output1) {
ge::GraphPrepare graph_prepare;
graph_prepare.compute_graph_ = BuildGraph3();

Status ret = graph_prepare.UpdateInputOutputByOptions();
EXPECT_EQ(ret, SUCCESS);
}
}

Loading…
Cancel
Save