From c5cce79af15157be0be80ceb7ac34afb26f3d2d4 Mon Sep 17 00:00:00 2001 From: wxl Date: Fri, 23 Oct 2020 15:19:09 +0800 Subject: [PATCH] Feature: support optional input for single op --- ge/generator/ge_generator.cc | 16 ++++++++++++++++ ge/offline/single_op_parser.cc | 15 +++++---------- metadef | 2 +- parser | 2 +- 4 files changed, 23 insertions(+), 12 deletions(-) diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index ad3084dc..93933bc8 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -528,6 +528,16 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr return SUCCESS; } +namespace { + bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tensor_desc) { + bool is_need = true; + // format and dtype is all reserved, stand for Optional input. When singleop scene + if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) { + is_need = false; + } + return is_need; + } +} Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline) { @@ -575,12 +585,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); + if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); arg_index++; } } else { for (const auto &in_desc : inputs) { GeTensorDesc input_desc = in_desc.GetTensorDesc(); + if (!IsNeedConnectInputOpForSingleOp(input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); arg_index++; } diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index df75e21d..8a86f5c5 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -226,16 +226,11 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { } int index = 0; - for (auto &tensor_desc : op_desc.input_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); + for (auto &tensor_desc : op_desc.output_desc) { + if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || + (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); return false; } ++index; diff --git a/metadef b/metadef index 302c31ca..ae80e9e2 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 302c31caf83f331e528b4a6005555af0cfbaca81 +Subproject commit ae80e9e2369458d468c16fd95bb970922ff3a084 diff --git a/parser b/parser index bdeeb7ff..c1530c60 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit bdeeb7ff55f2408cd01ee7a33bf2692ae53c9cbb +Subproject commit c1530c6083ea18c6c6c4c14b08253830e3982344