|
|
|
@@ -23,6 +23,8 @@ |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "parser/common/op_parser_factory.h" |
|
|
|
#include "parser/common/acl_graph_parser_util.h" |
|
|
|
#include "parser/common/parser_utils.h" |
|
|
|
#include "register/register_utils.h" |
|
|
|
|
|
|
|
using domi::tensorflow::AttrValue; |
|
|
|
using std::vector; |
|
|
|
@@ -62,24 +64,24 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { |
|
|
|
Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { |
|
|
|
GE_CHECK_NOTNULL(op_src); |
|
|
|
GE_CHECK_NOTNULL(op); |
|
|
|
GE_CHECK_NOTNULL(op_dest); |
|
|
|
|
|
|
|
ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); |
|
|
|
GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), |
|
|
|
"call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str()); |
|
|
|
op.BreakConnect(); |
|
|
|
|
|
|
|
const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); |
|
|
|
GE_CHECK_NOTNULL(node); |
|
|
|
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); |
|
|
|
bool has_axis = true; |
|
|
|
bool has_dims = true; |
|
|
|
|
|
|
|
domi::tensorflow::AttrValue axis; |
|
|
|
domi::tensorflow::AttrValue dims; |
|
|
|
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) { |
|
|
|
has_axis = false; |
|
|
|
} |
|
|
|
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) { |
|
|
|
has_dims = false; |
|
|
|
} |
|
|
|
|
|
|
|
bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis); |
|
|
|
bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims); |
|
|
|
if (!has_axis && !has_dims) { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
@@ -103,9 +105,9 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr |
|
|
|
int32_t result = values.i(i); |
|
|
|
v_result.push_back(result); |
|
|
|
} |
|
|
|
if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) { |
|
|
|
if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(), |
|
|
|
op->GetName().c_str(), op->GetType().c_str()); |
|
|
|
op_dest->GetName().c_str(), op_dest->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set squeeze axis attr failed"); |
|
|
|
return FAILED; |
|
|
|
} |
|
|
|
@@ -125,14 +127,14 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr |
|
|
|
"parse output desc failed"); |
|
|
|
} |
|
|
|
|
|
|
|
if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { |
|
|
|
if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(), |
|
|
|
op->GetName().c_str(), op->GetType().c_str()); |
|
|
|
op_dest->GetName().c_str(), op_dest->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set input desc failed"); |
|
|
|
return FAILED; |
|
|
|
} if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { |
|
|
|
} if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { |
|
|
|
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(), |
|
|
|
op->GetName().c_str(), op->GetType().c_str()); |
|
|
|
op_dest->GetName().c_str(), op_dest->GetType().c_str()); |
|
|
|
GELOGE(FAILED, "Set output desc failed"); |
|
|
|
return FAILED; |
|
|
|
}) |
|
|
|
|