| @@ -28,6 +28,10 @@ | |||||
| using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
| using domi::ParseParamFunc; | using domi::ParseParamFunc; | ||||
| using domi::CAFFE; | |||||
| using domi::caffe::LayerParameter; | |||||
| using domi::caffe::InnerProductParameter; | |||||
| using domi::caffe::ConvolutionParameter; | |||||
| using std::vector; | using std::vector; | ||||
| namespace ge { | namespace ge { | ||||
| @@ -42,7 +42,7 @@ class PARSER_FUNC_VISIBILITY CaffeCustomParserAdapter : public CaffeOpParser { | |||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| * @author | * @author | ||||
| */ | */ | ||||
| Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest); | |||||
| static Status ParseParams(const Operator &op_src, ge::OpDescPtr &op_dest); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -25,10 +25,11 @@ | |||||
| #include "parser/common/op_parser_factory.h" | #include "parser/common/op_parser_factory.h" | ||||
| using namespace ge::parser; | using namespace ge::parser; | ||||
| using domi::CAFFE; | |||||
| namespace ge { | namespace ge { | ||||
| Status CaffeDataParser::GetOutputDesc(const string &name, int dim_size, const std::vector<int64_t> &input_dims, | |||||
| ge::OpDescPtr &op) { | |||||
| Status CaffeDataParser::GetOutputDesc(const string &name, const std::vector<int64_t> &input_dims, | |||||
| const ge::OpDescPtr &op) { | |||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| GELOGI("The input dim size is %zu in layer %s.", input_dims.size(), name.c_str()); | GELOGI("The input dim size is %zu in layer %s.", input_dims.size(), name.c_str()); | ||||
| @@ -52,7 +53,7 @@ Status CaffeDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) { | |||||
| if (layer->type() == ge::parser::INPUT_TYPE) { | if (layer->type() == ge::parser::INPUT_TYPE) { | ||||
| GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | GE_CHK_STATUS_RET(ParseParamsForInput(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | ||||
| "layer type= %s", layer->name().c_str(), layer->type().c_str()); | "layer type= %s", layer->name().c_str(), layer->type().c_str()); | ||||
| } else if(layer->type() == ge::parser::DUMMY_DATA) { | |||||
| } else if (layer->type() == ge::parser::DUMMY_DATA) { | |||||
| GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | GE_CHK_STATUS_RET(ParseParamsForDummyData(layer, op), "[Parse][Params] failed, Caffe layer name = %s, " | ||||
| "layer type= %s", layer->name().c_str(), layer->type().c_str()); | "layer type= %s", layer->name().c_str(), layer->type().c_str()); | ||||
| } else { | } else { | ||||
| @@ -85,7 +86,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
| } | } | ||||
| string name = layer->name(); | string name = layer->name(); | ||||
| GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | ||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | |||||
| "[Get][OutputDesc] failed in layer %s", name.c_str()); | "[Get][OutputDesc] failed in layer %s", name.c_str()); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -102,7 +103,7 @@ Status CaffeDataParser::ParseParamsForInput(const domi::caffe::LayerParameter *l | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::vector<int64_t> dims = search->second; | std::vector<int64_t> dims = search->second; | ||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims, op), | |||||
| "[Get][OutputDesc] failed in layer %s.", name.c_str()); | "[Get][OutputDesc] failed in layer %s.", name.c_str()); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -130,7 +131,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
| string name = layer->name(); | string name = layer->name(); | ||||
| GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | GE_IF_BOOL_EXEC(shape_map.count(name) != 0, model_dims = shape_map.at(name)); | ||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims.size(), model_dims, op), | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, model_dims, op), | |||||
| "[Get][OutputDesc] failed in layer %s", name.c_str()); | "[Get][OutputDesc] failed in layer %s", name.c_str()); | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -147,7 +148,7 @@ Status CaffeDataParser::ParseParamsForDummyData(const domi::caffe::LayerParamete | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::vector<int64_t> dims = search->second; | std::vector<int64_t> dims = search->second; | ||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims.size(), dims, op), | |||||
| GE_CHK_STATUS_RET(GetOutputDesc(name, dims, op), | |||||
| "[Get][OutputDesc] failed in layer %s.", name.c_str()); | "[Get][OutputDesc] failed in layer %s.", name.c_str()); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -45,8 +45,7 @@ class PARSER_FUNC_VISIBILITY CaffeDataParser : public CaffeOpParser, public Data | |||||
| * @return SUCCESS parse successfully | * @return SUCCESS parse successfully | ||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| */ | */ | ||||
| Status GetOutputDesc(const std::string &name, int dim_size, | |||||
| const std::vector<int64_t> &input_dims, ge::OpDescPtr &op); | |||||
| Status GetOutputDesc(const std::string &name, const std::vector<int64_t> &input_dims, const ge::OpDescPtr &op); | |||||
| // caffe data layer type could be type of `Input` or `DummyData` | // caffe data layer type could be type of `Input` or `DummyData` | ||||
| Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | Status ParseParamsForInput(const domi::caffe::LayerParameter *layer, ge::OpDescPtr &op); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| using namespace ge::parser; | using namespace ge::parser; | ||||
| using domi::caffe::BlobProto; | |||||
| using domi::CAFFE; | using domi::CAFFE; | ||||
| namespace ge { | namespace ge { | ||||
| @@ -45,24 +45,6 @@ | |||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "proto/caffe/caffe.pb.h" | #include "proto/caffe/caffe.pb.h" | ||||
| using domi::caffe::ArgMaxParameter; | |||||
| using domi::caffe::BatchNormParameter; | |||||
| using domi::caffe::BlobProto; | |||||
| using domi::caffe::BlobShape; | |||||
| using domi::caffe::ConcatParameter; | |||||
| using domi::caffe::ConvolutionParameter; | |||||
| using domi::caffe::DetectionOutputParameter; | |||||
| using domi::caffe::EltwiseParameter; | |||||
| using domi::caffe::FillerParameter; | |||||
| using domi::caffe::InnerProductParameter; | |||||
| using domi::caffe::LayerParameter; | |||||
| using domi::caffe::PoolingParameter; | |||||
| using domi::caffe::PReLUParameter; | |||||
| using domi::caffe::ReshapeParameter; | |||||
| using domi::caffe::ROIAlignParameter; | |||||
| using domi::caffe::TanHParameter; | |||||
| using domi::caffe::UpsampleParameter; | |||||
| namespace ge { | namespace ge { | ||||
| /** | /** | ||||
| * @ingroup ge_omg | * @ingroup ge_omg | ||||
| @@ -107,7 +89,7 @@ class PARSER_FUNC_VISIBILITY CaffeOpParser : public OpParser { | |||||
| * @return SUCCESS parse successfully | * @return SUCCESS parse successfully | ||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| */ | */ | ||||
| static Status ConvertWeight(const BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight); | |||||
| static Status ConvertWeight(const domi::caffe::BlobProto &proto, const string &lay_name, ge::GeTensorPtr &weight); | |||||
| /** | /** | ||||
| * @ingroup ge_omg | * @ingroup ge_omg | ||||
| @@ -115,7 +97,7 @@ class PARSER_FUNC_VISIBILITY CaffeOpParser : public OpParser { | |||||
| * @param [in] proto Shape information before conversion | * @param [in] proto Shape information before conversion | ||||
| * @param [out] shape Save converted shape information | * @param [out] shape Save converted shape information | ||||
| */ | */ | ||||
| static void ConvertShape(const BlobProto &proto, std::vector<int64_t> &shape); | |||||
| static void ConvertShape(const domi::caffe::BlobProto &proto, std::vector<int64_t> &shape); | |||||
| private: | private: | ||||
| /** | /** | ||||
| @@ -126,7 +108,7 @@ class PARSER_FUNC_VISIBILITY CaffeOpParser : public OpParser { | |||||
| * @return SUCCESS parse weight type successfully | * @return SUCCESS parse weight type successfully | ||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| */ | */ | ||||
| static Status ParseWeightType(const BlobProto &proto, const ge::GeShape &shape, | |||||
| static Status ParseWeightType(const domi::caffe::BlobProto &proto, const ge::GeShape &shape, | |||||
| int size, const string &lay_name, ge::GeTensorPtr &weight); | int size, const string &lay_name, ge::GeTensorPtr &weight); | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -21,6 +21,13 @@ | |||||
| #include <sstream> | #include <sstream> | ||||
| #include <memory> | #include <memory> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <google/protobuf/compiler/importer.h> | |||||
| #include <google/protobuf/descriptor.h> | |||||
| #include <google/protobuf/dynamic_message.h> | |||||
| #include <google/protobuf/io/coded_stream.h> | |||||
| #include <google/protobuf/io/zero_copy_stream_impl.h> | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include "common/convert/message2operator.h" | #include "common/convert/message2operator.h" | ||||
| #include "parser/common/convert/pb2json.h" | #include "parser/common/convert/pb2json.h" | ||||
| #include "parser/common/acl_graph_parser_util.h" | #include "parser/common/acl_graph_parser_util.h" | ||||
| @@ -32,12 +39,6 @@ | |||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include <google/protobuf/compiler/importer.h> | |||||
| #include <google/protobuf/descriptor.h> | |||||
| #include <google/protobuf/dynamic_message.h> | |||||
| #include <google/protobuf/io/coded_stream.h> | |||||
| #include <google/protobuf/io/zero_copy_stream_impl.h> | |||||
| #include <google/protobuf/text_format.h> | |||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "omg/parser/parser_factory.h" | #include "omg/parser/parser_factory.h" | ||||
| #include "omg/parser/parser_inner_ctx.h" | #include "omg/parser/parser_inner_ctx.h" | ||||
| @@ -54,6 +55,8 @@ | |||||
| #include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
| #include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
| using domi::caffe::ConvolutionParameter; | |||||
| using domi::caffe::InnerProductParameter; | |||||
| using domi::caffe::LayerParameter; | using domi::caffe::LayerParameter; | ||||
| using domi::caffe::NetParameter; | using domi::caffe::NetParameter; | ||||
| using domi::ParseParamByOpFunc; | using domi::ParseParamByOpFunc; | ||||
| @@ -68,7 +71,7 @@ using std::ifstream; | |||||
| #define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ | #define CAFFE_CHECK_NULL_AND_REPROT_ERRORMSG(val, errormsg) \ | ||||
| do { \ | do { \ | ||||
| if (val == nullptr) { \ | |||||
| if ((val) == nullptr) { \ | |||||
| GELOGE(ge::PARAM_INVALID, errormsg); \ | GELOGE(ge::PARAM_INVALID, errormsg); \ | ||||
| REPORT_INNER_ERROR("E19999", errormsg); \ | REPORT_INNER_ERROR("E19999", errormsg); \ | ||||
| return ge::PARAM_INVALID; \ | return ge::PARAM_INVALID; \ | ||||
| @@ -1384,7 +1387,7 @@ Status CaffeModelParser::SaveDataLayerTops(const domi::caffe::LayerParameter &la | |||||
| if (node->GetType() == ge::parser::DATA) { | if (node->GetType() == ge::parser::DATA) { | ||||
| if (layer.top_size() != 1) { | if (layer.top_size() != 1) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E11035", {"opname", "size"}, | ErrorManager::GetInstance().ATCReportErrMessage("E11035", {"opname", "size"}, | ||||
| {name, std::to_string(layer.top_size())}); | |||||
| {name, std::to_string(layer.top_size())}); | |||||
| GELOGE(FAILED, "[Check][Type]Data layer[%s] top size must be 1, real size: %d", name.c_str(), layer.top_size()); | GELOGE(FAILED, "[Check][Type]Data layer[%s] top size must be 1, real size: %d", name.c_str(), layer.top_size()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1895,7 +1898,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||||
| #define CASE_FIELD_NAME(kName, method) \ | #define CASE_FIELD_NAME(kName, method) \ | ||||
| if (filed_name == kField##kName) { \ | if (filed_name == kField##kName) { \ | ||||
| string value = reflection->GetString(*message, field); \ | string value = reflection->GetString(*message, field); \ | ||||
| GELOGD("Parse result(%s : %s)", filed_name.c_str(), value.c_str());\ | |||||
| GELOGD("Parse res: (%s : %s)", filed_name.c_str(), value.c_str()); \ | |||||
| layer_proto->set_##method(value); \ | layer_proto->set_##method(value); \ | ||||
| return SUCCESS; \ | return SUCCESS; \ | ||||
| } | } | ||||
| @@ -1906,7 +1909,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||||
| if (filed_name == kField##kName) { \ | if (filed_name == kField##kName) { \ | ||||
| int field_size = reflection->FieldSize(*message, field); \ | int field_size = reflection->FieldSize(*message, field); \ | ||||
| for (int i = 0; i < field_size; ++i) { \ | for (int i = 0; i < field_size; ++i) { \ | ||||
| string value = reflection->GetRepeatedString(*message, field, i);\ | |||||
| auto value = reflection->GetRepeatedString(*message, field, i); \ | |||||
| layer_proto->add_##method(value); \ | layer_proto->add_##method(value); \ | ||||
| } \ | } \ | ||||
| return SUCCESS; \ | return SUCCESS; \ | ||||
| @@ -1917,7 +1920,7 @@ Status CaffeWeightsParser::ParseLayerField(const google::protobuf::Reflection *r | |||||
| if (filed_name == kFieldBlobs) { | if (filed_name == kFieldBlobs) { | ||||
| int field_size = reflection->FieldSize(*message, field); | int field_size = reflection->FieldSize(*message, field); | ||||
| for (int i = 0; i < field_size; ++i) { | for (int i = 0; i < field_size; ++i) { | ||||
| BlobProto *item_message = layer_proto->add_blobs(); | |||||
| domi::caffe::BlobProto *item_message = layer_proto->add_blobs(); | |||||
| const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | const google::protobuf::Message &sub_message = reflection->GetRepeatedMessage(*message, field, i); | ||||
| if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { | if (ConvertBlobsProto(&sub_message, item_message) != SUCCESS) { | ||||
| GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); | GELOGE(FAILED, "[Invoke][ConvertBlobsProto] ParseLayerField of field: %s failed.", field->name().c_str()); | ||||
| @@ -37,7 +37,6 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/operator.h" | |||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
| #include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
| @@ -26,6 +26,9 @@ | |||||
| using namespace ge::parser; | using namespace ge::parser; | ||||
| using domi::CAFFE; | using domi::CAFFE; | ||||
| using domi::caffe::BlobShape; | |||||
| using domi::caffe::LayerParameter; | |||||
| using domi::caffe::ReshapeParameter; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -81,7 +84,7 @@ Status CaffeReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr &op) | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status CaffeReshapeParser::ParseWeights(const Message *op_src, ge::OpDescPtr &op) { | |||||
| Status CaffeReshapeParser::ParseWeights(const Message *op_src, const ge::OpDescPtr &op) const { | |||||
| (void)op_src; | (void)op_src; | ||||
| (void)op; | (void)op; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -41,7 +41,7 @@ class PARSER_FUNC_VISIBILITY CaffeReshapeParser : public CaffeOpParser { | |||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| * @author | * @author | ||||
| */ | */ | ||||
| Status ParseWeights(const Message *op_src, ge::OpDescPtr &op); | |||||
| Status ParseWeights(const Message *op_src, const ge::OpDescPtr &op) const; | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -57,8 +57,8 @@ const uint32_t kSetOutputWithNodeAndIndex = 0x1; | |||||
| const uint32_t kSetOutputWithTensorName = 0x2; | const uint32_t kSetOutputWithTensorName = 0x2; | ||||
| const uint32_t kSetOutputModeMixed = 0x3; | const uint32_t kSetOutputModeMixed = 0x3; | ||||
| const std::set<domi::FrameworkType> kSupportTensorAsOutput = { | const std::set<domi::FrameworkType> kSupportTensorAsOutput = { | ||||
| domi::CAFFE, | |||||
| domi::ONNX | |||||
| domi::CAFFE, | |||||
| domi::ONNX | |||||
| }; | }; | ||||
| static string GetSoPath() { | static string GetSoPath() { | ||||
| @@ -318,7 +318,7 @@ domi::Status AclGrphParseUtil::ParseAclOutputNodes(const string &out_nodes) { | |||||
| index_v.emplace_back(index); | index_v.emplace_back(index); | ||||
| ge::GetParserContext().out_nodes_map.emplace(key_value_v[0], index_v); | ge::GetParserContext().out_nodes_map.emplace(key_value_v[0], index_v); | ||||
| } | } | ||||
| ge::GetParserContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); | |||||
| ge::GetParserContext().user_out_nodes.emplace_back(key_value_v[0], index); | |||||
| } | } | ||||
| if (set_output_mode == kSetOutputModeMixed) { | if (set_output_mode == kSetOutputModeMixed) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | ||||
| @@ -38,8 +38,8 @@ class AclGrphParseUtil { | |||||
| public: | public: | ||||
| AclGrphParseUtil() {} | AclGrphParseUtil() {} | ||||
| virtual ~AclGrphParseUtil() {} | virtual ~AclGrphParseUtil() {} | ||||
| domi::Status LoadOpsProtoLib(); | |||||
| void SaveCustomCaffeProtoPath(); | |||||
| static domi::Status LoadOpsProtoLib(); | |||||
| static void SaveCustomCaffeProtoPath(); | |||||
| domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | domi::Status AclParserInitialize(const std::map<std::string, std::string> &options); | ||||
| domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | domi::Status SetOutputNodeInfo(ge::Graph &graph, const std::map<AscendString, AscendString> &parser_params); | ||||
| domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | domi::Status ParseParamsBeforeGraph(const std::map<AscendString, AscendString> &parser_params, | ||||
| @@ -52,7 +52,7 @@ class AclGrphParseUtil { | |||||
| domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | domi::Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | ||||
| void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | void CreateOutputNodesInfo(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | ||||
| std::vector<std::string> &output_nodes_name); | std::vector<std::string> &output_nodes_name); | ||||
| void SetDefaultFormat(); | |||||
| static void SetDefaultFormat(); | |||||
| domi::Status ParseAclOutputNodes(const std::string &out_nodes); | domi::Status ParseAclOutputNodes(const std::string &out_nodes); | ||||
| domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); | domi::Status ParseAclOutputFp16NodesFormat(const std::string &is_output_fp16); | ||||
| domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); | domi::Status ParseAclEnableScope(const std::string &enable_scope_fusion_passes); | ||||
| @@ -158,7 +158,7 @@ std::string CurrentTimeInStr(); | |||||
| template <typename T, typename... Args> | template <typename T, typename... Args> | ||||
| static inline std::shared_ptr<T> MakeShared(Args &&... args) { | static inline std::shared_ptr<T> MakeShared(Args &&... args) { | ||||
| typedef typename std::remove_const<T>::type T_nc; | |||||
| using T_nc = typename std::remove_const<T>::type; | |||||
| std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...)); | std::shared_ptr<T> ret(new (std::nothrow) T_nc(std::forward<Args>(args)...)); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ const uint32_t kScalarLength = 1; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op) { | |||||
| FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const std::vector<int64_t> &shape, ge::OpDescPtr op) { | |||||
| GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "[Check][Param] ParseShape failed for data_op, op is null"); | GE_RETURN_WITH_LOG_IF_FALSE(op != nullptr, "[Check][Param] ParseShape failed for data_op, op is null"); | ||||
| const string &data_op_name = op->GetName(); | const string &data_op_name = op->GetName(); | ||||
| @@ -45,7 +45,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> & | |||||
| } | } | ||||
| // convert input | // convert input | ||||
| vector<int64_t> def_format_shape(shape); | |||||
| std::vector<int64_t> def_format_shape(shape); | |||||
| ge::GeTensorDesc i_tensor_desc; | ge::GeTensorDesc i_tensor_desc; | ||||
| ge::GeTensorDesc o_tensor_desc; | ge::GeTensorDesc o_tensor_desc; | ||||
| @@ -98,7 +98,7 @@ FMK_FUNC_HOST_VISIBILITY Status DataOpParser::ParseShape(const vector<int64_t> & | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DataOpParser::Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc) { | |||||
| Status DataOpParser::Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensor_desc) { | |||||
| tensor_desc.SetDataType(ge::DT_FLOAT16); | tensor_desc.SetDataType(ge::DT_FLOAT16); | ||||
| tensor_desc.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | tensor_desc.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | ||||
| ge::TensorUtils::SetReuseInput(tensor_desc, false); | ge::TensorUtils::SetReuseInput(tensor_desc, false); | ||||
| @@ -117,7 +117,8 @@ Status DataOpParser::Init5DInputTensor(const vector<int64_t> &shape, ge::GeTenso | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &tensor_desc) { | |||||
| Status DataOpParser::InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, | |||||
| ge::GeTensorDesc &tensor_desc) { | |||||
| // Fixed input ND | // Fixed input ND | ||||
| tensor_desc.SetFormat(static_cast<ge::Format>(DOMI_TENSOR_ND)); | tensor_desc.SetFormat(static_cast<ge::Format>(DOMI_TENSOR_ND)); | ||||
| tensor_desc.SetDataType(data_type); | tensor_desc.SetDataType(data_type); | ||||
| @@ -143,7 +144,7 @@ Status DataOpParser::InitNDTensor(const vector<int64_t> &shape, ge::DataType dat | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DataOpParser::Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| Status DataOpParser::Init5DOutputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| output.SetDataType(ge::DT_FLOAT16); | output.SetDataType(ge::DT_FLOAT16); | ||||
| output.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | output.SetFormat(static_cast<ge::Format>(domi::DOMI_TENSOR_NC1HWC0)); | ||||
| ge::TensorUtils::SetReuseInput(output, false); | ge::TensorUtils::SetReuseInput(output, false); | ||||
| @@ -162,7 +163,7 @@ Status DataOpParser::Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTens | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input) { | |||||
| Status DataOpParser::InitInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &input) { | |||||
| input.SetFormat(static_cast<ge::Format>(domiTensorFormat_t(DOMI_TENSOR_ND))); | input.SetFormat(static_cast<ge::Format>(domiTensorFormat_t(DOMI_TENSOR_ND))); | ||||
| input.SetDataType(ge::DT_FLOAT); | input.SetDataType(ge::DT_FLOAT); | ||||
| input.SetOriginDataType(ge::DT_FLOAT); | input.SetOriginDataType(ge::DT_FLOAT); | ||||
| @@ -181,7 +182,7 @@ Status DataOpParser::InitInputTensor(const vector<int64_t> &shape, ge::GeTensorD | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status DataOpParser::InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| Status DataOpParser::InitOutputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &output) { | |||||
| int64_t output_size = 0; | int64_t output_size = 0; | ||||
| ge::GeShape output_shape = ge::GeShape(shape); | ge::GeShape output_shape = ge::GeShape(shape); | ||||
| ge::Format format = ge::FORMAT_ND; | ge::Format format = ge::FORMAT_ND; | ||||
| @@ -32,9 +32,6 @@ | |||||
| #include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| using google::protobuf::Message; | |||||
| using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -54,7 +51,7 @@ class DataOpParser { | |||||
| * @return SUCCESS Parsing success | * @return SUCCESS Parsing success | ||||
| * @return FAILED Parsing failed | * @return FAILED Parsing failed | ||||
| */ | */ | ||||
| static Status ParseShape(const vector<int64_t> &shape, ge::OpDescPtr op); | |||||
| static Status ParseShape(const std::vector<int64_t> &shape, ge::OpDescPtr op); | |||||
| private: | private: | ||||
| /** | /** | ||||
| @@ -63,7 +60,7 @@ class DataOpParser { | |||||
| * @param [in] 4D shape information (dimensions) | * @param [in] 4D shape information (dimensions) | ||||
| * @param [out] Save converted shap information | * @param [out] Save converted shap information | ||||
| */ | */ | ||||
| static Status Init5DInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc); | |||||
| static Status Init5DInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &tensorDesc); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -73,7 +70,7 @@ class DataOpParser { | |||||
| * @return SUCCESS Convert success | * @return SUCCESS Convert success | ||||
| * @return FAILED Convert failed | * @return FAILED Convert failed | ||||
| */ | */ | ||||
| static Status Init5DOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| static Status Init5DOutputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -81,7 +78,7 @@ class DataOpParser { | |||||
| * @param [in] 4D shape information (dimensions) | * @param [in] 4D shape information (dimensions) | ||||
| * @param [out] input Save converted shap information | * @param [out] input Save converted shap information | ||||
| */ | */ | ||||
| static Status InitInputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &input); | |||||
| static Status InitInputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &input); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -91,7 +88,7 @@ class DataOpParser { | |||||
| * @return SUCCESS Convert success | * @return SUCCESS Convert success | ||||
| * @return FAILED Convert failed | * @return FAILED Convert failed | ||||
| */ | */ | ||||
| static Status InitOutputTensor(const vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| static Status InitOutputTensor(const std::vector<int64_t> &shape, ge::GeTensorDesc &output); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -101,7 +98,7 @@ class DataOpParser { | |||||
| * @return SUCCESS Convert success | * @return SUCCESS Convert success | ||||
| * @return FAILED Convert failed | * @return FAILED Convert failed | ||||
| */ | */ | ||||
| static Status InitNDTensor(const vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc); | |||||
| static Status InitNDTensor(const std::vector<int64_t> &shape, ge::DataType data_type, ge::GeTensorDesc &desc); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -17,9 +17,6 @@ | |||||
| #ifndef PARSER_COMMON_GRAPH_PASS_H_ | #ifndef PARSER_COMMON_GRAPH_PASS_H_ | ||||
| #define PARSER_COMMON_GRAPH_PASS_H_ | #define PARSER_COMMON_GRAPH_PASS_H_ | ||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "common/pass.h" | #include "common/pass.h" | ||||
| @@ -14,10 +14,11 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "parser/common/model_saver.h" | |||||
| #include <sys/stat.h> | #include <sys/stat.h> | ||||
| #include <fcntl.h> | #include <fcntl.h> | ||||
| #include "parser/common/model_saver.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| @@ -124,7 +125,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int ModelSaver::CreateDirectory | |||||
| auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
| if (dir_path_len >= PATH_MAX) { | if (dir_path_len >= PATH_MAX) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | ErrorManager::GetInstance().ATCReportErrMessage( | ||||
| "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); | |||||
| "E19002", {"filepath", "size"}, {directory_path, std::to_string(PATH_MAX)}); | |||||
| GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); | GELOGW("Path[%s] len is too long, it must be less than %d", directory_path.c_str(), PATH_MAX); | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -93,7 +93,7 @@ static void UpdateTensorForOpDesc(const ParserOperator &op, ge::OpDescPtr op_def | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(const ParserOperator &op, | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status ConvertToOpDesc(const ParserOperator &op, | ||||
| ge::OpDescPtr op_def) { | |||||
| const ge::OpDescPtr &op_def) { | |||||
| if (op_def == nullptr) { | if (op_def == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "param op_def is nullptr, check invalid."); | REPORT_INNER_ERROR("E19999", "param op_def is nullptr, check invalid."); | ||||
| GELOGE(ge::FAILED, "[Check][Param] param op_def is nullptr, check invalid."); | GELOGE(ge::FAILED, "[Check][Param] param op_def is nullptr, check invalid."); | ||||
| @@ -19,15 +19,11 @@ | |||||
| #include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
| #include "parser/common/op_def/operator.h" | #include "parser/common/op_def/operator.h" | ||||
| #include "graph/ge_attr_value.h" | |||||
| #include "graph/ge_tensor.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| namespace ge { | namespace ge { | ||||
| domi::Status ConvertToOpDesc(const ParserOperator &op, ge::OpDescPtr op_def); | |||||
| domi::Status ConvertToOpDesc(const ParserOperator &op, const ge::OpDescPtr &op_def); | |||||
| domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, ParserOperator &op); | domi::Status ConvertFromOpDesc(const ge::OpDescPtr op_def, ParserOperator &op); | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -49,7 +49,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator: | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::InputTensorDesc( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::InputTensorDesc( | ||||
| const ge::GeTensorDesc &input_tensordesc) { | |||||
| const ge::GeTensorDesc &input_tensordesc) { | |||||
| input_descs_.push_back(input_tensordesc); | input_descs_.push_back(input_tensordesc); | ||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -76,8 +76,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator: | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( | FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator::AttrVector( | ||||
| std::string key, | |||||
| std::vector<int64_t> &value) { | |||||
| std::string key, std::vector<int64_t> &value) { | |||||
| domi::AttrDef out; | domi::AttrDef out; | ||||
| auto it = op_attrs_.find(key); | auto it = op_attrs_.find(key); | ||||
| if (it != op_attrs_.end()) { | if (it != op_attrs_.end()) { | ||||
| @@ -91,12 +90,12 @@ FMK_FUNC_DEV_VISIBILITY FMK_FUNC_DEV_VISIBILITY ParserOperator &ParserOperator:: | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| ParserOperator &ParserOperator::Attr(const OpAttribute &attr) { | |||||
| auto it = op_attrs_.find(attr.name_); | |||||
| ParserOperator &ParserOperator::Attr(const OpAttribute &op_attr) { | |||||
| auto it = op_attrs_.find(op_attr.name_); | |||||
| if (it != op_attrs_.end()) { | if (it != op_attrs_.end()) { | ||||
| (void)op_attrs_.erase(it); | (void)op_attrs_.erase(it); | ||||
| } | } | ||||
| (void)op_attrs_.insert(std::make_pair(attr.name_, attr)); | |||||
| (void)op_attrs_.insert(std::make_pair(op_attr.name_, op_attr)); | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -29,8 +29,6 @@ | |||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "external/register/register.h" | #include "external/register/register.h" | ||||
| using domi::CAFFE; | |||||
| namespace ge { | namespace ge { | ||||
| class OpParser; | class OpParser; | ||||
| @@ -32,8 +32,7 @@ class GE_FUNC_VISIBILITY OpTypeContainer { | |||||
| void Register(const std::string &op_type) { op_type_list_.insert(op_type); } | void Register(const std::string &op_type) { op_type_list_.insert(op_type); } | ||||
| bool IsExisting(const std::string &op_type) { | bool IsExisting(const std::string &op_type) { | ||||
| auto iter_find = op_type_list_.find(op_type); | |||||
| return iter_find != op_type_list_.end(); | |||||
| return op_type_list_.count(op_type) > 0UL; | |||||
| } | } | ||||
| protected: | protected: | ||||
| @@ -24,7 +24,7 @@ FMK_FUNC_HOST_VISIBILITY WeightsParserFactory *WeightsParserFactory::Instance() | |||||
| } | } | ||||
| std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const domi::FrameworkType type) { | std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const domi::FrameworkType type) { | ||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::const_iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | if (iter != creator_map_.end()) { | ||||
| return iter->second(); | return iter->second(); | ||||
| } | } | ||||
| @@ -35,7 +35,7 @@ std::shared_ptr<WeightsParser> WeightsParserFactory::CreateWeightsParser(const d | |||||
| FMK_FUNC_HOST_VISIBILITY void WeightsParserFactory::RegisterCreator(const domi::FrameworkType type, | FMK_FUNC_HOST_VISIBILITY void WeightsParserFactory::RegisterCreator(const domi::FrameworkType type, | ||||
| WEIGHTS_PARSER_CREATOR_FUN fun) { | WEIGHTS_PARSER_CREATOR_FUN fun) { | ||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| std::map<domi::FrameworkType, WEIGHTS_PARSER_CREATOR_FUN>::const_iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | if (iter != creator_map_.end()) { | ||||
| GELOGW("WeightsParserFactory::RegisterCreator: %d creator already exist", type); | GELOGW("WeightsParserFactory::RegisterCreator: %d creator already exist", type); | ||||
| return; | return; | ||||
| @@ -54,7 +54,7 @@ FMK_FUNC_HOST_VISIBILITY ModelParserFactory *ModelParserFactory::Instance() { | |||||
| } | } | ||||
| std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::FrameworkType type) { | std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::FrameworkType type) { | ||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::const_iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | if (iter != creator_map_.end()) { | ||||
| return iter->second(); | return iter->second(); | ||||
| } | } | ||||
| @@ -65,7 +65,7 @@ std::shared_ptr<ModelParser> ModelParserFactory::CreateModelParser(const domi::F | |||||
| FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::FrameworkType type, | FMK_FUNC_HOST_VISIBILITY void ModelParserFactory::RegisterCreator(const domi::FrameworkType type, | ||||
| MODEL_PARSER_CREATOR_FUN fun) { | MODEL_PARSER_CREATOR_FUN fun) { | ||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| std::map<domi::FrameworkType, MODEL_PARSER_CREATOR_FUN>::const_iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | if (iter != creator_map_.end()) { | ||||
| GELOGW("ModelParserFactory::RegisterCreator: %d creator already exist", type); | GELOGW("ModelParserFactory::RegisterCreator: %d creator already exist", type); | ||||
| return; | return; | ||||
| @@ -25,7 +25,7 @@ namespace ge { | |||||
| namespace parser { | namespace parser { | ||||
| /// @ingroup fp16_t global filed | /// @ingroup fp16_t global filed | ||||
| /// @brief round mode of last valid digital | /// @brief round mode of last valid digital | ||||
| enum TagFp16RoundMode g_round_mode = kRoundToNearest; | |||||
| enum TagFp16RoundMode g_round_mode = TagFp16RoundMode::kRoundToNearest; | |||||
| void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { | void ExtractFp16(const uint16_t &val, uint16_t &s, int16_t &e, uint16_t &m) { | ||||
| // 1.Extract | // 1.Extract | ||||
| @@ -55,7 +55,7 @@ static bool IsRoundOne(uint64_t man, uint16_t trunc_len) { | |||||
| bool last_bit = ((man & mask0) > 0); | bool last_bit = ((man & mask0) > 0); | ||||
| bool trunc_high = false; | bool trunc_high = false; | ||||
| bool trunc_left = false; | bool trunc_left = false; | ||||
| if (g_round_mode == kRoundToNearest) { | |||||
| if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { | |||||
| trunc_high = ((man & mask1) > 0); | trunc_high = ((man & mask1) > 0); | ||||
| trunc_left = ((man & mask2) > 0); | trunc_left = ((man & mask2) > 0); | ||||
| } | } | ||||
| @@ -480,7 +480,7 @@ static uint32_t Fp16ToUInt32(const uint16_t &fp_val) { | |||||
| return m_ret; | return m_ret; | ||||
| } | } | ||||
| static uint16_t Fp16AddCalVal(uint16_t &s_ret, int16_t e_ret, uint16_t m_ret, uint32_t m_trunc, uint16_t shift_out) { | |||||
| static uint16_t Fp16AddCalVal(uint16_t s_ret, int16_t e_ret, uint16_t m_ret, uint32_t m_trunc, uint16_t shift_out) { | |||||
| uint16_t m_min = kFp16ManHideBit << shift_out; | uint16_t m_min = kFp16ManHideBit << shift_out; | ||||
| uint16_t m_max = m_min << 1; | uint16_t m_max = m_min << 1; | ||||
| // Denormal | // Denormal | ||||
| @@ -500,8 +500,8 @@ static uint16_t Fp16AddCalVal(uint16_t &s_ret, int16_t e_ret, uint16_t m_ret, ui | |||||
| bool b_last_bit = ((m_ret & 1) > 0); | bool b_last_bit = ((m_ret & 1) > 0); | ||||
| bool b_trunc_high = 0; | bool b_trunc_high = 0; | ||||
| bool b_trunc_left = 0; | bool b_trunc_left = 0; | ||||
| b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); | |||||
| b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); | |||||
| b_trunc_high = (TagFp16RoundMode::kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); | |||||
| b_trunc_left = (TagFp16RoundMode::kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); | |||||
| m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); | m_ret = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, m_ret, shift_out); | ||||
| while (m_ret >= m_max) { | while (m_ret >= m_max) { | ||||
| m_ret = m_ret >> 1; | m_ret = m_ret >> 1; | ||||
| @@ -623,8 +623,8 @@ static uint16_t Fp16Mul(uint16_t v_1, uint16_t v_2) { | |||||
| bool b_last_bit = ((mul_m & 1) > 0); | bool b_last_bit = ((mul_m & 1) > 0); | ||||
| bool b_trunc_high = 0; | bool b_trunc_high = 0; | ||||
| bool b_trunc_left = 0; | bool b_trunc_left = 0; | ||||
| b_trunc_high = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); | |||||
| b_trunc_left = (kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); | |||||
| b_trunc_high = (TagFp16RoundMode::kRoundToNearest == g_round_mode) && ((m_trunc & kFp32SignMask) > 0); | |||||
| b_trunc_left = (TagFp16RoundMode::kRoundToNearest == g_round_mode) && ((m_trunc & kFp32AbsMax) > 0); | |||||
| mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); | mul_m = ManRoundToNearest(b_last_bit, b_trunc_high, b_trunc_left, mul_m); | ||||
| while (mul_m >= m_max || e_ret < 0) { | while (mul_m >= m_max || e_ret < 0) { | ||||
| @@ -701,25 +701,25 @@ static uint16_t Fp16Div(uint16_t v_1, uint16_t v_2) { | |||||
| } | } | ||||
| // operate | // operate | ||||
| fp16_t fp16_t::operator+(const fp16_t fp) { | |||||
| fp16_t fp16_t::operator+(const fp16_t fp) const { | |||||
| uint16_t ret_val = Fp16Add(val, fp.val); | uint16_t ret_val = Fp16Add(val, fp.val); | ||||
| fp16_t ret(ret_val); | fp16_t ret(ret_val); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| fp16_t fp16_t::operator-(const fp16_t fp) { | |||||
| fp16_t fp16_t::operator-(const fp16_t fp) const { | |||||
| uint16_t ret_val = Fp16Sub(val, fp.val); | uint16_t ret_val = Fp16Sub(val, fp.val); | ||||
| fp16_t ret(ret_val); | fp16_t ret(ret_val); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| fp16_t fp16_t::operator*(const fp16_t fp) { | |||||
| fp16_t fp16_t::operator*(const fp16_t fp) const { | |||||
| uint16_t ret_val = Fp16Mul(val, fp.val); | uint16_t ret_val = Fp16Mul(val, fp.val); | ||||
| fp16_t ret(ret_val); | fp16_t ret(ret_val); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| fp16_t fp16_t::operator/(const fp16_t fp) { | |||||
| fp16_t fp16_t::operator/(const fp16_t fp) const { | |||||
| uint16_t ret_val = Fp16Div(val, fp.val); | uint16_t ret_val = Fp16Div(val, fp.val); | ||||
| fp16_t ret(ret_val); | fp16_t ret(ret_val); | ||||
| return ret; | return ret; | ||||
| @@ -968,7 +968,7 @@ static void SetValByUint16Val(const uint16_t &input_val, const uint16_t &sign, u | |||||
| bool b_last_bit = ((m_tmp & 1) > 0); | bool b_last_bit = ((m_tmp & 1) > 0); | ||||
| bool b_trunc_high = 0; | bool b_trunc_high = 0; | ||||
| bool b_trunc_left = 0; | bool b_trunc_left = 0; | ||||
| if (kRoundToNearest == g_round_mode) { // trunc | |||||
| if (TagFp16RoundMode::kRoundToNearest == g_round_mode) { // trunc | |||||
| b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
| b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
| } | } | ||||
| @@ -1027,7 +1027,7 @@ fp16_t &fp16_t::operator=(const uint16_t &ui_val) { | |||||
| bool b_last_bit = ((m_ret & 1) > 0); | bool b_last_bit = ((m_ret & 1) > 0); | ||||
| bool b_trunc_high = 0; | bool b_trunc_high = 0; | ||||
| bool b_trunc_left = 0; | bool b_trunc_left = 0; | ||||
| if (kRoundToNearest == g_round_mode) { // trunc | |||||
| if (TagFp16RoundMode::kRoundToNearest == g_round_mode) { // trunc | |||||
| b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
| b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
| } | } | ||||
| @@ -1071,7 +1071,7 @@ static void SetValByUint32Val(const uint32_t &input_val, const uint16_t &sign, u | |||||
| bool b_last_bit = ((m_tmp & 1) > 0); | bool b_last_bit = ((m_tmp & 1) > 0); | ||||
| bool b_trunc_high = 0; | bool b_trunc_high = 0; | ||||
| bool b_trunc_left = 0; | bool b_trunc_left = 0; | ||||
| if (kRoundToNearest == g_round_mode) { // trunc | |||||
| if (TagFp16RoundMode::kRoundToNearest == g_round_mode) { // trunc | |||||
| b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
| b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
| } | } | ||||
| @@ -1133,7 +1133,7 @@ fp16_t &fp16_t::operator=(const uint32_t &ui_val) { | |||||
| bool b_last_bit = ((m_tmp & 1) > 0); | bool b_last_bit = ((m_tmp & 1) > 0); | ||||
| bool b_trunc_high = false; | bool b_trunc_high = false; | ||||
| bool b_trunc_left = false; | bool b_trunc_left = false; | ||||
| if (g_round_mode == kRoundToNearest) { // trunc | |||||
| if (g_round_mode == TagFp16RoundMode::kRoundToNearest) { // trunc | |||||
| b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | b_trunc_high = ((m_trunc & kFp32SignMask) > 0); | ||||
| b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | b_trunc_left = ((m_trunc & kFp32AbsMax) > 0); | ||||
| } | } | ||||
| @@ -1239,7 +1239,7 @@ fp16_t::operator int64_t() const { return 0; } | |||||
| // Cannot be used, just in order to solve the compile error | // Cannot be used, just in order to solve the compile error | ||||
| fp16_t::operator uint64_t() const { return 0; } | fp16_t::operator uint64_t() const { return 0; } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int fp16_t::IsInf() const { | |||||
| if ((val & kFp16AbsMax) == kFp16ExpMask) { | if ((val & kFp16AbsMax) == kFp16ExpMask) { | ||||
| if (val & kFp16SignMask) { | if (val & kFp16SignMask) { | ||||
| return -1; | return -1; | ||||
| @@ -146,22 +146,22 @@ constexpr uint16_t kFp16MinNormal = 1.0f / (2 << 14); | |||||
| #define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) | #define FP16_EXTRAC_MAN(x) ((((x) >> 0) & 0x3FF) | (((((x) >> 10) & 0x1F) > 0 ? 1 : 0) * 0x400)) | ||||
| /// @ingroup fp16 basic operator | /// @ingroup fp16 basic operator | ||||
| /// @brief constructor of fp16 from sign exponent and mantissa | /// @brief constructor of fp16 from sign exponent and mantissa | ||||
| #define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m)&kFp16MaxMan)) | |||||
| #define FP16_CONSTRUCTOR(s, e, m) (((s) << kFp16SignIndex) | ((e) << kFp16ManLen) | ((m) & kFp16MaxMan)) | |||||
| /// @ingroup fp16 special value judgment | /// @ingroup fp16 special value judgment | ||||
| /// @brief whether a fp16 is zero | /// @brief whether a fp16 is zero | ||||
| #define FP16_IS_ZERO(x) (((x)&kFp16AbsMax) == 0) | |||||
| #define FP16_IS_ZERO(x) (((x) & kFp16AbsMax) == 0) | |||||
| /// @ingroup fp16 special value judgment | /// @ingroup fp16 special value judgment | ||||
| /// @brief whether a fp16 is a denormalized value | /// @brief whether a fp16 is a denormalized value | ||||
| #define FP16_IS_DENORM(x) ((((x)&kFp16ExpMask) == 0)) | |||||
| #define FP16_IS_DENORM(x) ((((x) & kFp16ExpMask) == 0)) | |||||
| /// @ingroup fp16 special value judgment | /// @ingroup fp16 special value judgment | ||||
| /// @brief whether a fp16 is infinite | /// @brief whether a fp16 is infinite | ||||
| #define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) | #define FP16_IS_INF(x) (((x)&kFp16AbsMax) == kFp16ExpMask) | ||||
| /// @ingroup fp16 special value judgment | /// @ingroup fp16 special value judgment | ||||
| /// @brief whether a fp16 is NaN | /// @brief whether a fp16 is NaN | ||||
| #define FP16_IS_NAN(x) (((x & kFp16ExpMask) == kFp16ExpMask) && (x & kFp16ManMask)) | |||||
| #define FP16_IS_NAN(x) ((((x) & kFp16ExpMask) == kFp16ExpMask) && ((x) & kFp16ManMask)) | |||||
| /// @ingroup fp16 special value judgment | /// @ingroup fp16 special value judgment | ||||
| /// @brief whether a fp16 is invalid | /// @brief whether a fp16 is invalid | ||||
| #define FP16_IS_INVALID(x) ((x & kFp16ExpMask) == kFp16ExpMask) | |||||
| #define FP16_IS_INVALID(x) (((x) & kFp16ExpMask) == kFp16ExpMask) | |||||
| /// @ingroup fp32 basic parameter | /// @ingroup fp32 basic parameter | ||||
| /// @brief fp32 exponent bias | /// @brief fp32 exponent bias | ||||
| constexpr uint16_t kFp32ExpBias = 127; | constexpr uint16_t kFp32ExpBias = 127; | ||||
| @@ -197,10 +197,10 @@ constexpr uint32_t kFp32MaxExp = 0xFF; | |||||
| constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | ||||
| /// @ingroup fp32 special value judgment | /// @ingroup fp32 special value judgment | ||||
| /// @brief whether a fp32 is NaN | /// @brief whether a fp32 is NaN | ||||
| #define FP32_IS_NAN(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (x & kFp32ManMask)) | |||||
| #define FP32_IS_NAN(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && ((x) & kFp32ManMask)) | |||||
| /// @ingroup fp32 special value judgment | /// @ingroup fp32 special value judgment | ||||
| /// @brief whether a fp32 is infinite | /// @brief whether a fp32 is infinite | ||||
| #define FP32_IS_INF(x) (((x & kFp32ExpMask) == kFp32ExpMask) && (!(x & kFp32ManMask))) | |||||
| #define FP32_IS_INF(x) ((((x) & kFp32ExpMask) == kFp32ExpMask) && (!((x) & kFp32ManMask))) | |||||
| /// @ingroup fp32 special value judgment | /// @ingroup fp32 special value judgment | ||||
| /// @brief whether a fp32 is a denormalized value | /// @brief whether a fp32 is a denormalized value | ||||
| #define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) | #define FP32_IS_DENORM(x) ((((x)&kFp32ExpMask) == 0)) | ||||
| @@ -215,7 +215,7 @@ constexpr uint32_t kFp32MaxMan = 0x7FFFFF; | |||||
| #define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit)) | #define FP32_EXTRAC_MAN(x) (((x)&kFp32ManMask) | (((((x) >> kFp32ManLen) & kFp32MaxExp) > 0 ? 1 : 0) * kFp32ManHideBit)) | ||||
| /// @ingroup fp32 basic operator | /// @ingroup fp32 basic operator | ||||
| /// @brief constructor of fp32 from sign exponent and mantissa | /// @brief constructor of fp32 from sign exponent and mantissa | ||||
| #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m)&kFp32MaxMan)) | |||||
| #define FP32_CONSTRUCTOR(s, e, m) (((s) << kFp32SignIndex) | ((e) << kFp32ManLen) | ((m) & kFp32MaxMan)) | |||||
| /// @ingroup fp64 basic parameter | /// @ingroup fp64 basic parameter | ||||
| /// @brief fp64 exponent bias | /// @brief fp64 exponent bias | ||||
| constexpr uint16_t kFp64ExpBias = 1023; | constexpr uint16_t kFp64ExpBias = 1023; | ||||
| @@ -251,10 +251,10 @@ constexpr uint64_t kFp64MaxExp = 0x07FF; | |||||
| constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; | constexpr uint64_t kFp64MaxMan = 0xFFFFFFFFFFFLLu; | ||||
| /// @ingroup fp64 special value judgment | /// @ingroup fp64 special value judgment | ||||
| /// @brief whether a fp64 is NaN | /// @brief whether a fp64 is NaN | ||||
| #define FP64_IS_NAN(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (x & kFp64ManMask)) | |||||
| #define FP64_IS_NAN(x) ((((x) & kFp64ExpMask) == kFp64ExpMask) && ((x) & kFp64ManMask)) | |||||
| /// @ingroup fp64 special value judgment | /// @ingroup fp64 special value judgment | ||||
| /// @brief whether a fp64 is infinite | /// @brief whether a fp64 is infinite | ||||
| #define FP64_IS_INF(x) (((x & kFp64ExpMask) == kFp64ExpMask) && (!(x & kFp64ManMask))) | |||||
| #define FP64_IS_INF(x) ((((x) & kFp64ExpMask) == kFp64ExpMask) && (!((x) & kFp64ManMask))) | |||||
| /// @ingroup integer special value judgment | /// @ingroup integer special value judgment | ||||
| /// @brief maximum positive value of int8_t (0111 1111) | /// @brief maximum positive value of int8_t (0111 1111) | ||||
| constexpr int8_t kInt8Max = 0x7F; | constexpr int8_t kInt8Max = 0x7F; | ||||
| @@ -284,7 +284,7 @@ constexpr uint64_t kBitLen64Max = 0xFFFFFFFFFFFFFFFFu; | |||||
| /// @ingroup fp16_t enum | /// @ingroup fp16_t enum | ||||
| /// @brief round mode of last valid digital | /// @brief round mode of last valid digital | ||||
| enum TagFp16RoundMode { | |||||
| enum class TagFp16RoundMode { | |||||
| kRoundToNearest = 0, // < round to nearest even | kRoundToNearest = 0, // < round to nearest even | ||||
| kRoundByTruncated, // < round by truncated | kRoundByTruncated, // < round by truncated | ||||
| kRoundModeReserved, | kRoundModeReserved, | ||||
| @@ -301,7 +301,7 @@ using fp16_t = struct TagFp16 { | |||||
| public: | public: | ||||
| /// @ingroup fp16_t constructor | /// @ingroup fp16_t constructor | ||||
| /// @brief Constructor without any param(default constructor) | /// @brief Constructor without any param(default constructor) | ||||
| TagFp16(void) { val = 0x0u; } | |||||
| TagFp16() : val(0x0u) {} | |||||
| /// @ingroup fp16_t constructor | /// @ingroup fp16_t constructor | ||||
| /// @brief Constructor with an uint16_t value | /// @brief Constructor with an uint16_t value | ||||
| @@ -315,25 +315,25 @@ public: | |||||
| /// @param [in] fp fp16_t object to be added | /// @param [in] fp fp16_t object to be added | ||||
| /// @brief Override addition operator to performing fp16_t addition | /// @brief Override addition operator to performing fp16_t addition | ||||
| /// @return Return fp16_t result of adding this and fp | /// @return Return fp16_t result of adding this and fp | ||||
| TagFp16 operator+(const TagFp16 fp); | |||||
| TagFp16 operator+(const TagFp16 fp) const; | |||||
| /// @ingroup fp16_t math operator | /// @ingroup fp16_t math operator | ||||
| /// @param [in] fp fp16_t object to be subtracted | /// @param [in] fp fp16_t object to be subtracted | ||||
| /// @brief Override addition operator to performing fp16_t subtraction | /// @brief Override addition operator to performing fp16_t subtraction | ||||
| /// @return Return fp16_t result of subtraction fp from this | /// @return Return fp16_t result of subtraction fp from this | ||||
| TagFp16 operator-(const TagFp16 fp); | |||||
| TagFp16 operator-(const TagFp16 fp) const; | |||||
| /// @ingroup fp16_t math operator | /// @ingroup fp16_t math operator | ||||
| /// @param [in] fp fp16_t object to be multiplied | /// @param [in] fp fp16_t object to be multiplied | ||||
| /// @brief Override multiplication operator to performing fp16_t multiplication | /// @brief Override multiplication operator to performing fp16_t multiplication | ||||
| /// @return Return fp16_t result of multiplying this and fp | /// @return Return fp16_t result of multiplying this and fp | ||||
| TagFp16 operator*(const TagFp16 fp); | |||||
| TagFp16 operator*(const TagFp16 fp) const; | |||||
| /// @ingroup fp16_t math operator divided | /// @ingroup fp16_t math operator divided | ||||
| /// @param [in] fp fp16_t object to be divided | /// @param [in] fp fp16_t object to be divided | ||||
| /// @brief Override division operator to performing fp16_t division | /// @brief Override division operator to performing fp16_t division | ||||
| /// @return Return fp16_t result of division this by fp | /// @return Return fp16_t result of division this by fp | ||||
| TagFp16 operator/(const TagFp16 fp); | |||||
| TagFp16 operator/(const TagFp16 fp) const; | |||||
| /// @ingroup fp16_t math operator | /// @ingroup fp16_t math operator | ||||
| /// @param [in] fp fp16_t object to be added | /// @param [in] fp fp16_t object to be added | ||||
| @@ -503,7 +503,7 @@ public: | |||||
| /// @param [in] fp fp16_t object to be judgement | /// @param [in] fp fp16_t object to be judgement | ||||
| /// @brief whether a fp16_t is inifinite | /// @brief whether a fp16_t is inifinite | ||||
| /// @return Returns 1:+INF -1:-INF 0:not INF | /// @return Returns 1:+INF -1:-INF 0:not INF | ||||
| int IsInf(); | |||||
| int IsInf() const; | |||||
| /// @ingroup fp16_t math conversion | /// @ingroup fp16_t math conversion | ||||
| /// @brief Convert fp16_t to float/fp32 | /// @brief Convert fp16_t to float/fp32 | ||||
| @@ -71,7 +71,7 @@ Status HandleNewOp(const NodePtr &node, | |||||
| } | } | ||||
| } | } | ||||
| Status ParserUtils::ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping) { | |||||
| Status ParserUtils::ExpandOneToManyGraph(const Graph &graph, OutputMapping &output_mapping) { | |||||
| GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | GELOGD("Begin run ParserUtils::ExpandOneToManyGraph."); | ||||
| for (const auto &gn : graph.GetDirectNode()) { | for (const auto &gn : graph.GetDirectNode()) { | ||||
| NodePtr n = NodeAdapter::GNode2Node(gn); | NodePtr n = NodeAdapter::GNode2Node(gn); | ||||
| @@ -105,7 +105,7 @@ Status ParserUtils::ExpandOneToManyGraph(Graph &graph, OutputMapping &output_map | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, | |||||
| Status ParserUtils::ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, const Graph &graph, | |||||
| OutputMapping &output_mapping) { | OutputMapping &output_mapping) { | ||||
| ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); | ComputeGraphPtr sub_compute_graph = GraphUtils::GetComputeGraph(subgraph); | ||||
| GE_CHECK_NOTNULL(sub_compute_graph); | GE_CHECK_NOTNULL(sub_compute_graph); | ||||
| @@ -27,13 +27,13 @@ class ParserUtils { | |||||
| public: | public: | ||||
| using OutputNodeInfo = std::pair<std::string, int32_t>; | using OutputNodeInfo = std::pair<std::string, int32_t>; | ||||
| using OutputMapping = std::unordered_map<std::string, OutputNodeInfo>; | using OutputMapping = std::unordered_map<std::string, OutputNodeInfo>; | ||||
| static Status ExpandOneToManyGraph(Graph &graph, OutputMapping &output_mapping); | |||||
| static Status ExpandOneToManyGraph(const Graph &graph, OutputMapping &output_mapping); | |||||
| static string GenOutputKey(const OutputNodeInfo &node_info); | static string GenOutputKey(const OutputNodeInfo &node_info); | ||||
| static void UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info); | static void UpdateOutputNodeInfo(const OutputMapping &final_output_nodes, OutputNodeInfo &output_node_info); | ||||
| static void UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes); | static void UpdateOutputCtx(const OutputMapping &final_output_nodes, OutputMapping &tensor_to_nodes); | ||||
| private: | private: | ||||
| static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, Graph &graph, | |||||
| static Status ExpandNodeToSubgraph(const Graph &subgraph, const NodePtr &node, const Graph &graph, | |||||
| OutputMapping &output_mapping); | OutputMapping &output_mapping); | ||||
| static Status HandleInputContext(const NodePtr &node, | static Status HandleInputContext(const NodePtr &node, | ||||
| const std::vector<NodePtr> &input_nodes, | const std::vector<NodePtr> &input_nodes, | ||||
| @@ -23,7 +23,9 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace parser { | namespace parser { | ||||
| const vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { return names_to_graph_passes_; } | |||||
| const std::vector<std::pair<std::string, GraphPass *>> &PassManager::GraphPasses() const { | |||||
| return names_to_graph_passes_; | |||||
| } | |||||
| Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { | Status PassManager::AddPass(const string &pass_name, GraphPass *pass) { | ||||
| GE_CHECK_NOTNULL(pass); | GE_CHECK_NOTNULL(pass); | ||||
| @@ -36,7 +38,8 @@ Status PassManager::Run(const ComputeGraphPtr &graph) { | |||||
| return Run(graph, names_to_graph_passes_); | return Run(graph, names_to_graph_passes_); | ||||
| } | } | ||||
| Status PassManager::Run(const ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &names_to_passes) { | |||||
| Status PassManager::Run(const ComputeGraphPtr &graph, | |||||
| std::vector<std::pair<std::string, GraphPass *>> &names_to_passes) { | |||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| bool not_changed = true; | bool not_changed = true; | ||||
| @@ -21,8 +21,6 @@ | |||||
| #include "common/graph_pass.h" | #include "common/graph_pass.h" | ||||
| using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace parser { | namespace parser { | ||||
| /// | /// | ||||
| @@ -36,7 +34,7 @@ public: | |||||
| /// get graph passes | /// get graph passes | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| const vector<std::pair<std::string, GraphPass *>> &GraphPasses() const; | |||||
| const std::vector<std::pair<std::string, GraphPass *>> &GraphPasses() const; | |||||
| /// | /// | ||||
| /// Add graph pass | /// Add graph pass | ||||
| @@ -64,12 +62,12 @@ public: | |||||
| /// @return others optimized failed | /// @return others optimized failed | ||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| static Status Run(const ge::ComputeGraphPtr &graph, vector<std::pair<std::string, GraphPass *>> &passes); | |||||
| static Status Run(const ge::ComputeGraphPtr &graph, std::vector<std::pair<std::string, GraphPass *>> &passes); | |||||
| ~PassManager(); | ~PassManager(); | ||||
| private: | private: | ||||
| vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_; | |||||
| std::vector<std::pair<std::string, GraphPass *>> names_to_graph_passes_; | |||||
| }; | }; | ||||
| } // namespace parser | } // namespace parser | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -57,12 +57,13 @@ void PreChecker::Init() { | |||||
| // Currently only Caffe and tensorflow are supported | // Currently only Caffe and tensorflow are supported | ||||
| domi::FrameworkType fmk_type = GetParserContext().type; | domi::FrameworkType fmk_type = GetParserContext().type; | ||||
| if (fmk_type == domi::CAFFE) | |||||
| if (fmk_type == domi::CAFFE) { | |||||
| fmk_op_types_ = &caffe_op_map; | fmk_op_types_ = &caffe_op_map; | ||||
| else if (fmk_type == domi::TENSORFLOW) | |||||
| } else if (fmk_type == domi::TENSORFLOW) { | |||||
| fmk_op_types_ = &tensorflow_op_map; | fmk_op_types_ = &tensorflow_op_map; | ||||
| else | |||||
| } else { | |||||
| return; | return; | ||||
| } | |||||
| } | } | ||||
| PreChecker::~PreChecker() {} | PreChecker::~PreChecker() {} | ||||
| @@ -238,7 +238,7 @@ Status ProtoFileParser::ParseProtoFile(const string &proto_file, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ProtoFileParser::AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) { | |||||
| Status ProtoFileParser::AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) const { | |||||
| ifstream read_custom; | ifstream read_custom; | ||||
| read_custom.open(custom_proto_file, std::ios::in); | read_custom.open(custom_proto_file, std::ios::in); | ||||
| if (read_custom.fail()) { | if (read_custom.fail()) { | ||||
| @@ -309,9 +309,8 @@ Status ProtoFileParser::AddCustomAndConflictMessage(const char *custom_proto_fil | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file, | |||||
| std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp) { | |||||
| Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file, std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp) const { | |||||
| std::string line_caffe; | std::string line_caffe; | ||||
| bool caffe_in_layer = false; | bool caffe_in_layer = false; | ||||
| bool caffe_in_unrepeated_message = true; | bool caffe_in_unrepeated_message = true; | ||||
| @@ -321,11 +320,11 @@ Status ProtoFileParser::WriteCaffeProtoFile(const char *custom_proto_file, | |||||
| tmp_message_name.assign(GetMessageName(line_caffe)); | tmp_message_name.assign(GetMessageName(line_caffe)); | ||||
| if (custom_repeat_message_map_.count(tmp_message_name) > 0) { | if (custom_repeat_message_map_.count(tmp_message_name) > 0) { | ||||
| caffe_in_unrepeated_message = false; | caffe_in_unrepeated_message = false; | ||||
| } else { | |||||
| caffe_in_unrepeated_message = true; | |||||
| if (tmp_message_name == kLayerParameter) { | |||||
| caffe_in_layer = true; | |||||
| } | |||||
| continue; | |||||
| } | |||||
| caffe_in_unrepeated_message = true; | |||||
| if (tmp_message_name == kLayerParameter) { | |||||
| caffe_in_layer = true; | |||||
| } | } | ||||
| } | } | ||||
| if (!caffe_in_unrepeated_message) { | if (!caffe_in_unrepeated_message) { | ||||
| @@ -25,9 +25,7 @@ namespace ge { | |||||
| class ProtoFileParser { | class ProtoFileParser { | ||||
| public: | public: | ||||
| ProtoFileParser(){}; | ProtoFileParser(){}; | ||||
| ProtoFileParser(const char *dest_path){ | |||||
| fusion_proto_path = dest_path; | |||||
| } | |||||
| explicit ProtoFileParser(const char *dest_path) : fusion_proto_path(dest_path) {} | |||||
| ~ProtoFileParser(); | ~ProtoFileParser(); | ||||
| Status CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, | Status CombineProtoFile(const char *caffe_proto_file, const char *custom_proto_file, | ||||
| std::string &dest_proto_file); | std::string &dest_proto_file); | ||||
| @@ -37,14 +35,13 @@ private: | |||||
| Status ParseProtoFile(const std::string &proto_file, | Status ParseProtoFile(const std::string &proto_file, | ||||
| std::map<int, std::pair<std::string, std::string> > &identifier_op_map, | std::map<int, std::pair<std::string, std::string> > &identifier_op_map, | ||||
| std::map<std::string, std::pair<int, std::string> > &op_identifier_map); | std::map<std::string, std::pair<int, std::string> > &op_identifier_map); | ||||
| Status WriteCaffeProtoFile(const char *custom_proto_file, | |||||
| std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp); | |||||
| Status WriteCaffeProtoFile(const char *custom_proto_file, std::ifstream &read_caffe, | |||||
| std::ofstream &write_tmp) const; | |||||
| Status WriteProtoFile(const char *caffe_proto_file, const char *custom_proto_file); | Status WriteProtoFile(const char *caffe_proto_file, const char *custom_proto_file); | ||||
| Status FindConflictLine(const char *proto_file, int identifier, | |||||
| static Status FindConflictLine(const char *proto_file, int identifier, | |||||
| std::string &dest_line); | std::string &dest_line); | ||||
| Status AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp); | |||||
| Status AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp); | |||||
| Status AddCustomAndConflictLayer(const char *custom_proto_file, std::ofstream &write_tmp) const; | |||||
| static Status AddCustomAndConflictMessage(const char *custom_proto_file, std::ofstream &write_tmp); | |||||
| void CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | void CheckConflictOp(const char *caffe_proto_file, const char *custom_proto_file, | ||||
| std::map<std::string, std::pair<int, std::string>> &caffe_op_identifier_map, | std::map<std::string, std::pair<int, std::string>> &caffe_op_identifier_map, | ||||
| std::map<std::string, std::pair<int, std::string>> &custom_op_identifier_map); | std::map<std::string, std::pair<int, std::string>> &custom_op_identifier_map); | ||||
| @@ -36,7 +36,7 @@ FMK_FUNC_HOST_VISIBILITY OpRegistrationTbe *OpRegistrationTbe::Instance() { | |||||
| } | } | ||||
| bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { | bool OpRegistrationTbe::Finalize(const OpRegistrationData ®_data, bool is_train) { | ||||
| static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{CAFFE, &caffe_op_map}}; | |||||
| static std::map<domi::FrameworkType, std::map<std::string, std::string> *> op_map = {{domi::CAFFE, &caffe_op_map}}; | |||||
| if (is_train) { | if (is_train) { | ||||
| op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; | op_map[domi::TENSORFLOW] = &tensorflow_train_op_map; | ||||
| } else { | } else { | ||||
| @@ -17,11 +17,14 @@ | |||||
| #include "tbe_plugin_loader.h" | #include "tbe_plugin_loader.h" | ||||
| #include <dirent.h> | #include <dirent.h> | ||||
| #include <dlfcn.h> | |||||
| #include <sys/stat.h> | #include <sys/stat.h> | ||||
| #include <unistd.h> | #include <unistd.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <cstring> | #include <cstring> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <type_traits> | |||||
| #include <typeinfo> | |||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -171,7 +174,7 @@ void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list | |||||
| GELOGW("%s is not a dir.", real_path.c_str()); | GELOGW("%s is not a dir.", real_path.c_str()); | ||||
| return; | return; | ||||
| } | } | ||||
| struct dirent *dent(0); | |||||
| struct dirent *dent(nullptr); | |||||
| DIR *dir = opendir(real_path.c_str()); | DIR *dir = opendir(real_path.c_str()); | ||||
| // Plugin path does not exist | // Plugin path does not exist | ||||
| if (dir == nullptr) { | if (dir == nullptr) { | ||||
| @@ -180,16 +183,15 @@ void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list | |||||
| } | } | ||||
| while ((dent = readdir(dir)) != nullptr) { | while ((dent = readdir(dir)) != nullptr) { | ||||
| if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) continue; | |||||
| if (strcmp(dent->d_name, ".") == 0 || strcmp(dent->d_name, "..") == 0) { | |||||
| continue; | |||||
| } | |||||
| string name = dent->d_name; | string name = dent->d_name; | ||||
| string full_name = real_path + "/" + name; | string full_name = real_path + "/" + name; | ||||
| const string so_suff = ".so"; | const string so_suff = ".so"; | ||||
| const string caffe_parser_so_suff = "lib_caffe_parser.so"; | const string caffe_parser_so_suff = "lib_caffe_parser.so"; | ||||
| const string aicpu_so_suff = "_aicpu.so"; | |||||
| const string aicpu_host_so_suff = "_online.so"; | |||||
| if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { | if (name.size() >= so_suff.size() && name.compare(name.size() - so_suff.size(), so_suff.size(), so_suff) == 0) { | ||||
| ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff, aicpu_so_suff, | |||||
| aicpu_host_so_suff); | |||||
| ProcessSoFullName(file_list, caffe_parser_path, full_name, caffe_parser_so_suff); | |||||
| } else { | } else { | ||||
| FindParserSo(full_name, file_list, caffe_parser_path); | FindParserSo(full_name, file_list, caffe_parser_path); | ||||
| } | } | ||||
| @@ -198,8 +200,7 @@ void TBEPluginLoader::FindParserSo(const string &path, vector<string> &file_list | |||||
| } | } | ||||
| void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | void TBEPluginLoader::ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | ||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff) { | |||||
| const string &caffe_parser_so_suff) { | |||||
| if (full_name.size() >= caffe_parser_so_suff.size() && | if (full_name.size() >= caffe_parser_so_suff.size() && | ||||
| full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | ||||
| caffe_parser_so_suff) == 0) { | caffe_parser_so_suff) == 0) { | ||||
| @@ -17,14 +17,11 @@ | |||||
| #ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | #ifndef PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | ||||
| #define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | #define PARSER_COMMON_TBE_PLUGIN_LOADER_H_ | ||||
| #include <dlfcn.h> | |||||
| #include <functional> | #include <functional> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <type_traits> | |||||
| #include <typeinfo> | |||||
| #include <vector> | #include <vector> | ||||
| #include "external/ge/ge_api_error_codes.h" | #include "external/ge/ge_api_error_codes.h" | ||||
| @@ -48,8 +45,7 @@ private: | |||||
| ~TBEPluginLoader() = default; | ~TBEPluginLoader() = default; | ||||
| Status ClearHandles_(); | Status ClearHandles_(); | ||||
| static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | static void ProcessSoFullName(vector<string> &file_list, string &caffe_parser_path, string &full_name, | ||||
| const string &caffe_parser_so_suff, const string &aicpu_so_suff, | |||||
| const string &aicpu_host_so_suff); | |||||
| const string &caffe_parser_so_suff); | |||||
| static void GetCustomOpPath(std::string &customop_path); | static void GetCustomOpPath(std::string &customop_path); | ||||
| static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | ||||
| static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path); | static void FindParserSo(const string &path, vector<string> &file_list, string &caffe_parser_path); | ||||
| @@ -18,10 +18,7 @@ | |||||
| #include <atomic> | #include <atomic> | ||||
| #include <functional> | #include <functional> | ||||
| #include <queue> | |||||
| #include <stdexcept> | #include <stdexcept> | ||||
| #include <utility> | |||||
| #include <vector> | |||||
| #include "register/register_types.h" | #include "register/register_types.h" | ||||
| @@ -54,7 +51,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ThreadPool::~ThreadPool() { | |||||
| } | } | ||||
| } | } | ||||
| void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { | |||||
| void ThreadPool::ThreadFunc(ThreadPool *const thread_pool) { | |||||
| if (thread_pool == nullptr) { | if (thread_pool == nullptr) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -63,7 +60,7 @@ void ThreadPool::ThreadFunc(ThreadPool *thread_pool) { | |||||
| { | { | ||||
| std::unique_lock<std::mutex> lock{thread_pool->m_lock_}; | std::unique_lock<std::mutex> lock{thread_pool->m_lock_}; | ||||
| thread_pool->cond_var_.wait( | thread_pool->cond_var_.wait( | ||||
| lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); | |||||
| lock, [thread_pool] { return thread_pool->is_stoped_.load() || !thread_pool->tasks_.empty(); }); | |||||
| if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { | if (thread_pool->is_stoped_ && thread_pool->tasks_.empty()) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -68,7 +68,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ThreadPool { | |||||
| return future; | return future; | ||||
| } | } | ||||
| static void ThreadFunc(ThreadPool *thread_pool); | |||||
| static void ThreadFunc(ThreadPool *const thread_pool); | |||||
| private: | private: | ||||
| std::vector<std::thread> pool_; | std::vector<std::thread> pool_; | ||||
| @@ -20,8 +20,6 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <string> | #include <string> | ||||
| #include <type_traits> | |||||
| #include <utility> | |||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| @@ -43,12 +41,12 @@ class Tuple { | |||||
| /// @brief constructor from initializer list | /// @brief constructor from initializer list | ||||
| /// @param init the initializer_list | /// @param init the initializer_list | ||||
| /// | /// | ||||
| inline Tuple(const std::initializer_list<ValueType> &init) { this->assign(init.begin(), init.end()); } | |||||
| explicit Tuple(const std::initializer_list<ValueType> &init) { this->assign(init.begin(), init.end()); } | |||||
| /// | /// | ||||
| /// @brief constructor from vector | /// @brief constructor from vector | ||||
| /// @param init the vector | /// @param init the vector | ||||
| /// | /// | ||||
| inline Tuple(const std::vector<ValueType> &init) { // NOLINT(runtime/explicit) | |||||
| explicit Tuple(const std::vector<ValueType> &init) { // NOLINT(runtime/explicit) | |||||
| this->assign(init.begin(), init.end()); | this->assign(init.begin(), init.end()); | ||||
| } | } | ||||
| /// | /// | ||||
| @@ -125,7 +123,9 @@ class Tuple { | |||||
| /// @param s the tuple to compare against | /// @param s the tuple to compare against | ||||
| /// | /// | ||||
| inline bool operator==(const Tuple<ValueType> &s) const { | inline bool operator==(const Tuple<ValueType> &s) const { | ||||
| if (ndim_ != s.ndim_) return false; | |||||
| if (ndim_ != s.ndim_) { | |||||
| return false; | |||||
| } | |||||
| return std::equal(begin(), end(), s.begin()); | return std::equal(begin(), end(), s.begin()); | ||||
| } | } | ||||
| /// | /// | ||||
| @@ -178,7 +178,9 @@ class Tuple { | |||||
| const ValueType *begin = t.begin(); | const ValueType *begin = t.begin(); | ||||
| const ValueType *end = t.end(); | const ValueType *end = t.end(); | ||||
| for (const ValueType *it = begin; it != end; ++it) { | for (const ValueType *it = begin; it != end; ++it) { | ||||
| if (it != begin) os << ','; | |||||
| if (it != begin) { | |||||
| os << ','; | |||||
| } | |||||
| os << *it; | os << *it; | ||||
| } | } | ||||
| os << ']'; | os << ']'; | ||||
| @@ -229,7 +231,9 @@ class Tuple { | |||||
| } | } | ||||
| break; | break; | ||||
| } | } | ||||
| if (IsRightBracket(ch)) break; | |||||
| if (IsRightBracket(ch)) { | |||||
| break; | |||||
| } | |||||
| } else if (IsRightBracket(ch)) { | } else if (IsRightBracket(ch)) { | ||||
| break; | break; | ||||
| } else { | } else { | ||||
| @@ -19,6 +19,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "external/graph/tensor.h" | |||||
| #include "parser/common/data_op_parser.h" | #include "parser/common/data_op_parser.h" | ||||
| #include "parser/onnx/onnx_op_parser.h" | #include "parser/onnx/onnx_op_parser.h" | ||||
| @@ -30,12 +31,12 @@ class PARSER_FUNC_VISIBILITY OnnxConstantParser : public OnnxOpParser { | |||||
| Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | Status ParseParams(const Message *op_src, ge::Operator &op_def) override; | ||||
| private: | private: | ||||
| Status ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def); | |||||
| Status ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | |||||
| Status ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count); | |||||
| void ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, | |||||
| static Status ParseConstFromInput(const ge::onnx::NodeProto *op_src, ge::Operator &op_def); | |||||
| static Status ParseConvertTensor(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | |||||
| static Status ParseConvertData(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count); | |||||
| static void ParseConvertDataElements(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor, int count, | |||||
| int64_t data_type); | int64_t data_type); | ||||
| Status ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | |||||
| static Status ParseConvertDataType(const ge::onnx::TensorProto &tensor_proto, ge::Tensor &tensor); | |||||
| template <typename T> | template <typename T> | ||||
| static Status SetTensorData(int32_t val_size, const google::protobuf::RepeatedField<T> &val_vector, int count, | static Status SetTensorData(int32_t val_size, const google::protobuf::RepeatedField<T> &val_vector, int count, | ||||
| @@ -32,7 +32,7 @@ class PARSER_FUNC_VISIBILITY OnnxDataParser : public OnnxOpParser { | |||||
| Status ParseInputFromUser(const ge::Operator &op_def); | Status ParseInputFromUser(const ge::Operator &op_def); | ||||
| bool IsSubgraphDataOp() { | |||||
| bool IsSubgraphDataOp() const { | |||||
| return is_subgraph_data_op_; | return is_subgraph_data_op_; | ||||
| } | } | ||||
| @@ -680,7 +680,7 @@ Status OnnxModelParser::GetGraphOutputs(std::vector<std::pair<Operator, std::vec | |||||
| auto itr = outputs_map_.find(output_name); | auto itr = outputs_map_.find(output_name); | ||||
| if (itr == outputs_map_.end()) { | if (itr == outputs_map_.end()) { | ||||
| GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | GELOGE(PARAM_INVALID, "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | ||||
| REPORT_INNER_ERROR( "E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "[Get][Outputs] Can not find output:%s in graph.", output_name.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -755,7 +755,7 @@ Status OnnxModelParser::AdaptAndFindAllOnnxGraph(ge::onnx::GraphProto &root_onnx | |||||
| SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance(); | SubgraphAdapterFactory *factory = SubgraphAdapterFactory::Instance(); | ||||
| GE_CHECK_NOTNULL(factory); | GE_CHECK_NOTNULL(factory); | ||||
| std::shared_ptr<SubgraphAdapter> subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type()); | std::shared_ptr<SubgraphAdapter> subgraph_adapter = factory->CreateSubgraphAdapter(node_proto->op_type()); | ||||
| if(subgraph_adapter == nullptr) { | |||||
| if (subgraph_adapter == nullptr) { | |||||
| GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str()); | GELOGD("Do not need adapt subgraph, op type:%s", node_proto->op_type().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -34,7 +34,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/register/register_error_codes.h" | |||||
| #include "omg/parser/model_parser.h" | #include "omg/parser/model_parser.h" | ||||
| #include "omg/parser/op_parser.h" | #include "omg/parser/op_parser.h" | ||||
| #include "omg/parser/weights_parser.h" | #include "omg/parser/weights_parser.h" | ||||
| @@ -19,13 +19,14 @@ | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| namespace ge{ | |||||
| namespace ge { | |||||
| using parser::IF; | |||||
| namespace { | namespace { | ||||
| const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | const std::map<std::string, int> kAttrNameToIndex = {{"then_branch", 0}, {"else_branch", 1}}; | ||||
| const int kIfNodeAttrSize = 2; | const int kIfNodeAttrSize = 2; | ||||
| } | |||||
| Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
| std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| } // namespace | |||||
| domi::Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs( | |||||
| ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | ||||
| GE_CHECK_NOTNULL(parent_node); | GE_CHECK_NOTNULL(parent_node); | ||||
| GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | GELOGI("Onnx parent node name=%s, op type=%s, adapt subgraph.", parent_node->name().c_str(), | ||||
| @@ -41,9 +42,9 @@ Status IfSubgraphAdapter::AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_n | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status IfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
| std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
| domi::Status IfSubgraphAdapter::ParseIfNodeSubgraphs( | |||||
| ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | |||||
| if (parent_node->attribute_size() != kIfNodeAttrSize) { | if (parent_node->attribute_size() != kIfNodeAttrSize) { | ||||
| GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | GELOGE(FAILED, "[Parse][Node] Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
| REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | REPORT_INNER_ERROR("E19999", "Invalid graph, if node attribute size:%d must be 2.", parent_node->attribute_size()); | ||||
| @@ -88,8 +89,8 @@ Status IfSubgraphAdapter::ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | |||||
| std::set<std::string> &all_inputs) { | |||||
| domi::Status IfSubgraphAdapter::GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, | |||||
| std::set<std::string> &all_inputs) { | |||||
| std::set<std::string> graph_inputs; | std::set<std::string> graph_inputs; | ||||
| std::set<std::string> graph_outputs; | std::set<std::string> graph_outputs; | ||||
| for (int i = 0; i < onnx_graph.node_size(); i++) { | for (int i = 0; i < onnx_graph.node_size(); i++) { | ||||
| @@ -21,17 +21,17 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "subgraph_adapter.h" | #include "subgraph_adapter.h" | ||||
| using ge::onnx::NodeProto; | |||||
| namespace ge { | namespace ge { | ||||
| class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | class PARSER_FUNC_VISIBILITY IfSubgraphAdapter : public SubgraphAdapter { | ||||
| public: | public: | ||||
| Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override; | |||||
| private: | |||||
| Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||||
| Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs); | |||||
| domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||||
| std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) override; | |||||
| private: | |||||
| domi::Status ParseIfNodeSubgraphs(ge::onnx::NodeProto *parent_node, std::vector<ge::onnx::GraphProto *> &onnx_graphs, | |||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph); | |||||
| domi::Status GetSubgraphsAllInputs(ge::onnx::GraphProto &onnx_graph, std::set<std::string> &all_inputs); | |||||
| void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph); | void AddInputNodeForGraph(const std::set<std::string> &all_inputs, ge::onnx::GraphProto &onnx_graph); | ||||
| void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node); | void AddInputForParentNode(const std::set<std::string> &all_inputs, ge::onnx::NodeProto &parent_node); | ||||
| }; | }; | ||||
| @@ -38,9 +38,6 @@ | |||||
| #include "framework/omg/parser/parser_types.h" | #include "framework/omg/parser/parser_types.h" | ||||
| #include "parser/onnx/onnx_util.h" | #include "parser/onnx/onnx_util.h" | ||||
| using Status = domi::Status; | |||||
| using namespace ge::parser; | |||||
| namespace ge { | namespace ge { | ||||
| class PARSER_FUNC_VISIBILITY SubgraphAdapter { | class PARSER_FUNC_VISIBILITY SubgraphAdapter { | ||||
| public: | public: | ||||
| @@ -50,7 +47,7 @@ class PARSER_FUNC_VISIBILITY SubgraphAdapter { | |||||
| /// @param [in/out] name_to_onnx_graph map name to onnx graph | /// @param [in/out] name_to_onnx_graph map name to onnx graph | ||||
| /// @return SUCCESS parse success | /// @return SUCCESS parse success | ||||
| /// @return FAILED Parse failed | /// @return FAILED Parse failed | ||||
| virtual Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||||
| virtual domi::Status AdaptAndFindAllSubgraphs(ge::onnx::NodeProto *parent_op, | |||||
| std::vector<ge::onnx::GraphProto *> &onnx_graphs, | std::vector<ge::onnx::GraphProto *> &onnx_graphs, | ||||
| std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | std::map<std::string, ge::onnx::GraphProto *> &name_to_onnx_graph) { | ||||
| return domi::SUCCESS; | return domi::SUCCESS; | ||||
| @@ -83,7 +83,7 @@ string NameMapHelper::Renormalize(const string &name) const { | |||||
| } | } | ||||
| domi::Status ComputeArgRange(const domi::tensorflow::NodeDef &node_def, const domi::tensorflow::OpDef::ArgDef &arg_def, | domi::Status ComputeArgRange(const domi::tensorflow::NodeDef &node_def, const domi::tensorflow::OpDef::ArgDef &arg_def, | ||||
| const domi::tensorflow::OpDef &op_def, int *num) { | |||||
| int *num) { | |||||
| GE_CHECK_NOTNULL(num); | GE_CHECK_NOTNULL(num); | ||||
| if (!arg_def.number_attr().empty()) { | if (!arg_def.number_attr().empty()) { | ||||
| // Same type repeated "num" times. | // Same type repeated "num" times. | ||||
| @@ -120,12 +120,12 @@ using NameRangeMap = std::map<string, std::pair<int, int>>; | |||||
| domi::Status NameRangesHelper(const domi::tensorflow::NodeDef &node_def, | domi::Status NameRangesHelper(const domi::tensorflow::NodeDef &node_def, | ||||
| const google::protobuf::RepeatedPtrField<domi::tensorflow::OpDef_ArgDef> &args, | const google::protobuf::RepeatedPtrField<domi::tensorflow::OpDef_ArgDef> &args, | ||||
| const domi::tensorflow::OpDef &op_def, NameRangeMap *result) { | |||||
| NameRangeMap *result) { | |||||
| GE_CHECK_NOTNULL(result); | GE_CHECK_NOTNULL(result); | ||||
| int start = 0; | int start = 0; | ||||
| int num = 0; | int num = 0; | ||||
| for (const auto &arg : args) { | for (const auto &arg : args) { | ||||
| GE_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, op_def, &num)); | |||||
| GE_RETURN_IF_ERROR(ComputeArgRange(node_def, arg, &num)); | |||||
| (*result)[arg.name()] = std::make_pair(start, start + num); | (*result)[arg.name()] = std::make_pair(start, start + num); | ||||
| start += num; | start += num; | ||||
| } | } | ||||
| @@ -136,7 +136,7 @@ domi::Status NameRangesForNode(const domi::tensorflow::NodeDef &node_def, const | |||||
| NameRangeMap *outputs) { | NameRangeMap *outputs) { | ||||
| GE_IF_BOOL_EXEC(outputs == nullptr, return FAILED); | GE_IF_BOOL_EXEC(outputs == nullptr, return FAILED); | ||||
| return NameRangesHelper(node_def, op_def.output_arg(), op_def, outputs); | |||||
| return NameRangesHelper(node_def, op_def.output_arg(), outputs); | |||||
| } | } | ||||
| domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelper &node_names, | domi::Status RemapFunctionDef(FunctionDef *fdef, const string &name, NameMapHelper &node_names, | ||||
| @@ -58,7 +58,7 @@ class GraphToFunctionDef { | |||||
| vector<ge::InDataAnchorPtr> &in_anchor, | vector<ge::InDataAnchorPtr> &in_anchor, | ||||
| vector<ge::OutDataAnchorPtr> &out_anchor); | vector<ge::OutDataAnchorPtr> &out_anchor); | ||||
| static bool FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, | |||||
| static bool FindAttrValue(const domi::tensorflow::NodeDef *node_def, | |||||
| const string attr_name, | const string attr_name, | ||||
| domi::tensorflow::AttrValue &attr_value); | domi::tensorflow::AttrValue &attr_value); | ||||
| @@ -43,7 +43,7 @@ Status ParserGraphOptimizer::FusionFmkop() { | |||||
| GE_CHECK_NOTNULL(graph_); | GE_CHECK_NOTNULL(graph_); | ||||
| std::unordered_map<string, std::vector<NodePtr>> node_cluser_Map; | std::unordered_map<string, std::vector<NodePtr>> node_cluser_Map; | ||||
| GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find framework node to be fused fail."); | GE_CHK_STATUS_RET(MarkForFusion(node_cluser_Map), "find framework node to be fused fail."); | ||||
| GE_IF_BOOL_EXEC(node_cluser_Map.size() == 0, return SUCCESS); | |||||
| GE_IF_BOOL_EXEC(node_cluser_Map.empty(), return SUCCESS); | |||||
| for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) { | for (auto it = node_cluser_Map.begin(); it != node_cluser_Map.end(); ++it) { | ||||
| GE_CHK_STATUS_RET(UpdateGraph(it->second), "fusion framework nodes failed. node:%s", (it->first).c_str()); | GE_CHK_STATUS_RET(UpdateGraph(it->second), "fusion framework nodes failed. node:%s", (it->first).c_str()); | ||||
| @@ -120,7 +120,7 @@ Status ParserGraphOptimizer::MarkForFusion(unordered_map<string, vector<NodePtr> | |||||
| } | } | ||||
| // find frameworkOP | // find frameworkOP | ||||
| Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map<string, vector<NodePtr>> &node_cluser_Map) { | |||||
| Status ParserGraphOptimizer::FindFmkNodeCluser(unordered_map<string, vector<NodePtr>> &node_cluser_Map) const { | |||||
| vector<NodePtr> temp_node_cluser; | vector<NodePtr> temp_node_cluser; | ||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| @@ -303,7 +303,7 @@ Status ParserGraphOptimizer::InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map) { | |||||
| Status ParserGraphOptimizer::LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map) const { | |||||
| for (auto node : graph_->GetDirectNode()) { | for (auto node : graph_->GetDirectNode()) { | ||||
| GE_IF_BOOL_EXEC(node_map.count(node->GetName()) == 0, continue); | GE_IF_BOOL_EXEC(node_map.count(node->GetName()) == 0, continue); | ||||
| NodePtr dst = node_map[node->GetName()]; | NodePtr dst = node_map[node->GetName()]; | ||||
| @@ -20,17 +20,11 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/omg/parser/parser_types.h" | |||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| using std::map; | |||||
| using std::string; | |||||
| using std::unordered_map; | |||||
| using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| class ParserGraphOptimizer { | class ParserGraphOptimizer { | ||||
| public: | public: | ||||
| @@ -45,31 +39,34 @@ class ParserGraphOptimizer { | |||||
| ge::ComputeGraphPtr graph_; | ge::ComputeGraphPtr graph_; | ||||
| domi::FrameworkType fmktype_; | domi::FrameworkType fmktype_; | ||||
| domi::Status FindFmkNodeCluser(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | |||||
| domi::Status MarkForFusion(unordered_map<string, vector<ge::NodePtr>> &node_cluser_Map); | |||||
| domi::Status FindFmkNodeCluser(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map) const; | |||||
| domi::Status UpdateGraph(vector<ge::NodePtr> &nodes); | |||||
| domi::Status MarkForFusion(std::unordered_map<std::string, std::vector<ge::NodePtr>> &node_cluser_Map); | |||||
| domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, vector<ge::NodePtr> &nodes, | |||||
| vector<ge::InDataAnchorPtr> &input_anchors, vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||||
| unordered_map<string, ge::NodePtr> &node_map); | |||||
| domi::Status UpdateGraph(std::vector<ge::NodePtr> &nodes); | |||||
| domi::Status LinkInnerAnchor(unordered_map<string, ge::NodePtr> &node_map); | |||||
| static domi::Status InsertNode(ge::ComputeGraphPtr sub_graph, std::vector<ge::NodePtr> &nodes, | |||||
| std::vector<ge::InDataAnchorPtr> &input_anchors, | |||||
| std::vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| std::map<ge::OutDataAnchorPtr, std::vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| std::vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| std::vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||||
| std::unordered_map<std::string, ge::NodePtr> &node_map); | |||||
| domi::Status RebuildOutputAnchors(vector<ge::OutDataAnchorPtr> &output_anchors, ge::OpDescPtr fusion_op_desc); | |||||
| domi::Status LinkInnerAnchor(std::unordered_map<std::string, ge::NodePtr> &node_map) const; | |||||
| domi::Status RebuildInputAnchors(vector<ge::InDataAnchorPtr> &input_anchors, ge::OpDescPtr fusion_op_desc); | |||||
| static domi::Status RebuildOutputAnchors(std::vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| ge::OpDescPtr fusion_op_desc); | |||||
| domi::Status RebuildFusionNode(vector<ge::InDataAnchorPtr> &input_anchors, | |||||
| vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| map<ge::OutDataAnchorPtr, vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| vector<ge::OutControlAnchorPtr> &output_control_anchors, ge::NodePtr fusion_node); | |||||
| static domi::Status RebuildInputAnchors(std::vector<ge::InDataAnchorPtr> &input_anchors, | |||||
| ge::OpDescPtr fusion_op_desc); | |||||
| static domi::Status RebuildFusionNode(std::vector<ge::InDataAnchorPtr> &input_anchors, | |||||
| std::vector<ge::OutDataAnchorPtr> &output_anchors, | |||||
| std::map<ge::OutDataAnchorPtr, std::vector<ge::InDataAnchorPtr>> &output_in_map, | |||||
| std::vector<ge::InControlAnchorPtr> &input_control_anchors, | |||||
| std::vector<ge::OutControlAnchorPtr> &output_control_anchors, | |||||
| ge::NodePtr fusion_node); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZER_H_ | ||||
| @@ -23,10 +23,9 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class IteratorFusionPass : public GraphPass { | class IteratorFusionPass : public GraphPass { | ||||
| public: | public: | ||||
| IteratorFusionPass(domi::FrameworkType type) | |||||
| : fmk_type_(type) {} | |||||
| explicit IteratorFusionPass(domi::FrameworkType type) : fmk_type_(type) {} | |||||
| virtual ~IteratorFusionPass() {} | |||||
| ~IteratorFusionPass() override {}; | |||||
| Status Run(ge::ComputeGraphPtr graph) final; | Status Run(ge::ComputeGraphPtr graph) final; | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "register/scope/scope_pass_impl.h" | #include "register/scope/scope_pass_impl.h" | ||||
| namespace ge { | namespace ge { | ||||
| shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { | |||||
| std::shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::GraphDef *graph_def) { | |||||
| GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr"); | GE_CHK_BOOL_EXEC(graph_def != nullptr, return nullptr, "graph_def is nullptr"); | ||||
| scope_graph_ = ge::parser::MakeShared<ScopeGraph>(); | scope_graph_ = ge::parser::MakeShared<ScopeGraph>(); | ||||
| if (scope_graph_ == nullptr) { | if (scope_graph_ == nullptr) { | ||||
| @@ -43,7 +43,7 @@ shared_ptr<ScopeGraph> ScopePassManager::BuildScopeGraph(domi::tensorflow::Graph | |||||
| return scope_graph_; | return scope_graph_; | ||||
| } | } | ||||
| Status ScopePassManager::AddPass(unique_ptr<ScopeBasePass> &pass) { | |||||
| Status ScopePassManager::AddPass(std::unique_ptr<ScopeBasePass> &pass) { | |||||
| GE_CHECK_NOTNULL(pass); | GE_CHECK_NOTNULL(pass); | ||||
| graph_passes_.push_back(std::move(pass)); | graph_passes_.push_back(std::move(pass)); | ||||
| @@ -51,7 +51,7 @@ Status ScopePassManager::AddPass(unique_ptr<ScopeBasePass> &pass) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ScopePassManager::Run(shared_ptr<ScopeGraph> &graph) { | |||||
| Status ScopePassManager::Run(std::shared_ptr<ScopeGraph> &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| bool not_changed = true; | bool not_changed = true; | ||||
| @@ -21,9 +21,6 @@ | |||||
| #include "external/register/scope/scope_fusion_pass_register.h" | #include "external/register/scope/scope_fusion_pass_register.h" | ||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| using std::shared_ptr; | |||||
| using std::unique_ptr; | |||||
| namespace ge { | namespace ge { | ||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -36,15 +33,15 @@ class ScopePassManager { | |||||
| ScopePassManager &operator=(const ScopePassManager &scope_pass_manager) = delete; | ScopePassManager &operator=(const ScopePassManager &scope_pass_manager) = delete; | ||||
| ~ScopePassManager() {} | ~ScopePassManager() {} | ||||
| shared_ptr<ScopeGraph> BuildScopeGraph(domi::tensorflow::GraphDef *graph_def); | |||||
| std::shared_ptr<ScopeGraph> BuildScopeGraph(domi::tensorflow::GraphDef *graph_def); | |||||
| domi::Status AddPass(unique_ptr<ScopeBasePass> &pass); | |||||
| domi::Status Run(shared_ptr<ScopeGraph> &graph); | |||||
| domi::Status AddPass(std::unique_ptr<ScopeBasePass> &pass); | |||||
| domi::Status Run(std::shared_ptr<ScopeGraph> &graph); | |||||
| std::shared_ptr<ScopeGraph> scope_graph_; | std::shared_ptr<ScopeGraph> scope_graph_; | ||||
| private: | private: | ||||
| std::vector<unique_ptr<ScopeBasePass>> graph_passes_; | |||||
| std::vector<std::unique_ptr<ScopeBasePass>> graph_passes_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -29,7 +29,7 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const char *const kSerializeFormat = "serialize_format"; | const char *const kSerializeFormat = "serialize_format"; | ||||
| } // namespace | } // namespace | ||||
| Status ParseParams(const Message *op_src, ArgOpOperator *op) { | |||||
| Status ParseParams(const Message *op_src, ArgOpOperator *const op) { | |||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | ||||
| @@ -30,8 +30,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowConstantParser : public TensorFlowOpParse | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | ||||
| private: | private: | ||||
| Status ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op); | |||||
| Status ParseValue(const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &opDesc); | |||||
| static Status ParseDType(const domi::tensorflow::NodeDef *node, ConstantOperator *op); | |||||
| static Status ParseValue(const domi::tensorflow::NodeDef *node, const ge::OpDescPtr &opDesc); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -31,7 +31,8 @@ Status TensorFlowCustomParserAdapter::ParseParams(const Message *op_src, ge::OpD | |||||
| GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | GELOGD("TF op node name = %s, op type= %s, parse params", node_src->name().c_str(), node_src->op().c_str()); | ||||
| GE_CHECK_NOTNULL(op_dest); | GE_CHECK_NOTNULL(op_dest); | ||||
| ParseParamFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc(op_dest->GetType(), node_src->op()); | |||||
| ParseParamFunc custom_op_parser = domi::OpRegistry::Instance()->GetParseParamFunc( | |||||
| op_dest->GetType(), node_src->op()); | |||||
| if (custom_op_parser == nullptr) { | if (custom_op_parser == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "No ParseParamFunc of node:%s exist in OpRegistry", node_src->name().c_str()); | REPORT_CALL_ERROR("E19999", "No ParseParamFunc of node:%s exist in OpRegistry", node_src->name().c_str()); | ||||
| GELOGE(FAILED, "No ParseParamFunc of node:%s exist in OpRegistry", node_src->name().c_str()); | GELOGE(FAILED, "No ParseParamFunc of node:%s exist in OpRegistry", node_src->name().c_str()); | ||||
| @@ -53,7 +53,7 @@ Status TensorFlowDataParser::ParseParams(const Message *op_src, ge::OpDescPtr &o | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, ge::OpDescPtr &op_def) { | |||||
| Status TensorFlowDataParser::ParseInputFromModel(const Message *op_src, const ge::OpDescPtr &op_def) { | |||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| GE_CHECK_NOTNULL(op_def); | GE_CHECK_NOTNULL(op_def); | ||||
| @@ -46,7 +46,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowDataParser : public TensorFlowOpParser, p | |||||
| * @return FAILED parse failed | * @return FAILED parse failed | ||||
| * @author | * @author | ||||
| */ | */ | ||||
| Status ParseInputFromModel(const Message *op_src, ge::OpDescPtr &op_def); | |||||
| Status ParseInputFromModel(const Message *op_src, const ge::OpDescPtr &op_def); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -31,12 +31,12 @@ namespace ge { | |||||
| do { \ | do { \ | ||||
| google::protobuf::RepeatedField<FIELD> val_vec; \ | google::protobuf::RepeatedField<FIELD> val_vec; \ | ||||
| int32_t val_size = 0; \ | int32_t val_size = 0; \ | ||||
| val_vec = tensor.FIELD##_val(); \ | |||||
| val_vec = (tensor).FIELD##_val(); \ | |||||
| val_size = val_vec.size(); \ | val_size = val_vec.size(); \ | ||||
| if (index < val_size) { \ | |||||
| param = val_vec.Get(index); \ | |||||
| } else if (tensor.has_tensor_shape()) { \ | |||||
| const std::string tensor_content = tensor.tensor_content(); \ | |||||
| if ((index) < val_size) { \ | |||||
| (param) = val_vec.Get(index); \ | |||||
| } else if ((tensor).has_tensor_shape()) { \ | |||||
| const std::string tensor_content = (tensor).tensor_content(); \ | |||||
| char *buf = const_cast<char *>(tensor_content.data()); \ | char *buf = const_cast<char *>(tensor_content.data()); \ | ||||
| FIELD *buf_v = reinterpret_cast<FIELD *>(buf); \ | FIELD *buf_v = reinterpret_cast<FIELD *>(buf); \ | ||||
| if (static_cast<uint32_t>(index) >= tensor_content.length() / sizeof(FIELD)) { \ | if (static_cast<uint32_t>(index) >= tensor_content.length() / sizeof(FIELD)) { \ | ||||
| @@ -45,7 +45,7 @@ namespace ge { | |||||
| GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \ | GELOGE(domi::PARAM_INVALID, "Const data size is smaller than index :%d,not supported!", index); \ | ||||
| return domi::PARAM_INVALID; \ | return domi::PARAM_INVALID; \ | ||||
| } \ | } \ | ||||
| param = buf_v[index]; \ | |||||
| (param) = buf_v[index]; \ | |||||
| } else { \ | } else { \ | ||||
| REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \ | REPORT_INNER_ERROR("E19999", "Const data size of node:%s is smaller than index:%d, not supported!", \ | ||||
| node_def->name().c_str(), index); \ | node_def->name().c_str(), index); \ | ||||
| @@ -75,7 +75,7 @@ Status TensorFlowFusionOpParser::GetTensorFromNode(const NodeDef *node_def, Tens | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowFusionOpParser::ParseParams(const vector<const NodeDef *> &v_input_const, NodePtr &op_dest) { | |||||
| Status TensorFlowFusionOpParser::ParseParams(const std::vector<const NodeDef *> &v_input_const, NodePtr &op_dest) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -147,8 +147,8 @@ Status TensorFlowFusionOpParser::ParseWeightFromConst(const NodeDef *node_def, g | |||||
| } | } | ||||
| domi::tensorflow::DataType data_type = tensor.dtype(); | domi::tensorflow::DataType data_type = tensor.dtype(); | ||||
| GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
| domi::TensorAssign::SetGeTensorDataType(domi::TensorAssign::ConvertTensorflowDataType(data_type), weight), | |||||
| "set ge tensor data type fail"); | |||||
| domi::TensorAssign::SetGeTensorDataType(domi::TensorAssign::ConvertTensorflowDataType(data_type), weight), | |||||
| "set ge tensor data type fail"); | |||||
| GE_CHK_STATUS_RET(domi::TensorAssign::SetGeTensor(tensor, weight), "set ge tensor fail"); | GE_CHK_STATUS_RET(domi::TensorAssign::SetGeTensor(tensor, weight), "set ge tensor fail"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| #include "proto/tensorflow/node_def.pb.h" | #include "proto/tensorflow/node_def.pb.h" | ||||
| using std::vector; | |||||
| using google::protobuf::Message; | using google::protobuf::Message; | ||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::tensorflow::TensorProto; | using domi::tensorflow::TensorProto; | ||||
| @@ -45,7 +44,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionOpParser : public TensorFlowOpParse | |||||
| * @return SUCCESS Parsing success | * @return SUCCESS Parsing success | ||||
| * @return FAILED Parsing failed | * @return FAILED Parsing failed | ||||
| */ | */ | ||||
| virtual Status ParseParams(const vector<const NodeDef *> &v_input_const, ge::NodePtr &node); | |||||
| virtual Status ParseParams(const std::vector<const NodeDef *> &v_input_const, ge::NodePtr &node); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -68,19 +67,19 @@ class PARSER_FUNC_VISIBILITY TensorFlowFusionOpParser : public TensorFlowOpParse | |||||
| * | * | ||||
| */ | */ | ||||
| // template <class T> | // template <class T> | ||||
| Status ParseParamFromConst(const NodeDef *input_const, int32_t ¶m); | |||||
| static Status ParseParamFromConst(const NodeDef *input_const, int32_t ¶m); | |||||
| Status ParseParamFromConst(const NodeDef *nodeDef, int32_t ¶m, int index); | |||||
| static Status ParseParamFromConst(const NodeDef *node_def, int32_t ¶m, int index); | |||||
| Status ParseParamFromConst(const NodeDef *input_const, float ¶m); | |||||
| static Status ParseParamFromConst(const NodeDef *input_const, float ¶m); | |||||
| Status ParseParamFromConst(const NodeDef *nodeDef, float ¶m, int index); | |||||
| static Status ParseParamFromConst(const NodeDef *node_def, float ¶m, int index); | |||||
| Status GetTensorFromNode(const NodeDef *nodeDef, TensorProto &tensor); | |||||
| static Status GetTensorFromNode(const NodeDef *node_def, TensorProto &tensor); | |||||
| Status ParseHalfFromConst(const NodeDef *node_def, float ¶m, int index = 0); | |||||
| static Status ParseHalfFromConst(const NodeDef *node_def, float ¶m, int index = 0); | |||||
| Status ParseWeightFromConst(const NodeDef *node_def, ge::GeTensorPtr &weight); | |||||
| static Status ParseWeightFromConst(const NodeDef *node_def, ge::GeTensorPtr &weight); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -221,7 +221,7 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque | |||||
| for (auto &node : parent_graph->GetDirectNode()) { | for (auto &node : parent_graph->GetDirectNode()) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| for (const auto subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { | |||||
| for (const auto &subgraph_name_to_index : op_desc->GetSubgraphNameIndexes()) { | |||||
| auto i = subgraph_name_to_index.second; | auto i = subgraph_name_to_index.second; | ||||
| auto subgraph_iname = op_desc->GetSubgraphInstanceName(i); | auto subgraph_iname = op_desc->GetSubgraphInstanceName(i); | ||||
| if (subgraph_iname.empty()) { | if (subgraph_iname.empty()) { | ||||
| @@ -239,8 +239,8 @@ Status GenSubgraphParseTasks(const ge::ComputeGraphPtr &parent_graph, std::deque | |||||
| } | } | ||||
| auto ret = ge::NodeUtils::SetSubgraph(*node, i, subgraph); | auto ret = ge::NodeUtils::SetSubgraph(*node, i, subgraph); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Set subgraph:%s to node:%s(%s) failed, index:%u", | |||||
| subgraph_iname.c_str(), node->GetName().c_str(), node->GetType().c_str(), i); | |||||
| REPORT_CALL_ERROR("E19999", "Set subgraph:%s to node:%s(%s) failed, index:%u", subgraph_iname.c_str(), | |||||
| node->GetName().c_str(), node->GetType().c_str(), i); | |||||
| GELOGE(ret, "Failed to set subgraph %s to node %s index %u", subgraph_iname.c_str(), node->GetName().c_str(), | GELOGE(ret, "Failed to set subgraph %s to node %s index %u", subgraph_iname.c_str(), node->GetName().c_str(), | ||||
| i); | i); | ||||
| return ret; | return ret; | ||||
| @@ -292,8 +292,8 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Call ParseSubgraphPostFunc:%s failed, subgraph:%s, node:%s(%s), ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call ParseSubgraphPostFunc:%s failed, subgraph:%s, node:%s(%s), ret:0x%X", | ||||
| arg.function_name.c_str(), arg.subgraph_name.c_str(), | |||||
| arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), ret); | |||||
| arg.function_name.c_str(), arg.subgraph_name.c_str(), arg.parent_node->GetName().c_str(), | |||||
| arg.parent_node->GetType().c_str(), ret); | |||||
| GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s subgraph name %s", arg.function_name.c_str(), | GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s subgraph name %s", arg.function_name.c_str(), | ||||
| arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str()); | arg.parent_node->GetName().c_str(), arg.parent_node->GetType().c_str(), arg.subgraph_name.c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -301,20 +301,21 @@ Status PostOpProcessForSubgraph(const ParseArg &arg) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGraphPtr &root_graph){ | |||||
| Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, const ComputeGraphPtr &root_graph) { | |||||
| // Inner function, input params have been checked by caller | // Inner function, input params have been checked by caller | ||||
| Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo(graph, | |||||
| [](int in, int &out)->Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [](int in, int &out)->Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }); | |||||
| Status status = AutoMappingSubgraphIndexByDataNodeAndOutputNodesInfo( | |||||
| graph, | |||||
| [](int in, int &out) -> Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }, | |||||
| [](int in, int &out) -> Status { | |||||
| out = in; | |||||
| return SUCCESS; | |||||
| }); | |||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), graph.GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Mapping][Subgraph]node:%s, sub graph name:%s.", node->GetName().c_str(), | |||||
| graph.GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.", | REPORT_CALL_ERROR("E19999", "Failed to map sub graph input and output, node:%s, sub graph name:%s.", | ||||
| node->GetName().c_str(), graph.GetName().c_str()); | node->GetName().c_str(), graph.GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -326,10 +327,10 @@ Status MappingAndAddSubGraph(const NodePtr &node, const Graph &graph, ComputeGra | |||||
| (void)node->GetOpDesc()->AddSubgraphName("f"); | (void)node->GetOpDesc()->AddSubgraphName("f"); | ||||
| auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph); | auto ret = NodeUtils::SetSubgraph(*node, 0, compute_graph); | ||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", | |||||
| node->GetName().c_str(), compute_graph->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", | |||||
| node->GetName().c_str(), compute_graph->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Set][Subgraph]Node:%s, sub graph name:%s.", node->GetName().c_str(), | |||||
| compute_graph->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Failed to set sub graph, node: %s, sub graph name: %s.", node->GetName().c_str(), | |||||
| compute_graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) { | for (const auto &sub_graph : compute_graph->GetAllSubgraphs()) { | ||||
| @@ -365,7 +366,8 @@ Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::Nod | |||||
| "may has no ir definition, if it is not a common decorate function operator"}); | "may has no ir definition, if it is not a common decorate function operator"}); | ||||
| GELOGE(FAILED, | GELOGE(FAILED, | ||||
| "Op %s has no ir definition, or has no attr [_disable_call_shape_inference] " | "Op %s has no ir definition, or has no attr [_disable_call_shape_inference] " | ||||
| "if it is a common decorate function operator.", op_name.c_str()); | |||||
| "if it is a common decorate function operator.", | |||||
| op_name.c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -379,8 +381,7 @@ Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::Nod | |||||
| for (size_t i = 0; i < input_tensor_num; ++i) { | for (size_t i = 0; i < input_tensor_num; ++i) { | ||||
| ge::GeTensorDesc input_tensor; | ge::GeTensorDesc input_tensor; | ||||
| if (op->AddInputDesc(input_tensor) != ge::GRAPH_SUCCESS) { | if (op->AddInputDesc(input_tensor) != ge::GRAPH_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(FAILED, "op [%s] type[%s] add input(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); | GELOGE(FAILED, "op [%s] type[%s] add input(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -389,8 +390,7 @@ Status TensorFlowModelParser::DefunToPartitionedCall(const domi::tensorflow::Nod | |||||
| for (size_t i = 0; i < output_tensor_num; ++i) { | for (size_t i = 0; i < output_tensor_num; ++i) { | ||||
| ge::GeTensorDesc output_tensor; | ge::GeTensorDesc output_tensor; | ||||
| if (op->AddOutputDesc(output_tensor) != ge::GRAPH_SUCCESS) { | if (op->AddOutputDesc(output_tensor) != ge::GRAPH_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", | |||||
| op->GetName().c_str(), op->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed", op->GetName().c_str(), op->GetType().c_str()); | |||||
| GELOGE(FAILED, "op [%s] type[%s] add output(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); | GELOGE(FAILED, "op [%s] type[%s] add output(%zu) tensor failed.", op_name.c_str(), op->GetType().c_str(), i); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -437,7 +437,7 @@ Status TensorFlowModelParser::TransNodeToOpDesc(const domi::tensorflow::NodeDef | |||||
| } | } | ||||
| Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, | Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, | ||||
| shared_ptr<OpParser> &op_parser) { | |||||
| const shared_ptr<OpParser> &op_parser) { | |||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| GE_CHECK_NOTNULL(op_parser); | GE_CHECK_NOTNULL(op_parser); | ||||
| @@ -459,8 +459,8 @@ Status TensorFlowModelParser::ParseOpParams(const domi::tensorflow::NodeDef *nod | |||||
| ge::Operator op_src(node_def->name(), node_def->op()); | ge::Operator op_src(node_def->name(), node_def->op()); | ||||
| status = domi::AutoMappingFn(node_def, op_src); | status = domi::AutoMappingFn(node_def, op_src); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", | |||||
| node_def->name().c_str(), node_def->op().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def->name().c_str(), | |||||
| node_def->op().c_str()); | |||||
| GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str()); | GELOGE(status, "Node[%s] auto mapping failed.", node_name.c_str()); | ||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -585,8 +585,8 @@ Status TensorFlowModelParser::AddNode(const domi::tensorflow::NodeDef *node_def, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void TensorFlowModelParser::GetInputOutputTensorNum(ge::OpDescPtr &op_desc, size_t &input_tensor_num, | |||||
| size_t &output_tensor_num) const { | |||||
| void TensorFlowModelParser::GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num, | |||||
| size_t &output_tensor_num) { | |||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| auto iter = op_node_context_map_.find(op_desc->GetName()); | auto iter = op_node_context_map_.find(op_desc->GetName()); | ||||
| if (iter == op_node_context_map_.end()) { | if (iter == op_node_context_map_.end()) { | ||||
| @@ -615,8 +615,6 @@ void TensorFlowModelParser::GetInputOutputTensorNum(ge::OpDescPtr &op_desc, size | |||||
| } | } | ||||
| } | } | ||||
| output_tensor_num = max_anchor_index + 1; | output_tensor_num = max_anchor_index + 1; | ||||
| return; | |||||
| } | } | ||||
| Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) { | Status TensorFlowModelParser::CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node) { | ||||
| @@ -777,8 +775,8 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(dest); | GE_CHECK_NOTNULL(dest); | ||||
| if (src_output_iter.second.size() != input_iter->second.size()) { | if (src_output_iter.second.size() != input_iter->second.size()) { | ||||
| REPORT_INNER_ERROR("E19999", "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.", | REPORT_INNER_ERROR("E19999", "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.", | ||||
| src_op_name.c_str(), input_iter->second.size(), | |||||
| dest_op_name.c_str(), src_output_iter.second.size()); | |||||
| src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(), | |||||
| src_output_iter.second.size()); | |||||
| GELOGE(INTERNAL_ERROR, "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.", | GELOGE(INTERNAL_ERROR, "Input size of op[%s]:%zu is not equal to Output size of op[%s]:%zu.", | ||||
| src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(), src_output_iter.second.size()); | src_op_name.c_str(), input_iter->second.size(), dest_op_name.c_str(), src_output_iter.second.size()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -805,13 +803,11 @@ Status TensorFlowModelParser::AddEdges(ge::ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(in_archor_ptr); | GE_CHECK_NOTNULL(in_archor_ptr); | ||||
| ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | ge::OutControlAnchorPtr out_archor_ptr = src->GetOutControlAnchor(); | ||||
| GE_CHECK_NOTNULL(out_archor_ptr); | GE_CHECK_NOTNULL(out_archor_ptr); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
| REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
| src->GetName().c_str(), dest->GetName().c_str()); | |||||
| return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", src->GetName().c_str(), | |||||
| dest->GetName().c_str() | |||||
| ); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::GraphUtils::AddEdge(out_archor_ptr, in_archor_ptr) != ge::GRAPH_SUCCESS, | |||||
| REPORT_INNER_ERROR("E19999", "Add link from op:%s to op:%s failed", | |||||
| src->GetName().c_str(), dest->GetName().c_str()); | |||||
| return INTERNAL_ERROR, "Add link failed from op[%s] to op[%s].", | |||||
| src->GetName().c_str(), dest->GetName().c_str()); | |||||
| } | } | ||||
| } | } | ||||
| dest_input_map.erase(input_iter); | dest_input_map.erase(input_iter); | ||||
| @@ -845,8 +841,8 @@ Status TensorFlowModelParser::CheckOpShapeDim(const domi::tensorflow::NodeDef *n | |||||
| GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS); | GE_IF_BOOL_EXEC(!is_attr_exist, return SUCCESS); | ||||
| GE_CHK_BOOL_EXEC(input_attr_value.has_list(), | GE_CHK_BOOL_EXEC(input_attr_value.has_list(), | ||||
| REPORT_INNER_ERROR("E19999", "Attr:%s of node_def:%s(%s) is empty, check invalid", | REPORT_INNER_ERROR("E19999", "Attr:%s of node_def:%s(%s) is empty, check invalid", | ||||
| ge::parser::ATTR_NAME_INPUT_TENSOR_DESC.c_str(), | |||||
| node_def->name().c_str(), node_def->op().c_str()); | |||||
| ge::parser::ATTR_NAME_INPUT_TENSOR_DESC.c_str(), node_def->name().c_str(), | |||||
| node_def->op().c_str()); | |||||
| return PARAM_INVALID, "output attr value vector is empty"); | return PARAM_INVALID, "output attr value vector is empty"); | ||||
| // list contain many TensorDescriptors | // list contain many TensorDescriptors | ||||
| @@ -924,11 +920,10 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||||
| } | } | ||||
| auto iterator = parser->adaptedOpTypeMap_.find(node_name); | auto iterator = parser->adaptedOpTypeMap_.find(node_name); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iterator == parser->adaptedOpTypeMap_.end(), | |||||
| REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", | |||||
| node_name.c_str()); | |||||
| return FAILED, | |||||
| "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| iterator == parser->adaptedOpTypeMap_.end(), | |||||
| REPORT_INNER_ERROR("E19999", "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
| return FAILED, "get adapted op type failed, node name = %s", node_name.c_str()); | |||||
| string op_type = iterator->second; | string op_type = iterator->second; | ||||
| // Log printing for determining operator type | // Log printing for determining operator type | ||||
| @@ -1021,11 +1016,10 @@ Status TensorFlowModelParser::ParseNodeDef(TensorFlowModelParser *parser, ge::Co | |||||
| node = graph->AddNode(op); | node = graph->AddNode(op); | ||||
| } | } | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node == nullptr), | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||||
| op->GetName().c_str(), op->GetType().c_str(), | |||||
| graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR, "add node failed."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| (node == nullptr), REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op->GetName().c_str(), | |||||
| op->GetType().c_str(), graph->GetName().c_str()); | |||||
| return INTERNAL_ERROR, "add node failed."); | |||||
| if (needFusion) { | if (needFusion) { | ||||
| shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); | shared_ptr<OpParser> fusion_op_parser = factory->CreateFusionOpParser(op_type); | ||||
| @@ -1121,10 +1115,10 @@ Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &g | |||||
| for (size_t j = 0; j < op_node_list_size; j++) { | for (size_t j = 0; j < op_node_list_size; j++) { | ||||
| const string op_node_name = op_node_name_list[j]; | const string op_node_name = op_node_name_list[j]; | ||||
| auto iterator = node_map_.find(op_node_name); | auto iterator = node_map_.find(op_node_name); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((iterator == node_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", | |||||
| op_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "add node failed."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| (iterator == node_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", "node:%s can't find in node_map_, check invalid", op_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "add node failed."); | |||||
| GE_CHECK_NOTNULL(iterator->second); | GE_CHECK_NOTNULL(iterator->second); | ||||
| GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); | GE_CHK_STATUS_RET(iterator->second->SetOwnerComputeGraph(graph), "set owner compute graph failed"); | ||||
| graph->AddNode(iterator->second); | graph->AddNode(iterator->second); | ||||
| @@ -1133,7 +1127,7 @@ Status TensorFlowModelParser::AddNodeToGraphAndMarkFormat(ge::ComputeGraphPtr &g | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *graph_def, | |||||
| Status TensorFlowModelParser::ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *const graph_def, | |||||
| shared_ptr<ge::ScopeGraph> &scope_graph) { | shared_ptr<ge::ScopeGraph> &scope_graph) { | ||||
| // Identifying scope fusion operators based on scope rules | // Identifying scope fusion operators based on scope rules | ||||
| GE_CHECK_NOTNULL(graph_def); | GE_CHECK_NOTNULL(graph_def); | ||||
| @@ -1183,8 +1177,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
| domi::tensorflow::GraphDef OriDef; | domi::tensorflow::GraphDef OriDef; | ||||
| bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); | bool read = ge::parser::ReadProtoFromArray(data, static_cast<int>(size), &OriDef); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, | |||||
| REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!read, REPORT_INNER_ERROR("E19999", "read graph proto from binary failed"); | |||||
| return INTERNAL_ERROR, "read_proto_from_binary failed."); | return INTERNAL_ERROR, "read_proto_from_binary failed."); | ||||
| domi::tensorflow::GraphDef graph_def; | domi::tensorflow::GraphDef graph_def; | ||||
| @@ -1254,7 +1247,7 @@ Status TensorFlowModelParser::ParseFromMemory(const char *data, uint32_t size, g | |||||
| GELOGD("[TF ParseFromMemory] infer input formats success"); | GELOGD("[TF ParseFromMemory] infer input formats success"); | ||||
| // Building input-output relationship between fusionop and common op | // Building input-output relationship between fusionop and common op | ||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, graph_def, op_node_name_list)); | |||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list)); | |||||
| ret = AddFusionNodeDef(scope_graph, op_node_name_list); | ret = AddFusionNodeDef(scope_graph, op_node_name_list); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -1321,13 +1314,13 @@ Status TensorFlowModelParser::GetFunctionProto(const string &file, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::Parse(const char *model_path, ge::Graph &graph) { | |||||
| Status TensorFlowModelParser::Parse(const char *file, ge::Graph &graph) { | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ||||
| GE_CHECK_NOTNULL(model_path); | |||||
| GE_CHECK_NOTNULL(file); | |||||
| ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ge::ComputeGraphPtr root_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
| GE_CHECK_NOTNULL(root_graph); | GE_CHECK_NOTNULL(root_graph); | ||||
| Status ret = Parse(model_path, root_graph); | |||||
| Status ret = Parse(file, root_graph); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | GELOGE(ret, "Parser graph %s failed.", graph.GetName().c_str()); | ||||
| return ret; | return ret; | ||||
| @@ -1489,7 +1482,7 @@ Status TensorFlowModelParser::ParseAllGraph(const google::protobuf::Message *pro | |||||
| GELOGD("[TF Parse] infer input formats success"); | GELOGD("[TF Parse] infer input formats success"); | ||||
| // Building input-output relationship between fusionop and common op | // Building input-output relationship between fusionop and common op | ||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, graph_def, op_node_name_list)); | |||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list)); | |||||
| GELOGD("[TF Parse] update all node op context success"); | GELOGD("[TF Parse] update all node op context success"); | ||||
| // set user-designate-inputs-order | // set user-designate-inputs-order | ||||
| @@ -1783,8 +1776,7 @@ bool TensorFlowModelParser::MaybeFusionOp(shared_ptr<ge::ScopeGraph> &scope_grap | |||||
| } | } | ||||
| bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info) { | bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFusionOpInfo *info) { | ||||
| GE_CHK_BOOL_EXEC(info != nullptr, | |||||
| REPORT_CALL_ERROR("E19999", "Param info is nullptr, check invalid"); | |||||
| GE_CHK_BOOL_EXEC(info != nullptr, REPORT_CALL_ERROR("E19999", "Param info is nullptr, check invalid"); | |||||
| return false, "fusion info is null."); | return false, "fusion info is null."); | ||||
| // 1.View in full match fusion strategy first | // 1.View in full match fusion strategy first | ||||
| // 2.View in scope fusion policy then | // 2.View in scope fusion policy then | ||||
| @@ -1802,7 +1794,7 @@ bool TensorFlowModelParser::IsFusionOpChild(const string &node_name, ge::ScopeFu | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool TensorFlowModelParser::FusionOpChildIgnore(shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| bool TensorFlowModelParser::FusionOpChildIgnore(const shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| const ge::ScopeFusionOpInfo &info) { | const ge::ScopeFusionOpInfo &info) { | ||||
| GE_CHECK_NOTNULL(scope_graph); | GE_CHECK_NOTNULL(scope_graph); | ||||
| bool ignore = false; | bool ignore = false; | ||||
| @@ -1814,7 +1806,7 @@ bool TensorFlowModelParser::FusionOpChildIgnore(shared_ptr<ge::ScopeGraph> &scop | |||||
| return ignore; | return ignore; | ||||
| } | } | ||||
| bool TensorFlowModelParser::IsFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| bool TensorFlowModelParser::IsFusionOp(const shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| const domi::tensorflow::NodeDef *node_def) { | const domi::tensorflow::NodeDef *node_def) { | ||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| @@ -1823,12 +1815,11 @@ bool TensorFlowModelParser::IsFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | Status TensorFlowModelParser::GetInPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | ||||
| const int32_t old_index, int32_t &new_index) { | const int32_t old_index, int32_t &new_index) { | ||||
| GE_CHECK_NOTNULL(scope_graph); | GE_CHECK_NOTNULL(scope_graph); | ||||
| Status ret; | |||||
| if (info.scope_pass) { | if (info.scope_pass) { | ||||
| auto &impl = scope_graph->impl_; | auto &impl = scope_graph->impl_; | ||||
| ret = impl->GetInputOrOutputIndex(info, old_index, true, new_index); | |||||
| return impl->GetInputOrOutputIndex(info, old_index, true, new_index); | |||||
| } | } | ||||
| return ret; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | Status TensorFlowModelParser::GetOutPutIndex(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | ||||
| const int32_t old_index, int32_t &new_index) { | const int32_t old_index, int32_t &new_index) { | ||||
| @@ -1862,7 +1853,6 @@ bool TensorFlowModelParser::ConstOpNeedUpdate(const string &op_name) { | |||||
| } | } | ||||
| Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, | Status TensorFlowModelParser::UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, | ||||
| const domi::tensorflow::GraphDef &graph_def, | |||||
| vector<string> &op_node_name_list) { | vector<string> &op_node_name_list) { | ||||
| GE_CHECK_NOTNULL(scope_graph); | GE_CHECK_NOTNULL(scope_graph); | ||||
| vector<string> tmp_op_node_name_list; | vector<string> tmp_op_node_name_list; | ||||
| @@ -2105,9 +2095,9 @@ Status TensorFlowModelParser::NormalizeAllNodeOpContext() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::NormalizeInputOrOutputMap(const string &node_name, | |||||
| std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map) { | |||||
| if (context_map.size() == 0) { | |||||
| Status TensorFlowModelParser::NormalizeInputOrOutputMap( | |||||
| const string &node_name, std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> &context_map) { | |||||
| if (context_map.empty()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -2138,7 +2128,7 @@ Status TensorFlowModelParser::NormalizeInputOrOutputMap(const string &node_name, | |||||
| compare_set.insert(name); | compare_set.insert(name); | ||||
| } | } | ||||
| if (temp_pairs.size() == 0) { | |||||
| if (temp_pairs.empty()) { | |||||
| // If there is no pair, the context can be deleted | // If there is no pair, the context can be deleted | ||||
| iter = context_map.erase(iter); | iter = context_map.erase(iter); | ||||
| continue; | continue; | ||||
| @@ -2175,7 +2165,7 @@ void TensorFlowModelParser::UpdateEdgesControlInfo(const ge::ScopeFusionOpInfo & | |||||
| } | } | ||||
| } | } | ||||
| bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const int32_t index) { | |||||
| bool TensorFlowModelParser::GetEdgesControlInfo(const string &node_name, const int32_t index) const { | |||||
| // If the node name is included, then confirm whether the index is the same | // If the node name is included, then confirm whether the index is the same | ||||
| auto iter = edges_control_map.find(node_name); | auto iter = edges_control_map.find(node_name); | ||||
| if (iter != edges_control_map.end()) { | if (iter != edges_control_map.end()) { | ||||
| @@ -2220,7 +2210,9 @@ Status TensorFlowWeightsParser::ParseFromMemory(const char *data, uint32_t size, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { return SUCCESS; } | |||||
| Status TensorFlowWeightsParser::Parse(const char *file, ge::Graph &graph) { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { | Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) { | ||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ||||
| @@ -2296,7 +2288,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
| GELOGD("[TF Parser] Get op nodes context from graph success"); | GELOGD("[TF Parser] Get op nodes context from graph success"); | ||||
| // Building input-output relationship between fusionop and common op | // Building input-output relationship between fusionop and common op | ||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, *graph_def, op_node_name_list)); | |||||
| GE_RETURN_IF_ERROR(UpdateAllNodeOpContext(scope_graph, op_node_name_list)); | |||||
| GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size()); | GELOGI("[TF Parser] TF op node size = %zu.", op_node_name_list.size()); | ||||
| PARSER_TIMESTAMP_START(AddFmkNode); | PARSER_TIMESTAMP_START(AddFmkNode); | ||||
| @@ -2329,8 +2321,7 @@ Status TensorFlowModelParser::ParseProto(const google::protobuf::Message *proto, | |||||
| ge::parser::PassManager iterator_fusion_pass; | ge::parser::PassManager iterator_fusion_pass; | ||||
| try { | try { | ||||
| (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", | |||||
| new ge::IteratorFusionPass(domi::TENSORFLOW)); | |||||
| (void)iterator_fusion_pass.AddPass("ParseProto::IteratorFusionPass", new ge::IteratorFusionPass(domi::TENSORFLOW)); | |||||
| } catch (std::bad_alloc &e) { | } catch (std::bad_alloc &e) { | ||||
| GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -2422,8 +2413,7 @@ Status TensorFlowModelParser::ParseProto(const std::string &serialized_proto, ge | |||||
| return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | return ParseProto(reinterpret_cast<const google::protobuf::Message *>(&graph_def), graph); | ||||
| } | } | ||||
| Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, | |||||
| domi::GetGraphCallbackV2 callback, | |||||
| Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | |||||
| ge::ComputeGraphPtr &root_graph) { | ge::ComputeGraphPtr &root_graph) { | ||||
| ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ErrorManager::GetInstance().SetStage(error_message::kModelCompile, error_message::kParser); | ||||
| ErrorManager::GetInstance().GenWorkStreamIdDefault(); | ErrorManager::GetInstance().GenWorkStreamIdDefault(); | ||||
| @@ -2481,20 +2471,17 @@ Status TensorFlowModelParser::ParseProtoWithSubgraph(const std::string &root_pro | |||||
| Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | Status TensorFlowModelParser::OptimizeIdentityByOutput(map<string, NodeDef *> &nodedef_map, | ||||
| const string &curr_node_name, bool &clear_input_flag) { | const string &curr_node_name, bool &clear_input_flag) { | ||||
| auto context_iter = op_node_context_map_.find(curr_node_name); | auto context_iter = op_node_context_map_.find(curr_node_name); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((context_iter == op_node_context_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Node:%s can't find in op_node_context_map_, check invalid", | |||||
| curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, | |||||
| "Can't find op node context."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| (context_iter == op_node_context_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "Can't find op node context."); | |||||
| OpNodeContext op_node_context = context_iter->second; | OpNodeContext op_node_context = context_iter->second; | ||||
| auto node_def_iter = nodedef_map.find(curr_node_name); | auto node_def_iter = nodedef_map.find(curr_node_name); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((node_def_iter == nodedef_map.end()), | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Node:%s can't find in nodedef_map, check invalid", | |||||
| curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "Can't find nodedef"); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| (node_def_iter == nodedef_map.end()), | |||||
| REPORT_INNER_ERROR("E19999", "Node:%s can't find in nodedef_map, check invalid", curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "Can't find nodedef"); | |||||
| domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; | domi::tensorflow::NodeDef *curr_node_def = node_def_iter->second; | ||||
| GE_CHECK_NOTNULL(curr_node_def); | GE_CHECK_NOTNULL(curr_node_def); | ||||
| bool has_out_retval = false; | bool has_out_retval = false; | ||||
| @@ -2565,12 +2552,10 @@ Status TensorFlowModelParser::OptimizeSnapShot(domi::tensorflow::NodeDef *curr_m | |||||
| } | } | ||||
| string curr_node_name = curr_mode_def->name(); | string curr_node_name = curr_mode_def->name(); | ||||
| auto context_iter = op_node_context_map_.find(curr_node_name); | auto context_iter = op_node_context_map_.find(curr_node_name); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((context_iter == op_node_context_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Node:%s can't find in op_node_context_map_, check invalid", | |||||
| curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, | |||||
| "Can't find op node context."); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| (context_iter == op_node_context_map_.end()), | |||||
| REPORT_INNER_ERROR("E19999", "Node:%s can't find in op_node_context_map_, check invalid", curr_node_name.c_str()); | |||||
| return INTERNAL_ERROR, "Can't find op node context."); | |||||
| OpNodeContext op_node_context = context_iter->second; | OpNodeContext op_node_context = context_iter->second; | ||||
| std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | std::map<std::string, std::vector<std::pair<int32_t, int32_t>>> output_map = op_node_context.output_map; | ||||
| @@ -2670,9 +2655,39 @@ Status TensorFlowModelParser::GraphDefOptimizeSnapShot(domi::tensorflow::GraphDe | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | |||||
| domi::tensorflow::NodeDef *nodeCurrent, | |||||
| bool &clearInputFlag) { | |||||
| Status TensorFlowModelParser::SetDestNodeName(domi::tensorflow::NodeDef *const node_current, | |||||
| domi::tensorflow::NodeDef *const node_dest, | |||||
| const int32_t input_idx, const bool is_control, | |||||
| bool &clear_input_flag) { | |||||
| GELOGI("current node name is %s ", node_current->name().c_str()); | |||||
| clear_input_flag = true; | |||||
| if (is_control) { | |||||
| string node_current_name = node_current->input(0); | |||||
| string current_name; | |||||
| if (CheckInputNodeName(node_current_name, ¤t_name, nullptr, nullptr) != SUCCESS) { | |||||
| GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", node_current_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| current_name = "^" + current_name; | |||||
| GELOGI("set nodeCurrentNameTmp: %s", current_name.c_str()); | |||||
| node_dest->set_input(input_idx, current_name); | |||||
| } else { | |||||
| node_dest->set_input(input_idx, node_current->input(0).c_str()); | |||||
| GELOGD("%s op set input:%s.", node_dest->name().c_str(), node_current->input(0).c_str()); | |||||
| } | |||||
| // DestroyTemporaryVariable node have only one input and one output. | |||||
| // If the number of inputs is greater than 1, all subsequent inputs are | |||||
| // control edge inputs. Therefore, after deleting DestroyTemporaryVariable, | |||||
| // these control edge inputs can be directly connected to nodeDst. | |||||
| for (int i = 1; i < node_current->input_size(); ++i) { | |||||
| node_dest->add_input(node_current->input(i)); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, | |||||
| domi::tensorflow::NodeDef *const nodeCurrent, | |||||
| bool &clearInputFlag) const { | |||||
| // Internal call to ensure that the parameter is not empty. | // Internal call to ensure that the parameter is not empty. | ||||
| GELOGI("DestroyTemporaryVariable optimizing."); | GELOGI("DestroyTemporaryVariable optimizing."); | ||||
| for (int w = 0; w < graph_def->node_size(); w++) { | for (int w = 0; w < graph_def->node_size(); w++) { | ||||
| @@ -2686,40 +2701,20 @@ void TensorFlowModelParser::OptimizeDestroyTemporaryVariable(domi::tensorflow::G | |||||
| GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeDstInputName.c_str()); | GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeDstInputName.c_str()); | ||||
| return; | return; | ||||
| } | } | ||||
| if (nodeDstInputNameTmp == nodeCurrent->name()) { | |||||
| GELOGI("current node name is %s ", nodeCurrent->name().c_str()); | |||||
| clearInputFlag = true; | |||||
| if (isControl) { | |||||
| string nodeCurrentName = nodeCurrent->input(0); | |||||
| string nodeCurrentNameTmp; | |||||
| if (CheckInputNodeName(nodeCurrentName, &nodeCurrentNameTmp, nullptr, nullptr) != SUCCESS) { | |||||
| GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeCurrentName.c_str()); | |||||
| return; | |||||
| } | |||||
| nodeCurrentNameTmp = "^" + nodeCurrentNameTmp; | |||||
| GELOGI("set nodeCurrentNameTmp: %s", nodeCurrentNameTmp.c_str()); | |||||
| nodeDst->set_input(k, nodeCurrentNameTmp); | |||||
| } else { | |||||
| nodeDst->set_input(k, nodeCurrent->input(0).c_str()); | |||||
| GELOGD("%s op set input:%s.", nodeDst->name().c_str(), nodeCurrent->input(0).c_str()); | |||||
| } | |||||
| // DestroyTemporaryVariable node have only one input and one output. | |||||
| // If the number of inputs is greater than 1, all subsequent inputs are | |||||
| // control edge inputs. Therefore, after deleting DestroyTemporaryVariable, | |||||
| // these control edge inputs can be directly connected to nodeDst. | |||||
| if (nodeCurrent->input_size() > 1) { | |||||
| for (int i = 1; i < nodeCurrent->input_size(); ++i) { | |||||
| nodeDst->add_input(nodeCurrent->input(i)); | |||||
| } | |||||
| } | |||||
| GELOGI("Optimize DestroyTemporaryVariable successful."); | |||||
| if (nodeDstInputNameTmp != nodeCurrent->name()) { | |||||
| continue; | |||||
| } | } | ||||
| if (SetDestNodeName(nodeCurrent, nodeDst, k, isControl, clearInputFlag) !=SUCCESS) { | |||||
| GELOGE(FAILED, "CheckInputNodeName failed, node is: %s", nodeCurrent->name().c_str()); | |||||
| return; | |||||
| } | |||||
| GELOGI("Optimize DestroyTemporaryVariable successful."); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | |||||
| domi::tensorflow::NodeDef *nodeCurrent) { | |||||
| Status TensorFlowModelParser::GraphDefOptimizeDestroyTemporaryVariable( | |||||
| domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *const nodeCurrent) const { | |||||
| if (graph_def == nullptr || nodeCurrent == nullptr) { | if (graph_def == nullptr || nodeCurrent == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param graph_def or nodeCurrent is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param graph_def or nodeCurrent is nullptr, check invalid"); | ||||
| GELOGE(FAILED, "input param is nullptr."); | GELOGE(FAILED, "input param is nullptr."); | ||||
| @@ -2844,7 +2839,7 @@ void TensorFlowModelParser::OptimizeTranspose(std::map<std::string, DelTranspose | |||||
| } | } | ||||
| } | } | ||||
| void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *graph_def) { | |||||
| void TensorFlowModelParser::SoftmaxAddAttr(GraphDef *const graph_def) { | |||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| for (int i = 0; i < graph_def->node_size(); ++i) { | for (int i = 0; i < graph_def->node_size(); ++i) { | ||||
| auto node_def = graph_def->mutable_node(i); | auto node_def = graph_def->mutable_node(i); | ||||
| @@ -2896,7 +2891,7 @@ Status TensorFlowModelParser::GraphDefOptimize(domi::tensorflow::GraphDef *graph | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::RemoveIsolateNode(ge::ComputeGraphPtr &graph) { | |||||
| Status TensorFlowModelParser::RemoveIsolateNode(const ge::ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| auto nodes = graph->GetDirectNode(); | auto nodes = graph->GetDirectNode(); | ||||
| @@ -3022,7 +3017,7 @@ Status TensorFlowModelParser::GetNodeFormat(const NodeDef *node, TfTranspose pre | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) { | |||||
| Status TensorFlowModelParser::GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) const { | |||||
| GE_CHECK_NOTNULL(transpose_node); | GE_CHECK_NOTNULL(transpose_node); | ||||
| transpose_direc = NO_TRANSPOSE; | transpose_direc = NO_TRANSPOSE; | ||||
| @@ -3088,7 +3083,7 @@ Status TensorFlowModelParser::TrimGraph(const domi::tensorflow::GraphDef &input_ | |||||
| } | } | ||||
| } | } | ||||
| Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, | Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, | ||||
| domi::tensorflow::GraphDef *output_graph_def) { | |||||
| domi::tensorflow::GraphDef *const output_graph_def) { | |||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| std::set<string> delete_nodes; | std::set<string> delete_nodes; | ||||
| std::set<string> input_nodes; | std::set<string> input_nodes; | ||||
| @@ -3108,8 +3103,8 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
| for (const string ¤t_input : current_inputs) { | for (const string ¤t_input : current_inputs) { | ||||
| delete_nodes.insert(current_input); | delete_nodes.insert(current_input); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E10016", {"parameter", "opname"}, {"input_shape", current_input}); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
| {"input_shape", current_input}); | |||||
| return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | ||||
| const NodeDef *current_node = node_lookup[current_input]; | const NodeDef *current_node = node_lookup[current_input]; | ||||
| GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
| @@ -3160,7 +3155,7 @@ Status TensorFlowModelParser::TrimGraphByInput(const domi::tensorflow::GraphDef | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, | Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, | ||||
| domi::tensorflow::GraphDef *output_graph_def) { | |||||
| domi::tensorflow::GraphDef *const output_graph_def) { | |||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| std::set<string> required_nodes; | std::set<string> required_nodes; | ||||
| std::set<string> input_nodes; | std::set<string> input_nodes; | ||||
| @@ -3185,8 +3180,8 @@ Status TensorFlowModelParser::TrimGraphByOutput(const domi::tensorflow::GraphDef | |||||
| required_nodes.insert(current_input); | required_nodes.insert(current_input); | ||||
| GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue); | GE_IF_BOOL_EXEC(input_nodes.count(current_input), continue); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!node_lookup.count(current_input), | ||||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||||
| "E10016", {"parameter", "opname"}, {"out_nodes", current_input}); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, | |||||
| {"out_nodes", current_input}); | |||||
| return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | return FAILED, "Input op[%s] not found in graph.", current_input.c_str()); | ||||
| const NodeDef *current_node = node_lookup[current_input]; | const NodeDef *current_node = node_lookup[current_input]; | ||||
| GE_CHECK_NOTNULL(current_node); | GE_CHECK_NOTNULL(current_node); | ||||
| @@ -3246,7 +3241,8 @@ string TensorFlowModelParser::NodeNameFromInput(const string &input_name) { | |||||
| } | } | ||||
| Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_parser, | Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_parser, | ||||
| const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node) { | |||||
| const domi::tensorflow::NodeDef *node_def, | |||||
| ge::NodePtr &node) const { | |||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(op_parser); | GE_CHECK_NOTNULL(op_parser); | ||||
| @@ -3263,12 +3259,11 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||||
| // Find all children of the fusion operator | // Find all children of the fusion operator | ||||
| auto iter = fusion_op_nodedef_map_.find(node_def->name()); | auto iter = fusion_op_nodedef_map_.find(node_def->name()); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter == fusion_op_nodedef_map_.end(), | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||||
| node_def->name().c_str()); | |||||
| return INTERNAL_ERROR, | |||||
| "FusionOp node %s has no children node!", node_def->name().c_str()); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
| iter == fusion_op_nodedef_map_.end(), | |||||
| REPORT_INNER_ERROR("E19999", "Node:%s can't find in fusion_op_nodedef_map_, check invalid", | |||||
| node_def->name().c_str()); | |||||
| return INTERNAL_ERROR, "FusionOp node %s has no children node!", node_def->name().c_str()); | |||||
| (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); | (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ge::ATTR_NAME_FUSIONOP_ORIGINAL_TYPE, node_def->op()); | ||||
| vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; | vector<const domi::tensorflow::NodeDef *> node_def_v = iter->second; | ||||
| @@ -3284,8 +3279,8 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||||
| ge::Operator op_src(node_def_src->name(), node_def_src->op()); | ge::Operator op_src(node_def_src->name(), node_def_src->op()); | ||||
| status = domi::AutoMappingFn(node_def_src, op_src); | status = domi::AutoMappingFn(node_def_src, op_src); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", | |||||
| node_def_src->name().c_str(), node_def_src->op().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Auto mapping node_def:%s(%s) to operator failed", node_def_src->name().c_str(), | |||||
| node_def_src->op().c_str()); | |||||
| GELOGE(status, "Node[%s] auto mapping failed", node_def_src->name().c_str()); | GELOGE(status, "Node[%s] auto mapping failed", node_def_src->name().c_str()); | ||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -3295,8 +3290,8 @@ Status TensorFlowModelParser::FusionNodeParseParams(shared_ptr<OpParser> &op_par | |||||
| ge::GeTensorDesc tensor_desc; | ge::GeTensorDesc tensor_desc; | ||||
| tensor_desc.SetName(node_def_src->input(i)); | tensor_desc.SetName(node_def_src->input(i)); | ||||
| if (op_desc->AddInputDesc(tensor_desc) != GRAPH_SUCCESS) { | if (op_desc->AddInputDesc(tensor_desc) != GRAPH_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str()); | |||||
| GELOGE(FAILED, "Op [%s] type[%s] add input(%d) tensor failed.", op_desc->GetName().c_str(), | GELOGE(FAILED, "Op [%s] type[%s] add input(%d) tensor failed.", op_desc->GetName().c_str(), | ||||
| op_desc->GetType().c_str(), i); | op_desc->GetType().c_str(), i); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -3392,8 +3387,8 @@ Status TensorFlowModelParser::OptimizeConstNodes4CustomOp(domi::tensorflow::Grap | |||||
| auto inputs = current_node->input(); | auto inputs = current_node->input(); | ||||
| if (static_cast<size_t>(inputs.size()) != it.input_order.size()) { | if (static_cast<size_t>(inputs.size()) != it.input_order.size()) { | ||||
| REPORT_INNER_ERROR("E19999", "Input size of node:%s(%s) is mismatched, new order size:%zu, input size:%d", | REPORT_INNER_ERROR("E19999", "Input size of node:%s(%s) is mismatched, new order size:%zu, input size:%d", | ||||
| current_node->name().c_str(), current_node->op().c_str(), | |||||
| it.input_order.size(), inputs.size()); | |||||
| current_node->name().c_str(), current_node->op().c_str(), it.input_order.size(), | |||||
| inputs.size()); | |||||
| GELOGE(INTERNAL_ERROR, "Size of input is mismatched, new order size is %zu, input size is %d.", | GELOGE(INTERNAL_ERROR, "Size of input is mismatched, new order size is %zu, input size is %d.", | ||||
| it.input_order.size(), inputs.size()); | it.input_order.size(), inputs.size()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -3462,8 +3457,7 @@ Status TensorFlowModelParser::AddControlEdgeAfterRemoveInputs(domi::tensorflow:: | |||||
| * @return false remove failed | * @return false remove failed | ||||
| * | * | ||||
| */ | */ | ||||
| Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, | |||||
| domi::tensorflow::NodeDef *node_def, | |||||
| Status TensorFlowModelParser::RemoveInputs(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *node_def, | |||||
| const set<uint32_t> &remove_index_set, | const set<uint32_t> &remove_index_set, | ||||
| const map<string, NodeDef *> &all_node_map) { | const map<string, NodeDef *> &all_node_map) { | ||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| @@ -3605,7 +3599,7 @@ Status TensorFlowModelParser::RemoveIsolateNode(domi::tensorflow::GraphDef *grap | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::RecordFusionResult(std::shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| Status TensorFlowModelParser::RecordFusionResult(const std::shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| const domi::tensorflow::NodeDef *node, ge::OpDescPtr &op_desc) { | const domi::tensorflow::NodeDef *node, ge::OpDescPtr &op_desc) { | ||||
| // The caller guarantees that the pointer is not null | // The caller guarantees that the pointer is not null | ||||
| GELOGI("RecordFusionResult for %s start.", op_desc->GetName().c_str()); | GELOGI("RecordFusionResult for %s start.", op_desc->GetName().c_str()); | ||||
| @@ -3922,8 +3916,10 @@ Status TensorFlowModelParser::AddFusionNodeDef(shared_ptr<ge::ScopeGraph> &scope | |||||
| } else { | } else { | ||||
| Status ret = AddFusionInnerNodeDef(scope_graph, op_node_name, node_name_list_new); | Status ret = AddFusionInnerNodeDef(scope_graph, op_node_name, node_name_list_new); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_INNER_ERROR("E19999", "Failed to add fusion inner nodes for fusion op:%s, " | |||||
| "please check FusionScopesResult set in scope fusion pass", op_node_name.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", | |||||
| "Failed to add fusion inner nodes for fusion op:%s, " | |||||
| "please check FusionScopesResult set in scope fusion pass", | |||||
| op_node_name.c_str()); | |||||
| GELOGE(ret, "Failed to add fusion inner node, fusion_op_name:%s.", op_node_name.c_str()); | GELOGE(ret, "Failed to add fusion inner node, fusion_op_name:%s.", op_node_name.c_str()); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -3960,8 +3956,8 @@ Status TensorFlowModelParser::AddScopeInnerNode(TensorFlowModelParser *parser, g | |||||
| node = graph->AddNode(op_desc); | node = graph->AddNode(op_desc); | ||||
| } | } | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(), | GELOGE(INTERNAL_ERROR, "Failed to Add scope inner node:%s, type:%s.", op_desc->GetName().c_str(), | ||||
| op_desc->GetType().c_str()); | op_desc->GetType().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -3990,7 +3986,7 @@ void TensorFlowModelParser::DumpNodeContext(const string &node_name, const OpNod | |||||
| GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str()); | GELOGD("phase:%s === End to dump context for node:%s ===", phase.c_str(), node_name.c_str()); | ||||
| } | } | ||||
| void TensorFlowModelParser::DumpAllNodeContext(const string &phase) { | |||||
| void TensorFlowModelParser::DumpAllNodeContext(const string &phase) const { | |||||
| if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | if (!IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -4048,7 +4044,7 @@ Status TensorFlowModelParser::UpdateOutputsInfo(const ParserUtils::OutputMapping | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowModelParser::AddExternalGraph(ComputeGraphPtr &root_graph) { | |||||
| Status TensorFlowModelParser::AddExternalGraph(const ComputeGraphPtr &root_graph) { | |||||
| GE_CHECK_NOTNULL(root_graph); | GE_CHECK_NOTNULL(root_graph); | ||||
| for (const NodePtr &node : root_graph->GetAllNodes()) { | for (const NodePtr &node : root_graph->GetAllNodes()) { | ||||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | if (node == nullptr || node->GetOpDesc() == nullptr) { | ||||
| @@ -77,7 +77,7 @@ struct DelTransposeInfo; | |||||
| class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | ||||
| public: | public: | ||||
| TensorFlowModelParser() {} | TensorFlowModelParser() {} | ||||
| virtual ~TensorFlowModelParser() {} | |||||
| ~TensorFlowModelParser() {} | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -137,7 +137,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| */ | */ | ||||
| ge::DataType ConvertToGeDataType(const uint32_t type) override; | ge::DataType ConvertToGeDataType(const uint32_t type) override; | ||||
| Status ParseAllGraph(const google::protobuf::Message *root_proto, ge::ComputeGraphPtr &root_graph) override ; | |||||
| Status ParseAllGraph(const google::protobuf::Message *proto, ge::ComputeGraphPtr &graph) override ; | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -158,10 +158,10 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @return SUCCESS | * @return SUCCESS | ||||
| * @return Others failed | * @return Others failed | ||||
| */ | */ | ||||
| Status ParseProtoWithSubgraph(const std::string &serialized_proto, domi::GetGraphCallbackV2 callback, | |||||
| ge::ComputeGraphPtr &graph) override; | |||||
| Status ParseProtoWithSubgraph(const std::string &root_proto, domi::GetGraphCallbackV2 callback, | |||||
| ge::ComputeGraphPtr &root_graph) override; | |||||
| private: | private: | ||||
| Status Parse(const char *file, ge::ComputeGraphPtr &graph); | |||||
| Status Parse(const char *model_path, ge::ComputeGraphPtr &root_graph); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -254,7 +254,8 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| bool ConstOpNeedUpdate(const string &op_name); | bool ConstOpNeedUpdate(const string &op_name); | ||||
| Status ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *graph_def, shared_ptr<ge::ScopeGraph> &scope_graph); | |||||
| static Status ExcuteScopeFusionPasses(domi::tensorflow::GraphDef *const graph_def, | |||||
| shared_ptr<ge::ScopeGraph> &scope_graph); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Run the scope fusion optimizer in list scope_passes_list | * @brief Run the scope fusion optimizer in list scope_passes_list | ||||
| @@ -264,7 +265,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @return SUCCESS Run successfully | * @return SUCCESS Run successfully | ||||
| * @return others Run failed | * @return others Run failed | ||||
| */ | */ | ||||
| Status RunScopeFusionPass(const vector<string> &scope_passes_list, | |||||
| static Status RunScopeFusionPass(const vector<string> &scope_passes_list, | |||||
| ScopePassManager &pass_manager, | ScopePassManager &pass_manager, | ||||
| shared_ptr<ge::ScopeGraph> &scope_graph); | shared_ptr<ge::ScopeGraph> &scope_graph); | ||||
| @@ -284,10 +285,10 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| /** | /** | ||||
| * @brief Inner child operators of fusion operators | * @brief Inner child operators of fusion operators | ||||
| */ | */ | ||||
| bool FusionOpChildIgnore(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info); | |||||
| static bool FusionOpChildIgnore(const shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info); | |||||
| // Is it a fusion operator | // Is it a fusion operator | ||||
| bool IsFusionOp(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def); | |||||
| static bool IsFusionOp(const shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::NodeDef *node_def); | |||||
| /** | /** | ||||
| * @brief get inPut index of the fusion operator | * @brief get inPut index of the fusion operator | ||||
| @@ -321,8 +322,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @return FAILED | * @return FAILED | ||||
| */ | */ | ||||
| Status UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, const domi::tensorflow::GraphDef &graph_def, | |||||
| vector<string> &op_node_name_list); | |||||
| Status UpdateAllNodeOpContext(shared_ptr<ge::ScopeGraph> &scope_graph, vector<string> &op_node_name_list); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -392,7 +392,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @brief get contral information | * @brief get contral information | ||||
| */ | */ | ||||
| bool GetEdgesControlInfo(const string &node_name, const int32_t index); | |||||
| bool GetEdgesControlInfo(const string &node_name, const int32_t index) const; | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -404,7 +404,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @return FAILED | * @return FAILED | ||||
| */ | */ | ||||
| Status CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, bool *control); | |||||
| static Status CheckInputNodeName(const string &input_node_name, string *node_name, int32_t *index, bool *control); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -416,7 +416,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @return FAILED | * @return FAILED | ||||
| */ | */ | ||||
| Status GeStoi(const string &input_node_name, const string &index_str, int32_t *index); | |||||
| static Status GeStoi(const string &input_node_name, const string &index_str, int32_t *index); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -455,19 +455,24 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | Status GraphDefOptimizeSnapShot(domi::tensorflow::GraphDef *graph_def, map<string, NodeDef *> &nodedef_map, | ||||
| const vector<NodeDef *> &nodedef_to_optimize); | const vector<NodeDef *> &nodedef_to_optimize); | ||||
| Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | Status GraphDefOptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, | ||||
| domi::tensorflow::NodeDef *nodeCurrent); | |||||
| domi::tensorflow::NodeDef *const nodeCurrent) const; | |||||
| Status OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, map<string, NodeDef *> &nodedef_map, | Status OptimizeSnapShot(domi::tensorflow::NodeDef *curr_mode_def, map<string, NodeDef *> &nodedef_map, | ||||
| const std::pair<string, int> &input_data, const std::vector<string> &control_list); | const std::pair<string, int> &input_data, const std::vector<string> &control_list); | ||||
| void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *graph_def, domi::tensorflow::NodeDef *nodeCurrent, | |||||
| bool &clearInputFlag); | |||||
| void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); | |||||
| void SoftmaxAddAttr(GraphDef *graph_def); | |||||
| static Status SetDestNodeName(domi::tensorflow::NodeDef *const node_current, | |||||
| domi::tensorflow::NodeDef *const node_dest, const int32_t input_idx, | |||||
| const bool is_control, bool &clear_input_flag); | |||||
| void OptimizeDestroyTemporaryVariable(domi::tensorflow::GraphDef *const graph_def, | |||||
| domi::tensorflow::NodeDef *const nodeCurrent, bool &clearInputFlag) const; | |||||
| static void OptimizeTranspose(std::map<std::string, DelTransposeInfo> &transposeInfo); | |||||
| static void SoftmaxAddAttr(GraphDef *const graph_def); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| * @brief Delete isolated nodes in graph | * @brief Delete isolated nodes in graph | ||||
| */ | */ | ||||
| Status RemoveIsolateNode(ge::ComputeGraphPtr &graph); | |||||
| static Status RemoveIsolateNode(const ge::ComputeGraphPtr &graph); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -489,19 +494,20 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| * @brief Get format transpose. | * @brief Get format transpose. | ||||
| */ | */ | ||||
| Status GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc); | |||||
| Status TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, domi::tensorflow::GraphDef *output_graph_def); | |||||
| Status TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, | |||||
| Status GetFormatTranspose(const NodeDef *transpose_node, TfTranspose &transpose_direc) const; | |||||
| static Status TrimGraph(const domi::tensorflow::GraphDef &input_graph_def, | |||||
| domi::tensorflow::GraphDef *output_graph_def); | domi::tensorflow::GraphDef *output_graph_def); | ||||
| Status TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, | |||||
| domi::tensorflow::GraphDef *output_graph_def); | |||||
| string NodeNameFromInput(const string &input_name); | |||||
| static Status TrimGraphByInput(const domi::tensorflow::GraphDef &input_graph_def, | |||||
| domi::tensorflow::GraphDef *const output_graph_def); | |||||
| static Status TrimGraphByOutput(const domi::tensorflow::GraphDef &input_graph_def, | |||||
| domi::tensorflow::GraphDef *const output_graph_def); | |||||
| static string NodeNameFromInput(const string &input_name); | |||||
| Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); | Status AddTensorDescToOpDesc(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); | ||||
| Status CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); | Status CheckoutInputNum(ge::OpDescPtr &op_desc, const domi::tensorflow::NodeDef *node); | ||||
| void UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc, | |||||
| static void UpdateInputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &input_desc, | |||||
| const size_t input_tensor_num); | const size_t input_tensor_num); | ||||
| void UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc, | |||||
| static void UpdateOutputTensor(ge::OpDescPtr &op_desc, const std::vector<ge::GeTensorDesc> &output_desc, | |||||
| size_t output_tensor_num); | size_t output_tensor_num); | ||||
| Status TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, const string &op_type); | Status TransNodeToOpDesc(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, const string &op_type); | ||||
| @@ -509,8 +515,9 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); | OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); | ||||
| Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | Status UppdateOutputMap(shared_ptr<ge::ScopeGraph> &scope_graph, const ge::ScopeFusionOpInfo &info, | ||||
| OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); | OpNodeContext &fusion_op_node_context, OpNodeContext &normal_op_node_context); | ||||
| void GetInputOutputTensorNum (ge::OpDescPtr &op_desc, size_t &input_tensor_num, size_t &output_tensor_num) const; | |||||
| Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid); | |||||
| void GetInputOutputTensorNum(const ge::OpDescPtr &op_desc, size_t &input_tensor_num, | |||||
| size_t &output_tensor_num); | |||||
| static Status CheckOpShapeDim(const domi::tensorflow::NodeDef *node_def, const std::set<int> &dims, bool &valid); | |||||
| Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type); | Status CheckOpType(const domi::tensorflow::NodeDef *node_def, string &op_type); | ||||
| /** | /** | ||||
| @@ -531,7 +538,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| */ | */ | ||||
| Status FusionNodeParseParams(shared_ptr<OpParser> &op_parser, | Status FusionNodeParseParams(shared_ptr<OpParser> &op_parser, | ||||
| const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node); | |||||
| const domi::tensorflow::NodeDef *node_def, ge::NodePtr &node) const; | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -595,22 +602,22 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| Status GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def); | Status GetTensorflowGraphInOutMap(domi::tensorflow::GraphDef *graph_def); | ||||
| Status RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def); | Status RemoveIsolateNode(domi::tensorflow::GraphDef *graph_def); | ||||
| static Status RecordFusionResult(std::shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| static Status RecordFusionResult(const std::shared_ptr<ge::ScopeGraph> &scope_graph, | |||||
| const domi::tensorflow::NodeDef *node, | const domi::tensorflow::NodeDef *node, | ||||
| ge::OpDescPtr &op_def); | |||||
| ge::OpDescPtr &op_desc); | |||||
| Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); | |||||
| static Status GetFunctionProto(const string &file, domi::tensorflow::GraphDefLibrary &graph_def_library); | |||||
| Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | Status SetOriginNodeContext(NodeDef *node_def, OpNodeContext &op_node_context, | ||||
| const std::vector<std::pair<std::string, int32_t>> &inputs, | const std::vector<std::pair<std::string, int32_t>> &inputs, | ||||
| const std::vector<std::pair<std::string, int32_t>> &outputs); | const std::vector<std::pair<std::string, int32_t>> &outputs); | ||||
| void GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, | |||||
| static void GetFusionInputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, | |||||
| std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input, | std::map<string, std::pair<std::string, std::pair<int32_t, int32_t>>> &remap_data_input, | ||||
| std::map<string, std::vector<string>> &remap_ctrl_input, | std::map<string, std::vector<string>> &remap_ctrl_input, | ||||
| std::set<string> &fusion_input_nodes); | std::set<string> &fusion_input_nodes); | ||||
| void GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, | |||||
| static void GetFusionOutputInfo(const string &fusion_op_name, OpNodeContext &fusion_context, | |||||
| std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output, | std::map<string, std::vector<std::pair<std::string, std::pair<int32_t, int32_t>>>> &remap_data_output, | ||||
| std::map<string, std::vector<string>> &remap_ctrl_output, | std::map<string, std::vector<string>> &remap_ctrl_output, | ||||
| std::set<string> &fusion_output_nodes); | std::set<string> &fusion_output_nodes); | ||||
| @@ -634,13 +641,14 @@ class PARSER_FUNC_VISIBILITY TensorFlowModelParser : public domi::ModelParser { | |||||
| static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | static Status AddScopeInnerNode(TensorFlowModelParser *parser, ge::ComputeGraphPtr &graph, | ||||
| std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def); | std::mutex *graph_mutex, const domi::tensorflow::NodeDef *node_def); | ||||
| void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); | |||||
| void DumpAllNodeContext(const string &phase); | |||||
| static void DumpNodeContext(const string &node_name, const OpNodeContext &ctx, const string &phase); | |||||
| void DumpAllNodeContext(const string &phase) const; | |||||
| Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, shared_ptr<OpParser> &op_parser); | |||||
| Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); | |||||
| static Status ParseOpParams(const domi::tensorflow::NodeDef *node_def, ge::OpDescPtr &op, | |||||
| const shared_ptr<OpParser> &op_parser); | |||||
| static Status CheckAndUpdateInputDesc(ge::ComputeGraphPtr &compute_graph); | |||||
| static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); | static Status UpdateOutputsInfo(const ParserUtils::OutputMapping &final_output_nodes); | ||||
| static Status AddExternalGraph(ComputeGraphPtr &root_graph); | |||||
| static Status AddExternalGraph(const ComputeGraphPtr &root_graph); | |||||
| /** | /** | ||||
| * save <node_name, node_def> | * save <node_name, node_def> | ||||
| @@ -51,7 +51,7 @@ class TensorflowParserBuilder; | |||||
| class PARSER_FUNC_VISIBILITY TensorflowWeightParserBuilder : public TensorflowFinalizeable { | class PARSER_FUNC_VISIBILITY TensorflowWeightParserBuilder : public TensorflowFinalizeable { | ||||
| public: | public: | ||||
| virtual ~TensorflowWeightParserBuilder() {} | |||||
| ~TensorflowWeightParserBuilder() override {} | |||||
| }; | }; | ||||
| template <typename Param> | template <typename Param> | ||||
| @@ -64,7 +64,7 @@ class PARSER_FUNC_VISIBILITY TensorflowParserBuilder : public TensorflowWeightPa | |||||
| explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {} | explicit TensorflowParserBuilder(const std::string &davinci_optype) : davinci_optype_(davinci_optype) {} | ||||
| ~TensorflowParserBuilder() {} | |||||
| ~TensorflowParserBuilder() override {} | |||||
| TensorflowParserBuilder &SetParseParamsFn(ParseParamsFn parse_params_fn) { | TensorflowParserBuilder &SetParseParamsFn(ParseParamsFn parse_params_fn) { | ||||
| parse_params_fn_ = parse_params_fn; | parse_params_fn_ = parse_params_fn; | ||||
| @@ -95,9 +95,10 @@ class PARSER_FUNC_VISIBILITY TensorflowOpParserAdapter : public TensorFlowOpPars | |||||
| using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>; | using ParseParamsFn = std::function<domi::Status(const domi::tensorflow::NodeDef *, Param *)>; | ||||
| public: | public: | ||||
| TensorflowOpParserAdapter(TensorflowParserBuilder<Param> builder) { parse_params_fn_ = builder.parse_params_fn_; } | |||||
| explicit TensorflowOpParserAdapter(TensorflowParserBuilder<Param> builder) { | |||||
| parse_params_fn_ = builder.parse_params_fn_; } | |||||
| ~TensorflowOpParserAdapter() {} | |||||
| ~TensorflowOpParserAdapter() override {} | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override { | ||||
| const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src); | const domi::tensorflow::NodeDef *node = static_cast<const domi::tensorflow::NodeDef *>(op_src); | ||||
| @@ -55,9 +55,9 @@ Status TensorFlowRefSwitchParser::ParseT(const domi::tensorflow::NodeDef *node, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TensorFlowRefSwitchParser::ParseParams(const Message *opSrc, ge::OpDescPtr &opDest) { | |||||
| GE_CHECK_NOTNULL(opSrc); | |||||
| const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(opSrc); | |||||
| Status TensorFlowRefSwitchParser::ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) { | |||||
| GE_CHECK_NOTNULL(op_src); | |||||
| const NodeDef *node = DOMI_DYNAMIC_CAST<const NodeDef *>(op_src); | |||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| RefSwitchOperator op; | RefSwitchOperator op; | ||||
| @@ -70,7 +70,7 @@ Status TensorFlowRefSwitchParser::ParseParams(const Message *opSrc, ge::OpDescPt | |||||
| GE_RETURN_IF_ERROR(PostParseParams(node, &op)); | GE_RETURN_IF_ERROR(PostParseParams(node, &op)); | ||||
| Status status = ConvertToOpDesc(op, opDest); | |||||
| Status status = ConvertToOpDesc(op, op_dest); | |||||
| return status; | return status; | ||||
| } | } | ||||
| @@ -65,7 +65,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowRefSwitchParser : public TensorFlowOpPars | |||||
| * @return SUCCESS 解析成功 | * @return SUCCESS 解析成功 | ||||
| * @return FAILED 解析失败 | * @return FAILED 解析失败 | ||||
| */ | */ | ||||
| Status ParseT(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op); | |||||
| static Status ParseT(const domi::tensorflow::NodeDef *node, RefSwitchOperator *op); | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
| }; | }; | ||||
| @@ -70,7 +70,7 @@ Status TensorFlowReshapeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| GetParserContext().train_flag == true, | |||||
| GetParserContext().train_flag, | |||||
| ge::GeTensorDesc input_desc; | ge::GeTensorDesc input_desc; | ||||
| ge::GeTensorDesc output_desc; | ge::GeTensorDesc output_desc; | ||||
| @@ -22,7 +22,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser { | class PARSER_FUNC_VISIBILITY TensorFlowReshapeParser : public TensorFlowOpParser { | ||||
| private: | private: | ||||
| Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||||
| static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||||
| public: | public: | ||||
| /** | /** | ||||
| @@ -122,8 +122,8 @@ Status TensorFlowShapeNParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
| if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { | if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_OUTPUT_TENSOR_DESC, output_attr_value)) { | ||||
| GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
| TensorFlowUtil::TransTensorDescriptor(output_attr_value, &op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG), | |||||
| "trans output_attr_value failed, op: %s", node->name().c_str()); | |||||
| TensorFlowUtil::TransTensorDescriptor(output_attr_value, &op, TENSORFLOW_NORMAL_OUTPUT_TENSOR_FLAG), | |||||
| "trans output_attr_value failed, op: %s", node->name().c_str()); | |||||
| ret = ConvertToOpDesc(op, op_dest); | ret = ConvertToOpDesc(op, op_dest); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return ret; | return ret; | ||||
| @@ -32,9 +32,9 @@ class PARSER_FUNC_VISIBILITY TensorFlowShapeNParser : public TensorFlowOpParser | |||||
| Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | Status PreParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
| Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | Status PostParseParams(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | ||||
| Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| static Status ParseN(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| static Status ParseInType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| static Status ParseOutType(const domi::tensorflow::NodeDef *node, ShapeNOperator *op); | |||||
| // AUTO GEN PLEASE DO NOT MODIFY IT | // AUTO GEN PLEASE DO NOT MODIFY IT | ||||
| }; | }; | ||||
| @@ -116,7 +116,7 @@ Status TensorFlowSqueezeParser::ParseParams(const Message *op_src, ge::OpDescPtr | |||||
| domi::tensorflow::AttrValue output_attr_value; | domi::tensorflow::AttrValue output_attr_value; | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| GetParserContext().train_flag == true, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; | |||||
| GetParserContext().train_flag, ge::GeTensorDesc input_desc; ge::GeTensorDesc output_desc; | |||||
| if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { | if (TensorFlowUtil::FindAttrValue(node, ge::ATTR_NAME_INPUT_TENSOR_DESC, input_attr_value)) { | ||||
| GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); | GE_CHK_BOOL_RET_STATUS(ParseDesc(input_attr_value, input_desc) == SUCCESS, FAILED, "parse input desc failed"); | ||||
| @@ -25,7 +25,7 @@ class PARSER_FUNC_VISIBILITY TensorFlowSqueezeParser : public TensorFlowOpParser | |||||
| Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | Status ParseParams(const Message *op_src, ge::OpDescPtr &op_dest) override; | ||||
| private: | private: | ||||
| Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||||
| static Status ParseDesc(const domi::tensorflow::AttrValue &attr_value, ge::GeTensorDesc &ge_desc); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -31,9 +31,9 @@ | |||||
| using domi::tensorflow::DT_INVALID; | using domi::tensorflow::DT_INVALID; | ||||
| namespace ge { | namespace ge { | ||||
| using AttrValueMap = ::google::protobuf::Map<string, domi::tensorflow::AttrValue>; | |||||
| using AttrValueMap = ::google::protobuf::Map<std::string, domi::tensorflow::AttrValue>; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrValue( | ||||
| const domi::tensorflow::NodeDef *node_def, const string &attr_name, domi::tensorflow::AttrValue &attr_value) { | |||||
| const domi::tensorflow::NodeDef *node_def, const std::string &attr_name, domi::tensorflow::AttrValue &attr_value) { | |||||
| GE_CHECK_NOTNULL(node_def); | GE_CHECK_NOTNULL(node_def); | ||||
| const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr(); | const google::protobuf::Map<std::string, domi::tensorflow::AttrValue> &attr = node_def->attr(); | ||||
| const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | const google::protobuf::Map<std::string, domi::tensorflow::AttrValue>::const_iterator it = attr.find(attr_name); | ||||
| @@ -46,7 +46,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::FindAttrVa | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::CheckAttrHasType( | ||||
| const domi::tensorflow::AttrValue &attr_value, const string &type) { | |||||
| const domi::tensorflow::AttrValue &attr_value, const std::string &type) { | |||||
| uint32_t num_set = 0; | uint32_t num_set = 0; | ||||
| #define VALIDATE_FIELD(name, type_string, oneof_case) \ | #define VALIDATE_FIELD(name, type_string, oneof_case) \ | ||||
| do { \ | do { \ | ||||
| @@ -59,7 +59,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch | |||||
| ++num_set; \ | ++num_set; \ | ||||
| } \ | } \ | ||||
| } else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \ | } else if (attr_value.value_case() == domi::tensorflow::AttrValue::oneof_case) { \ | ||||
| if (type != type_string) { \ | |||||
| if (type != (type_string)) { \ | |||||
| GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \ | GELOGE(FAILED, "GeAttrValue had value with type '" type_string "' when '%s' expected", type.c_str()); \ | ||||
| return FAILED; \ | return FAILED; \ | ||||
| } \ | } \ | ||||
| @@ -118,10 +118,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Ch | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::ParseDataType( | ||||
| const NodeDef *node_src, const string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
| const NodeDef *node_src, const std::string &attr_src, domi::tensorflow::DataType &data_type) { | |||||
| GE_CHECK_NOTNULL(node_src); | GE_CHECK_NOTNULL(node_src); | ||||
| string node_name = node_src->name(); | |||||
| std::string node_name = node_src->name(); | |||||
| // Find the value of attr_src from node_src | // Find the value of attr_src from node_src | ||||
| domi::tensorflow::AttrValue attr_value; | domi::tensorflow::AttrValue attr_value; | ||||
| @@ -152,7 +152,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA | |||||
| "In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype); | "In FrameworkOp translate datatype:%d failed, domi cann't support.", tf_datatype); | ||||
| ge_desc.SetDataType(type); | ge_desc.SetDataType(type); | ||||
| int shape_dim_dim = a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i_size(); | int shape_dim_dim = a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i_size(); | ||||
| vector<int64_t> data_dim; | |||||
| std::vector<int64_t> data_dim; | |||||
| for (int j = 0; j < shape_dim_dim; j++) { | for (int j = 0; j < shape_dim_dim; j++) { | ||||
| data_dim.push_back(a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i(j)); | data_dim.push_back(a_list.func(i).attr().at(SERIALIZE_SHAPE).list().i(j)); | ||||
| } | } | ||||
| @@ -162,14 +162,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool TensorFlowUtil::ParseFromA | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::TransTensorDescriptor( | ||||
| const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const string &type) { | |||||
| const domi::tensorflow::AttrValue &attr_value, ParserOperator *op, const uint32_t io, const std::string &type) { | |||||
| GE_CHECK_NOTNULL(op); | GE_CHECK_NOTNULL(op); | ||||
| if (!attr_value.has_list()) { | if (!attr_value.has_list()) { | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| vector<int32_t> tf_in_type; | |||||
| vector<int32_t> tf_out_type; | |||||
| std::vector<int32_t> tf_in_type; | |||||
| std::vector<int32_t> tf_out_type; | |||||
| // list contain many TensorDescriptors | // list contain many TensorDescriptors | ||||
| domi::tensorflow::AttrValue_ListValue a_list = attr_value.list(); | domi::tensorflow::AttrValue_ListValue a_list = attr_value.list(); | ||||
| for (int32_t i = 0; i < a_list.func_size(); i++) { | for (int32_t i = 0; i < a_list.func_size(); i++) { | ||||
| @@ -193,7 +193,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
| // Adjust shape to fit resnet50 network only. | // Adjust shape to fit resnet50 network only. | ||||
| GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); | GE_IF_BOOL_EXEC((type == ge::parser::FUSEDBATCHNORMGRAD) && (tmp_dim == 0), ge_desc.SetShape(ge::GeShape()); | ||||
| break;); | break;); | ||||
| GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), vector<int64_t> data_dim = {tmp_dim}; | |||||
| GE_IF_BOOL_EXEC((type == ge::parser::MEAN) && (tmp_dim == 0), std::vector<int64_t> data_dim = {tmp_dim}; | |||||
| ge_desc.SetShape(ge::GeShape(data_dim)); break;); | ge_desc.SetShape(ge::GeShape(data_dim)); break;); | ||||
| } | } | ||||
| ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ge::TensorUtils::SetRealDimCnt(ge_desc, ge_desc.GetShape().GetDimNum()); | ||||
| @@ -215,7 +215,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY domi::Status TensorFlowUtil::Tr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TensorFlowUtil::AddNodeAttr( | ||||
| const string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { | |||||
| const std::string &attr_name, const domi::tensorflow::AttrValue &value, domi::tensorflow::NodeDef *node_def) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | GE_CHK_BOOL_TRUE_EXEC_INFO(node_def == nullptr, return, "input parameter is null."); | ||||
| node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | node_def->mutable_attr()->insert(AttrValueMap::value_type(attr_name, value)); | ||||
| } | } | ||||
| @@ -36,8 +36,7 @@ | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "proto/tensorflow/graph.pb.h" | #include "proto/tensorflow/graph.pb.h" | ||||
| using std::string; | |||||
| using std::vector; | |||||
| using domi::tensorflow::NodeDef; | using domi::tensorflow::NodeDef; | ||||
| using domi::tensorflow::FunctionDef; | using domi::tensorflow::FunctionDef; | ||||
| using domi::tensorflow::AttrValue_ListValue; | using domi::tensorflow::AttrValue_ListValue; | ||||
| @@ -45,79 +44,79 @@ using domi::tensorflow::FunctionDefLibrary; | |||||
| namespace ge { | namespace ge { | ||||
| /***************************TensorFlow attribute type, constant definition*******************************************/ | /***************************TensorFlow attribute type, constant definition*******************************************/ | ||||
| static const string TENSORFLOW_ATTR_TYPE_STRING = "string"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_INT = "int"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_FLOAT = "float"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_BOOL = "bool"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_TYPE = "type"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_SHAPE = "shape"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor"; | |||||
| static const string TENSORFLOW_ATTR_TYPE_FUNC = "func"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)"; | |||||
| static const string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_STRING = "string"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_INT = "int"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_FLOAT = "float"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_BOOL = "bool"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_TYPE = "type"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_SHAPE = "shape"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_TENSOR = "tensor"; | |||||
| static const std::string TENSORFLOW_ATTR_TYPE_FUNC = "func"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_STRING = "list(string)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_INT = "list(int)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_FLOAT = "list(float)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_BOOL = "list(bool)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_TYPE = "list(type)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_SHAPE = "list(shape)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_TENSOR = "list(tensor)"; | |||||
| static const std::string TENSORFLOW_ATTR_LIST_TYPE_FUNC = "list(func)"; | |||||
| /***************************constant definition*******************************************/ | /***************************constant definition*******************************************/ | ||||
| static const string TENSORFLOW_ATTR_OUTPUT_OP = "output_op"; | |||||
| static const string TENSORFLOW_ATTR_T = "T"; | |||||
| static const string TENSORFLOW_ATTR_N = "N"; | |||||
| static const string TENSORFLOW_ATTR_DATA_FORMAT = "data_format"; | |||||
| static const string TENSORFLOW_ATTR_PADDING = "padding"; | |||||
| static const string TENSORFLOW_ATTR_KSIZE = "ksize"; | |||||
| static const string TENSORFLOW_ATTR_STRIDES = "strides"; | |||||
| static const string TENSORFLOW_ATTR_DILATIONS = "dilations"; | |||||
| static const string TENSORFLOW_ATTR_DTYPE = "dtype"; | |||||
| static const string TENSORFLOW_ATTR_VALUE = "value"; | |||||
| static const string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a"; | |||||
| static const string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b"; | |||||
| static const string TENSORFLOW_ATTR_SHAPE = "shape"; | |||||
| static const string TENSORFLOW_ATTR_TIDX = "Tidx"; | |||||
| static const string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings"; | |||||
| static const string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples"; | |||||
| static const string TENSORFLOW_ATTR_TINDICES = "Tindices"; | |||||
| static const string TENSORFLOW_ATTR_TPARAMS = "Tparams"; | |||||
| static const string TENSORFLOW_ATTR_TAXIS = "Taxis"; | |||||
| static const string TENSORFLOW_ATTR_DSTT = "DstT"; | |||||
| static const string TENSORFLOW_ATTR_SRCT = "SrcT"; | |||||
| static const string TENSORFLOW_ATTR_PERM = "perm"; | |||||
| static const string TENSORFLOW_ATTR_INDEX = "Index"; | |||||
| static const string TENSORFLOW_ATTR_TSHAPE = "Tshape"; | |||||
| static const string TENSORFLOW_ATTR_AXIS = "Axis"; | |||||
| static const string TENSORFLOW_ATTR_BIAS = "bias"; | |||||
| static const string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius"; | |||||
| static const string TENSORFLOW_ATTR_ALPHA = "alpha"; | |||||
| static const string TENSORFLOW_ATTR_BETA = "beta"; | |||||
| static const string TENSORFLOW_ATTR_MODE = "mode"; | |||||
| static const std::string TENSORFLOW_ATTR_OUTPUT_OP = "output_op"; | |||||
| static const std::string TENSORFLOW_ATTR_T = "T"; | |||||
| static const std::string TENSORFLOW_ATTR_N = "N"; | |||||
| static const std::string TENSORFLOW_ATTR_DATA_FORMAT = "data_format"; | |||||
| static const std::string TENSORFLOW_ATTR_PADDING = "padding"; | |||||
| static const std::string TENSORFLOW_ATTR_KSIZE = "ksize"; | |||||
| static const std::string TENSORFLOW_ATTR_STRIDES = "strides"; | |||||
| static const std::string TENSORFLOW_ATTR_DILATIONS = "dilations"; | |||||
| static const std::string TENSORFLOW_ATTR_DTYPE = "dtype"; | |||||
| static const std::string TENSORFLOW_ATTR_VALUE = "value"; | |||||
| static const std::string TENSORFLOW_ATTR_TRANSINPUT = "transpose_a"; | |||||
| static const std::string TENSORFLOW_ATTR_TRANSWEIGHT = "transpose_b"; | |||||
| static const std::string TENSORFLOW_ATTR_SHAPE = "shape"; | |||||
| static const std::string TENSORFLOW_ATTR_TIDX = "Tidx"; | |||||
| static const std::string TENSORFLOW_ATTR_TPADDINGS = "Tpaddings"; | |||||
| static const std::string TENSORFLOW_ATTR_TMULTIPLES = "Tmultiples"; | |||||
| static const std::string TENSORFLOW_ATTR_TINDICES = "Tindices"; | |||||
| static const std::string TENSORFLOW_ATTR_TPARAMS = "Tparams"; | |||||
| static const std::string TENSORFLOW_ATTR_TAXIS = "Taxis"; | |||||
| static const std::string TENSORFLOW_ATTR_DSTT = "DstT"; | |||||
| static const std::string TENSORFLOW_ATTR_SRCT = "SrcT"; | |||||
| static const std::string TENSORFLOW_ATTR_PERM = "perm"; | |||||
| static const std::string TENSORFLOW_ATTR_INDEX = "Index"; | |||||
| static const std::string TENSORFLOW_ATTR_TSHAPE = "Tshape"; | |||||
| static const std::string TENSORFLOW_ATTR_AXIS = "Axis"; | |||||
| static const std::string TENSORFLOW_ATTR_BIAS = "bias"; | |||||
| static const std::string TENSORFLOW_ATTR_DEPTH_RADIUS = "depth_radius"; | |||||
| static const std::string TENSORFLOW_ATTR_ALPHA = "alpha"; | |||||
| static const std::string TENSORFLOW_ATTR_BETA = "beta"; | |||||
| static const std::string TENSORFLOW_ATTR_MODE = "mode"; | |||||
| // op:Const | // op:Const | ||||
| static const string TENSORFLOWF_NODE_OP_CONST = "Const"; | |||||
| static const string TENSORFLOWF_NODE_OP_IDENTITY = "Identity"; | |||||
| static const string TENSORFLOWF_NODE_OP_SWITCH = "Switch"; | |||||
| static const string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder"; | |||||
| static const string TENSORFLOWF_NODE_OP_ADDN = "AddN"; | |||||
| static const string TENSORFLOWF_NODE_OP_MATMUL = "MatMul"; | |||||
| static const string TENSORFLOWF_NODE_OP_RELU = "Relu"; | |||||
| static const string TENSORFLOWF_NODE_OP_SHAPE = "Shape"; | |||||
| static const string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose"; | |||||
| static const string TENSORFLOWF_NODE_OP_MERGE = "Merge"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_CONST = "Const"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_IDENTITY = "Identity"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_SWITCH = "Switch"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_PLACEHOLDER = "Placeholder"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_ADDN = "AddN"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_MATMUL = "MatMul"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_RELU = "Relu"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_SHAPE = "Shape"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_TRANSPOSE = "Transpose"; | |||||
| static const std::string TENSORFLOWF_NODE_OP_MERGE = "Merge"; | |||||
| // data_format | // data_format | ||||
| static const string TENSORFLOWF_TENSOR_NCHW = "NCHW"; | |||||
| static const string TENSORFLOWF_TENSOR_NHWC = "NHWC"; | |||||
| static const std::string TENSORFLOWF_TENSOR_NCHW = "NCHW"; | |||||
| static const std::string TENSORFLOWF_TENSOR_NHWC = "NHWC"; | |||||
| static const int TENSORFLOW_CONV_STRIDE_NUM = 4; | static const int TENSORFLOW_CONV_STRIDE_NUM = 4; | ||||
| static const int TENSORFLOW_CONV_DILATION_NUM = 4; | static const int TENSORFLOW_CONV_DILATION_NUM = 4; | ||||
| // padding | // padding | ||||
| static const string TENSORFLOWF_OP_PADDING_VALID = "VALID"; | |||||
| static const string TENSORFLOWF_OP_PADDING_SAME = "SAME"; | |||||
| static const std::string TENSORFLOWF_OP_PADDING_VALID = "VALID"; | |||||
| static const std::string TENSORFLOWF_OP_PADDING_SAME = "SAME"; | |||||
| // normal input size | // normal input size | ||||
| static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2; | static const uint32_t TENSORFLOW_NORMAL_INPUT_SIZE_MATMUL = 2; | ||||
| @@ -144,7 +143,7 @@ class TensorFlowUtil { | |||||
| * @return false attribute does not exist | * @return false attribute does not exist | ||||
| * | * | ||||
| */ | */ | ||||
| static bool FindAttrValue(const domi::tensorflow::NodeDef *nodeDef, const string &attr_name, | |||||
| static bool FindAttrValue(const domi::tensorflow::NodeDef *node_def, const std::string &attr_name, | |||||
| domi::tensorflow::AttrValue &attr_value); | domi::tensorflow::AttrValue &attr_value); | ||||
| /** | /** | ||||
| @@ -156,7 +155,7 @@ class TensorFlowUtil { | |||||
| * @return FAILED failed | * @return FAILED failed | ||||
| * | * | ||||
| */ | */ | ||||
| static domi::Status CheckAttrHasType(const domi::tensorflow::AttrValue &attr_value, const string &type); | |||||
| static domi::Status CheckAttrHasType(const domi::tensorflow::AttrValue &attr_value, const std::string &type); | |||||
| /** | /** | ||||
| * @ingroup domi_omg | * @ingroup domi_omg | ||||
| @@ -169,7 +168,7 @@ class TensorFlowUtil { | |||||
| * | * | ||||
| */ | */ | ||||
| static domi::Status ParseDataType(const NodeDef *node_src, | static domi::Status ParseDataType(const NodeDef *node_src, | ||||
| const string &attr_src, | |||||
| const std::string &attr_src, | |||||
| domi::tensorflow::DataType &data_type); | domi::tensorflow::DataType &data_type); | ||||
| /** | /** | ||||
| @@ -184,7 +183,7 @@ class TensorFlowUtil { | |||||
| static domi::Status TransTensorDescriptor(const domi::tensorflow::AttrValue &attr_value, | static domi::Status TransTensorDescriptor(const domi::tensorflow::AttrValue &attr_value, | ||||
| ParserOperator *op, | ParserOperator *op, | ||||
| const uint32_t io, | const uint32_t io, | ||||
| const string &type = ""); | |||||
| const std::string &type = ""); | |||||
| /* | /* | ||||
| * @brief 添加NodeDef属性 | * @brief 添加NodeDef属性 | ||||
| * @param [in] attr_name attribute name | * @param [in] attr_name attribute name | ||||
| @@ -193,7 +192,7 @@ class TensorFlowUtil { | |||||
| * @return void | * @return void | ||||
| * | * | ||||
| */ | */ | ||||
| static void AddNodeAttr(const string &attr_name, | |||||
| static void AddNodeAttr(const std::string &attr_name, | |||||
| const domi::tensorflow::AttrValue &value, | const domi::tensorflow::AttrValue &value, | ||||
| domi::tensorflow::NodeDef *node_def); | domi::tensorflow::NodeDef *node_def); | ||||
| @@ -23,7 +23,7 @@ | |||||
| using namespace ge::parser; | using namespace ge::parser; | ||||
| namespace ge { | namespace ge { | ||||
| Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *op) { | |||||
| Status ParseParams(const Message *op_src, VarIsInitializedOpOperator *const op) { | |||||
| GE_CHECK_NOTNULL(op_src); | GE_CHECK_NOTNULL(op_src); | ||||
| const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | const NodeDef *node = reinterpret_cast<const NodeDef *>(op_src); | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| @@ -169,7 +169,7 @@ static Status InitOutTensor(const vector<int64_t> &shape, int64_t data_type, ge: | |||||
| ge::Format format) { | ge::Format format) { | ||||
| out_tensor_desc.SetFormat(format); | out_tensor_desc.SetFormat(format); | ||||
| out_tensor_desc.SetDataType((ge::DataType)data_type); | |||||
| out_tensor_desc.SetDataType(static_cast<ge::DataType>(data_type)); | |||||
| ge::TensorUtils::SetReuseInput(out_tensor_desc, false); | ge::TensorUtils::SetReuseInput(out_tensor_desc, false); | ||||
| ge::TensorUtils::SetRealDimCnt(out_tensor_desc, shape.size()); | ge::TensorUtils::SetRealDimCnt(out_tensor_desc, shape.size()); | ||||
| @@ -180,7 +180,7 @@ static Status InitOutTensor(const vector<int64_t> &shape, int64_t data_type, ge: | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| static Status ParseVarShape(const domi::tensorflow::NodeDef *node, VariableOperator *op) { | |||||
| static Status ParseVarShape(const domi::tensorflow::NodeDef *node, VariableOperator *const op) { | |||||
| // The upper caller guarantees input params is not empty. | // The upper caller guarantees input params is not empty. | ||||
| string node_src_name = node->name(); | string node_src_name = node->name(); | ||||
| domi::tensorflow::AttrValue attr_value; | domi::tensorflow::AttrValue attr_value; | ||||