From: @zhengyuanhua Reviewed-by: @xchu42 Signed-off-by:tags/v1.3.0
@@ -619,19 +619,25 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr | |||
return SUCCESS; | |||
} | |||
input->SetDataType(dt_set); | |||
int64_t input_shape_size = 0; | |||
int64_t output_shape_size = 0; | |||
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); | |||
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); | |||
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); | |||
return FAILED; | |||
} | |||
ge::TensorUtils::SetSize(*input, input_shape_size); | |||
const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); | |||
GE_CHECK_NOTNULL(output); | |||
output->SetDataType(dt_set); | |||
ge::TensorUtils::SetSize(*output, output_shape_size); | |||
GeShape shape = input->GetShape(); | |||
if (!shape.IsUnknownShape()) { | |||
int64_t input_shape_size = 0; | |||
int64_t output_shape_size = 0; | |||
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); | |||
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); | |||
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { | |||
GELOGE(GRAPH_FAILED, "[Process][InputOp] Get tensor size of op [%s] failed!", node_ptr->GetName().c_str()); | |||
return FAILED; | |||
} | |||
ge::TensorUtils::SetSize(*input, input_shape_size); | |||
ge::TensorUtils::SetSize(*output, output_shape_size); | |||
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); | |||
} | |||
if (is_dynamic_batch) { | |||
GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); | |||
auto switchn_op_desc = switchn_node->GetOpDesc(); | |||
@@ -1255,6 +1261,12 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { | |||
return GE_GRAPH_GRAPH_NODE_NULL; | |||
} | |||
GeTensorDesc output = op_desc_ptr->GetOutputDesc(0); | |||
GeShape output_shape = output.GetShape(); | |||
if (output_shape.IsUnknownShape()) { | |||
GELOGD("[Adjust][DataOpOutput] Shape of op [%s] output is unknown.", node->GetName().c_str()); | |||
return SUCCESS; | |||
} | |||
int64_t tensor_size = 0; | |||
graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); | |||
if (graph_status != GRAPH_SUCCESS) { | |||
@@ -244,9 +244,11 @@ class GFlagUtils { | |||
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | |||
" --input_format Format of input data. E.g.: \"NCHW\"\n" | |||
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | |||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||
"Use double quotation marks (\") to enclose each argument.\n" | |||
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | |||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||
"Use double quotation marks (\") to enclose each argument.\n" | |||
" E.g.: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2,c2~c3,h2,w2]\"\n" | |||
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | |||
" --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). " | |||
"Use double quotation marks (\") to enclose each argument.\n" | |||
@@ -50,6 +50,28 @@ ComputeGraphPtr BuildGraph1(){ | |||
return builder.GetGraph(); | |||
} | |||
ComputeGraphPtr BuildGraph2() { | |||
auto builder = ut::GraphBuilder("g2"); | |||
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({22, -1})); | |||
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); | |||
auto data_opdesc = data1->GetOpDesc(); | |||
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); | |||
data1->UpdateOpDesc(data_opdesc); | |||
return builder.GetGraph(); | |||
} | |||
ComputeGraphPtr BuildGraph3() { | |||
auto builder = ut::GraphBuilder("g3"); | |||
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT); | |||
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); | |||
auto data_opdesc = data1->GetOpDesc(); | |||
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); | |||
data1->UpdateOpDesc(data_opdesc); | |||
return builder.GetGraph(); | |||
} | |||
TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { | |||
ge::GraphPrepare graph_prepare; | |||
graph_prepare.compute_graph_ = BuildGraph1(); | |||
@@ -88,4 +110,12 @@ TEST_F(UtestGraphPreproces, test_check_user_input) { | |||
Status ret = graph_prepare.CheckUserInput(user_input); | |||
EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED); | |||
} | |||
TEST_F(UtestGraphPreproces, test_update_input_output1) { | |||
ge::GraphPrepare graph_prepare; | |||
graph_prepare.compute_graph_ = BuildGraph3(); | |||
Status ret = graph_prepare.UpdateInputOutputByOptions(); | |||
EXPECT_EQ(ret, SUCCESS); | |||
} | |||
} |