From a2da82c86fd78c0e199d87fdbf4137b370b53dcd Mon Sep 17 00:00:00 2001 From: l00444296 Date: Tue, 8 Dec 2020 15:33:42 +0800 Subject: [PATCH] Feature: Get default from ge ir graph while no user input shape --- ge/ir_build/ge_ir_build.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/ge/ir_build/ge_ir_build.cc b/ge/ir_build/ge_ir_build.cc index 2035dbe0..518ab49b 100644 --- a/ge/ir_build/ge_ir_build.cc +++ b/ge/ir_build/ge_ir_build.cc @@ -36,6 +36,7 @@ #include "model/ge_model.h" #include "graph/shape_refiner.h" #include "graph/opsproto_manager.h" +#include "graph/utils/type_utils.h" using std::string; using namespace std; @@ -225,7 +226,7 @@ class Impl { ~Impl() { (void)generator_.Finalize(); }; graphStatus CheckOptions(const std::map &options); graphStatus CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs); - graphStatus GetDefaultInputShape(const Graph &graph, string &default_shape); + graphStatus GetDefaultInputShapeAndFormat(const Graph &graph, string &default_shape, string &input_format); graphStatus Init(const Graph &graph, const std::map &options); graphStatus BuildModel(const Graph &graph, const std::map &options, ModelBufferData &ge_models); @@ -279,7 +280,7 @@ graphStatus Impl::CheckOptions(const std::map &options return GRAPH_SUCCESS; } -graphStatus Impl::GetDefaultInputShape(const Graph &graph, string &default_shape) { +graphStatus Impl::GetDefaultInputShapeAndFormat(const Graph &graph, string &default_shape, string &input_format) { auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { @@ -307,7 +308,11 @@ graphStatus Impl::GetDefaultInputShape(const Graph &graph, string &default_shape tmp_shape_str = tmp_shape_str.substr(0, tmp_shape_str.size() - 1); tmp_shape_str += ";"; default_shape += tmp_shape_str; - GELOGD("Data op name: %s, data shape: %s", data_op_name.c_str(), tmp_shape_str.c_str()); + + ge::Format data_format = tensor.GetFormat(); + input_format.assign(ge::TypeUtils::FormatToSerialString(data_format)); + GELOGD("Data op name: %s, data shape: %s, data format: %s.", data_op_name.c_str(), tmp_shape_str.c_str(), + input_format.c_str()); } } default_shape = (default_shape.empty() ? default_shape : default_shape.substr(0, default_shape.size() - 1)); @@ -334,13 +339,14 @@ graphStatus Impl::Init(const Graph &graph, const std::map