diff --git a/ge/generator/ge_generator.cc b/ge/generator/ge_generator.cc index ad3084dc..2f56c367 100644 --- a/ge/generator/ge_generator.cc +++ b/ge/generator/ge_generator.cc @@ -528,6 +528,16 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr return SUCCESS; } +namespace { + bool IsNeedConnectInputOpForSingleOp(GeTensorDesc &tesor_desc) { + bool is_need = true; + // format and dtype is all reserved, stand for Optional input. When singleop scene + if (tensor_desc.GetFormat() == FORMAT_RESERVED && tensor_desc.GetDataType() == DT_UNDEFINED) { + is_need = false; + } + return is_need; + } +} Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &inputs, const vector &outputs, const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, bool is_offline) { @@ -575,12 +585,18 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in if (inputs.empty()) { for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { GE_CHECK_NOTNULL_EXEC(input_desc, return INTERNAL_ERROR); + if (!IsNeedConnectInputOpForSingleOp(*input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, *input_desc, arg_index, false)); arg_index++; } } else { for (const auto &in_desc : inputs) { GeTensorDesc input_desc = in_desc.GetTensorDesc(); + if (!IsNeedConnectInputOpForSingleOp(input_desc)) { + continue; + } GE_CHK_STATUS_RET_NOLOG(AddInputs(compute_graph, op_node, input_desc, arg_index, true)); arg_index++; } diff --git a/ge/offline/single_op_parser.cc b/ge/offline/single_op_parser.cc index df75e21d..8a86f5c5 100644 --- a/ge/offline/single_op_parser.cc +++ b/ge/offline/single_op_parser.cc @@ -226,16 +226,11 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { } int index = 0; - for (auto &tensor_desc : op_desc.input_desc) { - if (tensor_desc.type == DT_UNDEFINED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); - return false; - } - - if (tensor_desc.format == FORMAT_RESERVED) { - ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); + for (auto &tensor_desc : op_desc.output_desc) { + if ((tensor_desc.type == DT_UNDEFINED && tensor_desc.format != FORMAT_RESERVED) || + (tensor_desc.type != DT_UNDEFINED && tensor_desc.format == FORMAT_RESERVED)){ + ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); + GELOGE(PARAM_INVALID, "Input's dataType or format is invalid when the index is %d", index); return false; } ++index;