diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index b5b0daa4..08dd6f98 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -619,19 +619,25 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr return SUCCESS; } input->SetDataType(dt_set); - int64_t input_shape_size = 0; - int64_t output_shape_size = 0; - ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); - ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); - if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); - return FAILED; - } - ge::TensorUtils::SetSize(*input, input_shape_size); const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(dt_set); - ge::TensorUtils::SetSize(*output, output_shape_size); + + GeShape shape = input->GetShape(); + if (!shape.IsUnknownShape()) { + int64_t input_shape_size = 0; + int64_t output_shape_size = 0; + ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); + ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); + if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "[Process][InputOp] Get tensor size of op [%s] failed!", node_ptr->GetName().c_str()); + return FAILED; + } + ge::TensorUtils::SetSize(*input, input_shape_size); + ge::TensorUtils::SetSize(*output, output_shape_size); + GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); + } + if (is_dynamic_batch) { GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); auto switchn_op_desc = switchn_node->GetOpDesc(); @@ -1255,6 +1261,12 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { return GE_GRAPH_GRAPH_NODE_NULL; } GeTensorDesc output = op_desc_ptr->GetOutputDesc(0); + GeShape output_shape = output.GetShape(); + if (output_shape.IsUnknownShape()) { + GELOGD("[Adjust][DataOpOutput] Shape of op [%s] output is unknown.", node->GetName().c_str()); + return SUCCESS; + } + int64_t tensor_size = 0; graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); if (graph_status != GRAPH_SUCCESS) { diff --git a/ge/offline/main.cc b/ge/offline/main.cc index 69ee29de..30285780 100755 --- a/ge/offline/main.cc +++ b/ge/offline/main.cc @@ -244,9 +244,11 @@ class GFlagUtils { " --framework Framework type. 0:Caffe; 1:MindSpore; 3:Tensorflow; 5:Onnx\n" " --input_format Format of input data. E.g.: \"NCHW\"\n" " --input_shape Shape of input data. Separate multiple nodes with semicolons (;). " - " --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." "Use double quotation marks (\") to enclose each argument.\n" " E.g.: \"input_name1:n1,c1,h1,w1;input_name2:n2,c2,h2,w2\"\n" + " --input_shape_range Shape range of input data. Separate multiple nodes with semicolons (;)." + "Use double quotation marks (\") to enclose each argument.\n" + " E.g.: \"input_name1:[n1~n2,c1,h1,w1];input_name2:[n2,c2~c3,h2,w2]\"\n" " --dynamic_batch_size Set dynamic batch size. E.g.: \"batchsize1,batchsize2,batchsize3\"\n" " --dynamic_image_size Set dynamic image size. Separate multiple nodes with semicolons (;). " "Use double quotation marks (\") to enclose each argument.\n" diff --git a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc index 69192631..ff49f34c 100644 --- a/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc +++ b/tests/ut/ge/graph/preprocess/graph_preprocess_unittest.cc @@ -50,6 +50,28 @@ ComputeGraphPtr BuildGraph1(){ return builder.GetGraph(); } +ComputeGraphPtr BuildGraph2() { + auto builder = ut::GraphBuilder("g2"); + auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, std::vector({22, -1})); + ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); + auto data_opdesc = data1->GetOpDesc(); + AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); + + data1->UpdateOpDesc(data_opdesc); + return builder.GetGraph(); +} + +ComputeGraphPtr BuildGraph3() { + auto builder = ut::GraphBuilder("g3"); + auto data1 = builder.AddNode("data1", DATA, 1, 1, FORMAT_NCHW, DT_FLOAT); + ge::AttrUtils::SetStr(data1->GetOpDesc(), ATTR_ATC_USER_DEFINE_DATATYPE, "DT_INT8"); + auto data_opdesc = data1->GetOpDesc(); + AttrUtils::SetInt(data_opdesc, ATTR_NAME_INDEX, 0); + + data1->UpdateOpDesc(data_opdesc); + return builder.GetGraph(); +} + TEST_F(UtestGraphPreproces, test_dynamic_input_shape_parse) { ge::GraphPrepare graph_prepare; graph_prepare.compute_graph_ = BuildGraph1(); @@ -88,4 +110,12 @@ TEST_F(UtestGraphPreproces, test_check_user_input) { Status ret = graph_prepare.CheckUserInput(user_input); EXPECT_EQ(ret, GE_GRAPH_INIT_FAILED); } + +TEST_F(UtestGraphPreproces, test_update_input_output1) { + ge::GraphPrepare graph_prepare; + graph_prepare.compute_graph_ = BuildGraph3(); + + Status ret = graph_prepare.UpdateInputOutputByOptions(); + EXPECT_EQ(ret, SUCCESS); +} } \ No newline at end of file