/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "parser/tensorflow/tensorflow_squeeze_parser.h"
#include <memory>
#include <vector>
#include "framework/common/debug/ge_log.h"
#include "common/util.h"
#include "framework/omg/parser/parser_inner_ctx.h"
#include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h"
#include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "register/register_utils.h"

using domi::tensorflow::AttrValue;
using std::vector;
using std::shared_ptr;
using domi::TENSORFLOW;
using namespace ge::parser;

namespace ge {
/**
 * Builds a GeTensorDesc from a TensorFlow tensor-desc attribute and fills in its byte size.
 *
 * The desc itself is decoded by TensorFlowUtil::ParseFromAttrValueList from element 0 of the
 * attribute's list. The size is then (product of dims) * (byte length of the data type).
 *
 * @param attr_value TensorFlow attribute holding the serialized tensor description list.
 * @param ge_desc    Output descriptor. Shape/type are filled by the decoder; size and real dim
 *                   count are set here.
 * @return SUCCESS on success.
 *         domi::PARAM_INVALID if decoding fails or the data type has no known length.
 *         Otherwise the status returned by PARSER_INT64_MULCHECK when the size would overflow int64.
 */
Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc) {
  int32_t tf_datatype = 0;
  auto a_list = attr_value.list();
  GE_CHK_BOOL_RET_STATUS(TensorFlowUtil::ParseFromAttrValueList(ge_desc, a_list, 0, tf_datatype), domi::PARAM_INVALID,
                         "parse ge_desc failed.");

  uint32_t size_type = 0U;
  const auto data_type = ge_desc.GetDataType();
  if (!ge::TypeUtils::GetDataTypeLength(data_type, size_type)) {
    REPORT_CALL_ERROR("E19999", "Data type %s is not supported",
                      ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
    GELOGE(domi::PARAM_INVALID, "Can't GetDataTypeLength of data_type: %s",
           ge::TypeUtils::DataTypeToSerialString(data_type).c_str());
    return domi::PARAM_INVALID;
  }

  // calculate size: element count of the shape.
  // A negative dim contributes its magnitude (same convention as before), and every
  // multiplication -- including the negative-dim path, previously unchecked -- is guarded
  // against int64 overflow.
  const auto &shape = ge_desc.GetShape();
  const size_t dim_num = shape.GetDimNum();
  int64_t real_size = 1;
  for (size_t j = 0U; j < dim_num; ++j) {
    const int64_t tmp_dim = shape.GetDim(j);
    const int64_t dim_value = (tmp_dim < 0) ? -tmp_dim : tmp_dim;
    PARSER_INT64_MULCHECK(real_size, dim_value);
    real_size *= dim_value;
  }
  PARSER_INT64_MULCHECK(real_size, size_type);
  const int64_t total_size = real_size * static_cast<int64_t>(size_type);
  ge::TensorUtils::SetSize(ge_desc, total_size);
  ge::TensorUtils::SetRealDimCnt(ge_desc, static_cast<uint32_t>(dim_num));

  GELOGD("after translate tf_desc, datatype: %s, format: %s, real size: %ld, size_type: %u",
         ge::TypeUtils::DataTypeToSerialString(ge_desc.GetDataType()).c_str(),
         ge::TypeUtils::FormatToSerialString(ge_desc.GetFormat()).c_str(), total_size, size_type);
  return SUCCESS;
}

/**
 * Converts a TensorFlow Squeeze NodeDef into GE op attributes.
 *
 * Steps:
 *  1. Runs the generic auto-mapping of TF attrs onto op_dest.
 *  2. Copies the squeeze axes into the SQUEEZE_ATTR_AXIS list attr. The axes come from
 *     SQUEEZE_ATTR_AXIS or SQUEEZE_ATTR_DIMS; the two are mutually exclusive.
 *  3. When GetParserContext().train_flag is set, stores the parsed input/output tensor descs
 *     under RESHAPE_ATTR_NAME_INPUT_DESC / RESHAPE_ATTR_NAME_OUTPUT_DESC.
 *     (Presumably this flag marks training graphs -- confirm against parser_inner_ctx.h.)
 *
 * If neither axis attribute is present, the function returns SUCCESS right after step 1,
 * without setting the axis or desc attrs.
 *
 * @param op_src  TensorFlow NodeDef (as protobuf Message) describing the Squeeze node.
 * @param op_dest GE op description to populate.
 * @return SUCCESS.
 *         domi::PARAM_INVALID when both axis attributes are present.
 *         FAILED when setting an attribute or parsing a tensor desc fails.
 *         The auto-mapping status when that step fails.
 */
Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
  GE_CHECK_NOTNULL(op_src);
  GE_CHECK_NOTNULL(op_dest);
  ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
  GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), "call auto mapping failed for node:%s",
                    ParserUtils::GetOperatorName(op).c_str());
  op.BreakConnect();

  const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
  GE_CHECK_NOTNULL(node);
  GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());

  domi::tensorflow::AttrValue axis;
  domi::tensorflow::AttrValue dims;
  const bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
  const bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
  if (!has_axis && !has_dims) {
    return SUCCESS;
  }
  if (has_axis && has_dims) {
    // Both spellings of the squeeze axes are present: ambiguous, reject.
    REPORT_CALL_ERROR("E19999", "In NodeDef %s, Attr %s and %s both exist, only one is allowed, check invalid",
                      node->name().c_str(), SQUEEZE_ATTR_AXIS.c_str(), SQUEEZE_ATTR_DIMS.c_str());
    GELOGE(domi::PARAM_INVALID, "In NodeDef %s dim and axis both exist.", node->name().c_str());
    return domi::PARAM_INVALID;
  }

  // Bind by reference (no copy of the list). The TF list field `i` is repeated int64, so
  // the axes are kept as int64_t instead of being narrowed to int32_t.
  const domi::tensorflow::AttrValue_ListValue &values = has_axis ? axis.list() : dims.list();
  const vector<int64_t> v_result(values.i().begin(), values.i().end());
  if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(),
                      op_dest->GetName().c_str(), op_dest->GetType().c_str());
    GELOGE(FAILED, "Set squeeze axis attr failed");
    return FAILED;
  }

  if (GetParserContext().train_flag) {
    // Descs stay default-constructed when the corresponding attribute is absent.
    ge::GeTensorDesc input_desc;
    ge::GeTensorDesc output_desc;
    domi::tensorflow::AttrValue input_attr_value;
    domi::tensorflow::AttrValue output_attr_value;
    if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) {
      GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed");
    }
    if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) {
      GE_CHK_BOOL_RET_STATUS(ParseDesc(output_attr_value, output_desc) == SUCCESS, FAILED,
                             "parse output desc failed");
    }
    if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
      REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(),
                        op_dest->GetName().c_str(), op_dest->GetType().c_str());
      GELOGE(FAILED, "Set input desc failed");
      return FAILED;
    }
    if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
      REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(),
                        op_dest->GetName().c_str(), op_dest->GetType().c_str());
      GELOGE(FAILED, "Set output desc failed");
      return FAILED;
    }
  }
  return SUCCESS;
}

REGISTER_OP_PARSER_CREATOR(TENSORFLOW, SQUEEZE, TensorFlowSqueezeParser);
}  // namespace ge