From: @zhengyuanhua Reviewed-by: @xchu42 Signed-off-by:tags/v1.3.0
@@ -619,19 +619,25 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
input->SetDataType(dt_set); | input->SetDataType(dt_set); | ||||
int64_t input_shape_size = 0; | |||||
int64_t output_shape_size = 0; | |||||
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); | |||||
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); | |||||
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); | |||||
return FAILED; | |||||
} | |||||
ge::TensorUtils::SetSize(*input, input_shape_size); | |||||
const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); | const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); | ||||
GE_CHECK_NOTNULL(output); | GE_CHECK_NOTNULL(output); | ||||
output->SetDataType(dt_set); | output->SetDataType(dt_set); | ||||
ge::TensorUtils::SetSize(*output, output_shape_size); | |||||
GeShape shape = input->GetShape(); | |||||
if (!shape.IsUnknownShape()) { | |||||
int64_t input_shape_size = 0; | |||||
int64_t output_shape_size = 0; | |||||
ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); | |||||
ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); | |||||
if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "[Process][InputOp] Get tensor size of op [%s] failed!", node_ptr->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
ge::TensorUtils::SetSize(*input, input_shape_size); | |||||
ge::TensorUtils::SetSize(*output, output_shape_size); | |||||
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); | |||||
} | |||||
if (is_dynamic_batch) { | if (is_dynamic_batch) { | ||||
GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); | GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); | ||||
auto switchn_op_desc = switchn_node->GetOpDesc(); | auto switchn_op_desc = switchn_node->GetOpDesc(); | ||||
@@ -1255,6 +1261,12 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { | |||||
return GE_GRAPH_GRAPH_NODE_NULL; | return GE_GRAPH_GRAPH_NODE_NULL; | ||||
} | } | ||||
GeTensorDesc output = op_desc_ptr->GetOutputDesc(0); | GeTensorDesc output = op_desc_ptr->GetOutputDesc(0); | ||||
GeShape output_shape = output.GetShape(); | |||||
if (output_shape.IsUnknownShape()) { | |||||
GELOGD("[Adjust][DataOpOutput] Shape of op [%s] output is unknown.", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
int64_t tensor_size = 0; | int64_t tensor_size = 0; | ||||
graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); | graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); | ||||
if (graph_status != GRAPH_SUCCESS) { | if (graph_status != GRAPH_SUCCESS) { | ||||
@@ -244,9 +244,11 @@ class GFlagUtils { | |||||
" --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" | ||||
" --input_format Format of input data. E.g.: \"NCHW\"\n" | " --input_format Format of input data. E.g.: \"NCHW\"\n" | ||||
" --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " | ||||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||||
"Use double quotation marks (\") to enclose each argument.\n" | "Use double quotation marks (\") to enclose each argument.\n" | ||||
" E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" | ||||
" --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." | |||||
"Use double quotation marks (\") to enclose each argument.\n" | |||||
" E.g.: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2,c2~c3,h2,w2]\"\n" | |||||
" --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" | ||||
" --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). " | " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). " | ||||
"Use double quotation marks (\") to enclose each argument.\n" | "Use double quotation marks (\") to enclose each argument.\n" | ||||
@@ -50,6 +50,28 @@ ComputeGraphPtr BuildGraph1(){ | |||||
return builder.GetGraph(); | return builder.GetGraph(); | ||||
} | } | ||||
ComputeGraphPtr BuildGraph2() { | |||||
auto builder = ut::GraphBuilder("g2"); | |||||
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector<int64_t>({22, -1})); | |||||
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); | |||||
auto data_opdesc = data1->GetOpDesc(); | |||||
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); | |||||
data1->UpdateOpDesc(data_opdesc); | |||||
return builder.GetGraph(); | |||||
} | |||||
ComputeGraphPtr BuildGraph3() { | |||||
auto builder = ut::GraphBuilder("g3"); | |||||
auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT); | |||||
ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); | |||||
auto data_opdesc = data1->GetOpDesc(); | |||||
AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); | |||||
data1->UpdateOpDesc(data_opdesc); | |||||
return builder.GetGraph(); | |||||
} | |||||
TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { | TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { | ||||
ge::GraphPrepare graph_prepare; | ge::GraphPrepare graph_prepare; | ||||
graph_prepare.compute_graph_ = BuildGraph1(); | graph_prepare.compute_graph_ = BuildGraph1(); | ||||
@@ -88,4 +110,12 @@ TEST_F(UtestGraphPreproces, test_check_user_input) { | |||||
Status ret = graph_prepare.CheckUserInput(user_input); | Status ret = graph_prepare.CheckUserInput(user_input); | ||||
EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED); | EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED); | ||||
} | } | ||||
TEST_F(UtestGraphPreproces, test_update_input_output1) { | |||||
ge::GraphPrepare graph_prepare; | |||||
graph_prepare.compute_graph_ = BuildGraph3(); | |||||
Status ret = graph_prepare.UpdateInputOutputByOptions(); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
} | |||||
} | } |