Browse Source

!577 tf reshape and squeeze add automapping

Merge pull request !577 from yanmingda/ge_dev
pull/580/head
yanmingda i-robot 3 years ago
parent
commit
8d810ad0ee
4 changed files with 31 additions and 22 deletions
  1. +11
    -4
      parser/tensorflow/tensorflow_reshape_parser.cc
  2. +1
    -1
      parser/tensorflow/tensorflow_reshape_parser.h
  3. +18
    -16
      parser/tensorflow/tensorflow_squeeze_parser.cc
  4. +1
    -1
      parser/tensorflow/tensorflow_squeeze_parser.h

+ 11
- 4
parser/tensorflow/tensorflow_reshape_parser.cc View File

@@ -21,7 +21,9 @@
#include "parser/common/util.h" #include "parser/common/util.h"
#include "parser/tensorflow/tensorflow_util.h" #include "parser/tensorflow/tensorflow_util.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "omg/parser/parser_inner_ctx.h" #include "omg/parser/parser_inner_ctx.h"
#include "register/register_utils.h"


using domi::TENSORFLOW; using domi::TENSORFLOW;
using namespace ge::parser; using namespace ge::parser;
@@ -57,9 +59,14 @@ Status TensorFlowReshapeParser::ParseDesc(const domi::tensorflow::AttrValue &att
return SUCCESS; return SUCCESS;
} }


Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
GE_CHECK_NOTNULL(op_dest);

ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op),
"call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str());
op.BreakConnect();


const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); const domi::tensorflow::NodeDef *node_src = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node_src); GE_CHECK_NOTNULL(node_src);
@@ -82,10 +89,10 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr
"parse output desc failed"); "parse output desc failed");
} }


GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED,
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc), FAILED,
"set input desc failed"); "set input desc failed");


GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED,
GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc), FAILED,
"set output desc failed");); "set output desc failed"););


return SUCCESS; return SUCCESS;


+ 1
- 1
parser/tensorflow/tensorflow_reshape_parser.h View File

@@ -34,7 +34,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser
* @return FAILED parse failed * @return FAILED parse failed
* @author * @author
*/ */
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;
}; };
} // namespace ge } // namespace ge




+ 18
- 16
parser/tensorflow/tensorflow_squeeze_parser.cc View File

@@ -23,6 +23,8 @@
#include "graph/utils/type_utils.h" #include "graph/utils/type_utils.h"
#include "parser/common/op_parser_factory.h" #include "parser/common/op_parser_factory.h"
#include "parser/common/acl_graph_parser_util.h" #include "parser/common/acl_graph_parser_util.h"
#include "parser/common/parser_utils.h"
#include "register/register_utils.h"


using domi::tensorflow::AttrValue; using domi::tensorflow::AttrValue;
using std::vector; using std::vector;
@@ -62,24 +64,24 @@ Status TensorFlowSqueezeParser::ParseDesc(const domi::tensorflow::AttrValue &att
return SUCCESS; return SUCCESS;
} }


Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) {
Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) {
GE_CHECK_NOTNULL(op_src); GE_CHECK_NOTNULL(op_src);
GE_CHECK_NOTNULL(op);
GE_CHECK_NOTNULL(op_dest);

ge::Operator op = ge::OpDescUtils::CreateOperatorFromOpDesc(op_dest);
GE_CHK_STATUS_RET(domi::OperatorAutoMapping(op_src, op),
"call auto mapping failed for node:%s", ParserUtils::GetOperatorName(op).c_str());
op.BreakConnect();


const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src); const domi::tensorflow::NodeDef *node = DOMI_DYNAMIC_CAST<const domi::tensorflow::NodeDef *>(op_src);
GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node);
GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str()); GELOGD("TF op node name = %s, op type= %s, parse params", node->name().c_str(), node->op().c_str());
bool has_axis = true;
bool has_dims = true;


domi::tensorflow::AttrValue axis; domi::tensorflow::AttrValue axis;
domi::tensorflow::AttrValue dims; domi::tensorflow::AttrValue dims;
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis)) {
has_axis = false;
}
if (!TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims)) {
has_dims = false;
}

bool has_axis = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_AXIS, axis);
bool has_dims = TensorFlowUtil::FindAttrValue(node, SQUEEZE_ATTR_DIMS, dims);
if (!has_axis && !has_dims) { if (!has_axis && !has_dims) {
return SUCCESS; return SUCCESS;
} }
@@ -103,9 +105,9 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
int32_t result = values.i(i); int32_t result = values.i(i);
v_result.push_back(result); v_result.push_back(result);
} }
if (!ge::AttrUtils::SetListInt(op, SQUEEZE_ATTR_AXIS, v_result)) {
if (!ge::AttrUtils::SetListInt(op_dest, SQUEEZE_ATTR_AXIS, v_result)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(), REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", SQUEEZE_ATTR_AXIS.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set squeeze axis attr failed"); GELOGE(FAILED, "Set squeeze axis attr failed");
return FAILED; return FAILED;
} }
@@ -125,14 +127,14 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr
"parse output desc failed"); "parse output desc failed");
} }


if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_INPUT_DESC, input_desc)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(), REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_INPUT_DESC.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set input desc failed"); GELOGE(FAILED, "Set input desc failed");
return FAILED; return FAILED;
} if (!ge::AttrUtils::SetTensorDesc(op, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
} if (!ge::AttrUtils::SetTensorDesc(op_dest, RESHAPE_ATTR_NAME_OUTPUT_DESC, output_desc)) {
REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(), REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", RESHAPE_ATTR_NAME_OUTPUT_DESC.c_str(),
op->GetName().c_str(), op->GetType().c_str());
op_dest->GetName().c_str(), op_dest->GetType().c_str());
GELOGE(FAILED, "Set output desc failed"); GELOGE(FAILED, "Set output desc failed");
return FAILED; return FAILED;
}) })


+ 1
- 1
parser/tensorflow/tensorflow_squeeze_parser.h View File

@@ -22,7 +22,7 @@
namespace ge { namespace ge {
class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser { class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser {
public: public:
Status ParseParams(const Message *op_src, ge::OpDescPtr &op) override;
Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override;


private: private:
static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc);


Loading…
Cancel
Save