From 8d810ad0eef164fa0a2eec5649268a3dc1752f30 Mon Sep 17 00:00:00 2001 From: yanmingda Date: Sat, 25 Jun 2022 11:54:22 +0000 Subject: [PATCH] !577 tf reshape and squeeze add automapping Merge pull request !577 from yanmingda/ge_dev --- .../tensorflow/tensorflow_reshape_parser.cc | 15 +++++--- parser/tensorflow/tensorflow_reshape_parser.h | 2 +- .../tensorflow/tensorflow_squeeze_parser.cc | 34 ++++++++++--------- parser/tensorflow/tensorflow_squeeze_parser.h | 2 +- 4 files changed, 31 insertions(+), 22 deletions(-) diff --git a/parser/tensorflow/tensorflow_reshape_parser.cc b/parser/tensorflow/tensorflow_reshape_parser.cc index e06f14a..eed8c07 100644 --- a/parser/tensorflow/tensorflow_reshape_parser.cc +++ b/parser/tensorflow/tensorflow_reshape_parser.cc @@ -21,7 +21,9 @@ #include "parser/common/util.h" #include "parser/tensorflow/tensorflow_util.h" #include "parser/common/acl_graph_parser_util.h" +#include "parser/common/parser_utils.h" #include "omg/parser/parser_inner_ctx.h" +#include "register/register_utils.h" using domi::TENSORFLOW; using namespace ge::parser; @@ -57,9 +59,14 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att return SUCCESS; } -Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { +Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { GE_CHECK_NOTNULL(op_src); - GE_CHECK_NOTNULL(op); + GE_CHECK_NOTNULL(op_dest); + + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), + "call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str()); + op.BreakConnect(); const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST(op_src); GE_CHECK_NOTNULL(node_src); @@ -82,10 +89,10 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr "parse output desc failed"); } - GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED, + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED, "set input desc failed"); - GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED, + GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED, "set output desc failed");); return SUCCESS; diff --git a/parser/tensorflow/tensorflow_reshape_parser.h b/parser/tensorflow/tensorflow_reshape_parser.h index 2d2a9bd..d0d2c3f 100644 --- a/parser/tensorflow/tensorflow_reshape_parser.h +++ b/parser/tensorflow/tensorflow_reshape_parser.h @@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser * @return FAILED parse failed * @author */ - Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; }; } // namespace ge diff --git a/parser/tensorflow/tensorflow_squeeze_parser.cc b/parser/tensorflow/tensorflow_squeeze_parser.cc index 76e7f8d..fa55c93 100644 --- a/parser/tensorflow/tensorflow_squeeze_parser.cc +++ b/parser/tensorflow/tensorflow_squeeze_parser.cc @@ -23,6 +23,8 @@ #include "graph/utils/type_utils.h" #include "parser/common/op_parser_factory.h" #include "parser/common/acl_graph_parser_util.h" +#include "parser/common/parser_utils.h" +#include "register/register_utils.h" using domi::tensorflow::AttrValue; using std::vector; @@ -62,24 +64,24 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att return SUCCESS; } -Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { +Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { GE_CHECK_NOTNULL(op_src); - GE_CHECK_NOTNULL(op); + GE_CHECK_NOTNULL(op_dest); + + ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest); + GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op), + "call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str()); + op.BreakConnect(); const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST(op_src); GE_CHECK_NOTNULL(node); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); - bool has_axis = true; - bool has_dims = true; domi::tensorflow::AttrValue axis; domi::tensorflow::AttrValue dims; - if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) { - has_axis = false; - } - if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) { - has_dims = false; - } + + bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis); + bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims); if (!has_axis && !has_dims) { return SUCCESS; } @@ -103,9 +105,9 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr int32_t result = values.i(i); v_result.push_back(result); } - if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) { + if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) { REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(), - op->GetName().c_str(), op->GetType().c_str()); + op_dest->GetName().c_str(), op_dest->GetType().c_str()); GELOGE(FAILED, "Set squeeze axis attr failed"); return FAILED; } @@ -125,14 +127,14 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr "parse output desc failed"); } - if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { + if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) { REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(), - op->GetName().c_str(), op->GetType().c_str()); + op_dest->GetName().c_str(), op_dest->GetType().c_str()); GELOGE(FAILED, "Set input desc failed"); return FAILED; - } if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { + } if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) { REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(), - op->GetName().c_str(), op->GetType().c_str()); + op_dest->GetName().c_str(), op_dest->GetType().c_str()); GELOGE(FAILED, "Set output desc failed"); return FAILED; }) diff --git a/parser/tensorflow/tensorflow_squeeze_parser.h b/parser/tensorflow/tensorflow_squeeze_parser.h index c2bba6f..b7675b3 100644 --- a/parser/tensorflow/tensorflow_squeeze_parser.h +++ b/parser/tensorflow/tensorflow_squeeze_parser.h @@ -22,7 +22,7 @@ namespace ge { class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { public: - Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override; + Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; private: static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc);