Merge pull request !35 from yanghaoran/mastertags/v0.5.0-beta
| @@ -16,6 +16,7 @@ | |||
| cmake_minimum_required(VERSION 3.14) | |||
| project (GraphEngine[CXX]) | |||
| set(CMAKE_CXX_STANDARD 14) | |||
| add_compile_definitions(_GLIBCXX_USE_CXX11_ABI=0) | |||
| set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${PROJECT_BINARY_DIR}) | |||
| set(GE_SOURCE_DIR ${CMAKE_CURRENT_LIST_DIR}) | |||
| @@ -71,6 +72,7 @@ elseif(DEFINED ENV{D_LINK_PATH}) | |||
| find_library(register libregister.so ${GE_LIB_PATH}) | |||
| find_library(hccl libhccl.so ${GE_LIB_PATH}) | |||
| find_library(resource libresource.so ${GE_LIB_PATH}) | |||
| find_library(error_manager liberror_manager.so ${GE_LIB_PATH}) | |||
| else() | |||
| # Ascend mode | |||
| if(DEFINED ENV{ASCEND_CUSTOM_PATH}) | |||
| @@ -88,6 +90,7 @@ else() | |||
| find_library(runtime libruntime.so ${ASCEND_RUNTIME_DIR}) | |||
| find_library(register libregister.so ${ASCEND_RUNTIME_DIR}) | |||
| find_library(resource libresource.so ${ASCEND_RUNTIME_DIR}) | |||
| find_library(error_manager liberror_manager.so ${ASCEND_RUNTIME_DIR}) | |||
| endif() | |||
| # add compile flags | |||
| @@ -44,6 +44,9 @@ class GraphOptimizer { | |||
| // optimize original graph, using in graph preparation stage | |||
| virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | |||
| // optimize original graph, using for conversion operator insert in graph preparation stage | |||
| virtual Status OptimizeOriginalGraphJudgeInsert(ComputeGraph &graph) { return SUCCESS; } | |||
| // optimize fused graph | |||
| virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | |||
| @@ -0,0 +1,36 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef COMPRESS_H | |||
| #define COMPRESS_H | |||
| #include <uchar.h> | |||
| enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; | |||
| struct CompressConfig { | |||
| size_t inputSize; // length of data to compress | |||
| size_t engineNum; // how many decompress engines | |||
| size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) | |||
| size_t channel; // channels of L2 or DDR. For load balance | |||
| size_t fractalSize; // size of compressing block | |||
| bool isTight; // whether compose compressed data tightly | |||
| }; | |||
| CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||
| size_t& compressedLength); | |||
| #endif // COMPRESS_H | |||
| @@ -0,0 +1,83 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef ERROR_MANAGER_H_ | |||
| #define ERROR_MANAGER_H_ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| class ErrorManager { | |||
| public: | |||
| /// | |||
| /// @brief Obtain ErrorManager instance | |||
| /// @return ErrorManager instance | |||
| /// | |||
| static ErrorManager &GetInstance(); | |||
| /// | |||
| /// @brief init | |||
| /// @param [in] path current so path | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int Init(std::string path); | |||
| /// | |||
| /// @brief Report error message | |||
| /// @param [in] errCode error code | |||
| /// @param [in] mapArgs parameter map | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map); | |||
| /// @brief output error message | |||
| /// @param [in] handle print handle | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int OutputErrMessage(int handle); | |||
| /// @brief Report error message | |||
| /// @param [in] vector parameter key, vector parameter value | |||
| /// | |||
| void ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key = {}, | |||
| const std::vector<std::string> &value = {}); | |||
| private: | |||
| struct ErrorInfo { | |||
| std::string error_id; | |||
| std::string error_message; | |||
| std::vector<std::string> arglist; | |||
| }; | |||
| ErrorManager() {} | |||
| ~ErrorManager() {} | |||
| ErrorManager(const ErrorManager &) = delete; | |||
| ErrorManager(ErrorManager &&) = delete; | |||
| ErrorManager &operator=(const ErrorManager &) = delete; | |||
| ErrorManager &operator=(ErrorManager &&) = delete; | |||
| int ParseJsonFile(std::string path); | |||
| int ReadJsonFile(const std::string &file_path, void *handle); | |||
| bool is_init_ = false; | |||
| std::map<std::string, ErrorInfo> error_map_; | |||
| std::vector<std::string> error_message_evc_; | |||
| }; | |||
| #endif // ERROR_MANAGER_H_ | |||
| @@ -65,6 +65,8 @@ class PlatformInfoManager { | |||
| void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
| void ParseUnzipOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
| void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||
| void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||
| @@ -65,6 +65,10 @@ typedef struct tagAiCoreSpec { | |||
| uint64_t ubbankNum; | |||
| uint64_t ubburstInOneBlock; | |||
| uint64_t ubbankGroupNum; | |||
| uint32_t unzipEngines; | |||
| uint32_t unzipMaxRatios; | |||
| uint32_t unzipChannels; | |||
| uint8_t unzipIsTight; | |||
| } AiCoreSpec; | |||
| typedef struct tagAiCoreMemoryRates { | |||
| @@ -82,14 +82,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Session { | |||
| /// @brief run graph in the session with specific session id asynchronously | |||
| /// @param [in] graphId: graph id | |||
| /// @param [in] inputs: input data | |||
| /// @param [out] outputs: output data | |||
| /// @param [out] callback: callback while runing graph has been finished. | |||
| /// The callback function will not be checked. | |||
| /// Please ensure that the implementation of the function is trusted. | |||
| /// @return Status result of function | |||
| /// | |||
| Status RunGraphAsync(uint32_t graphId, const std::vector<ge::TensorInfo> &inputs, | |||
| std::vector<ge::TensorInfo> &outputs, std::function<void(Status)> callback); | |||
| Status RunGraphAsync(uint32_t graphId, const std::vector<ge::InputTensorInfo> &inputs, RunAsyncCallback callback); | |||
| /// | |||
| /// @ingroup ge_graph | |||
| @@ -21,6 +21,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <set> | |||
| #include <functional> | |||
| #include <memory> | |||
| namespace ge { | |||
| // Option key: graph run mode | |||
| @@ -40,6 +42,12 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||
| const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
| const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; | |||
| const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||
| const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||
| // profiling flag | |||
| const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | |||
| const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||
| // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | |||
| const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | |||
| const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | |||
| @@ -173,6 +181,9 @@ const std::string AICORE_NUM = "ge.aicoreNum"; | |||
| // Configure L1FUSION | |||
| const std::string L1_FUSION = "ge.l1Fusion"; | |||
| // Configure l1,l2,and others optimize option | |||
| const std::string BUFFER_OPTIMIZE = "ge.bufferOptimize"; | |||
| // Configure Small Channel flag | |||
| const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||
| @@ -188,6 +199,9 @@ const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||
| // Save original model file name | |||
| const std::string ORIGINAL_MODEL_FILE = "ge.originalModelFile"; | |||
| // FE enable quant optimize | |||
| const std::string QUANT_OPTIMIZE = "ge.quantOptimize"; | |||
| const char *const OPTION_GE_MAX_DUMP_FILE_NUM = "ge.maxDumpFileNum"; | |||
| const char *const OPTION_GE_MAX_DUMP_FILE_SIZE = "ge.maxDumpFileSize"; | |||
| const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||
| @@ -196,36 +210,49 @@ const char *const OPTION_GE_MAX_DUMP_OP_NUM = "ge.maxDumpOpNum"; | |||
| // Its value should be "0" or "1", default value is "1" | |||
| const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | |||
| // Configure whether to use single stream. | |||
| // Its value should be "true" or "false", default value is "false" | |||
| const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | |||
| // Graph run mode | |||
| enum GraphRunMode { PREDICTION = 0, TRAIN }; | |||
| // Data description | |||
| struct DataDesc { | |||
| void *data = nullptr; // data address | |||
| uint32_t length = 0; // data size | |||
| bool isDataSupportMemShare = false; | |||
| // Input/Output tensor info | |||
| struct InputTensorInfo { | |||
| uint32_t data_type; // data type | |||
| std::vector<int64_t> dims; // shape description | |||
| void *data; // tensor data | |||
| int64_t length; // tensor length | |||
| }; | |||
| // Input/Output shape description | |||
| struct ShapeDesc { | |||
| int64_t num = 0; | |||
| int64_t channel = 0; | |||
| int64_t height = 0; | |||
| int64_t width = 0; | |||
| std::vector<int64_t> dims; | |||
| struct OutputTensorInfo { | |||
| uint32_t data_type; // data type | |||
| std::vector<int64_t> dims; // shape description | |||
| std::unique_ptr<uint8_t[]> data; // tensor data | |||
| int64_t length; // tensor length | |||
| OutputTensorInfo() : data_type(0), dims({}), data(nullptr), length(0) {} | |||
| OutputTensorInfo(OutputTensorInfo &&out) | |||
| : data_type(out.data_type), dims(out.dims), data(std::move(out.data)), length(out.length) {} | |||
| OutputTensorInfo &operator=(OutputTensorInfo &&out) { | |||
| if (this != &out) { | |||
| data_type = out.data_type; | |||
| dims = out.dims; | |||
| data = std::move(out.data); | |||
| length = out.length; | |||
| } | |||
| return *this; | |||
| } | |||
| OutputTensorInfo(const OutputTensorInfo &) = delete; | |||
| OutputTensorInfo &operator=(const OutputTensorInfo &) = delete; | |||
| }; | |||
| // Input/Output tensor info | |||
| struct TensorInfo { | |||
| uint32_t dataType; // data type | |||
| DataDesc data; // tensor data | |||
| ShapeDesc shapeInfo; // tensor shape | |||
| }; | |||
| using Status = uint32_t; | |||
| using RunAsyncCallback = std::function<void(Status, std::vector<ge::OutputTensorInfo> &)>; | |||
| // for ir build | |||
| namespace ir_option { | |||
| static const char *const INPUT_FORMAT = "input_format"; | |||
| static const char *const INPUT_SHAPE = "input_shape"; | |||
| static const char *const OP_NAME_MAP = "op_name_map"; | |||
| static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||
| static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||
| static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||
| @@ -235,13 +262,15 @@ static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||
| static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||
| static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||
| static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||
| static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; | |||
| // for interface: aclgrphBuildModel | |||
| const std::set<std::string> ir_builder_suppported_options = { | |||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||
| DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
| AUTO_TUNE_MODE}; | |||
| const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, | |||
| DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE}; | |||
| // for interface: aclgrphBuildInitialize | |||
| const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||
| const std::set<std::string> global_options = { | |||
| HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
| AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM}; | |||
| } // namespace ir_option | |||
| } // namespace ge | |||
| @@ -55,12 +55,16 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const; | |||
| graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const; | |||
| graphStatus GetAllOpName(std::vector<string> &op_name) const; | |||
| graphStatus SaveToFile(const string &file_name) const; | |||
| graphStatus LoadFromFile(const string &file_name); | |||
| const std::string &GetName() const; | |||
| /// | |||
| /// Set is need train iteration. | |||
| /// If set true, it means this graph need to be run iteration some | |||
| @@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||
| static std::unique_ptr<InferenceContext> Create(); | |||
| private: | |||
| InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
| explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
| std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -44,11 +44,16 @@ | |||
| namespace ge { | |||
| class OperatorImpl; | |||
| class NamedAttrs; | |||
| class Graph; | |||
| class AttrValue; | |||
| using SubgraphBuilder = std::function<Graph(const std::string &name)>; | |||
| using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | |||
| class Graph; | |||
| using GraphBuilderCallback = std::function<Graph()>; | |||
| class OpIO; | |||
| using OutHandler = std::shared_ptr<OpIO>; | |||
| using InHandler = std::shared_ptr<OpIO>; | |||
| @@ -69,6 +74,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| using OpBool = bool; | |||
| using OpTensor = Tensor; | |||
| using OpType = ge::DataType; | |||
| using OpNamedAttrs = ge::NamedAttrs; | |||
| using OpListInt = std::vector<int64_t>; | |||
| using OpListFloat = std::vector<float>; | |||
| using OpListString = std::vector<string>; | |||
| @@ -77,6 +83,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| using OpBytes = std::vector<uint8_t>; | |||
| using OpListListInt = std::vector<std::vector<int64_t>>; | |||
| using OpListType = std::vector<ge::DataType>; | |||
| using OpListNamedAttrs = std::vector<ge::NamedAttrs>; | |||
| Operator() {} | |||
| @@ -132,6 +139,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| void SetInferenceContext(const InferenceContextPtr &inference_context); | |||
| InferenceContextPtr GetInferenceContext() const; | |||
| void SetGraphBuilder(const GraphBuilderCallback &builder); | |||
| graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const; | |||
| void AddSubgraphName(const string &name); | |||
| string GetSubgraphName(int index) const; | |||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); | |||
| size_t GetInputsSize() const; | |||
| @@ -190,8 +203,21 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| Operator &SetAttr(const string &name, const ge::DataType &attr_value); | |||
| graphStatus GetAttr(const string &name, ge::DataType &attr_value) const; | |||
| // func type | |||
| Operator &SetAttr(const string &name, const ge::NamedAttrs &attr_value); | |||
| graphStatus GetAttr(const string &name, ge::NamedAttrs &attr_value) const; | |||
| Operator &SetAttr(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||
| graphStatus GetAttr(const string &name, std::vector<ge::NamedAttrs> &attr_value) const; | |||
| void BreakConnect() const; | |||
| size_t GetSubgraphNamesCount() const; | |||
| std::vector<std::string> GetSubgraphNames() const; | |||
| SubgraphBuilder GetSubgraphBuilder(const string &name) const; | |||
| Graph GetSubgraph(const string &name) const; | |||
| SubgraphBuilder GetDynamicSubgraphBuilder(const string &name, uint32_t index) const; | |||
| Graph GetDynamicSubgraph(const string &name, uint32_t index) const; | |||
| protected: | |||
| void AttrRegister(const string &name, float attr_value); | |||
| void AttrRegister(const string &name, const std::vector<float> &attr_value); | |||
| @@ -207,6 +233,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| void AttrRegister(const string &name, const std::vector<std::vector<int64_t>> &attr_value); | |||
| void AttrRegister(const string &name, const std::vector<ge::DataType> &attr_value); | |||
| void AttrRegister(const string &name, const ge::DataType &attr_value); | |||
| void AttrRegister(const string &name, const ge::NamedAttrs &attr_value); | |||
| void AttrRegister(const string &name, const std::vector<ge::NamedAttrs> &attr_value); | |||
| explicit Operator(OperatorImplPtr &&op_impl); | |||
| @@ -224,6 +252,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| void DynamicInputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||
| void DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index); | |||
| void DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back = true); | |||
| void RequiredAttrRegister(const string &name); | |||
| @@ -235,6 +265,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | |||
| void SubgraphRegister(const std::string &name, bool dynamic); | |||
| void SubgraphCountRegister(const std::string &name, uint32_t count); | |||
| void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); | |||
| private: | |||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | |||
| @@ -22,10 +22,11 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "./operator.h" | |||
| #include "./operator_factory.h" | |||
| #include "./tensor.h" | |||
| #include "./types.h" | |||
| #include "graph/operator.h" | |||
| #include "graph/operator_factory.h" | |||
| #include "graph/tensor.h" | |||
| #include "graph/types.h" | |||
| #include "graph/graph.h" | |||
| namespace ge { | |||
| using std::function; | |||
| @@ -46,6 +47,10 @@ class OpReg { | |||
| OpReg &OUTPUT() { return *this; } | |||
| OpReg &GRAPH() { return *this; } | |||
| OpReg &DYNAMIC_GRAPH() { return *this; } | |||
| OpReg &INFER_SHAPE_AND_TYPE() { return *this; } | |||
| }; | |||
| @@ -191,6 +196,10 @@ class OpReg { | |||
| Operator::DynamicInputRegister(#x, num, isPushBack); \ | |||
| return *this; \ | |||
| } \ | |||
| _THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ | |||
| Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||
| return *this; \ | |||
| } \ | |||
| TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | |||
| graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | |||
| return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | |||
| @@ -229,6 +238,51 @@ class OpReg { | |||
| void __dy_output_##x() { \ | |||
| (void)OpReg() | |||
| #define GRAPH(x) \ | |||
| N(); \ | |||
| __graph_##x(); \ | |||
| } \ | |||
| \ | |||
| public: \ | |||
| static const string name_graph_##x() { return #x; } \ | |||
| SubgraphBuilder get_subgraph_builder_##x() const { return Operator::GetSubgraphBuilder(#x); } \ | |||
| _THIS_TYPE &set_subgraph_builder_##x(const SubgraphBuilder &v) { \ | |||
| Operator::SetSubgraphBuilder(#x, 0, v); \ | |||
| return *this; \ | |||
| } \ | |||
| Graph get_subgraph_##x() const { return Operator::GetSubgraph(#x); } \ | |||
| \ | |||
| private: \ | |||
| void __graph_##x() { \ | |||
| Operator::SubgraphRegister(#x, false); \ | |||
| Operator::SubgraphCountRegister(#x, 1); \ | |||
| (void)OpReg() | |||
| #define DYNAMIC_GRAPH(x) \ | |||
| N(); \ | |||
| __graph_##x(); \ | |||
| } \ | |||
| \ | |||
| public: \ | |||
| static const string name_graph_##x() { return #x; } \ | |||
| _THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ | |||
| Operator::SubgraphCountRegister(#x, num); \ | |||
| return *this; \ | |||
| } \ | |||
| SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ | |||
| return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||
| } \ | |||
| Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||
| _THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ | |||
| Operator::SetSubgraphBuilder(#x, index, v); \ | |||
| return *this; \ | |||
| } \ | |||
| \ | |||
| private: \ | |||
| void __graph_##x() { \ | |||
| Operator::SubgraphRegister(#x, true); \ | |||
| (void)OpReg() | |||
| #define PASTE(g_register, y) g_register##y | |||
| #define __OP_END_IMPL__(x, y) \ | |||
| N(); \ | |||
| @@ -21,6 +21,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <utility> | |||
| #include "./ge_error_codes.h" | |||
| #include "./types.h" | |||
| @@ -62,6 +63,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY TensorDesc { | |||
| void Update(const Shape &shape, Format format = FORMAT_ND, DataType dt = DT_FLOAT); | |||
| Shape GetShape() const; | |||
| void SetShape(const Shape &shape); | |||
| // set shape with -2, it stand for unknown shape | |||
| graphStatus SetUnknownDimNumShape(); | |||
| // for unknown shape | |||
| graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||
| graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||
| Format GetFormat() const; | |||
| void SetFormat(Format format); | |||
| @@ -23,7 +23,9 @@ | |||
| namespace ge { | |||
| static const int64_t UNKNOWN_DIM = -1; | |||
| static const int64_t UNKNOWN_DIM_NUM = -2; | |||
| static const std::vector<int64_t> UNKNOWN_SHAPE = {0}; | |||
| static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | |||
| #ifdef HOST_VISIBILITY | |||
| #define GE_FUNC_HOST_VISIBILITY __attribute__((visibility("default"))) | |||
| @@ -140,10 +142,19 @@ enum Format { | |||
| FORMAT_NC, | |||
| FORMAT_DHWNC, | |||
| FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||
| FORMAT_FRACTAL_ZN_LSTM, | |||
| FORMAT_RESERVED, | |||
| FORMAT_ALL | |||
| }; | |||
| // for unknown shape op type | |||
| enum UnknowShapeOpType { | |||
| DEPEND_IN_SHAPE = 1, // op out shape get by input shape | |||
| DEPEND_CONST_VALUE = 2, // op out shape get by const op value | |||
| DEPEND_SHAPE_RANGE = 3, // op out shape get by range | |||
| DEPEND_COMPUTE = 4 // op out shape get by totally computing | |||
| }; | |||
| struct TensorDescInfo { | |||
| Format format_ = FORMAT_RESERVED; // tbe op register support format | |||
| DataType dataType_ = DT_UNDEFINED; // tbe op register support datatype | |||
| @@ -58,12 +58,18 @@ Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||
| Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | |||
| std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | |||
| int in_pos = -1, int out_pos = -1); | |||
| Status AutoMappingSubgraphIndex(const ge::Graph &graph, const std::function<int(int data_index)> &input, | |||
| const std::function<int(int netoutput_index)> &output); | |||
| Status AutoMappingSubgraphIndex(const ge::Graph &graph, | |||
| const std::function<Status(int data_index, int &parent_input_index)> &input, | |||
| const std::function<Status(int netoutput_index, int &parent_output_index)> &output); | |||
| using google::protobuf::Message; | |||
| class OpRegistrationDataImpl; | |||
| using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | |||
| using FusionParseParamFunc = | |||
| std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||
| using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| public: | |||
| @@ -81,6 +87,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||
| OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | |||
| OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | |||
| OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | |||
| @@ -93,6 +101,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| domi::FrameworkType GetFrameworkType() const; | |||
| ParseParamFunc GetParseParamFn() const; | |||
| FusionParseParamFunc GetFusionParseParamFn() const; | |||
| ParseSubgraphFunc GetParseSubgraphPostFn() const; | |||
| private: | |||
| std::shared_ptr<OpRegistrationDataImpl> impl_; | |||
| @@ -116,27 +125,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||
| namespace ge { | |||
| using OpRegistrationData = domi::OpRegistrationData; | |||
| using OpReceiver = domi::OpReceiver; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||
| public: | |||
| HostCpuOp() = default; | |||
| virtual ~HostCpuOp() = default; | |||
| virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||
| std::map<std::string, Tensor> &outputs) = 0; | |||
| }; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||
| public: | |||
| HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||
| }; | |||
| #define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||
| static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||
| ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | |||
| @@ -51,24 +51,24 @@ inline pid_t GetTid() { | |||
| return tid; | |||
| } | |||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = domi::GetCurrentTimestap() | |||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
| #define GE_TIMESTAMP_END(stage, stage_name) \ | |||
| do { \ | |||
| uint64_t endUsec_##stage = domi::GetCurrentTimestap(); \ | |||
| uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||
| (endUsec_##stage - startUsec_##stage)); \ | |||
| } while (0); | |||
| #define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
| uint64_t startUsec_##stage = domi::GetCurrentTimestap(); \ | |||
| uint64_t call_num_of##stage = 0; \ | |||
| #define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
| uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| uint64_t call_num_of##stage = 0; \ | |||
| uint64_t time_of##stage = 0 | |||
| #define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = domi::GetCurrentTimestap()) | |||
| #define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||
| #define GE_TIMESTAMP_ADD(stage) \ | |||
| time_of##stage += domi::GetCurrentTimestap() - startUsec_##stage; \ | |||
| #define GE_TIMESTAMP_ADD(stage) \ | |||
| time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||
| call_num_of##stage++ | |||
| #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | |||
| @@ -22,7 +22,6 @@ | |||
| #include "cce/cce_def.hpp" | |||
| #include "common/string_util.h" | |||
| #include "common/util.h" | |||
| #include "dlog/log.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "ge/ge_api_error_codes.h" | |||
| @@ -30,7 +29,7 @@ using cce::CC_STATUS_SUCCESS; | |||
| using cce::ccStatus_t; | |||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||
| #define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||
| #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
| #else | |||
| #include <android/log.h> | |||
| #if defined(BUILD_VERSION_PERF) | |||
| @@ -103,17 +102,17 @@ using cce::ccStatus_t; | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| std::string msg; \ | |||
| (void)msg.append(domi::StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| domi::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| DOMI_LOGE("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| std::string msg; \ | |||
| (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| DOMI_LOGE("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| @@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | |||
| @@ -25,6 +25,7 @@ | |||
| #include "common/fmk_error_codes.h" | |||
| #include "ge/ge_api_error_codes.h" | |||
| #include "external/graph/types.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| namespace ge { | |||
| enum RuntimeType { HOST = 0, DEVICE = 1 }; | |||
| @@ -130,7 +131,8 @@ class ModelListener { | |||
| /// @param [in] data_index Index of the input_data | |||
| /// @param [in] resultCode Execution results | |||
| /// | |||
| virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code) = 0; | |||
| virtual Status OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t result_code, | |||
| std::vector<ge::OutputTensorInfo> &outputs) = 0; | |||
| }; | |||
| // OMM configuration item | |||
| @@ -147,6 +149,8 @@ struct Options { | |||
| std::string rankTableFile; | |||
| int32_t ge_hccl_flag = 0; | |||
| int32_t physical_device_id; | |||
| std::string profiling_mode; | |||
| std::string profiling_options; | |||
| }; | |||
| // Profiling info of task | |||
| @@ -20,7 +20,7 @@ | |||
| #include <gflags/gflags.h> | |||
| #include <string> | |||
| namespace domi { | |||
| namespace ge { | |||
| class GflagsUtils { | |||
| public: | |||
| static bool IsSetCommandTrue(const char *name) { | |||
| @@ -66,6 +66,6 @@ class GflagsUtils { | |||
| } | |||
| } | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_GFLAGS_UTIL_H_ | |||
| @@ -26,7 +26,7 @@ | |||
| #include "graph/model.h" | |||
| #include "model/ge_model.h" | |||
| namespace domi { | |||
| namespace ge { | |||
| class ModelHelper { | |||
| public: | |||
| ModelHelper() = default; | |||
| @@ -38,7 +38,7 @@ class ModelHelper { | |||
| Status LoadModel(const ge::ModelData& model_data); | |||
| Status GetModelBufferData(ge::ModelBufferData& model); | |||
| ModelFileHeader* GetFileHeader() { return file_header_; } | |||
| const ModelFileHeader* GetFileHeader() const { return file_header_; } | |||
| GeModelPtr GetGeModel(); | |||
| void SetSaveMode(bool val) { is_offline_ = val; } | |||
| @@ -65,9 +65,8 @@ class ModelHelper { | |||
| Status LoadTask(OmFileLoadHelper& om_load_helper); | |||
| Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||
| Status ReleaseLocalModelData() noexcept; | |||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||
| const uint8_t* data, size_t size); | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | |||
| @@ -26,8 +26,10 @@ | |||
| #include "framework/common/ge_types.h" | |||
| using ProcParam = struct PROC_PARAM; | |||
| using std::string; | |||
| using std::vector; | |||
| namespace domi { | |||
| namespace ge { | |||
| struct ModelPartition { | |||
| ModelPartitionType type; | |||
| uint8_t *data = 0; | |||
| @@ -88,5 +90,5 @@ class OmFileSaveHelper { | |||
| ModelFileHeader model_header_; | |||
| OmFileContext context_; | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_HELPER_OM_FILE_HELPER_H_ | |||
| @@ -30,7 +30,7 @@ | |||
| using std::vector; | |||
| namespace domi { | |||
| namespace ge { | |||
| // Size of RC memory alignment, 2M | |||
| constexpr size_t ALIGN_SIZE = 2097152; | |||
| @@ -118,6 +118,6 @@ class L2CacheOptimize { | |||
| bool Cross(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | |||
| bool Connect(const RCMemoryBlock &l_block, const RCMemoryBlock &r_block); | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||
| @@ -1,810 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
| #define INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
| #include <string> | |||
| #include "framework/common/fmk_types.h" | |||
| namespace domi { | |||
| // Public Attribute | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHT_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_QUANTIZE_FACTOR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALPHA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BETA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADMODES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BIAS_TERM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PADS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PAD_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WINDOWS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GLOBAL_POOLING; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CEIL_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELU_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ALGO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FILTER_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_K; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_NORM_REGION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_LOCAL_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_ALPHA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LRN_BETA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROADCAST; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TIDX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TPADDINGS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_IMG_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NET_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TMULTIPLES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTIPLES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_T; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_N; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TSHAPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NAN_OPT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AIPP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string NEW_AIPP_CONV_OP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| static const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
| static const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_BATCH_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_NODE_DEF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_OP_DEF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INPUT_TENSOR_DESC; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUTPUT_TENSOR_DESC; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INFERRED_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_WEIGHTS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DIM_ALIGN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
| // To be deleted | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_TO_BE_DELETED; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_PROPOSAL; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_CONV_DECODEBBOX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_RESHAPE_FUSION_BOX_TYPE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_LOC_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_CONF_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_OCR_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
| // Refinedet | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_LOC_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_CONF_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIORBOX_CONCAT; | |||
| // _Arg | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_INDEX; | |||
| // _RetVal | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETVAL_ATTR_NAME_INDEX; | |||
| // Data | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DATA_ATTR_NAME_DATA_TYPE; | |||
| // Send | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SEND_ATTR_EVENT_ID; | |||
| // Recv | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RECV_ATTR_EVENT_ID; | |||
| // convolution | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_COEF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STRIDES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATIONS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ALGO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_GROUP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_STRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_DILATION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_NUM_OUTPUT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_KERNEL; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_FILTER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BIAS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_RELU_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_ADJ; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_TARGET_SHAPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_BEFORE_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_ATTR_NAME_HAS_BIAS; | |||
| // Pooling | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAN_OPT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_GLOBAL_POOLING; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_WINDOW; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_STRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_CEIL_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_DATA_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_BEFORE_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOLING_ATTR_NAME_ALGO; | |||
| // Eltwise | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_COEFF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_WEIGHT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_RELU_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_ALPHA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ELTWISE_ATTR_BETA; | |||
| // BatchNorm | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_EPSILON; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_USE_GLOBAL_STATS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_MOVING_AVERAGE_FRACTION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_MEAN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_BIAS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
| // Huberloss | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HUBER_LOSS_ATTR_DELTA; | |||
| // SSDRealDivTileMul | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
| // SSDSumMulRealDivMean | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
| SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
| /// ConcatFive2Four | |||
| /// ConcatFour2Five | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_CLASS_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TRANS_FOR_LOSS_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOX_TYPE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_HIGH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_FEATURE_MAP_WIDTH; | |||
| // Scale | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SCALE_ATTR_BIAS; | |||
| // FullConnection | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_FILTER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_BIAS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_NUM_OUTPUT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_CONNECTION_ATTR_RELU_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FULL_ATTR_NAME_ALGO; | |||
| // SoftmaxOpParams | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_ALGO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_MODE; | |||
| // SparseSoftmaxCrossEntropy | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFTMAX_CROSS_ENTROPY_IS_GRAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_CROSS_ENTROPY_LABELSMOOTHING; | |||
| // Activation | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ACTIVATION_ATTR_COEF; | |||
| // Concat | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_ATTR_NAME_AXIS; | |||
| // Const | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_DATA_TRANSTYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
| // Roipooling | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLED_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SPATIAL_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_RIO_POOLING_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_POOLING_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIPOOLING_ATTR_NAME_SAMPLING_RATIO; | |||
| // DetectionOutput | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IMG_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BATCH_SIZE; | |||
| // Ssd DetectionOutput | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ETA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_SHARED_LOCATION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BACKGROUND_LABEL_ID; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CODE_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
| DETECTIONOUTPUT_ATTR_VARIANCE_ENCODED_IN_TARGET; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_KEEP_TOP_K; | |||
| // Refinedet DetectionOutput | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_SCORE; | |||
| // yolo DetectionOutput | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_ClASSES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_BIASES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_RELATIVE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_OBJECTNESS_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_CLASS_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_POST_TOP_K; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_IOU_THRESHOLD_DECAY; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_COOR_SCALE_FACTOR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DETECTIONOUTPUT_ATTR_YOLO_VERSION; | |||
| // DetectionPostprocess | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CLS_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_CONF_THRESH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_NMS_THRESH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_POST_NMS_TOPN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POSTPROCESS_ATTR_NAME_BBOX_REG_WEIGHT; | |||
| // Spatialtransfrom | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_OUTPUT_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_BORDER_VALUE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPTIALTF_ATTR_NAME_AFFINE_TRANSFORM; | |||
| // Proposal | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_FEAT_STRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_BASE_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_MIN_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_RATIO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_PRE_NMS_TOPN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_POST_NMS_TOPN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_NMS_THRESH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_NAME_TOP_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PROPOSAL_ATTR_IMG_W; | |||
| // Softmax | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SOFTMAX_ATTR_AXIS; | |||
| // Permute | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_ORDER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PERMUTE_ATTR_PERM; | |||
| // SSD Normalize | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_CHANNEL_SHARED; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSDNORMALIZE_ATTR_EPS; | |||
| // Flatten | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_ATTR_END_AXIS; | |||
| // SsdPRIORBOX | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_FLIP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_CLIP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_IMG_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_STEP_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_OFFSET; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MIN_SIZE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_MAX_SIZE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
| // RefinedetPRIORBOX | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
| // PRelu | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PRELU_ATTR_CHANNEL_SHARED; | |||
| // Psroi pooling | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_SPATIAL_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_OUTPUT_DIM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PSROIPOOLING_ATTR_GROUP_SIZE; | |||
| // Power | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_POWER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POWER_ATTR_NAME_SHIFT; | |||
| // Log | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_SHIFT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_ATTR_NAME_BASE; | |||
| // Pack | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PACK_ATTR_NAME_NUM; | |||
| // Dynamic stitch | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
| // Unpack | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UNPACK_ATTR_NAME_NUM; | |||
| // Gathernd | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TINDICES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERND_ATTR_NAME_TPARAMS; | |||
| // Argmax | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_TOPK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
| // Upsample | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
| // Relu | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
| // FreeSpaceExtract | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FREESPACEEXTRACT_ATTR_NAME_ORG_HEIGHT; | |||
| // split | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SLICE_POINT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_SIZE_SPLIT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPLIT_ATTR_NAME_NUM_SPLIT; | |||
| // Tvm | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_MAGIC; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_BLOCKDIM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TVM_ATTR_NAME_METADATA; | |||
| // Squeeze | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_ATTR_DIMS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SQUEEZE_OP_NAME; | |||
| // Stride slice | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_BEGIN_MASK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_END_MASK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_ELLIPSIS_MASK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_NEW_AXIS_MASK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK; | |||
| // Slice | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_BEGINS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SLICE_ATTR_NAME_SIZES; | |||
| // Roialign | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SPATIAL_SCALE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||
| // Generate_rpn_proposal | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_POST_NMS_TOPK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GENERATE_RPN_PROPOSAL_ATTR_RPN_MINI_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
| GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_NMS_THRESH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string | |||
| GENERATE_RPN_PROPOSAL_ATTR_RPN_PROPOSAL_FILTER_THRESH; | |||
| // Decode_bbox | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DECODE_BBOX_ATTR_DECODECLIP; | |||
| // Cast | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_DSTT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CAST_ATTR_SRCT; | |||
| // Fastrcnnn predications | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_TOPK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_SCORE_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NMS_THRESHOLD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FASTRCNN_PREDICTIONS_ATTR_NUM_CLASSES; | |||
| // REORG | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_STRIDE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REORG_ATTR_REVERSE; | |||
| // MERGE | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_DEAD_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MERGE_PRENODE_FLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TO_BE_OUTPUT; | |||
| static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
| // Concatv2 | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_TIDX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONCAT_V2_ATTR_N; | |||
| // SUM | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_TIDX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SUM_ATTR_KEEP_DIMS; | |||
| // ResizeBilinear | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALIGN_CORNERS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_HEIGHT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_WIDTH; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ZOOM_FACTOR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_SHRINK_FACTOR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_BEGIN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_PAD_END; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
| // RetinaNet | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RETINANET_ANCHOR_FUSION; | |||
| // MatMul | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_X; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_TRANSPOSE_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_HAS_BIAS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MATMUL_ATTR_IS_TRAINING; | |||
| // Flatten | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_START_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FLATTEN_END_AXIS; | |||
| // Reshape | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NUM_AXES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_SHAPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_ALPHA; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_BETA; | |||
| // Frameoworkop | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_IN_DATATYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string T_OUT_DATATYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_N; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_C; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_H; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OUT_W; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_DEPTH_CONV; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_PAD_CONV; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BEFORE_PAD; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ANN_MEAN_KEEPDIMS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_PADDINGDS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_ATTR_CONSTANT_VALUE; | |||
| // ConvGradFilter | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE; | |||
| // ConvGradInput | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE; | |||
| // Rnn | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_MODE_STATIC; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MUTI_RNN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CELL_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string CNN_RNN; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GRU_CELL; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_HT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_XT_HT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RNN_BATCH_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_CELL_CLIP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_PROJ_CLIP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_ACTIVATE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MAP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_OUT_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_STATE_OUT_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_TIME_MAJOR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
| // Upsample | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
| // PadV2 | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PADS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_T; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
| // MirrorPad | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
| // Filler | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FILLER_VALUE; | |||
| // Shufflechannel | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHUFFLE_CHANNEL_GROUP; | |||
| // TopKV2 | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TOPKV2_ATTR_K; | |||
| // Calibaration | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_H_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string STRIDE_W_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_TOP_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_BOTTOM_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_RIGHT_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string PAD_LEFT_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IS_CONST; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_GROUP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_DILATION_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_EPSILON; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_POOLING_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CLASS_NUM; | |||
| // model | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TARGET_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_STREAM_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
| // Public Attribute | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_IMPLY_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BYTE_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_INFERENCE_ID; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_OPDEF; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FUSION_SCOPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_OPATTR; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_RELUFLAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_SEQLEN_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_X_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_CONT_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_XSTATIC_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_MINI; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_TINY; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string TARGET_TYPE_LITE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_STREAM_LABEL; | |||
| // L2_normalize | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_AXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string L2_NORMALIZE_ATTR_EPS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_WINDOW; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
| // HCOM | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_ROOT_RANK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCE_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_RANK_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_REDUCTION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_GROUP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SR_TAG; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SRC_RANK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DEST_RANK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_SHAPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HCOM_ATTR_DATA_TYPE; | |||
| // Log time stamp | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_LOGID; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LOG_TIME_STAMP_NOTIFY; | |||
| // SpaceToDepth/DepthToSpace | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_BLOCK_SIZE; | |||
| // SparseSoftmaxCrossEntropyWithLogits | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
| // MaxPoolGradWithArgmax | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
| // AvgPoolGrad | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
| // Pad | |||
| extern const std::string ATTR_PAD_FORMAT; | |||
| // Varible | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_4D_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_5D_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DATA_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHAPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string HALF_VAR_NAME_END; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_CONTAINER; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SHARED_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_DTYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_ADDR_OFFSET; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_SRC_VAR_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_SAVE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_SRC_VAR_NAME; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
| // Assign | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ASSIGN_VALIDATE_SHAPE; | |||
| // ShapeN | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_N; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_IN_TYPE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SHAPEN_ATTR_OUT_TYPE; | |||
| // Space2bacth batch2space | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_BLOCK; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string BATCH_SPACE_ATTR_PADDING; | |||
| // Depth_to_space space_to_depth | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
| // FakeQuantWithMinMaxVars | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
| // Mobilenet_ssd_conv_fusion | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
| // Lsh project | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string LSH_PROJ_TYPE; | |||
| // Control flow | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
| // GatherV2 attr def | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
| // Reshape attr def | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
| // Axis attr def | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_AXIS_ORG_OP; | |||
| // The node link with SparseSoftmaxCrossEntropyWithLogits | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
| // For constant folding | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
| } // namespace domi | |||
| #endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
| @@ -21,11 +21,17 @@ | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include "common/op/attr_define.h" | |||
| #include "common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "proto/om.pb.h" | |||
| namespace domi { | |||
| using domi::AttrDef; | |||
| using domi::AttrDef_ListValue; | |||
| using domi::ModelDef; | |||
| using domi::NamedAttrs; | |||
| using domi::OpDef; | |||
| namespace ge { | |||
| using AttrDefMap = ::google::protobuf::Map<::std::string, ::domi::AttrDef>; | |||
| using AttrDefPair = ::google::protobuf::MapPair<std::string, domi::AttrDef>; | |||
| @@ -150,6 +156,6 @@ bool GetAttrDefListValue(const std::string &key, int idx, int32_t *value, const | |||
| bool GetAttrDefListValue(const std::string &key, int idx, uint32_t *value, const AttrDefMap &attr); | |||
| bool GetAttrDefListValue(const std::string &key, int idx, float *value, const AttrDefMap &attr); | |||
| bool GetAttrDefListValue(const std::string &key, int idx, double *value, const AttrDefMap &attr); | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | |||
| @@ -62,6 +62,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||
| class OpUtils { | |||
| public: | |||
| /// | |||
| @@ -22,7 +22,7 @@ | |||
| #include <math.h> | |||
| #include <stdint.h> | |||
| namespace domi { | |||
| namespace ge { | |||
| // general | |||
| const float DEFAULT_ALPHA_VALUE = 1.0; | |||
| const float DEFAULT_BETA_VALUE = 0.0; | |||
| @@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||
| // Shufflechannel | |||
| const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | |||
| @@ -20,7 +20,7 @@ | |||
| #include <set> | |||
| #include <string> | |||
| namespace domi { | |||
| namespace ge { | |||
| class OpTypeContainer { | |||
| public: | |||
| static OpTypeContainer *Instance() { | |||
| @@ -57,6 +57,6 @@ class OpTypeRegistrar { | |||
| const OpTypeRegistrar g_##var_name##_reg(str_name); | |||
| #define IS_OPTYPE_EXISTING(str_name) (OpTypeContainer::Instance()->IsExisting(str_name)) | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_OP_TYPES_H_ | |||
| @@ -25,10 +25,10 @@ | |||
| /// MAKE_GUARD([&] { Release Resource 1 }) | |||
| /// Acquire Resource 2 | |||
| // MAKE_GUARD([&] { Release Resource 2 }) | |||
| #define GE_MAKE_GUARD(var, callback) domi::ScopeGuard make_guard_##var(callback) | |||
| #define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||
| #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | |||
| namespace domi { | |||
| namespace ge { | |||
| class ScopeGuard { | |||
| public: | |||
| // Noncopyable | |||
| @@ -55,6 +55,6 @@ class ScopeGuard { | |||
| std::function<void()> on_exit_scope_; | |||
| bool dismissed_; | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_SCOPE_GUARD_H_ | |||
| @@ -25,7 +25,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| namespace domi { | |||
| namespace ge { | |||
| class StringUtils { | |||
| public: | |||
| static std::string &Ltrim(std::string &s) { | |||
| @@ -151,6 +151,6 @@ class StringUtils { | |||
| return ret > 0 ? buffer : ""; | |||
| } | |||
| }; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_STRING_UTIL_H_ | |||
| @@ -26,6 +26,7 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "framework/common/fmk_error_codes.h" | |||
| #include "framework/common/fmk_types.h" | |||
| #include "framework/common/op_types.h" | |||
| @@ -46,9 +47,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_A | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_STATUS; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | |||
| } // namespace ge | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; | |||
| namespace domi { | |||
| // Supported public properties name | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_DUMP_PATH; // Dump path | |||
| @@ -68,14 +68,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFIL | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::map<std::string, std::string> PROFILE_COMPONENT_MAP; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROFILE_CONFIG; | |||
| /// @brief Data structure definition related to task sinking | |||
| /// Build model | |||
| enum BuildMode { | |||
| GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||
| GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||
| GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||
| }; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASKS; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR; | |||
| @@ -341,8 +333,9 @@ REGISTER_OPTYPE_DECLARE(END, "End"); | |||
| REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | |||
| REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | |||
| REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | |||
| REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") | |||
| /***************ANN dedicated operator *************************/ | |||
| // ANN dedicated operator | |||
| REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | |||
| REGISTER_OPTYPE_DECLARE(ANN_CONVOLUTION, "AnnConvolution"); | |||
| REGISTER_OPTYPE_DECLARE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | |||
| @@ -359,7 +352,7 @@ REGISTER_OPTYPE_DECLARE(ANN_QUANTIZE, "AnnQuant"); | |||
| REGISTER_OPTYPE_DECLARE(ANN_PAD, "AnnPad"); | |||
| REGISTER_OPTYPE_DECLARE(ANN_RESIZE_BILINEAR, "AnnResizeBilinear"); | |||
| /********************Training operator ***********************/ | |||
| // Training operator | |||
| REGISTER_OPTYPE_DECLARE(GATHERV2, "GatherV2"); | |||
| REGISTER_OPTYPE_DECLARE(CONVGRADFILTER, "Conv2DBackpropFilter"); | |||
| REGISTER_OPTYPE_DECLARE(CONV2D, "Conv2D"); | |||
| @@ -438,11 +431,13 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | |||
| REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | |||
| REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | |||
| REGISTER_OPTYPE_DECLARE(LogTimeStamp, "LogTimeStamp"); | |||
| REGISTER_OPTYPE_DECLARE(PARALLELCONCATSTART, "_ParallelConcatStart"); | |||
| REGISTER_OPTYPE_DECLARE(CONSTANTOP, "Constant"); | |||
| REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | |||
| REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | |||
| REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | |||
| REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | |||
| REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||
| REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
| REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
| REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
| @@ -450,6 +445,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||
| REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||
| REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||
| @@ -828,9 +824,6 @@ static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | |||
| // number of partitions in the current model | |||
| static constexpr uint32_t PARTITION_SIZE = 4; | |||
| #define SIZE_OF_MODEL_PARTITION_TABLE(table) \ | |||
| (sizeof(domi::ModelPartitionTable) + sizeof(domi::ModelPartitionMemInfo) * (table).num) | |||
| enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | |||
| struct ModelPartitionMemInfo { | |||
| @@ -844,6 +837,8 @@ struct ModelPartitionTable { | |||
| ModelPartitionMemInfo partition[0]; | |||
| }; | |||
| #define SIZE_OF_MODEL_PARTITION_TABLE(table) (sizeof(ModelPartitionTable) + sizeof(ModelPartitionMemInfo) * (table).num) | |||
| static constexpr int32_t PTHREAD_CREAT_SUCCESS = 0; // pthread_creat success | |||
| // Filter format | |||
| @@ -975,8 +970,8 @@ typedef enum tagDomiNanPropagation { | |||
| // mode of cropandresize | |||
| typedef enum tagDomiCropAndResizeMode { | |||
| DOMI_RESIZE_METHOD_BILINEAR = 0, /**< resize bilinear */ | |||
| DOMI_RESIZE_METHOD_NEAREST, /**< resize nearest */ | |||
| DOMI_RESIZE_METHOD_BILINEAR = 0, // resize bilinear | |||
| DOMI_RESIZE_METHOD_NEAREST, // resize nearest | |||
| DOMI_RESIZE_RESERVED | |||
| } domiCropAndResizeMode_t; | |||
| @@ -1063,6 +1058,15 @@ struct BasicInfo { | |||
| uint32_t total_size; // total memory size | |||
| }; | |||
| #pragma pack() // Cancels single-byte alignment | |||
| } // namespace ge | |||
| namespace domi { | |||
| /// @brief Data structure definition related to task sinking | |||
| enum BuildMode { | |||
| GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||
| GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||
| GEN_TASK_WITH_FUSION = 5 // Carrying task data (with UB/L1/L2 enabled for all convergence functions) | |||
| }; | |||
| } // namespace domi | |||
| #endif // INC_FRAMEWORK_COMMON_TYPES_H_ | |||
| @@ -30,12 +30,12 @@ | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| #include "mmpa/mmpa_api.h" | |||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
| do { \ | |||
| if (size <= 0) { \ | |||
| DOMI_LOGE(param[#size] is not a positive number); \ | |||
| return PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
| do { \ | |||
| if (size <= 0) { \ | |||
| DOMI_LOGE("param[%s] is not a positive number", #size); \ | |||
| return PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | |||
| @@ -44,7 +44,7 @@ | |||
| if (!b) { \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // new ge marco | |||
| // Encapsulate common resource releases | |||
| @@ -113,101 +113,101 @@ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
| #define GE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, just return and record the error | |||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | |||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| exec_expr; \ | |||
| } \ | |||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } while (0) | |||
| // Check whether the parameter is null. If yes, return directly and record the error log | |||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, return false and record the error log | |||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return false; \ | |||
| } \ | |||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE("param[%s] must not be null.", #val); \ | |||
| return false; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is out of bounds | |||
| #define GE_CHECK_SIZE(size) \ | |||
| do { \ | |||
| if (size == 0) { \ | |||
| DOMI_LOGE(param[#size] is out of range); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_SIZE(size) \ | |||
| do { \ | |||
| if (size == 0) { \ | |||
| DOMI_LOGE("param[%s] is out of range", #size); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the container is empty | |||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
| do { \ | |||
| if (vector.empty()) { \ | |||
| DOMI_LOGE(param[#vector] is empty !); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
| do { \ | |||
| if (vector.empty()) { \ | |||
| DOMI_LOGE("param[%s] is empty!", #vector); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the value on the left is greater than or equal to the value on the right | |||
| #define GE_CHECK_GE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs < rhs) { \ | |||
| DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_GE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs < rhs) { \ | |||
| DOMI_LOGE("param[%s] is less than[%s]", #lhs, #rhs); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the value on the left is less than or equal to the value on the right | |||
| #define GE_CHECK_LE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs > rhs) { \ | |||
| DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_LE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs > rhs) { \ | |||
| DOMI_LOGE("param[%s] is greater than[%s]", #lhs, #rhs); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| #define GE_DELETE_NEW_SINGLE(var) \ | |||
| { \ | |||
| do { \ | |||
| if (var != nullptr) { \ | |||
| delete var; \ | |||
| var = nullptr; \ | |||
| } \ | |||
| }; | |||
| } while (0) | |||
| #define GE_DELETE_NEW_ARRAY(var) \ | |||
| { \ | |||
| do { \ | |||
| if (var != nullptr) { \ | |||
| delete[] var; \ | |||
| var = nullptr; \ | |||
| } \ | |||
| }; | |||
| } while (0) | |||
| /** | |||
| * @ingroup domi_common | |||
| @@ -220,7 +220,7 @@ static constexpr int32_t OM_PROTO_VERSION = 2; | |||
| */ | |||
| #define CEIL(N, n) (((N) + (n)-1) / (n)) | |||
| namespace domi { | |||
| namespace ge { | |||
| using google::protobuf::Message; | |||
| /// | |||
| @@ -373,7 +373,7 @@ std::string RealPath(const char *path); | |||
| /// @param [in] file_path path of input file | |||
| /// @param [out] result | |||
| /// | |||
| bool CheckInputPathValid(const std::string &file_path); | |||
| bool CheckInputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||
| /// | |||
| /// @ingroup domi_common | |||
| @@ -381,7 +381,7 @@ bool CheckInputPathValid(const std::string &file_path); | |||
| /// @param [in] file_path path of output file | |||
| /// @param [out] result | |||
| /// | |||
| bool CheckOutputPathValid(const std::string &file_path); | |||
| bool CheckOutputPathValid(const std::string &file_path, const std::string &atc_param = ""); | |||
| /// | |||
| /// @ingroup domi_common | |||
| @@ -390,6 +390,6 @@ bool CheckOutputPathValid(const std::string &file_path); | |||
| /// @param [out] result | |||
| /// | |||
| bool ValidateStr(const std::string &filePath, const std::string &mode); | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_UTIL_H_ | |||
| @@ -47,6 +47,8 @@ class GeGenerator { | |||
| Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||
| Status GenerateInfershapeGraph(const Graph &graph); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief: Build single OP in Model. | |||
| @@ -33,7 +33,7 @@ class MemoryAssigner { | |||
| MemoryAssigner &operator=(const MemoryAssigner &) = delete; | |||
| Status AssignMemory(bool is_loop_graph, size_t &mem_offset); | |||
| Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size); | |||
| private: | |||
| ge::ComputeGraphPtr compute_graph_; | |||
| @@ -28,21 +28,27 @@ | |||
| #include "framework/common/types.h" | |||
| #include "register/register_fmk_types.h" | |||
| using domi::DOMI_TENSOR_ND; | |||
| using domi::DOMI_TENSOR_RESERVED; | |||
| using domi::domiTensorFormat_t; | |||
| using domi::FMK_TYPE_RESERVED; | |||
| using domi::FrameworkType; | |||
| using std::map; | |||
| using std::string; | |||
| using std::unordered_map; | |||
| using std::vector; | |||
| namespace domi { | |||
| namespace ge { | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief run model | |||
| */ | |||
| enum RunMode { | |||
| GEN_OM_MODEL = 0, // generate offline model file | |||
| MODEL_TO_JSON = 1, // convert to JSON file | |||
| ONLY_PRE_CHECK = 3, // only for pre-check | |||
| PBTXT_TO_JSON = 5 // pbtxt to json | |||
| GEN_OM_MODEL = 0, // generate offline model file | |||
| MODEL_TO_JSON = 1, // convert to JSON file | |||
| MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape | |||
| ONLY_PRE_CHECK = 3, // only for pre-check | |||
| PBTXT_TO_JSON = 5 // pbtxt to json | |||
| }; | |||
| /// | |||
| @@ -93,7 +99,7 @@ struct OmgContext { | |||
| std::string ddk_version; | |||
| // preferential format used by the entire network | |||
| domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | |||
| FrameworkType type = FMK_TYPE_RESERVED; | |||
| domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||
| RunMode run_mode = ONLY_PRE_CHECK; | |||
| bool train_flag = false; | |||
| // whether to use FP16 high precision | |||
| @@ -102,23 +108,25 @@ struct OmgContext { | |||
| std::string output_type; | |||
| // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | |||
| // network require special processing based on the specific network. | |||
| // e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||
| // fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||
| // convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||
| // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||
| // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||
| // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||
| // operator needs to be deleted for the convolution kernel. | |||
| std::string net_name; | |||
| // Whether to use dynamic batch size or dynamic image size | |||
| bool is_dynamic_input = false; | |||
| std::string dynamic_batch_size; | |||
| std::string dynamic_image_size; | |||
| }; | |||
| } // namespace ge | |||
| namespace domi { | |||
| /** | |||
| * @ingroup domi_omg | |||
| * @brief get OMG context | |||
| * @return OmgContext context | |||
| */ | |||
| OmgContext &GetContext(); | |||
| ge::OmgContext &GetContext(); | |||
| struct TEBinInfo { | |||
| // It is obsolete. It will be automatically obtained from the binfilename field of the JSON file later. | |||
| @@ -26,7 +26,7 @@ | |||
| #include "common/string_util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace domi { | |||
| namespace ge { | |||
| class PlatformVersionManager { | |||
| public: | |||
| PlatformVersionManager() = delete; | |||
| @@ -40,6 +40,6 @@ class PlatformVersionManager { | |||
| return SUCCESS; | |||
| } | |||
| }; // class PlatformManager | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_OMG_VERSION_H_ | |||
| @@ -86,16 +86,16 @@ class _GeSerializable { | |||
| } | |||
| template <class T, class... Args> | |||
| static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||
| static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||
| GeAttrValue itemVal = SaveItemAsAttrValue(item); | |||
| (void)namedAttrs.SetAttr(itemName, itemVal); | |||
| SaveItem(namedAttrs, args...); | |||
| } | |||
| static void SaveItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) {} | |||
| static void SaveItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) {} | |||
| template <class T, class... Args> | |||
| static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs, string itemName, T &item, Args &... args) { | |||
| static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs, string itemName, T &item, Args &... args) { | |||
| auto itemVal = namedAttrs.GetItem(itemName); | |||
| auto status = LoadItemFromAttrValue(item, itemVal); | |||
| if (status != GRAPH_SUCCESS) { | |||
| @@ -104,7 +104,9 @@ class _GeSerializable { | |||
| return LoadItem(namedAttrs, args...); | |||
| } | |||
| static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | |||
| static graphStatus LoadItem(GeAttrValue::NAMED_ATTRS &namedAttrs __attribute__((__unused__))) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| }; | |||
| #define _GE_FI(a) #a, a | |||
| @@ -171,13 +173,13 @@ class _GeSerializable { | |||
| \ | |||
| private: \ | |||
| ge::graphStatus Save(GeAttrValue &ar) const { \ | |||
| GeAttrValue::NamedAttrs named_attrs; \ | |||
| GeAttrValue::NAMED_ATTRS named_attrs; \ | |||
| _GeSerializable::SaveItem(named_attrs, _GE_INVOKE_VAR_MACRO(__VA_ARGS__)); \ | |||
| return ar.SetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||
| return ar.SetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||
| } \ | |||
| ge::graphStatus Load(const GeAttrValue &ar) { \ | |||
| GeAttrValue::NamedAttrs named_attrs; \ | |||
| ge::graphStatus status = ar.GetValue<GeAttrValue::NamedAttrs>(named_attrs); \ | |||
| GeAttrValue::NAMED_ATTRS named_attrs; \ | |||
| ge::graphStatus status = ar.GetValue<GeAttrValue::NAMED_ATTRS>(named_attrs); \ | |||
| if (status != GRAPH_SUCCESS) { \ | |||
| return status; \ | |||
| } \ | |||
| @@ -83,6 +83,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| // AddNode with NodePtr | |||
| NodePtr AddNode(NodePtr node); | |||
| NodePtr AddNode(OpDescPtr op); | |||
| NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. | |||
| NodePtr AddNodeFront(NodePtr node); | |||
| NodePtr AddNodeFront(const OpDescPtr &op); | |||
| NodePtr AddInputNode(NodePtr node); | |||
| @@ -236,8 +237,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| std::deque<NodePtr> &stack); | |||
| graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
| std::map<string, NodePtr> &breadth_node_map); | |||
| graphStatus TopologicalSortingSubgraph(); | |||
| graphStatus TopologicalSortingGraph(); | |||
| graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | |||
| Vistor<NodePtr> AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const; | |||
| size_t GetInEdgeSize(const NodePtr &node); | |||
| size_t GetOutEdgeSize(const NodePtr &node); | |||
| graphStatus RemoveExtraOutEdge(const NodePtr &node); | |||
| @@ -32,6 +32,12 @@ namespace ge { | |||
| #define GE_FUNC_DEV_VISIBILITY | |||
| #endif | |||
| // Public attribute | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TYPE; | |||
| @@ -58,6 +64,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | |||
| @@ -74,8 +82,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | |||
| // GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||
| // ATTR_NAME_WEIGHTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | |||
| @@ -123,6 +130,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
| @@ -140,12 +154,24 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
| // to be deleted | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | |||
| @@ -158,15 +184,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | |||
| // _Arg | |||
| @@ -255,7 +281,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
| // Huberloss | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||
| // SSDRealDivTileMul | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
| // SSDSumMulRealDivMean | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
| /// ConcatFive2Four | |||
| /// ConcatFour2Five | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||
| // Scale | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | |||
| @@ -292,7 +340,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
| // Roipooling | |||
| @@ -305,6 +352,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||
| // DetectionOutput | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
| @@ -363,6 +411,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||
| // Permute | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||
| // SSD Normalize | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
| @@ -403,9 +452,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | |||
| // Log | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||
| // Pack | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | |||
| // Dynamic stitch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
| // Unpack | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | |||
| // Gathernd | |||
| @@ -414,8 +469,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||
| // Argmax | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
| // Upsample | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
| // Relu | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
| @@ -486,6 +549,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | |||
| static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
| // ENTER | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | |||
| @@ -511,6 +575,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
| // RetinaNet | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||
| // MatMul | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | |||
| @@ -559,10 +626,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
| // Upsample | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
| // PadV2 | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
| // MirrorPad | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
| // Filler | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | |||
| @@ -583,36 +670,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | |||
| @@ -627,24 +684,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_HUGE_STREAM_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
| // Public attribute | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | |||
| @@ -678,6 +731,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TARGET_T | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_OUTPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | |||
| @@ -696,6 +751,161 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
| // L2_normalize | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
| // HCOM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
| // Log time stamp | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||
| // SpaceToDepth/DepthToSpace | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||
| // SparseSoftmaxCrossEntropyWithLogits | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
| // MaxPoolGradWithArgmax | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
| // AvgPoolGrad | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
| // Varible | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
| // Assign | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||
| // ShapeN | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||
| // Space2bacth batch2space | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||
| // Depth_to_space space_to_depth | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
| // FakeQuantWithMinMaxVars | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
| // Mobilenet_ssd_conv_fusion | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
| // Lsh project | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||
| // Control flow | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
| // GatherV2 attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
| // Reshape attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
| // Axis attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||
| // The node link with SparseSoftmaxCrossEntropyWithLogits | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
| // For constant folding | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
| // Used for mark the active label list to find stream of activated node | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | |||
| @@ -708,7 +918,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| // Control flow | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | |||
| @@ -722,6 +931,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| // Function Op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_CONST_TYPE; | |||
| // Used for mark the active node is for loop, type:bool | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | |||
| @@ -752,6 +962,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE | |||
| // For mutil-batch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; | |||
| // For inserted op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | |||
| @@ -772,6 +983,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| // used for l1 fusion and other fusion in future | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||
| @@ -782,10 +994,44 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | |||
| // functional ops attr | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | |||
| // used for label switch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||
| // Varible | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
| // HCOM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DATATYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_DATATYPE; | |||
| // used for LX tiling | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_L1_SPACE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_TYPE_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST; | |||
| // Dynamic stitch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| @@ -22,7 +22,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/anchor.h" | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/detail/attributes_holder.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/node.h" | |||
| @@ -77,6 +77,8 @@ class ModelSerializeImp { | |||
| void SetProtobufOwner(const ProtoMsgOwner &bufferProtobufOnwer) { protobuf_owner_ = bufferProtobufOnwer; } | |||
| private: | |||
| bool RebuildOwnership(ComputeGraphPtr &compute_graph, std::map<std::string, ComputeGraphPtr> &subgraphs); | |||
| std::vector<NodeNameGraphReq> graph_input_node_names_; | |||
| std::vector<NodeNameGraphReq> graph_output_node_names_; | |||
| std::vector<NodeNameNodeReq> node_input_node_names_; | |||
| @@ -43,30 +43,31 @@ using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||
| using ConstComputeGraphPtr = std::shared_ptr<const ComputeGraph>; | |||
| class GeTensorDesc; | |||
| class GeAttrValue; | |||
| class GeAttrValueImp; | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NamedAttrs : public AttrHolder { | |||
| public: | |||
| class NamedAttrs : public AttrHolder { | |||
| public: | |||
| NamedAttrs(); | |||
| virtual ~NamedAttrs() = default; | |||
| void SetName(const std::string &name); | |||
| string GetName() const; | |||
| GeAttrValue GetItem(const string &key) const; | |||
| protected: | |||
| ProtoAttrMapHelper MutableAttrMap() override; | |||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||
| private: | |||
| // Create namedAttrs from protobuf obj | |||
| NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||
| GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||
| friend class GeAttrValueImp; | |||
| }; | |||
| NamedAttrs(); | |||
| virtual ~NamedAttrs() = default; | |||
| void SetName(const std::string &name); | |||
| string GetName() const; | |||
| GeAttrValue GetItem(const string &key) const; | |||
| protected: | |||
| ProtoAttrMapHelper MutableAttrMap() override; | |||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||
| private: | |||
| // Create namedAttrs from protobuf obj | |||
| NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *protoMsg); | |||
| GeIrProtoHelper<proto::NamedAttrs> named_attrs_; | |||
| friend class GeAttrValueImp; | |||
| friend class GeAttrValue; | |||
| }; | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| public: | |||
| using INT = int64_t; | |||
| using FLOAT = float; | |||
| using BOOL = bool; | |||
| @@ -75,7 +76,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| using TENSOR_DESC = GeTensorDesc; | |||
| using GRAPH = ComputeGraphPtr; | |||
| using BYTES = Buffer; | |||
| using NAMED_ATTRS = NamedAttrs; | |||
| using NAMED_ATTRS = ge::NamedAttrs; | |||
| using DATA_TYPE = ge::DataType; | |||
| using LIST_INT = vector<INT>; | |||
| @@ -90,6 +91,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| using LIST_LIST_INT = vector<vector<int64_t>>; | |||
| using LIST_DATA_TYPE = vector<ge::DataType>; | |||
| using NamedAttrs = ge::NamedAttrs; // for cce use (ge::GeAttrValue::NamedAttrs). | |||
| enum ValueType { | |||
| VT_NONE = 0, | |||
| VT_STRING, | |||
| @@ -87,6 +87,12 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
| GeShape &MutableShape(); | |||
| void SetShape(GeShape shape); | |||
| // set shape with -2, it stand for unknown shape | |||
| void SetUnknownDimNumShape(); | |||
| // for unknown shape | |||
| graphStatus SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range); | |||
| graphStatus GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const; | |||
| GeShape GetOriginShape() const; | |||
| void SetOriginShape(const GeShape &originShape); | |||
| @@ -25,11 +25,7 @@ | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/graph.h" | |||
| namespace domi { | |||
| class ModelHelper; | |||
| } | |||
| namespace ge { | |||
| using domi::ModelHelper; | |||
| using std::map; | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -50,6 +50,8 @@ class GeAttrValue; | |||
| using ConstOpDesc = const OpDesc; | |||
| enum SubgraphType { kStatic, kDynamic, kSubgraphTypeEnd }; | |||
| class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| public: | |||
| template <class T> | |||
| @@ -83,6 +85,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| graphStatus AddInputDescForward(const string &name, const unsigned int num); | |||
| graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | |||
| graphStatus AddOutputDescForward(const string &name, const unsigned int num); | |||
| graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | |||
| @@ -141,6 +145,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
| graphStatus AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index); | |||
| graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
| bool IsOptionalInput(const string &name) const; | |||
| @@ -214,6 +220,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| void SetIsInputConst(const vector<bool> &is_input_const); | |||
| vector<bool> GetIsInputConst() const; | |||
| void SetOpInferDepends(const vector<string> &depend_names); | |||
| vector<string> GetOpInferDepends() const; | |||
| string GetInputNameByIndex(uint32_t index) const; | |||
| int GetInputIndexByName(const string &name) const; | |||
| @@ -236,12 +245,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| std::string GetOpEngineName() const; | |||
| void RegisterSubgraphIrName(const std::string &name, SubgraphType type); | |||
| const std::map<std::string, SubgraphType> &GetSubgraphIrNames() const; | |||
| SubgraphType GetSubgraphTypeByIrName(const std::string &name) const; | |||
| graphStatus AddSubgraphName(const std::string &name); | |||
| const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||
| std::string GetSubgraphInstanceName(uint32_t index) const; | |||
| const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||
| void AddSubgraphInstanceName(std::string name); | |||
| /// Does not provide functions `AddSubgraphInstance` or `AppendSubgraphInstance`, | |||
| /// because this kind of functions will only append a new subgraph instance name | |||
| /// at the tail of `subgraph_instance_names_` and ignore the synchronous change of `subgraph_names_to_index_`. | |||
| /// If we want to append a new subgraph instance name, the function `AddSubgraphName` should be called first. | |||
| /// \param index | |||
| /// \param name | |||
| /// \return | |||
| graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | |||
| void RemoveSubgraphInstanceName(const std::string &name); | |||
| protected: | |||
| @@ -256,7 +276,23 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeIrProtoHelper<ge::proto::OpDef> op_def_; | |||
| std::vector<std::string> subgraph_instance_names_; | |||
| // subgraph names to index, for a `if` operator: | |||
| // then_branch: 0 | |||
| // else_branch: 1 | |||
| // or for a `case` node: | |||
| // branches0: 0 | |||
| // branches1: 1 | |||
| // branches2: 2 | |||
| std::map<std::string, uint32_t> subgraph_names_to_index_; | |||
| // subgraph ir names to type, for a `if` operator: | |||
| // then_branch: static | |||
| // else_branch: dynamic | |||
| // or for a `case` op: | |||
| // branches: dynamic | |||
| std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | |||
| vector<GeTensorDescPtr> inputs_desc_{}; | |||
| vector<GeTensorDescPtr> outputs_desc_{}; | |||
| map<string, uint32_t> output_name_idx_{}; | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef COMMON_GRAPH_REF_RELATION_H_ | |||
| #define COMMON_GRAPH_REF_RELATION_H_ | |||
| #include <deque> | |||
| #include <string> | |||
| #include <unordered_map> | |||
| #include <vector> | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/types.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "node.h" | |||
| namespace ge { | |||
| enum InOutFlag { | |||
| NODE_IN = 0, // input flag | |||
| NODE_OUT = 1, // output flag | |||
| }; | |||
| struct RefCell { | |||
| std::string node_name; | |||
| ge::NodePtr node = nullptr; | |||
| InOutFlag in_out = NODE_IN; | |||
| int in_out_idx = 0; | |||
| bool operator==(const RefCell &c) const { | |||
| return node_name == c.node_name && node == c.node && in_out == c.in_out && in_out_idx == c.in_out_idx; | |||
| } | |||
| RefCell() = default; | |||
| RefCell(std::string name, ge::NodePtr node_ptr, InOutFlag in_out_flag, int idx) { | |||
| node_name = name; | |||
| node = node_ptr; | |||
| in_out = in_out_flag; | |||
| in_out_idx = idx; | |||
| }; | |||
| ~RefCell() = default; | |||
| }; | |||
| struct RefCellHash { | |||
| size_t operator()(const RefCell &c) const { | |||
| unsigned long number = reinterpret_cast<unsigned long>(reinterpret_cast<uintptr_t>(c.node.get())); | |||
| string tmp = c.node_name + std::to_string(c.in_out) + std::to_string(c.in_out_idx) + std::to_string(number); | |||
| return std::hash<string>()(tmp); | |||
| } | |||
| }; | |||
| class RefRelations { | |||
| public: | |||
| graphStatus LookUpRefRelations(const RefCell &key, std::unordered_set<RefCell, RefCellHash> &result); | |||
| graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||
| graphStatus Clear(); | |||
| RefRelations(); | |||
| ~RefRelations() = default; | |||
| public: | |||
| class Impl; | |||
| std::shared_ptr<Impl> impl_ = nullptr; | |||
| }; | |||
| } // namespace ge | |||
| #endif // COMMON_GRAPH_REF_RELATION_H_ | |||
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #ifndef INC_GRAPH_USR_TYPES_H_ | |||
| #define INC_GRAPH_USR_TYPES_H_ | |||
| #include <atomic> | |||
| #include <memory> | |||
| @@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||
| #undef USR_TYPE_BYTES_DEC | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #endif // INC_GRAPH_USR_TYPES_H_ | |||
| @@ -62,9 +62,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
| static bool SetListGraph(AttrHolderAdapter &&obj, const string &name, const vector<ComputeGraphPtr> &value); | |||
| static bool SetBytes(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::BYTES &value); | |||
| static bool SetListBytes(AttrHolderAdapter &&obj, const string &name, const vector<GeAttrValue::BYTES> &value); | |||
| static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NamedAttrs &value); | |||
| static bool SetNamedAttrs(AttrHolderAdapter &&obj, const string &name, const GeAttrValue::NAMED_ATTRS &value); | |||
| static bool SetListNamedAttrs(AttrHolderAdapter &&obj, const string &name, | |||
| const vector<GeAttrValue::NamedAttrs> &value); | |||
| const vector<GeAttrValue::NAMED_ATTRS> &value); | |||
| static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<ConstOpDescPtr> &value); | |||
| static bool SetListOpDesc(AttrHolderAdapter &&obj, const string &name, const vector<OpDescPtr> &value); | |||
| @@ -91,9 +91,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
| static bool GetListGraph(ConstAttrHolderAdapter &&obj, const string &name, vector<ComputeGraphPtr> &value); | |||
| static bool GetBytes(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::BYTES &value); | |||
| static bool GetListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<GeAttrValue::BYTES> &value); | |||
| static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NamedAttrs &value); | |||
| static bool GetNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, GeAttrValue::NAMED_ATTRS &value); | |||
| static bool GetListNamedAttrs(ConstAttrHolderAdapter &&obj, const string &name, | |||
| vector<GeAttrValue::NamedAttrs> &value); | |||
| vector<GeAttrValue::NAMED_ATTRS> &value); | |||
| static bool GetListOpDesc(ConstAttrHolderAdapter &&obj, const string &name, vector<OpDescPtr> &value); | |||
| // Value will be moved | |||
| static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | |||
| @@ -95,12 +95,35 @@ | |||
| }; | |||
| namespace ge { | |||
| enum IOType { kIn, kOut }; | |||
| struct NodeIndexIO { | |||
| NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type) | |||
| : node(std::move(node)), index(index), io_type(io_type) {} | |||
| NodeIndexIO(ge::NodePtr node, int index, IOType io_type) | |||
| : node(std::move(node)), index(static_cast<uint32_t>(index)), io_type(io_type) {} | |||
| ~NodeIndexIO() {} | |||
| NodePtr node = nullptr; | |||
| uint32_t index = 0; | |||
| IOType io_type = kOut; | |||
| std::string ToString() const { | |||
| if ((node == nullptr) || (node->GetOwnerComputeGraph() == nullptr)) { | |||
| return ""; | |||
| } | |||
| return node->GetName() + (io_type == kOut ? "_out_" : "_in_") + std::to_string(index); | |||
| } | |||
| }; | |||
| class GraphUtils { | |||
| public: | |||
| static ComputeGraphPtr GetComputeGraph(const Graph &graph); | |||
| static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph); | |||
| static graphStatus RecoverGraphOperators(const Graph &graph); | |||
| static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs); | |||
| static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||
| @@ -262,6 +285,108 @@ class GraphUtils { | |||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
| static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||
| /// | |||
| /// Get reference-mapping of all data_anchors in graph | |||
| /// @param [in] graph | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus GetRefMapping(const ComputeGraphPtr &graph, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| private: | |||
| /// | |||
| /// Get reference-mapping for in_data_anchors of node | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus HandleInAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Get reference-mapping for out_data_anchors of node | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus HandleOutAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Handle input of subgraph | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus HandleSubgraphInput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Handle input of Merge op | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus HandleMergeInput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Handle output of subgraph | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus HandleSubgraphOutput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Union ref-mapping | |||
| /// @param [in] exist_node_info1 | |||
| /// @param [in] exist_node_info2 | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @param [out] symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol); | |||
| /// | |||
| /// Update symbol mapping with a new reference pair | |||
| /// @param [in] cur_node_info | |||
| /// @param [in] exist_node_info | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Check if out_data_anchor is reference of input | |||
| /// @param [in] out_data_anchor | |||
| /// @param [out] reuse_in_index | |||
| /// @return bool | |||
| /// | |||
| static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index); | |||
| }; | |||
| class ComputeGraphBuilder { | |||
| @@ -441,12 +566,12 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
| private: | |||
| /// | |||
| /// @brief Build inputs | |||
| /// @brief Add data nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||
| void AddDataNodes(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add data node | |||
| @@ -455,41 +580,15 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
| NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Build outputs | |||
| /// @brief Add RetVal nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add NetOutput node | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return NodePtr | |||
| /// | |||
| NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add input/output tensor for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_desc | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| OpDescPtr &net_output_desc); | |||
| /// | |||
| /// @brief Add edge for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_node | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| const NodePtr &net_output_node); | |||
| void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | |||
| std::string name_; | |||
| NodePtr parent_node_; | |||
| @@ -55,11 +55,44 @@ class NodeUtils { | |||
| static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | |||
| static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
| static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
| // check node whether unknown shape.If node shape contain -1 or -2,out param "is_unknow" will be true; | |||
| // for func op, it will check subgraph yet, if some node shape of subgraph contain -1 or -2, | |||
| // the out param "is_unknow" will be true too | |||
| static graphStatus GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow); | |||
| static std::string GetNodeType(const Node &node); | |||
| static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||
| static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||
| static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | |||
| /// | |||
| /// Check if node is input of subgraph | |||
| /// @param [in] node | |||
| /// @return bool | |||
| /// | |||
| static bool IsSubgraphInput(const NodePtr &node); | |||
| /// | |||
| /// Check if node is output of subgraph | |||
| /// @param [in] node | |||
| /// @return bool | |||
| /// | |||
| static bool IsSubgraphOutput(const NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph original input node. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| static NodePtr GetParentInput(const NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| /// @param [out] string | |||
| /// @return bool | |||
| /// | |||
| static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); | |||
| private: | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
| @@ -81,6 +81,9 @@ class OpDescUtils { | |||
| static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||
| static graphStatus SetSubgraphInstanceName(const std::string& subgraph_name, | |||
| const std::string& subgraph_instance_name, OpDescPtr& op_desc); | |||
| private: | |||
| static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||
| static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | |||
| @@ -104,6 +107,14 @@ class OpDescBuilder { | |||
| /// | |||
| OpDescBuilder& AddInput(const std::string& name); | |||
| /// | |||
| /// @brief Add input | |||
| /// @param [in] name | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddInput(const std::string& name, const GeTensorDesc& tensor); | |||
| /// | |||
| /// @brief Add dynamic input | |||
| /// @param [in] name | |||
| @@ -112,6 +123,15 @@ class OpDescBuilder { | |||
| /// | |||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||
| /// | |||
| /// @brief Add dynamic input | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||
| /// | |||
| /// @brief Add output | |||
| /// @param [in] name | |||
| @@ -119,6 +139,14 @@ class OpDescBuilder { | |||
| /// | |||
| OpDescBuilder& AddOutput(const std::string& name); | |||
| /// | |||
| /// @brief Add output | |||
| /// @param [in] name | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddOutput(const std::string& name, const GeTensorDesc& tensor); | |||
| /// | |||
| /// @brief Add dynamic output | |||
| /// @param [in] name | |||
| @@ -127,6 +155,15 @@ class OpDescBuilder { | |||
| /// | |||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||
| /// | |||
| /// @brief Add dynamic output | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num, const GeTensorDesc& tensor); | |||
| /// | |||
| /// @brief Build op_desc | |||
| /// @return OpDescPtr | |||
| @@ -136,8 +173,8 @@ class OpDescBuilder { | |||
| private: | |||
| std::string name_; | |||
| std::string type_; | |||
| std::vector<std::string> inputs_; | |||
| std::vector<std::string> outputs_; | |||
| std::vector<std::pair<std::string, GeTensorDesc>> inputs_; | |||
| std::vector<std::pair<std::string, GeTensorDesc>> outputs_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -34,13 +34,12 @@ ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| ge_protobuf_generate(ge PROTO_ONNX_SRCS PROTO_ONNX_HDRS ${ONNX_PROTO_LIST}) | |||
| # need to remove dependencies on pb files later | |||
| file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "*.cc" | |||
| "utils/*.cc" | |||
| "opsproto/*.cc" | |||
| "detail/*.cc" | |||
| "debug/*.cc" | |||
| "op_imp.cc" | |||
| "option/*.cc" | |||
| ) | |||
| @@ -53,7 +53,6 @@ void Anchor::UnlinkAll() noexcept { | |||
| if (Unlink(peer_anchor_ptr) != GRAPH_SUCCESS) { | |||
| GELOGW("unlink peer_anchor_ptr failed."); | |||
| } | |||
| } while (!peer_anchors_.empty()); | |||
| } | |||
| } | |||
| @@ -42,8 +42,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const | |||
| : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||
| attrs_.InitDefault(); | |||
| } | |||
| ComputeGraph::~ComputeGraph() {} | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() const { return name_; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { | |||
| @@ -53,24 +56,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS | |||
| } | |||
| return s; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | |||
| vector<NodePtr> all_nodes(nodes_.size()); | |||
| (void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); | |||
| for (const auto &sub_graph : sub_graph_) { | |||
| if (sub_graph == nullptr) { | |||
| GELOGW("sub graph is nullptr"); | |||
| if (sub_graph_.empty()) { | |||
| return Vistor<NodePtr>(shared_from_this(), nodes_); | |||
| } | |||
| std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||
| return AllGraphNodes(subgraphs); | |||
| } | |||
| ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::shared_ptr<ComputeGraph>> &subgraphs) const { | |||
| std::vector<NodePtr> all_nodes; | |||
| std::deque<NodePtr> candidates; | |||
| candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); | |||
| while (!candidates.empty()) { | |||
| NodePtr node = candidates.front(); | |||
| all_nodes.emplace_back(node); | |||
| candidates.pop_front(); | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| continue; | |||
| } | |||
| for (const auto &node : sub_graph->GetAllNodes()) { | |||
| all_nodes.push_back(node); | |||
| const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
| for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { | |||
| auto subgraph = GetSubgraph(*name_iter); | |||
| if (subgraph != nullptr) { | |||
| subgraphs.emplace_back(subgraph); | |||
| candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); | |||
| } | |||
| } | |||
| } | |||
| return Vistor<NodePtr>(shared_from_this(), all_nodes); | |||
| } | |||
| size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | |||
| return Vistor<NodePtr>(shared_from_this(), nodes_); | |||
| } | |||
| ComputeGraph::Vistor<NodePtr> ComputeGraph::GetInputNodes() const { | |||
| return Vistor<NodePtr>(shared_from_this(), input_nodes_); | |||
| } | |||
| @@ -82,6 +111,7 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::GetOutputNodes() const { | |||
| } | |||
| return Vistor<NodePtr>(shared_from_this(), result); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(const std::string &name) const { | |||
| for (const auto &node : nodes_) { | |||
| if (node == nullptr) { | |||
| @@ -203,10 +233,6 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { | |||
| return nullptr; | |||
| } | |||
| node->GetOpDesc()->SetId(nodes_.size()); | |||
| if (nodes_[0] == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "nodes_ size or nodes_[0] is nullptr"); | |||
| return nullptr; | |||
| } | |||
| if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | |||
| (void)nodes_.insert(nodes_.begin() + 1, node); | |||
| } else { | |||
| @@ -248,6 +274,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD | |||
| GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | |||
| return AddNode(node_ptr); | |||
| } | |||
| NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||
| if (op == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||
| return nullptr; | |||
| } | |||
| op->SetId(id); | |||
| NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | |||
| GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | |||
| nodes_.push_back(node); | |||
| return node; | |||
| } | |||
| NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| @@ -259,6 +299,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
| } | |||
| return node; | |||
| } | |||
| NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); | |||
| @@ -336,6 +377,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN | |||
| } | |||
| return GRAPH_FAILED; | |||
| } | |||
| // Used in sub_graph scenes | |||
| graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | |||
| if (node == nullptr) { | |||
| @@ -372,20 +414,24 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { | |||
| GE_IF_BOOL_EXEC(find_node == false, return GRAPH_FAILED); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) { | |||
| if (sub_graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
| return nullptr; | |||
| } | |||
| sub_graph_.push_back(sub_graph); | |||
| names_to_subgraph_[sub_graph->GetName()] = sub_graph; | |||
| return sub_graph; | |||
| } | |||
| graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) { | |||
| if (sub_graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| names_to_subgraph_.erase(sub_graph->GetName()); | |||
| auto iter = find(sub_graph_.begin(), sub_graph_.end(), sub_graph); | |||
| if (iter != sub_graph_.end()) { | |||
| (void)sub_graph_.erase(iter); | |||
| @@ -462,8 +508,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | |||
| const std::string &name) const { | |||
| auto iter = names_to_subgraph_.find(name); | |||
| return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
| std::shared_ptr<ComputeGraph> parent = parent_graph_.lock(); | |||
| if (parent == nullptr) { | |||
| auto iter = names_to_subgraph_.find(name); | |||
| return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
| } else { | |||
| return parent->GetSubgraph(name); | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | |||
| @@ -495,7 +546,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | |||
| for (auto &input : input_nodes_) { | |||
| size_t update_num = 0; | |||
| for (auto &input : nodes_) { | |||
| if (update_num >= input_mapping.size()) { | |||
| break; | |||
| } | |||
| uint32_t cur_index = 0; | |||
| if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
| continue; | |||
| @@ -508,6 +563,7 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||
| GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| update_num++; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| @@ -520,9 +576,9 @@ ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mappi | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | |||
| NodePtr net_output = FindNode(kNodeNameNetOutput); | |||
| NodePtr net_output = FindNode(NODE_NAME_NET_OUTPUT); | |||
| if (net_output == nullptr) { | |||
| GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||
| GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", NODE_NAME_NET_OUTPUT); | |||
| return GRAPH_FAILED; | |||
| } | |||
| OpDescPtr op_desc = net_output->GetOpDesc(); | |||
| @@ -557,13 +613,13 @@ ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_map | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | |||
| std::vector<NodePtr> node_vec = nodes_; | |||
| for (const auto &node : GetAllNodes()) { | |||
| for (const auto &node : GetDirectNode()) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
| GELOGW("node or OpDescPtr is nullptr."); | |||
| continue; | |||
| } | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); | |||
| if (node->GetOpDesc()->GetType() == kRecvType) { | |||
| if (node->GetOpDesc()->GetType() == RECV) { | |||
| auto iter = find(node_vec.begin(), node_vec.end(), node); | |||
| if (iter == node_vec.end()) { | |||
| GELOGW("no node found."); | |||
| @@ -574,7 +630,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||
| auto dst_iter = find(node_vec.begin(), node_vec.end(), node->GetOutControlNodes().at(0)); | |||
| (void)node_vec.insert(dst_iter, node); | |||
| } | |||
| if (node->GetOpDesc()->GetType() == kSendType) { | |||
| if (node->GetOpDesc()->GetType() == SEND) { | |||
| auto iter = find(node_vec.begin(), node_vec.end(), node); | |||
| if (iter == node_vec.end()) { | |||
| GELOGW("no node found."); | |||
| @@ -602,7 +658,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||
| graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
| std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
| std::vector<NodePtr> &stack) { | |||
| GELOGI("Runing_Dfs_Sort"); | |||
| GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); | |||
| // Record the number of non data nodes but no input nodes | |||
| GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | |||
| @@ -647,7 +703,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
| graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||
| std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
| std::deque<NodePtr> &stack) { | |||
| GELOGI("Runing_Bfs_Sort"); | |||
| GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); | |||
| std::vector<NodePtr> stack_input; | |||
| std::map<string, NodePtr> breadth_node_map; | |||
| // Record the number of non data nodes but no input nodes | |||
| @@ -708,23 +764,36 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | |||
| auto ret = TopologicalSortingSubgraph(); | |||
| auto ret = TopologicalSortingGraph(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Sub graph partition Failed"); | |||
| return ret; | |||
| } | |||
| if (sub_graph_.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| // partition sub graph | |||
| for (const auto &sub_graph : GetAllSubgraphs()) { | |||
| ret = sub_graph->TopologicalSortingSubgraph(); | |||
| for (const auto &sub_graph : sub_graph_) { | |||
| ret = sub_graph->TopologicalSortingGraph(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Sub graph topological sort Failed"); | |||
| return ret; | |||
| } | |||
| } | |||
| std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||
| (void)AllGraphNodes(subgraphs); | |||
| if (sub_graph_.size() != subgraphs.size()) { // Graph Partition use subgraph, Keep original | |||
| GELOGW("Keep original subgraph for graph size %zu not equal %zu.", sub_graph_.size(), subgraphs.size()); | |||
| return SUCCESS; | |||
| } | |||
| sub_graph_.swap(subgraphs); | |||
| return SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||
| graphStatus ComputeGraph::TopologicalSortingGraph() { | |||
| std::vector<NodePtr> node_vec; | |||
| std::map<NodePtr, uint32_t> map_in_edge_num; | |||
| bool use_BFS = false; | |||
| @@ -735,7 +804,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||
| use_BFS = true; | |||
| } | |||
| } else { | |||
| GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); | |||
| GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||
| } | |||
| if (use_BFS) { | |||
| @@ -793,8 +862,8 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
| map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | |||
| if (map_in_edge_num[node] == 0) { | |||
| if ((node->GetOpDesc()->GetType() != kDataType) && (node->GetOpDesc()->GetType() != kAippDataType) && | |||
| (node->GetOpDesc()->GetType() != kInputType) && (node->GetOpDesc()->GetType() != kAnnDataType)) { | |||
| if ((node->GetOpDesc()->GetType() != DATA) && (node->GetOpDesc()->GetType() != AIPPDATA) && | |||
| (node->GetOpDesc()->GetType() != INPUT_TYPE) && (node->GetOpDesc()->GetType() != ANN_DATA)) { | |||
| // At present, can only judge the isolated point without input and output. | |||
| // It is impossible to judge the situation with multiple output nodes. | |||
| if (verify_isolated && GetOutEdgeSize(node) == 0) { | |||
| @@ -832,6 +901,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
| size_t in_edge_size = 0; | |||
| if (node == nullptr) { | |||
| @@ -884,6 +954,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::IsValid() const { return is_valid_flag_; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
| GELOGI("graph name = %s.", GetName().c_str()); | |||
| for (const auto &node : GetAllNodes()) { | |||
| @@ -915,6 +986,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
| } | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| auto next_nodes = node->GetOutAllNodes(); | |||
| @@ -954,6 +1026,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||
| } | |||
| } | |||
| } | |||
| // If there is an input control side | |||
| auto in_ctrl_anchor = node->GetInControlAnchor(); | |||
| GE_CHECK_NOTNULL(in_ctrl_anchor); | |||
| @@ -991,6 +1064,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Isolate | |||
| return RemoveExtraOutEdge(node); | |||
| } | |||
| graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| // Remove redundant output edges | |||
| @@ -1041,7 +1115,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InferSh | |||
| node_ptr->GetName().c_str()); | |||
| graphStatus status = node_ptr->InferShapeAndType(); | |||
| GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == kDataType || GRAPH_PARAM_INVALID != status, break, | |||
| GE_CHK_BOOL_EXEC_INFO(node_ptr->GetType() == DATA || GRAPH_PARAM_INVALID != status, break, | |||
| "Op %s does not have the IMPLEMT_INFERFUNC definition," | |||
| " and subsequent operators no longer perform shape inference.", | |||
| node_ptr->GetName().c_str()); | |||
| @@ -16,237 +16,41 @@ | |||
| #ifndef COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||
| #define COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||
| #include <limits.h> | |||
| #include <stdint.h> | |||
| #include <algorithm> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| namespace ge { | |||
| #define GE_REGISTER_OPTYPE(var_name, str_name) static const char* var_name __attribute__((unused)) = str_name | |||
| #define GE_REGISTER_OPTYPE(var_name, str_name) static const char *var_name __attribute__((unused)) = str_name | |||
| GE_REGISTER_OPTYPE(DATA, "Data"); | |||
| GE_REGISTER_OPTYPE(AIPPDATA, "AippData"); | |||
| GE_REGISTER_OPTYPE(CONVOLUTION, "Convolution"); | |||
| GE_REGISTER_OPTYPE(CORRELATION, "Correlation"); | |||
| GE_REGISTER_OPTYPE(CORRELATIONV2, "Correlation_V2"); | |||
| GE_REGISTER_OPTYPE(DECONVOLUTION, "Deconvolution"); | |||
| GE_REGISTER_OPTYPE(POOLING, "Pooling"); | |||
| GE_REGISTER_OPTYPE(ELTWISE, "Eltwise"); | |||
| GE_REGISTER_OPTYPE(RELU, "ReLU"); | |||
| GE_REGISTER_OPTYPE(RELU6, "ReLU6"); | |||
| GE_REGISTER_OPTYPE(SIGMOID, "Sigmoid"); | |||
| GE_REGISTER_OPTYPE(ABSVAL, "AbsVal"); | |||
| GE_REGISTER_OPTYPE(TANH, "TanH"); | |||
| GE_REGISTER_OPTYPE(PRELU, "PReLU"); | |||
| GE_REGISTER_OPTYPE(BATCHNORM, "BatchNorm"); | |||
| GE_REGISTER_OPTYPE(FUSIONBATCHNORM, "FusionBatchNorm"); | |||
| GE_REGISTER_OPTYPE(SCALE, "Scale"); | |||
| GE_REGISTER_OPTYPE(FULL_CONNECTION, "FullConnection"); | |||
| GE_REGISTER_OPTYPE(SOFTMAX, "Softmax"); | |||
| GE_REGISTER_OPTYPE(PLUS, "Plus"); | |||
| GE_REGISTER_OPTYPE(ACTIVATION, "Activation"); | |||
| GE_REGISTER_OPTYPE(FLATTEN, "Flatten"); | |||
| GE_REGISTER_OPTYPE(ADD, "Add"); | |||
| GE_REGISTER_OPTYPE(SUB, "Sub"); | |||
| GE_REGISTER_OPTYPE(MUL, "Mul"); | |||
| GE_REGISTER_OPTYPE(MATMUL, "MatMul"); | |||
| GE_REGISTER_OPTYPE(RSQRT, "Rsqrt"); | |||
| GE_REGISTER_OPTYPE(BIASADD, "BiasAdd"); | |||
| GE_REGISTER_OPTYPE(RESHAPE, "Reshape"); | |||
| GE_REGISTER_OPTYPE(DEPCONVOLUTION, "ConvolutionDepthwise"); | |||
| GE_REGISTER_OPTYPE(DROPOUT, "Dropout"); | |||
| GE_REGISTER_OPTYPE(CONCAT, "Concat"); | |||
| GE_REGISTER_OPTYPE(ROIPOOLING, "ROIPooling"); | |||
| GE_REGISTER_OPTYPE(PROPOSAL, "Proposal"); | |||
| GE_REGISTER_OPTYPE(FSRDETECTIONOUTPUT, "FSRDetectionOutput"); | |||
| GE_REGISTER_OPTYPE(DETECTIONPOSTPROCESS, "Detectpostprocess"); | |||
| GE_REGISTER_OPTYPE(LRN, "LRN"); | |||
| GE_REGISTER_OPTYPE(TRANSDATA, "TransData"); | |||
| GE_REGISTER_OPTYPE(PERMUTE, "Permute"); | |||
| GE_REGISTER_OPTYPE(SSDNORMALIZE, "SSDNormalize"); | |||
| GE_REGISTER_OPTYPE(SSDPRIORBOX, "SSDPriorBox"); | |||
| GE_REGISTER_OPTYPE(NETOUTPUT, "NetOutput"); | |||
| GE_REGISTER_OPTYPE(SSDDETECTIONOUTPUT, "SSDDetectionOutput"); | |||
| GE_REGISTER_OPTYPE(CHANNELAXPY, "ChannelAxpy"); | |||
| GE_REGISTER_OPTYPE(PSROIPOOLING, "PSROIPooling"); | |||
| GE_REGISTER_OPTYPE(POWER, "Power"); | |||
| GE_REGISTER_OPTYPE(ROIALIGN, "ROIAlign"); | |||
| GE_REGISTER_OPTYPE(PYTHON, "Python"); | |||
| GE_REGISTER_OPTYPE(FREESPACEEXTRACT, "FreespaceExtract"); | |||
| GE_REGISTER_OPTYPE(SPATIALTF, "SpatialTransform"); | |||
| GE_REGISTER_OPTYPE(SHAPE, "Shape"); | |||
| GE_REGISTER_OPTYPE(ARGMAX, "ArgMax"); | |||
| GE_REGISTER_OPTYPE(GATHERND, "GatherNd"); | |||
| GE_REGISTER_OPTYPE(GATHER, "Gather"); | |||
| GE_REGISTER_OPTYPE(REALDIV, "RealDiv"); | |||
| GE_REGISTER_OPTYPE(PACK, "Pack"); | |||
| GE_REGISTER_OPTYPE(SLICE, "Slice"); | |||
| GE_REGISTER_OPTYPE(FLOORDIV, "FloorDiv"); | |||
| GE_REGISTER_OPTYPE(_WHILE, "_While"); | |||
| GE_REGISTER_OPTYPE(WHILE, "While"); | |||
| GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); | |||
| GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | |||
| GE_REGISTER_OPTYPE(STRIDEDSLICE, "StridedSlice"); | |||
| GE_REGISTER_OPTYPE(RANGE, "Range"); | |||
| GE_REGISTER_OPTYPE(RPNPROPOSALS, "GenerateRpnProposals"); | |||
| GE_REGISTER_OPTYPE(DECODEBBOX, "DecodeBBox"); | |||
| GE_REGISTER_OPTYPE(PAD, "Pad"); | |||
| GE_REGISTER_OPTYPE(TILE, "Tile"); | |||
| GE_REGISTER_OPTYPE(SIZE, "Size"); | |||
| GE_REGISTER_OPTYPE(CLIPBOXES, "Clipboxes"); | |||
| GE_REGISTER_OPTYPE(FASTRCNNPREDICTIONS, "FastrcnnPredictions"); | |||
| GE_REGISTER_OPTYPE(SPLIT, "Split"); | |||
| GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||
| GE_REGISTER_OPTYPE(MEAN, "Mean"); | |||
| GE_REGISTER_OPTYPE(GREATER, "Greater"); | |||
| GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||
| GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); | |||
| GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||
| GE_REGISTER_OPTYPE(REFMERGE, "RefMerge"); | |||
| GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||
| GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||
| GE_REGISTER_OPTYPE(LOOPCOND, "LoopCond"); | |||
| GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||
| GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||
| GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||
| GE_REGISTER_OPTYPE(EXIT, "Exit"); | |||
| GE_REGISTER_OPTYPE(REFEXIT, "RefExit"); | |||
| GE_REGISTER_OPTYPE(CONTROLTRIGGER, "ControlTrigger"); | |||
| GE_REGISTER_OPTYPE(TRANSPOSE, "Transpose"); | |||
| GE_REGISTER_OPTYPE(CAST, "Cast"); | |||
| GE_REGISTER_OPTYPE(REGION, "Region"); | |||
| GE_REGISTER_OPTYPE(YOLO, "Yolo"); | |||
| GE_REGISTER_OPTYPE(YOLODETECTIONOUTPUT, "YoloDetectionOutput"); | |||
| GE_REGISTER_OPTYPE(FILL, "Fill"); | |||
| GE_REGISTER_OPTYPE(REVERSE, "Reverse"); | |||
| GE_REGISTER_OPTYPE(UNPACK, "Unpack"); | |||
| GE_REGISTER_OPTYPE(YOLO2REORG, "Yolo2Reorg"); | |||
| GE_REGISTER_OPTYPE(REDUCESUM, "ReduceSum"); | |||
| GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||
| GE_REGISTER_OPTYPE(RESIZEBILINEAR, "ResizeBilinear"); | |||
| GE_REGISTER_OPTYPE(MAXIMUM, "Maximum"); | |||
| GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||
| GE_REGISTER_OPTYPE(ARG, "_Arg"); | |||
| GE_REGISTER_OPTYPE(FUSEDBATCHNORMGRAD, "FusedBatchNormGrad"); | |||
| GE_REGISTER_OPTYPE(LSTM, "LSTM"); | |||
| GE_REGISTER_OPTYPE(HIGHWAY, "HighWay"); | |||
| GE_REGISTER_OPTYPE(RNN, "RNN"); | |||
| GE_REGISTER_OPTYPE(ATTENTIONDECODER, "AttentionDecoder"); | |||
| GE_REGISTER_OPTYPE(LOGICAL_NOT, "LogicalNot"); | |||
| GE_REGISTER_OPTYPE(LOGICAL_AND, "LogicalAnd"); | |||
| GE_REGISTER_OPTYPE(EQUAL, "Equal"); | |||
| GE_REGISTER_OPTYPE(INTERP, "Interp"); | |||
| GE_REGISTER_OPTYPE(SHUFFLECHANNEL, "ShuffleChannel"); | |||
| GE_REGISTER_OPTYPE(AIPP, "Aipp"); | |||
| GE_REGISTER_OPTYPE(CROPANDRESIZE, "CropAndResize"); | |||
| GE_REGISTER_OPTYPE(UNUSEDCONST, "UnusedConst"); | |||
| GE_REGISTER_OPTYPE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); | |||
| GE_REGISTER_OPTYPE(BROADCASTARGS, "BroadcastArgs"); | |||
| GE_REGISTER_OPTYPE(STOPGRADIENT, "StopGradient"); | |||
| GE_REGISTER_OPTYPE(PPREVENTGRADIENT, "PreventGradient"); | |||
| GE_REGISTER_OPTYPE(GUARANTEECONST, "GuaranteeConst"); | |||
| GE_REGISTER_OPTYPE(SPARSETODENSE, "SparseToDense"); | |||
| GE_REGISTER_OPTYPE(NONMAXSUPPRESSION, "NonMaxSuppression"); | |||
| GE_REGISTER_OPTYPE(TOPKV2, "TopKV2"); | |||
| GE_REGISTER_OPTYPE(INVERTPERMUTATION, "InvertPermutation"); | |||
| GE_REGISTER_OPTYPE(MULTINOMIAL, "Multinomial"); | |||
| GE_REGISTER_OPTYPE(REVERSESEQUENCE, "ReverseSequence"); | |||
| GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||
| GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||
| // ANN specific operator | |||
| GE_REGISTER_OPTYPE(ANN_MEAN, "AnnMean"); | |||
| GE_REGISTER_OPTYPE(ANN_CONVOLUTION, "AnnConvolution"); | |||
| GE_REGISTER_OPTYPE(ANN_DEPCONVOLUTION, "AnnDepthConv"); | |||
| GE_REGISTER_OPTYPE(DIV, "Div"); | |||
| GE_REGISTER_OPTYPE(ANN_FULLCONNECTION, "AnnFullConnection"); | |||
| GE_REGISTER_OPTYPE(ANN_NETOUTPUT, "AnnNetOutput"); | |||
| GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||
| // Training operator | |||
| GE_REGISTER_OPTYPE(CONVGRADFILTER, "Conv2DBackpropFilter"); | |||
| GE_REGISTER_OPTYPE(CONV2D, "Conv2D"); | |||
| GE_REGISTER_OPTYPE(CONV2DBACKPROPINPUT, "Conv2DBackpropInput"); | |||
| GE_REGISTER_OPTYPE(ACTIVATIONGRAD, "ReluGrad"); | |||
| GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||
| GE_REGISTER_OPTYPE(AVGPOOLGRAD, "AvgPoolGrad"); | |||
| GE_REGISTER_OPTYPE(SQUARE, "Square"); | |||
| GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||
| GE_REGISTER_OPTYPE(END, "End"); | |||
| GE_REGISTER_OPTYPE(VARIABLE, "Variable"); | |||
| GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); | |||
| /// @ingroup domi_omg | |||
| /// @brief INPUT node type | |||
| static const char* const kInputType = "Input"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief AIPP tag, tag for aipp conv operator | |||
| /// | |||
| static const char* const kAippConvFlag = "Aipp_Conv_Flag"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief AIPP tag, tag for aipp data operator | |||
| /// | |||
| static const char* const kAippDataFlag = "Aipp_Data_Flag"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief AIPP tag, tag for aipp data operator | |||
| /// | |||
| static const char* const kAippDataType = "AippData"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief DATA node type | |||
| /// | |||
| static const char* const kDataType = "Data"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Frame operator type | |||
| /// | |||
| static const char* const kFrameworkOpType = "FrameworkOp"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Data node type | |||
| /// | |||
| static const char* const kAnnDataType = "AnnData"; | |||
| static const char* const kAnnNetoutputType = "AnnNetOutput"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Convolution node type | |||
| /// | |||
| static const char* const kNodeNameNetOutput = "Node_Output"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief RECV node type | |||
| /// | |||
| static const char* const kRecvType = "Recv"; | |||
| GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief SEND node type | |||
| /// | |||
| static const char* const kSendType = "Send"; | |||
| GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Convolution node type | |||
| /// | |||
| static const char* const kOpTypeConvolution = "Convolution"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Add convolution node name to hard AIPP | |||
| /// | |||
| static const char* const kAippConvOpNmae = "aipp_conv_op"; | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Operator configuration item separator | |||
| /// | |||
| static const char* const kOpConfDelimiter = ":"; | |||
| GE_REGISTER_OPTYPE(RECV, "Recv"); | |||
| GE_REGISTER_OPTYPE(SEND, "Send"); | |||
| }; // namespace ge | |||
| #endif // COMMON_GRAPH_DEBUG_GE_OP_TYPES_H_ | |||
| @@ -15,11 +15,14 @@ | |||
| */ | |||
| #include "format_refiner.h" | |||
| #include <deque> | |||
| #include <iostream> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include "graph/ref_relation.h" | |||
| #include "./compute_graph.h" | |||
| #include "./ge_error_codes.h" | |||
| #include "./graph/ge_tensor.h" | |||
| @@ -34,14 +37,41 @@ | |||
| #include "utils/tensor_utils.h" | |||
| #include "utils/type_utils.h" | |||
| using namespace ge; | |||
| using namespace std; | |||
| namespace ge { | |||
| namespace { | |||
| static const std::unordered_set<string> kChangeDimNodes = {RESHAPE, PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
| static bool net_format_is_nd = true; | |||
| static Format g_user_set_format = FORMAT_ND; | |||
| static bool is_first_infer = true; | |||
| static RefRelations reflection_builder; | |||
| } // namespace | |||
| graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||
| std::deque<ge::NodePtr> &nodes, ge::Format to_be_set_format) { | |||
| for (const auto &cell : reflection) { | |||
| auto node = cell.node; | |||
| auto in_out_idx = cell.in_out_idx; | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| if (cell.in_out == ge::NODE_IN) { | |||
| auto desc = node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(in_out_idx)); | |||
| desc.SetOriginFormat(to_be_set_format); | |||
| desc.SetFormat(to_be_set_format); | |||
| (void)node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||
| } else { | |||
| auto desc = node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(in_out_idx)); | |||
| desc.SetOriginFormat(to_be_set_format); | |||
| desc.SetFormat(to_be_set_format); | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(in_out_idx), desc); | |||
| } | |||
| nodes.push_back(cell.node); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | |||
| @@ -66,7 +96,6 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| anchor_points.clear(); | |||
| // Get all anchor point nodes and switch nodes | |||
| for (const auto &node_ptr : graph->GetAllNodes()) { | |||
| std::vector<bool> is_node_set_format; | |||
| if (node_ptr == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -86,7 +115,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| for (uint32_t i = 0; i < input_size; i++) { | |||
| // Operator pre-set format but not origin format | |||
| auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||
| // Pre-save data node and default infer fail | |||
| // Pre-save data node (only main graph data) and default infer fail | |||
| if (node_ptr->GetType() == DATA) { | |||
| data_nodes.push_back(node_ptr); | |||
| } | |||
| @@ -163,6 +192,16 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
| } | |||
| // Check format whether have been set | |||
| int idx = peer_out_data_anchor->GetIdx(); | |||
| // do peer_out_node name and index as key to lookup reflections | |||
| ge::RefCell key(peer_out_data_node->GetName(), peer_out_data_node, ge::NODE_OUT, idx); | |||
| std::unordered_set<RefCell, RefCellHash> reflection; | |||
| auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d out edge", | |||
| (peer_out_data_node->GetName()).c_str(), idx); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
| @@ -181,18 +220,26 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
| continue; | |||
| } | |||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
| if (reflection.empty()) { | |||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
| // Call operator infer format api (forward) to get out format | |||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
| graphStatus status = peer_out_data_node->InferOriginFormat(); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||
| return GRAPH_FAILED; | |||
| // Call operator infer format api (forward) to get out format | |||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
| status = peer_out_data_node->InferOriginFormat(); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_out_data_node->GetName()).c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| nodes.push_back(peer_out_data_node); | |||
| } else { | |||
| auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| nodes.push_back(peer_out_data_node); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| @@ -213,17 +260,23 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
| continue; | |||
| } | |||
| for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| if (peer_in_data_anchor == nullptr) { | |||
| GELOGW("Node[%s] some peer_in_anchor is null", (node->GetName()).c_str()); | |||
| continue; | |||
| } | |||
| GE_IF_BOOL_EXEC(peer_in_data_anchor == nullptr, continue); | |||
| auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); | |||
| if (peer_in_data_node == nullptr || peer_in_data_node->GetOpDesc() == nullptr) { | |||
| GELOGW("Node[%s] peer_in_data_node or peer_in_data_node desc is null", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| GE_IF_BOOL_EXEC(peer_in_data_node == nullptr, continue); | |||
| GE_IF_BOOL_EXEC(peer_in_data_node->GetOpDesc() == nullptr, continue); | |||
| // Check format whether have been set | |||
| int idx = peer_in_data_anchor->GetIdx(); | |||
| // do peer_out_node name and index as key to lookup reflections | |||
| ge::RefCell key(peer_in_data_node->GetName(), peer_in_data_node, ge::NODE_IN, idx); | |||
| std::unordered_set<RefCell, RefCellHash> reflection; | |||
| auto status = reflection_builder.LookUpRefRelations(key, reflection); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "LookUpRefRelations failed!Node is [%s],the %d input edge", | |||
| (peer_in_data_node->GetName()).c_str(), idx); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
| @@ -240,24 +293,33 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
| GELOGD("Node[%s] is change dim node. do not infer origin format", (peer_in_data_node->GetName()).c_str()); | |||
| continue; | |||
| } | |||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||
| (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(idx, ge_tensor_desc); | |||
| /// Because netoutput node added before infer format ,so netoutput is end condition | |||
| /// must set netoutput format , because saved result depend on format | |||
| if (peer_in_data_node_type == NETOUTPUT) { | |||
| continue; | |||
| } | |||
| if (reflection.empty()) { | |||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||
| (void)peer_in_data_node->GetOpDesc()->UpdateInputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
| // Call operator infer format api (forward) to get out format | |||
| GELOGD("call infer format func[Forward]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||
| graphStatus status = peer_in_data_node->InferOriginFormat(); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||
| return GRAPH_FAILED; | |||
| /// Because netoutput node added before infer format ,so netoutput is end condition | |||
| /// must set netoutput format , because saved result depend on format | |||
| if (peer_in_data_node_type == NETOUTPUT) { | |||
| continue; | |||
| } | |||
| // Call operator infer format api (forward) to get out format | |||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_in_data_node->GetName()).c_str()); | |||
| status = peer_in_data_node->InferOriginFormat(); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Node[%s] infer format failed", (peer_in_data_node->GetName()).c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| nodes.push_back(peer_in_data_node); | |||
| } else { | |||
| auto status = ReflectionProcess(reflection, nodes, to_be_set_format); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "reflection process failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| nodes.push_back(peer_in_data_node); | |||
| } | |||
| } | |||
| } | |||
| @@ -355,8 +417,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| GELOGE(GRAPH_FAILED, "input graph is null"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| // build reflection relations of boundary | |||
| (void)reflection_builder.Clear(); | |||
| auto status = reflection_builder.BuildRefRelations(*graph); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "build reflection relations failed for main and subgraph!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| // User set global net format | |||
| graphStatus status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||
| status = GetAnchorPoints(graph, anchor_points, data_nodes, node_status); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "GetAnchorPoints Process Faild!"); | |||
| return GRAPH_FAILED; | |||
| @@ -18,6 +18,12 @@ | |||
| namespace ge { | |||
| // Public attribute | |||
| const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; | |||
| const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; | |||
| const std::string ATTR_NAME_UNKNOWN_SHAPE_TYPE = "_unknown_shape_type"; | |||
| const std::string ATTR_NAME_NAME = "name"; | |||
| const std::string ATTR_NAME_TYPE = "type"; | |||
| @@ -42,6 +48,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||
| const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | |||
| const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||
| const std::string ATTR_NAME_PAD = "pad"; | |||
| const std::string ATTR_NAME_PADS = "pad"; | |||
| @@ -83,6 +91,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||
| const std::string ATTR_NAME_AXIS = "axis"; | |||
| const std::string ATTR_NAME_BROADCAST = "broadcast"; | |||
| const std::string ATTR_NAME_OUTPUT = "output"; | |||
| const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | |||
| const std::string ATTR_NAME_TIDX = "t_idx"; | |||
| @@ -103,6 +112,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||
| const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | |||
| const std::string ATTR_NAME_AIPP = "aipp"; | |||
| const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||
| const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
| const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | |||
| const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | |||
| @@ -111,6 +127,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | |||
| const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||
| @@ -122,9 +139,12 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||
| const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | |||
| const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | |||
| const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | |||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
| const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||
| const std::string ATTR_NAME_CONTINUOUS_STREAM_LABEL = "_continuous_stream_label"; | |||
| const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||
| const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
| // To be deleted | |||
| const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | |||
| @@ -138,15 +158,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||
| const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
| const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
| // Refinedet | |||
| const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | |||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | |||
| const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
| // _Arg | |||
| const std::string ATTR_NAME_INDEX = "index"; | |||
| @@ -236,6 +254,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||
| const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | |||
| const std::string BATCHNORM_ATTR_SCALE = "scale"; | |||
| const std::string BATCHNORM_ATTR_BIAS = "bias"; | |||
| const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||
| const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||
| const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||
| // huberloss | |||
| const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||
| // SSDRealDivTileMul | |||
| const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||
| // SSDSumMulRealDivMean | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||
| // ConcatFive2Four | |||
| // ConcatFour2Five | |||
| const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||
| const std::string SSD_CLASS_NUM = "class_num"; | |||
| const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||
| const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||
| const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||
| const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||
| // Scale | |||
| const std::string SCALE_ATTR_SCALE = "scale"; | |||
| @@ -340,6 +382,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||
| // Permute | |||
| const std::string PERMUTE_ATTR_ORDER = "order"; | |||
| const std::string PERMUTE_ATTR_PERM = "perm"; | |||
| // SSD Normalize | |||
| const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | |||
| @@ -367,6 +410,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| // RefinedetDetectionOutput | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| // PRelu | |||
| const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | |||
| @@ -380,11 +427,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||
| const std::string POWER_ATTR_NAME_SCALE = "scale"; | |||
| const std::string POWER_ATTR_NAME_SHIFT = "shift"; | |||
| // log | |||
| const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||
| const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||
| const std::string LOG_ATTR_NAME_BASE = "base"; | |||
| // Pack | |||
| const std::string PACK_ATTR_NAME_NUM = "N"; | |||
| // Unpack | |||
| const std::string UNPACK_ATTR_NAME_NUM = "num"; | |||
| const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
| // Gathernd | |||
| const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | |||
| const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | |||
| @@ -394,6 +446,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||
| const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | |||
| const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | |||
| const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | |||
| const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||
| const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||
| const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||
| // upsample | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||
| // Relu | |||
| const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | |||
| @@ -531,19 +590,41 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||
| const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | |||
| // Rnn | |||
| const std::string RNN_MODE_ = "rnn_"; | |||
| const std::string CNN_RNN = "cnn_rnn"; | |||
| const std::string RNN_MODE_STATIC = "rnn_static"; | |||
| const std::string MUTI_RNN = "multi_rnn"; | |||
| const std::string CNN_RNN = "cnn_rnn"; | |||
| const std::string RNN_MODE_ = "rnn_"; | |||
| const std::string CELL_MODE = "mode"; | |||
| const std::string LSTM_CELL = "lstm_cell"; | |||
| const std::string GRU_CELL = "gru_cell"; | |||
| const std::string RNN_HT = "ht"; | |||
| const std::string RNN_XT_HT = "xt_ht"; | |||
| const std::string RNN_BATCH_SIZE = "batch_size"; | |||
| const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||
| const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||
| const std::string LSTM_ACTIVATE = "lstm_activate"; | |||
| const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||
| const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||
| const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||
| const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||
| const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||
| // Upsample | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | |||
| // PadV2 | |||
| const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||
| const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||
| const std::string PADV2_ATTR_NAME_T = "T"; | |||
| const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
| const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||
| // MirrorPad | |||
| const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||
| const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||
| const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
| const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||
| // Filler | |||
| const std::string FILLER_TYPE = "filler_type"; | |||
| const std::string FILLER_VALUE = "filler_value"; | |||
| @@ -554,9 +635,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||
| // TopKV2 | |||
| const std::string TOPKV2_ATTR_K = "k"; | |||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
| // Calibaration | |||
| const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | |||
| const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | |||
| @@ -611,10 +689,14 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||
| const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | |||
| const std::string ATTR_MODEL_HUGE_STREAM_LIST = "huge_stream_list"; | |||
| const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||
| const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||
| const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | |||
| const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | |||
| const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | |||
| @@ -660,8 +742,125 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||
| const std::string TARGET_TYPE_LITE = "LITE"; | |||
| // l2_normalize | |||
| const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
| const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||
| const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||
| const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||
| const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||
| const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||
| const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||
| // HCOM | |||
| const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
| const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
| const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
| const std::string HCOM_ATTR_GROUP = "group"; | |||
| const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||
| const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||
| const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||
| const std::string HCOM_ATTR_FUSION = "fusion"; | |||
| const std::string HCOM_ATTR_SHAPE = "shape"; | |||
| const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
| // SpaceToDepth/DepthToSpace | |||
| const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||
| // SparseSoftmaxCrossEntropyWithLogits | |||
| const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||
| // MaxPoolGradWithArgmax | |||
| const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||
| // AvgPoolGrad | |||
| const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||
| // Pad | |||
| const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||
| // Varible | |||
| const std::string VAR_ATTR_FORMAT = "_var_format"; | |||
| const std::string VAR_ATTR_NAME = "var_name"; | |||
| const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||
| const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||
| const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||
| const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||
| const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||
| const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||
| const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||
| const std::string VAR_ATTR_SHAPE = "shape"; | |||
| const std::string HALF_VAR_NAME_END = "_fp16"; | |||
| const std::string VAR_ATTR_INITED = "var_is_inited"; | |||
| const std::string VAR_ATTR_CONTAINER = "container"; | |||
| const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||
| const std::string VAR_ATTR_DTYPE = "dtype"; | |||
| const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
| const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||
| const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
| const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
| const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
| const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
| // Assign | |||
| const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||
| // space2bacth batch2space | |||
| const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||
| const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||
| // depth_to_space space_to_depth | |||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
| // FakeQuantWithMinMaxVars | |||
| const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||
| const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||
| // mobilenet_ssd_conv_fusion | |||
| const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||
| const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||
| const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||
| // lsh project | |||
| const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||
| // log time stamp | |||
| const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||
| const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||
| // ShapeN | |||
| const std::string SHAPEN_ATTR_N = "N"; | |||
| const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||
| const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||
| // GatherV2 attr def | |||
| const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||
| const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||
| const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||
| // Reshape attr def | |||
| const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||
| const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||
| // axis attr def | |||
| const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||
| const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||
| const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||
| const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||
| // For constant folding | |||
| const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||
| const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | |||
| const std::string ATTR_NAME_CONTINUOUS_INPUT_ALLOC = "continuous_input_alloc"; | |||
| const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | |||
| const std::string ATTR_NAME_REFERENCE = "reference"; | |||
| @@ -694,6 +893,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||
| const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | |||
| const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | |||
| const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | |||
| const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||
| const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||
| const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | |||
| const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | |||
| @@ -705,6 +906,7 @@ const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | |||
| // Function Op | |||
| const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||
| const std::string ATTR_NAME_PARENT_CONST_TYPE = "_parent_const_type"; | |||
| // Used for mark the active node is for loop, type:bool | |||
| const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | |||
| @@ -719,6 +921,7 @@ const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
| // l1 fusion and other fusion in future | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
| const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
| const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||
| const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||
| @@ -730,6 +933,9 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||
| const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
| const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
| const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||
| const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||
| const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||
| // Atomic addr clean attrs | |||
| const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
| @@ -748,6 +954,8 @@ const std::string ATTR_NEED_COMPILE = "_node_need_compile"; | |||
| const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||
| const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; | |||
| // For inserted op | |||
| const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | |||
| @@ -764,7 +972,22 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | |||
| // functional ops attr | |||
| const std::string ATTR_NAME_WHILE_COND = "cond"; | |||
| const std::string ATTR_NAME_WHILE_BODY = "body"; | |||
| // used for label switch | |||
| const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||
| const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||
| const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
| const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
| // used for LX tiling | |||
| const std::string ATTR_NAME_OP_L1_SPACE = "_l1_space"; | |||
| const std::string ATTR_NAME_FUSION_TYPE_LIST = "_fusion_type_list"; | |||
| const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_list_list"; | |||
| const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; | |||
| const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||
| const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||
| } // namespace ge | |||
| @@ -31,19 +31,18 @@ using std::string; | |||
| using std::vector; | |||
| namespace ge { | |||
| GeAttrValue::NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||
| NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||
| GeAttrValue::NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | |||
| : named_attrs_(owner, proto_msg) {} | |||
| NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} | |||
| void GeAttrValue::NamedAttrs::SetName(const std::string &name) { | |||
| void NamedAttrs::SetName(const std::string &name) { | |||
| auto proto_msg = named_attrs_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| proto_msg->set_name(name); | |||
| } | |||
| } | |||
| string GeAttrValue::NamedAttrs::GetName() const { | |||
| string NamedAttrs::GetName() const { | |||
| auto proto_msg = named_attrs_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| return proto_msg->name(); | |||
| @@ -51,13 +50,13 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||
| return string(); | |||
| } | |||
| GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | |||
| GeAttrValue NamedAttrs::GetItem(const string &key) const { | |||
| GeAttrValue value; | |||
| GetAttr(key, value); | |||
| (void)GetAttr(key, value); | |||
| return value; | |||
| } | |||
| ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||
| ProtoAttrMapHelper NamedAttrs::MutableAttrMap() { | |||
| auto proto_msg = named_attrs_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), proto_msg->mutable_attr()); | |||
| @@ -65,7 +64,7 @@ ProtoAttrMapHelper GeAttrValue::NamedAttrs::MutableAttrMap() { | |||
| return ProtoAttrMapHelper(named_attrs_.GetProtoOwner(), nullptr); | |||
| } | |||
| ConstProtoAttrMapHelper GeAttrValue::NamedAttrs::GetAttrMap() const { | |||
| ConstProtoAttrMapHelper NamedAttrs::GetAttrMap() const { | |||
| auto proto_msg = named_attrs_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| return ConstProtoAttrMapHelper(named_attrs_.GetProtoOwner(), &proto_msg->attr()); | |||
| @@ -515,7 +514,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAtt | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NamedAttrs &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue::NAMED_ATTRS &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
| return false; | |||
| } | |||
| @@ -528,7 +527,7 @@ bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const GeAttrValue: | |||
| return true; | |||
| } | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NamedAttrs> &value) { | |||
| bool GeAttrValueImp::SetValue(proto::AttrDef &proto_attr_val, const vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
| if (!AttrUtilsHelper::SetValueCheckAndSetListType(proto_attr_val, | |||
| proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS)) { | |||
| return false; | |||
| @@ -739,7 +738,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
| GeAttrValue::NamedAttrs &value) { | |||
| GeAttrValue::NAMED_ATTRS &value) { | |||
| if (!AttrUtilsHelper::GetValueCheckType(proto_attr_val, proto::AttrDef::kFunc)) { | |||
| return false; | |||
| } | |||
| @@ -752,7 +751,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoMsgOwner &, | |||
| vector<GeAttrValue::NamedAttrs> &value) { | |||
| vector<GeAttrValue::NAMED_ATTRS> &value) { | |||
| value.clear(); | |||
| if (!AttrUtilsHelper::GetValueCheckListType( | |||
| proto_attr_val, proto::AttrDef_ListValue_ListValueType_VT_LIST_NAMED_ATTRS, ListValueItemCheck(na))) { | |||
| @@ -760,7 +759,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| } | |||
| auto &list = proto_attr_val.list(); | |||
| for (const auto &item : list.na()) { | |||
| value.emplace_back(GeAttrValue::NamedAttrs()); | |||
| value.emplace_back(GeAttrValue::NAMED_ATTRS()); | |||
| if (value.empty()) { | |||
| return false; | |||
| } | |||
| @@ -967,7 +966,7 @@ ATTR_UTILS_SET_GET_IMP(TensorDesc, GeTensorDesc) | |||
| ATTR_UTILS_SET_IMP(Tensor, GeTensorPtr) | |||
| ATTR_UTILS_SET_IMP(Tensor, ConstGeTensorPtr) | |||
| ATTR_UTILS_SET_IMP(Tensor, GeTensor) | |||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NamedAttrs) | |||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | |||
| ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | |||
| ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | |||
| ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | |||
| @@ -982,7 +981,7 @@ ATTR_UTILS_SET_GET_IMP(ListTensorDesc, vector<GeTensorDesc>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensorPtr>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<ConstGeTensorPtr>) | |||
| ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | |||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NamedAttrs>) | |||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | |||
| ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | |||
| ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | |||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | |||
| @@ -83,6 +83,12 @@ size_t GeShape::GetDimNum() const { | |||
| auto proto_msg = shape_def_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| if (proto_msg->dim_size() >= 0) { | |||
| // check whether contain -2, if true, return -1 | |||
| for (auto i : proto_msg->dim()) { | |||
| if (i == UNKNOWN_DIM_NUM) { | |||
| return 0; | |||
| } | |||
| } | |||
| return proto_msg->dim_size(); | |||
| } else { | |||
| return 0; | |||
| @@ -157,6 +163,10 @@ int64_t GeShape::GetShapeSize() const { | |||
| return 0; | |||
| } | |||
| for (auto i : proto_msg->dim()) { | |||
| // if unknown shape, return -1 | |||
| if (i == UNKNOWN_DIM || i == UNKNOWN_DIM_NUM) { | |||
| return UNKNOWN_DIM; | |||
| } | |||
| res *= i; | |||
| } | |||
| } | |||
| @@ -209,6 +219,7 @@ const string TENSOR_UTILS_RC = "rc"; | |||
| const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | |||
| const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | |||
| const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | |||
| const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||
| GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | |||
| @@ -396,6 +407,35 @@ GeShape &GeTensorDesc::MutableShape() { return ShapeReference(); } | |||
| void GeTensorDesc::SetShape(GeShape shape) { ShapeReference() = std::move(shape); } | |||
| // set shape with -2, it stand for unknown shape | |||
| void GeTensorDesc::SetUnknownDimNumShape() { SetShape(GeShape({UNKNOWN_DIM_NUM})); } | |||
| // for unknown shape | |||
| graphStatus GeTensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||
| std::vector<vector<int64_t>> shape_range; | |||
| for (const auto &ele : range) { | |||
| shape_range.emplace_back(std::vector<int64_t>({ele.first, ele.second})); | |||
| } | |||
| auto ret = AttrUtils::SetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||
| return ret ? GRAPH_SUCCESS : GRAPH_FAILED; | |||
| } | |||
| graphStatus GeTensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||
| std::vector<vector<int64_t>> shape_range; | |||
| (void)AttrUtils::GetListListInt(this, TENSOR_UTILS_SHAPE_RANGE, shape_range); | |||
| for (const auto &ele : shape_range) { | |||
| // here must be only two elemenet because pair | |||
| if (ele.size() != 2) { | |||
| GELOGE(GRAPH_FAILED, "shape_range must contain only 2 value but really is %lu", ele.size()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::pair<int64_t, int64_t> pair({ele[0], ele[1]}); | |||
| range.push_back(pair); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GeShape GeTensorDesc::GetOriginShape() const { | |||
| vector<int64_t> origin_shape; | |||
| if (!AttrUtils::GetListInt(this, TENSOR_UTILS_ORIGIN_SHAPE, origin_shape)) { | |||
| @@ -16,11 +16,12 @@ | |||
| #include "external/graph/graph.h" | |||
| #include "debug/ge_util.h" | |||
| #include "external/graph/operator.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/debug/ge_op_types.h" | |||
| #include "graph/model.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| using std::map; | |||
| using std::pair; | |||
| @@ -214,6 +215,23 @@ class GraphImpl { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||
| for (auto &op : op_list_) { | |||
| auto op_type = op.second.GetOpType(); | |||
| if (op_type == type) { | |||
| ops.push_back(op.second); | |||
| continue; | |||
| } | |||
| if (op_type == ge::FRAMEWORKOP) { | |||
| op.second.GetAttr(ge::ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, op_type); | |||
| if (op_type == type) { | |||
| ops.push_back(op.second); | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| void SetNeedIteration(bool need_iteration) { | |||
| if (compute_graph_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Set need iteration failed, as compute graph is null."); | |||
| @@ -222,6 +240,8 @@ class GraphImpl { | |||
| compute_graph_->SetNeedIteration(need_iteration); | |||
| } | |||
| const std::string &GetName() const { return name_; } | |||
| private: | |||
| std::string name_; | |||
| std::string output_name_; | |||
| @@ -255,6 +275,11 @@ graphStatus Graph::FindOpByName(const std::string &name, Operator &op) const { | |||
| return impl_->FindOpByName(name, op); | |||
| } | |||
| graphStatus Graph::FindOpByType(const string &type, std::vector<ge::Operator> &ops) const { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->FindOpByType(type, ops); | |||
| } | |||
| Graph &Graph::SetInputs(const vector<ge::Operator> &inputs) { | |||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return *this, "SetInputs failed: graph can not be used, impl is nullptr.") | |||
| GE_CHK_BOOL_EXEC(inputs.size() > 0, return *this, "SetInputs failed: input operator size can not be 0."); | |||
| @@ -331,6 +356,8 @@ graphStatus Graph::LoadFromFile(const string &file_name) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::string &Graph::GetName() const { return impl_->GetName(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph | |||
| GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) { | |||
| GE_CHK_BOOL_EXEC_NOLOG(compute_graph != nullptr, return Graph("")); | |||
| @@ -343,4 +370,15 @@ GraphUtils::CreateGraphFromComputeGraph(const ge::ComputeGraphPtr compute_graph) | |||
| return graph; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::RecoverGraphOperators(const Graph &graph) { | |||
| GE_CHECK_NOTNULL(graph.impl_); | |||
| GE_CHECK_NOTNULL(graph.impl_->compute_graph_); | |||
| graph.impl_->op_list_.clear(); | |||
| for (const auto &node : graph.impl_->compute_graph_->GetDirectNode()) { | |||
| graph.impl_->op_list_[node->GetName()] = OpDescUtils::CreateOperatorFromNode(node); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -16,7 +16,10 @@ | |||
| #include "graph/model_serialize.h" | |||
| #include <google/protobuf/text_format.h> | |||
| #include <queue> | |||
| #include <iostream> | |||
| #include "debug/ge_attr_define.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -26,6 +29,7 @@ | |||
| #include "utils/graph_utils.h" | |||
| #include "debug/ge_op_types.h" | |||
| using std::map; | |||
| using std::string; | |||
| namespace ge { | |||
| @@ -121,6 +125,11 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
| } | |||
| } | |||
| } | |||
| op_def_proto->set_id(op_desc->GetId()); | |||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
| op_def_proto->add_subgraph_name(name); | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -196,6 +205,14 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||
| GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | |||
| return false; | |||
| } | |||
| for (auto subgraph : compute_graph->GetAllSubgraphs()) { | |||
| if (!SerializeGraph(subgraph, model_proto->add_graph(), is_dump)) { | |||
| GELOGE(GRAPH_FAILED, "Serialize subgraph failed"); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| @@ -228,6 +245,14 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||
| GE_CHK_BOOL_RET_STATUS(temp_value != nullptr, false, "temp_value is nullptr"); | |||
| op_desc->outputs_desc_.push_back(temp_value); | |||
| } | |||
| op_desc->SetId(op_def_proto.id()); | |||
| uint32_t graph_index = 0; | |||
| for (const std::string &name : op_def_proto.subgraph_name()) { | |||
| op_desc->AddSubgraphName(name); | |||
| op_desc->SetSubgraphInstanceName(graph_index++, name); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -238,7 +263,7 @@ bool ModelSerializeImp::UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &op | |||
| GELOGW("UnserializeOpDesc error."); | |||
| } | |||
| NodePtr node = graph->AddNode(op_desc); | |||
| NodePtr node = graph->AddNode(op_desc, op_desc->GetId()); | |||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr."); | |||
| // Inputs | |||
| @@ -319,6 +344,35 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||
| return true; | |||
| } | |||
| bool ModelSerializeImp::RebuildOwnership(ComputeGraphPtr &compute_graph, map<string, ComputeGraphPtr> &subgraphs) { | |||
| std::queue<ComputeGraphPtr> all_graphs; | |||
| all_graphs.emplace(compute_graph); | |||
| while (!all_graphs.empty()) { | |||
| ComputeGraphPtr graph = all_graphs.front(); | |||
| all_graphs.pop(); | |||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||
| const OpDescPtr op_desc = node->GetOpDesc(); | |||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
| auto it = subgraphs.find(name); | |||
| if (it == subgraphs.end()) { | |||
| GELOGE(GRAPH_FAILED, "Node:%s, Subgraph:%s not found, num:%zu.", op_desc->GetName().c_str(), name.c_str(), | |||
| subgraphs.size()); | |||
| return false; | |||
| } | |||
| ComputeGraphPtr &subgraph = it->second; | |||
| subgraph->SetParentGraph(graph); | |||
| subgraph->SetParentNode(node); | |||
| compute_graph->AddSubgraph(subgraph->GetName(), subgraph); | |||
| all_graphs.emplace(subgraph); | |||
| } | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_proto) { | |||
| model.name_ = model_proto.name(); | |||
| model.version_ = model_proto.version(); | |||
| @@ -332,7 +386,31 @@ bool ModelSerializeImp::UnserializeModel(Model &model, proto::ModelDef &model_pr | |||
| if (UnserializeGraphWithoutEdge(compute_graph_ptr, graph_proto)) { | |||
| model.graph_ = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | |||
| } | |||
| // 0 is main graph, following is subgraph. | |||
| map<string, ComputeGraphPtr> subgraphs; | |||
| for (int idx = 1; idx < graphs_proto.size(); ++idx) { | |||
| ComputeGraphPtr subgraph; | |||
| ModelSerializeImp impl; | |||
| if (!impl.UnserializeGraphWithoutEdge(subgraph, graphs_proto[idx])) { | |||
| GELOGE(GRAPH_FAILED, "UnserializeGraphWithoutEdge failed"); | |||
| return false; | |||
| } | |||
| if (!impl.HandleNodeNameRef()) { | |||
| GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||
| return false; | |||
| } | |||
| subgraphs[subgraph->GetName()] = subgraph; | |||
| } | |||
| if (!RebuildOwnership(compute_graph_ptr, subgraphs)) { | |||
| GELOGE(GRAPH_FAILED, "Rebuild graph ownership failed"); | |||
| return false; | |||
| } | |||
| } | |||
| if (!HandleNodeNameRef()) { | |||
| GELOGE(GRAPH_FAILED, "HandleNodeNameRef failed"); | |||
| return false; | |||
| @@ -61,6 +61,8 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||
| const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | |||
| const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; | |||
| const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||
| const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||
| @@ -227,6 +229,40 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||
| } | |||
| } | |||
| graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| string input_name = name + std::to_string(i); | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
| "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
| std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
| if (in_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (index > inputs_desc_.size()) { | |||
| GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); | |||
| // Update index in input_name_idx | |||
| for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
| if (it->second >= (index + i)) { | |||
| it->second += 1; | |||
| } | |||
| } | |||
| (void)input_name_idx.insert(make_pair(input_name, i + index)); | |||
| } | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| @@ -239,7 +275,6 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||
| GELOGE(GRAPH_FAILED, "AddInputDescForward failed, malloc shared_ptr failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | |||
| // Update index in input_name_idx | |||
| @@ -634,6 +669,13 @@ graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int n | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigned int num, size_t index) { | |||
| if (AddInputDescMiddle(name, num, index) != GRAPH_SUCCESS) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | |||
| if (is_push_back) { | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| @@ -1054,6 +1096,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetDstName | |||
| return dst_name; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpInferDepends(const vector<string> &depend_names) { | |||
| auto ret = AttrUtils::SetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||
| if (ret != true) { | |||
| GELOGE(GRAPH_FAILED, "set op_infer_depends fail."); | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<string> OpDesc::GetOpInferDepends() const { | |||
| vector<string> depend_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OP_INFER_DEPENDS, depend_names); | |||
| return depend_names; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetDstIndex(const vector<int64_t> &dst_index) { | |||
| auto proto_msg = op_def_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| @@ -1199,20 +1254,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &O | |||
| return subgraph_instance_names_; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||
| subgraph_instance_names_.emplace_back(std::move(name)); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | |||
| for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | |||
| if (*iter == name) { | |||
| subgraph_instance_names_.erase(iter); | |||
| *iter = ""; | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | |||
| GELOGI("Add subgraph name is %s", name.c_str()); | |||
| auto iter = subgraph_names_to_index_.find(name); | |||
| if (iter != subgraph_names_to_index_.end()) { | |||
| GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | |||
| @@ -1220,6 +1272,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphNa | |||
| } | |||
| auto size = subgraph_names_to_index_.size(); | |||
| subgraph_names_to_index_[name] = size; | |||
| subgraph_instance_names_.resize(size + 1); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -1227,4 +1280,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint3 | |||
| const { | |||
| return subgraph_names_to_index_; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::SetSubgraphInstanceName(uint32_t index, | |||
| const std::string &name) { | |||
| GELOGI("Add sub graph instans name is %s, index is %u", name.c_str(), index); | |||
| if (index >= subgraph_instance_names_.size()) { | |||
| GE_LOGE("The index %u exceeds the max instance coutn %zu", index, subgraph_instance_names_.size()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| subgraph_instance_names_[index] = name; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RegisterSubgraphIrName(const string &name, | |||
| SubgraphType type) { | |||
| subgraph_ir_names_to_type_[name] = type; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, SubgraphType> &OpDesc::GetSubgraphIrNames() | |||
| const { | |||
| return subgraph_ir_names_to_type_; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY SubgraphType | |||
| OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { | |||
| auto iter = subgraph_ir_names_to_type_.find(name); | |||
| if (iter == subgraph_ir_names_to_type_.end()) { | |||
| return kSubgraphTypeEnd; | |||
| } | |||
| return iter->second; | |||
| } | |||
| } // namespace ge | |||
| @@ -15,6 +15,7 @@ | |||
| */ | |||
| #include "external/graph/operator.h" | |||
| #include "external/graph/operator_factory.h" | |||
| #include <stdint.h> | |||
| #include <algorithm> | |||
| #include <mutex> | |||
| @@ -38,6 +39,11 @@ | |||
| #include "utils/tensor_adapter.h" | |||
| #include "utils/tensor_utils.h" | |||
| #include "utils/type_utils.h" | |||
| #include <algorithm> | |||
| #include <mutex> | |||
| #include <queue> | |||
| #include <set> | |||
| #include <stdint.h> | |||
| using std::enable_shared_from_this; | |||
| using std::make_pair; | |||
| @@ -343,15 +349,71 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
| InferenceContextPtr GetInferenceContext() const { return inference_context_; } | |||
| void SubgraphRegister(const std::string &name, bool dynamic) { | |||
| op_desc_->RegisterSubgraphIrName(name, dynamic ? kDynamic : kStatic); | |||
| } | |||
| void SubgraphCountRegister(const std::string &name, uint32_t count) { | |||
| if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { | |||
| op_desc_->AddSubgraphName(name); | |||
| } else { | |||
| for (uint32_t i = 0; i < count; ++i) { | |||
| op_desc_->AddSubgraphName(name + std::to_string(i)); | |||
| } | |||
| } | |||
| subgraph_names_to_builders_[name].resize(count, nullptr); | |||
| } | |||
| void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||
| auto iter = subgraph_names_to_builders_.find(name); | |||
| if (iter == subgraph_names_to_builders_.end()) { | |||
| GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||
| return; | |||
| } | |||
| if (iter->second.size() <= index) { | |||
| GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", | |||
| name.c_str(), index, iter->second.size()); | |||
| return; | |||
| } | |||
| iter->second[index] = builder; | |||
| } | |||
| SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t index) const { | |||
| auto iter = subgraph_names_to_builders_.find(name); | |||
| if (iter == subgraph_names_to_builders_.end()) { | |||
| GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||
| return nullptr; | |||
| } | |||
| if (iter->second.size() <= index) { | |||
| GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", | |||
| name.c_str(), index, iter->second.size()); | |||
| return nullptr; | |||
| } | |||
| return iter->second[index]; | |||
| } | |||
| std::vector<std::string> GetSubgraphNames() const { | |||
| std::vector<std::string> names; | |||
| for (const auto &subgraph_name_to_type : op_desc_->GetSubgraphIrNames()) { | |||
| names.emplace_back(subgraph_name_to_type.first); | |||
| } | |||
| return names; | |||
| } | |||
| size_t GetSubgraphNamesCount() const { return op_desc_->GetSubgraphIrNames().size(); } | |||
| OpDescPtr op_desc_ = nullptr; | |||
| private: | |||
| ge::ConstNodePtr node_{nullptr}; | |||
| ge::InferenceContextPtr inference_context_; | |||
| GraphBuilderCallback graph_builder_callback_; | |||
| std::map<string, std::vector<OpIO>> output_links_{}; | |||
| std::map<string, OpIO> input_link_{}; | |||
| std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | |||
| std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | |||
| std::map<std::string, std::vector<SubgraphBuilder>> subgraph_names_to_builders_; | |||
| }; | |||
| // Used to manage OperatorImpl instances created by ge api. | |||
| @@ -559,7 +621,6 @@ InferenceContextPtr Operator::GetInferenceContext() const { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return nullptr, "operator impl is nullptr."); | |||
| return operator_impl_->GetInferenceContext(); | |||
| } | |||
| TensorDesc Operator::GetInputDesc(uint32_t index) const { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return TensorDesc(), "operator impl is nullptr."); | |||
| return TensorAdapter::GeTensorDesc2TensorDesc(operator_impl_->GetInputDesc(index)); | |||
| @@ -698,7 +759,7 @@ const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() con | |||
| void Operator::InputRegister(const string &name) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||
| (void)operator_impl_->GetOpDescImpl()->AddInputDesc(name, GeTensorDesc()); | |||
| } | |||
| void Operator::OptionalInputRegister(const string &name) { | |||
| @@ -745,6 +806,12 @@ void Operator::DynamicInputRegister(const string &name, const unsigned int num, | |||
| (void)operator_impl_->GetOpDescImpl()->AddDynamicInputDesc(name, num, is_push_back); | |||
| } | |||
| void Operator::DynamicInputRegisterByIndex(const string &name, const unsigned int num, size_t index) { | |||
| GE_CHK_BOOL_EXEC(!!operator_impl_, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(nullptr != operator_impl_->GetOpDescImpl(), return, "GetOpDescImpl is nullptr."); | |||
| operator_impl_->GetOpDescImpl()->AddDynamicInputDescByIndex(name, num, index); | |||
| } | |||
| int Operator::GetDynamicInputNum(const string &name) const { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | |||
| @@ -896,6 +963,11 @@ OP_ATTR_GET_IMP(string &, Str) | |||
| OP_ATTR_SET_IMP(const vector<string> &, ListStr) | |||
| OP_ATTR_GET_IMP(vector<string> &, ListStr) | |||
| OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
| OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
| OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
| OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
| OP_ATTR_REG_IMP(int64_t, Int) | |||
| OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | |||
| OP_ATTR_REG_IMP(float, Float) | |||
| @@ -905,6 +977,8 @@ OP_ATTR_REG_IMP(const vector<string> &, ListStr) | |||
| OP_ATTR_REG_IMP(bool, Bool) | |||
| OP_ATTR_REG_IMP(const vector<bool> &, ListBool) | |||
| OP_ATTR_REG_IMP(const vector<vector<int64_t>> &, ListListInt) | |||
| OP_ATTR_REG_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
| OP_ATTR_REG_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
| #undef OP_ATTR_SET_IMP | |||
| #undef OP_ATTR_GET_IMP | |||
| @@ -1114,6 +1188,95 @@ void Operator::AttrRegister(const string &name, const OpBytes &attr_value) { | |||
| } | |||
| } | |||
| void Operator::SubgraphRegister(const std::string &name, bool dynamic) { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
| return; | |||
| } | |||
| operator_impl_->SubgraphRegister(name, dynamic ? kDynamic : kStatic); | |||
| } | |||
| void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
| return; | |||
| } | |||
| operator_impl_->SubgraphCountRegister(name, count); | |||
| } | |||
| void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||
| return; | |||
| } | |||
| operator_impl_->SetSubgraphBuilder(name, index, builder); | |||
| } | |||
| std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } | |||
| SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||
| return nullptr; | |||
| } | |||
| return operator_impl_->GetSubgraphBuilder(name, index); | |||
| } | |||
| SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } | |||
| Graph Operator::GetSubgraph(const string &name) const { | |||
| if (operator_impl_ == nullptr) { | |||
| GE_LOGE("Failed to get subgraph %s, the operator impl is null", name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| auto op_desc = OpDescUtils::GetOpDescFromOperator(*this); | |||
| if (op_desc == nullptr) { | |||
| GE_LOGE("Failed to get subgraph %s, the op_desc is null", name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||
| auto iter = subgraph_names_to_index.find(name); | |||
| if (iter == subgraph_names_to_index.end()) { | |||
| GE_LOGE("Failed to get subgraph %s, the name may be invalid", name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| auto subgraph_instance_name = op_desc->GetSubgraphInstanceName(iter->second); | |||
| if (subgraph_instance_name.empty()) { | |||
| GE_LOGE("Failed to get subgraph %s index %u, the subgraph may not be added", name.c_str(), iter->second); | |||
| return Graph(""); | |||
| } | |||
| auto node = operator_impl_->GetNode(); | |||
| if (node == nullptr) { | |||
| GE_LOGE("Failed to get subgraph %s, the node is null", name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| if (root_graph == nullptr) { | |||
| GE_LOGE("Failed to get subgraph %s, can not find the root graph", name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| auto subgraph = root_graph->GetSubgraph(subgraph_instance_name); | |||
| if (subgraph == nullptr) { | |||
| GE_LOGE("Failed to get subgraph %s index %u, can not find the instance %s from the root graph", name.c_str(), | |||
| iter->second, subgraph_instance_name.c_str()); | |||
| return Graph(""); | |||
| } | |||
| return GraphUtils::CreateGraphFromComputeGraph(subgraph); | |||
| } | |||
| Graph Operator::GetDynamicSubgraph(const string &name, uint32_t index) const { | |||
| return GetSubgraph(name + std::to_string(index)); | |||
| } | |||
| size_t Operator::GetSubgraphNamesCount() const { | |||
| if (operator_impl_ == nullptr) { | |||
| GE_LOGE("Failed to get subgraph names count, the operator impl is null"); | |||
| return 0; | |||
| } | |||
| return operator_impl_->GetSubgraphNamesCount(); | |||
| } | |||
| class GraphBuilderImpl { | |||
| public: | |||
| explicit GraphBuilderImpl(const string &name) : graph_(ComGraphMakeShared<ComputeGraph>(name)) { | |||
| @@ -96,7 +96,6 @@ VerifyFunc OperatorFactoryImpl::GetVerifyFunc(const std::string &operator_type) | |||
| graphStatus OperatorFactoryImpl::RegisterOperatorCreator(const string &operator_type, OpCreator const &op_creator) { | |||
| if (operator_creators_ == nullptr) { | |||
| GELOGI("operator_creators_ init"); | |||
| operator_creators_.reset(new (std::nothrow) std::map<string, OpCreator>()); | |||
| } | |||
| auto it = operator_creators_->find(operator_type); | |||
| @@ -0,0 +1,422 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/ref_relation.h" | |||
| #include <unordered_set> | |||
| #include <unordered_map> | |||
| #include "utils/mem_utils.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "debug/ge_util.h" | |||
| #include "debug/ge_attr_define.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| using namespace std; | |||
| using namespace ge; | |||
| namespace ge { | |||
| namespace { | |||
| const char *kRefIndex = "_parent_node_index"; | |||
| const string kWhile = "While"; | |||
| const string kIf = "If"; | |||
| const string kCase = "Case"; | |||
| const int kMaxElementNum = 100; | |||
| std::unordered_set<string> function_op = {kWhile, kIf, kCase}; | |||
| } // namespace | |||
| /* Impl */ | |||
| class RefRelations::Impl { | |||
| public: | |||
| graphStatus LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||
| unsigned long number = static_cast<unsigned long>(reinterpret_cast<uintptr_t>(key.node.get())); | |||
| std::string lookup_key = | |||
| key.node_name + std::to_string(key.in_out) + std::to_string(key.in_out_idx) + std::to_string(number); | |||
| auto iter = look_up_table_.find(lookup_key); | |||
| if (iter != look_up_table_.end()) { | |||
| for (auto &c : iter->second) { | |||
| result.insert(c); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GELOGW("can not find any relations! key value is %s", lookup_key.c_str()); | |||
| return GRAPH_SUCCESS; | |||
| }; | |||
| graphStatus BuildRefRelations(ge::ComputeGraph &root_graph); | |||
| graphStatus Clear() { | |||
| GELOGD("Start clear boundary reflections between main graph and sub graph!"); | |||
| look_up_table_.clear(); | |||
| values_.clear(); | |||
| return GRAPH_SUCCESS; | |||
| }; | |||
| private: | |||
| graphStatus BuildLookUpTables(); | |||
| graphStatus BuildRefRelationsForBranch(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
| vector<vector<RefCell>> &node_refs); | |||
| graphStatus BuildRefRelationsForWhile(const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
| vector<vector<RefCell>> &node_refs); | |||
| graphStatus BuildRelationsWithFuncNodeType(const NodePtr &root_node, | |||
| const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, | |||
| vector<vector<RefCell>> &node_refs); | |||
| void GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||
| vector<NodePtr> &netoutput_nodes, const std::vector<std::string> &sub_graph_names, | |||
| const std::string &node_type); | |||
| graphStatus GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph); | |||
| graphStatus ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, vector<vector<NodePtr>> &classed_data_nodes); | |||
| graphStatus ProcessSubgraphNetoutput(const vector<NodePtr> &netoutput_nodes, | |||
| vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes); | |||
| std::unordered_map<string, vector<RefCell>> look_up_table_; | |||
| std::vector<vector<vector<RefCell>>> values_; | |||
| }; | |||
| // Node Level | |||
| graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
| GELOGD("Enter BuildRefRelationsForBranch!"); | |||
| size_t ref_i = 0; | |||
| for (const auto &ref_i_data_nodes : classed_data_nodes) { | |||
| vector<RefCell> in_ref_i_all_refs; | |||
| RefCell cell_root; | |||
| cell_root.node_name = root_node->GetName(); | |||
| cell_root.node = root_node; | |||
| cell_root.in_out = NODE_IN; | |||
| cell_root.in_out_idx = ref_i; | |||
| in_ref_i_all_refs.emplace_back(cell_root); | |||
| for (const auto &data : ref_i_data_nodes) { | |||
| RefCell cell_in; | |||
| RefCell cell_out; | |||
| cell_in.node_name = data->GetName(); | |||
| cell_in.node = data; | |||
| cell_in.in_out = NODE_IN; | |||
| cell_in.in_out_idx = 0; | |||
| cell_out.node_name = data->GetName(); | |||
| cell_out.node = data; | |||
| cell_out.in_out = NODE_OUT; | |||
| cell_out.in_out_idx = 0; | |||
| in_ref_i_all_refs.emplace_back(cell_in); | |||
| in_ref_i_all_refs.emplace_back(cell_out); | |||
| } | |||
| node_refs.emplace_back(in_ref_i_all_refs); | |||
| ref_i++; | |||
| } | |||
| size_t ref_o = 0; | |||
| for (const auto &ref_o_net_nodes : classed_netoutput_nodes) { | |||
| vector<RefCell> out_ref_i_all_refs; | |||
| RefCell cell_root; | |||
| cell_root.node_name = root_node->GetName(); | |||
| cell_root.node = root_node; | |||
| cell_root.in_out = NODE_OUT; | |||
| cell_root.in_out_idx = ref_o; | |||
| out_ref_i_all_refs.emplace_back(cell_root); | |||
| for (const auto &ele : ref_o_net_nodes) { | |||
| RefCell cell_netoutput_in; | |||
| RefCell cell_netoutput_out; | |||
| cell_netoutput_in.node_name = (ele.first)->GetName(); | |||
| cell_netoutput_in.node = ele.first; | |||
| cell_netoutput_in.in_out = NODE_IN; | |||
| cell_netoutput_in.in_out_idx = ele.second; | |||
| cell_netoutput_out.node_name = (ele.first)->GetName(); | |||
| cell_netoutput_out.node = ele.first; | |||
| cell_netoutput_out.in_out = NODE_OUT; | |||
| cell_netoutput_out.in_out_idx = ele.second; | |||
| out_ref_i_all_refs.emplace_back(cell_netoutput_in); | |||
| out_ref_i_all_refs.emplace_back(cell_netoutput_out); | |||
| } | |||
| node_refs.emplace_back(out_ref_i_all_refs); | |||
| ref_o++; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus RefRelations::Impl::BuildLookUpTables() { | |||
| for (size_t i = 0; i < values_.size(); i++) { | |||
| vector<vector<RefCell>> &val = values_[i]; | |||
| for (const auto &ele : val) { | |||
| for (const auto &ref_cell : ele) { | |||
| string key = ref_cell.node_name + std::to_string(ref_cell.in_out) + std::to_string(ref_cell.in_out_idx) + | |||
| std::to_string(static_cast<unsigned long>(reinterpret_cast<uintptr_t>(ref_cell.node.get()))); | |||
| look_up_table_[key] = ele; | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
| GELOGD("Enter BuildRefRelations for while op!"); | |||
| // data_nodes has been sorted | |||
| // for while, input num must be same as output num | |||
| auto input_num = root_node->GetAllInDataAnchorsSize(); | |||
| size_t ref_i = 0; | |||
| while (ref_i < input_num) { | |||
| auto &ref_i_data_nodes = classed_data_nodes[ref_i]; | |||
| auto &ref_i_net_nodes = classed_netoutput_nodes[ref_i]; | |||
| vector<RefCell> ref_i_all_refs; | |||
| RefCell cell_root_i; | |||
| RefCell cell_root_o; | |||
| cell_root_i.node_name = root_node->GetName(); | |||
| cell_root_i.node = root_node; | |||
| cell_root_i.in_out = NODE_IN; | |||
| cell_root_i.in_out_idx = ref_i; | |||
| ref_i_all_refs.emplace_back(cell_root_i); | |||
| cell_root_o.node_name = root_node->GetName(); | |||
| cell_root_o.node = root_node; | |||
| cell_root_o.in_out = NODE_OUT; | |||
| cell_root_o.in_out_idx = ref_i; | |||
| ref_i_all_refs.emplace_back(cell_root_o); | |||
| for (const auto &data : ref_i_data_nodes) { | |||
| RefCell cell_in; | |||
| RefCell cell_out; | |||
| cell_in.node_name = data->GetName(); | |||
| cell_in.node = data; | |||
| cell_in.in_out = NODE_IN; | |||
| cell_in.in_out_idx = 0; | |||
| cell_out.node_name = data->GetName(); | |||
| cell_out.node = data; | |||
| cell_out.in_out = NODE_OUT; | |||
| cell_out.in_out_idx = 0; | |||
| ref_i_all_refs.emplace_back(cell_in); | |||
| ref_i_all_refs.emplace_back(cell_out); | |||
| } | |||
| for (const auto &ele : ref_i_net_nodes) { | |||
| RefCell cell_netoutput_in; | |||
| RefCell cell_netoutput_out; | |||
| cell_netoutput_in.node_name = (ele.first)->GetName(); | |||
| cell_netoutput_in.node = ele.first; | |||
| cell_netoutput_in.in_out = NODE_IN; | |||
| cell_netoutput_in.in_out_idx = ele.second; | |||
| cell_netoutput_out.node_name = (ele.first)->GetName(); | |||
| cell_netoutput_out.node = ele.first; | |||
| cell_netoutput_out.in_out = NODE_OUT; | |||
| cell_netoutput_out.in_out_idx = ele.second; | |||
| ref_i_all_refs.emplace_back(cell_netoutput_in); | |||
| ref_i_all_refs.emplace_back(cell_netoutput_out); | |||
| } | |||
| node_refs.emplace_back(ref_i_all_refs); | |||
| ref_i++; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| // build ref relations according to diff func op type | |||
| graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( | |||
| const NodePtr &root_node, const vector<vector<NodePtr>> &classed_data_nodes, | |||
| const vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes, vector<vector<RefCell>> &node_refs) { | |||
| // data_nodes has been sorted | |||
| auto node_type = root_node->GetType(); | |||
| auto status = GRAPH_SUCCESS; | |||
| if (node_type == kIf || node_type == kCase) { | |||
| status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
| } else if (node_type == kWhile) { | |||
| status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
| } else { | |||
| GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); | |||
| status = GRAPH_PARAM_INVALID; | |||
| } | |||
| return status; | |||
| } | |||
| void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &root_graph, vector<NodePtr> &data_nodes, | |||
| vector<NodePtr> &netoutput_nodes, | |||
| const std::vector<std::string> &sub_graph_names, | |||
| const std::string &node_type) { | |||
| int sub_graph_idx = 0; | |||
| for (const auto &name : sub_graph_names) { | |||
| auto sub_graph = root_graph.GetSubgraph(name); | |||
| for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||
| auto sub_graph_node_type = sub_graph_node->GetType(); | |||
| if (sub_graph_node_type == DATA) { | |||
| data_nodes.emplace_back(sub_graph_node); | |||
| } else if (sub_graph_node_type == NETOUTPUT) { | |||
| // if while, the first subgraph must be cond subgraph. | |||
| // There is no meaning for refs ,so continue | |||
| if (node_type == kWhile && sub_graph_idx == 0) { | |||
| continue; | |||
| } | |||
| netoutput_nodes.emplace_back(sub_graph_node); | |||
| } | |||
| continue; | |||
| } | |||
| sub_graph_idx++; | |||
| } | |||
| } | |||
| graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::ComputeGraph &root_graph) { | |||
| auto parent_graph_ptr = graph.GetParentGraph(); | |||
| if (parent_graph_ptr == nullptr) { | |||
| root_graph = graph; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto root_graph_ptr = GraphUtils::FindRootGraph(parent_graph_ptr); | |||
| if (root_graph_ptr == nullptr) { | |||
| GE_LOGE("Get null root graph"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| root_graph = *root_graph_ptr; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | |||
| vector<vector<NodePtr>> &classed_data_nodes) { | |||
| int max_ref_idx = 0; | |||
| for (const auto &e : data_nodes) { | |||
| int i; | |||
| bool is_exist = true; | |||
| is_exist = AttrUtils::GetInt(e->GetOpDesc(), kRefIndex, i); | |||
| if (!is_exist) { | |||
| GELOGE(GRAPH_FAILED, "Invalid SubGraph NetOutput node[%s].no attr %s", e->GetName().c_str(), kRefIndex); | |||
| return GRAPH_FAILED; | |||
| } | |||
| max_ref_idx = (i > max_ref_idx) ? i : max_ref_idx; | |||
| } | |||
| while (!data_nodes.empty()) { | |||
| auto data = data_nodes.back(); | |||
| data_nodes.pop_back(); | |||
| int ref_idx = 0; | |||
| (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); | |||
| classed_data_nodes[ref_idx].emplace_back(data); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||
| const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | |||
| for (const auto &sub_netoutput_node : netoutput_nodes) { | |||
| auto op_desc = sub_netoutput_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| for (const auto &in_data_anchor : sub_netoutput_node->GetAllInDataAnchors()) { | |||
| auto in_desc = op_desc->MutableInputDesc(in_data_anchor->GetIdx()); | |||
| if (in_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Invalid NetOutput node [%s] idx [%lu], no tensor on it", | |||
| sub_netoutput_node->GetName().c_str(), in_data_anchor->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| int ref_o; | |||
| if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { | |||
| if (ref_o >= kMaxElementNum) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| classed_netoutput_nodes[ref_o].emplace_back( | |||
| std::pair<NodePtr, size_t>({sub_netoutput_node, static_cast<size_t>(in_data_anchor->GetIdx())})); | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||
| /* First Step: Get root graph */ | |||
| ge::ComputeGraph &root_graph = graph; | |||
| auto status = GetRootGraph(graph, root_graph); | |||
| if (status != GRAPH_SUCCESS) { | |||
| return status; | |||
| } | |||
| for (const auto &node : graph.GetAllNodes()) { | |||
| auto node_type = node->GetType(); | |||
| if (function_op.find(node_type) == function_op.end()) { | |||
| continue; | |||
| } | |||
| std::vector<NodePtr> ref_nodes; | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| vector<NodePtr> data_nodes; | |||
| vector<NodePtr> netoutput_nodes; | |||
| // Get data and netoutput of sub_graph | |||
| GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); | |||
| vector<vector<NodePtr>> classed_data_nodes(kMaxElementNum); // according to ref_idx | |||
| vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx | |||
| status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); | |||
| return status; | |||
| } | |||
| // for netoutput | |||
| // check netoutput | |||
| // here main graph output number must be the same as every sub_graph netoutput node | |||
| // key: netoutput node_ptr ,<ref_idx, net_in_idx> | |||
| status = ProcessSubgraphNetoutput(netoutput_nodes, classed_netoutput_nodes); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "process netoutput failed!"); | |||
| return status; | |||
| } | |||
| vector<vector<RefCell>> node_refs; | |||
| status = BuildRelationsWithFuncNodeType(node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(status, "BuildRelationsWithFuncNodeType Failed! Node is [%s]!", node->GetName().c_str()); | |||
| return status; | |||
| } | |||
| if (!node_refs.empty()) { | |||
| values_.push_back(node_refs); | |||
| } | |||
| } | |||
| /* Seconde Step: generate map */ | |||
| status = BuildLookUpTables(); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(status, "Build look up tables failed!"); | |||
| return status; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /* Ref Relations Interface */ | |||
| RefRelations::RefRelations() { | |||
| impl_ = MakeShared<Impl>(); | |||
| if (impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "MakeShared failed!"); | |||
| return; | |||
| } | |||
| } | |||
| graphStatus RefRelations::LookUpRefRelations(const RefCell &key, unordered_set<RefCell, RefCellHash> &result) { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->LookUpRefRelations(key, result); | |||
| } | |||
| graphStatus RefRelations::BuildRefRelations(ge::ComputeGraph &root_graph) { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->BuildRefRelations(root_graph); | |||
| } | |||
| graphStatus RefRelations::Clear() { | |||
| GE_CHECK_NOTNULL(impl_); | |||
| return impl_->Clear(); | |||
| } | |||
| } // namespace ge | |||
| @@ -21,7 +21,7 @@ | |||
| #include <unordered_map> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "framework/common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "debug/ge_log.h" | |||
| @@ -37,7 +37,6 @@ | |||
| namespace ge { | |||
| namespace { | |||
| constexpr const char *kRefIndex = "parent_node_index"; | |||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| @@ -47,6 +46,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| @@ -63,7 +66,7 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||
| if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| @@ -76,7 +79,10 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("Ref index is %d, input_desc dtype is %d, node name is %s", ref_i, input_desc->GetDataType(), | |||
| node->GetName().c_str()); | |||
| auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| @@ -101,6 +107,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| @@ -132,11 +142,14 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| node->GetName().c_str(), edge_anchor->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGI("Netoutput in anchor index is %zu, input tensor dim is %zu", edge_anchor->GetIdx(), | |||
| edge_desc->GetShape().GetDimNum()); | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||
| if (!AttrUtils::GetInt(edge_desc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||
| continue; | |||
| } | |||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
| if (output_desc == nullptr) { | |||
| GE_LOGE( | |||
| @@ -29,6 +29,7 @@ namespace { | |||
| /// Extra 1 byte store '\0' | |||
| const int EXTRA_STORE_POINTER_FOR_STRING = 8; | |||
| const int EXTRA_STORE_POINTER_FOR_STRING_AND_END_SYMBOL = 9; | |||
| const int64_t UNKNOWN_DIM_SIZE = -1; | |||
| } // namespace | |||
| namespace ge { | |||
| @@ -65,6 +66,7 @@ class TensorDescImpl { | |||
| TensorDescImpl(const Shape &shape, Format format, DataType dt) : shape_(shape), format_(format), data_type_(dt) {} | |||
| Shape shape_; | |||
| std::vector<std::pair<int64_t, int64_t>> range_; | |||
| Format format_ = FORMAT_ND; | |||
| Format origin_format_ = FORMAT_ND; | |||
| DataType data_type_ = DT_FLOAT; | |||
| @@ -94,7 +96,16 @@ class ShapeImpl { | |||
| public: | |||
| ShapeImpl() = default; | |||
| ~ShapeImpl() = default; | |||
| explicit ShapeImpl(const std::vector<int64_t> &dims) : dims_(dims) {} | |||
| explicit ShapeImpl(const std::vector<int64_t> &dims) { | |||
| bool is_unknown_dim_num = false; | |||
| for (const auto &dim : dims) { | |||
| if (dim == UNKNOWN_DIM_NUM) { | |||
| is_unknown_dim_num = true; | |||
| break; | |||
| } | |||
| } | |||
| dims_ = is_unknown_dim_num ? std::vector<int64_t>({UNKNOWN_DIM_NUM}) : dims; | |||
| } | |||
| std::vector<int64_t> dims_; | |||
| }; | |||
| @@ -105,6 +116,11 @@ Shape::Shape(const std::vector<int64_t> &dims) { impl_ = ComGraphMakeShared<Shap | |||
| size_t Shape::GetDimNum() const { | |||
| if (impl_ != nullptr) { | |||
| for (auto i : impl_->dims_) { | |||
| if (i == UNKNOWN_DIM_NUM) { | |||
| return 0; | |||
| } | |||
| } | |||
| return impl_->dims_.size(); | |||
| } | |||
| return 0; | |||
| @@ -146,6 +162,10 @@ int64_t Shape::GetShapeSize() const { | |||
| } | |||
| int64_t size = 1; | |||
| for (auto i : impl_->dims_) { | |||
| if (i == UNKNOWN_DIM_NUM || i == UNKNOWN_DIM) { | |||
| return UNKNOWN_DIM_SIZE; | |||
| } | |||
| if (!Int64MulNotOverflow(size, i)) { | |||
| GELOGE(GRAPH_FAILED, "mul overflow: %ld, %ld", size, i); | |||
| size = 0; | |||
| @@ -217,6 +237,34 @@ void TensorDesc::SetShape(const Shape &shape) { | |||
| } | |||
| } | |||
| // set shape with -2, it stand for unknown shape | |||
| graphStatus TensorDesc::SetUnknownDimNumShape() { | |||
| if (impl != nullptr) { | |||
| impl->shape_ = Shape({UNKNOWN_DIM_NUM}); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GELOGE(GRAPH_FAILED, "Set unknown shape failed,because no impl class!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| // for unknown shape | |||
| graphStatus TensorDesc::SetShapeRange(const std::vector<std::pair<int64_t, int64_t>> &range) { | |||
| if (impl != nullptr) { | |||
| impl->range_ = range; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GELOGE(GRAPH_FAILED, "SetShapeRange failed!impl is nullptr!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| graphStatus TensorDesc::GetShapeRange(std::vector<std::pair<int64_t, int64_t>> &range) const { | |||
| if (impl != nullptr) { | |||
| range = impl->range_; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GELOGE(GRAPH_FAILED, "impl is nullptr!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| Shape TensorDesc::GetOriginShape() const { | |||
| if (impl != nullptr) { | |||
| return impl->origin_shape_; | |||
| @@ -541,6 +589,17 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||
| tensor_desc.GetDataType()); | |||
| ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | |||
| ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| auto status = tensor_desc.GetShapeRange(shape_range); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||
| return ge_tensor_desc; | |||
| } | |||
| status = ge_tensor_desc.SetShapeRange(shape_range); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||
| return ge_tensor_desc; | |||
| } | |||
| auto size = tensor_desc.GetSize(); | |||
| TensorUtils::SetSize(ge_tensor_desc, size); | |||
| @@ -554,6 +613,17 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||
| ge_tensor_desc.GetDataType()); | |||
| tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | |||
| tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| auto status = ge_tensor_desc.GetShapeRange(shape_range); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Get shape range failed!"); | |||
| return tensor_desc; | |||
| } | |||
| status = tensor_desc.SetShapeRange(shape_range); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Set shape range failed!"); | |||
| return tensor_desc; | |||
| } | |||
| int64_t size = 0; | |||
| (void)TensorUtils::GetSize(ge_tensor_desc, size); | |||
| tensor_desc.SetSize(size); | |||
| @@ -28,6 +28,7 @@ | |||
| #include <cstring> | |||
| #include <fstream> | |||
| #include <iomanip> | |||
| #include <queue> | |||
| #include "./ge_context.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -390,8 +391,8 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||
| return GRAPH_FAILED; | |||
| } | |||
| if ((RemoveEdge(src, dst) != GRAPH_SUCCESS) || | |||
| (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { | |||
| (void)RemoveEdge(src, dst); | |||
| if (AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | |||
| dst_node->GetName().c_str(), insert_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| @@ -399,7 +400,7 @@ GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDa | |||
| OutControlAnchorPtr new_out_ctrl_anchor = insert_node->GetOutControlAnchor(); | |||
| GE_CHECK_NOTNULL(new_out_ctrl_anchor); | |||
| for (InControlAnchorPtr peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
| for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | |||
| if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | |||
| (AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | |||
| GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | |||
| @@ -706,7 +707,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||
| GELOGE(GRAPH_FAILED, "File name is too longer!"); | |||
| return; | |||
| } | |||
| std::unique_ptr<char> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||
| std::unique_ptr<char[]> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||
| if (real_path == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "New real_path failed."); | |||
| return; | |||
| @@ -1275,6 +1276,423 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR | |||
| return result; | |||
| } | |||
| /// | |||
| /// Get reference-mapping of all data_anchors in graph | |||
| /// @param [in] graph | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| for (auto &node : graph->GetAllNodes()) { | |||
| // in_data_anchor | |||
| if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| // out_data_anchor | |||
| if (HandleOutAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Find ref_mapping for out_data_anchors of node %s failed.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Get reference-mapping for in_data_anchors of node | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| if (NodeUtils::IsSubgraphOutput(node)) { | |||
| return HandleSubgraphOutput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| if (NodeUtils::IsSubgraphInput(node)) { | |||
| return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| std::string type = node->GetType(); | |||
| if ((type == MERGE) || (type == STREAMMERGE)) { | |||
| return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors[symbol] = {cur_node_info}; | |||
| anchor_to_symbol[symbol] = symbol; | |||
| } else { | |||
| NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
| if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Update symbol mapping failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Get reference-mapping for out_data_anchors of node | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| NodeIndexIO cur_node_info = NodeIndexIO(node, out_data_anchor->GetIdx(), kOut); | |||
| if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||
| continue; | |||
| } | |||
| int32_t reuse_in_index = -1; | |||
| if (IsRefFromInput(out_data_anchor, reuse_in_index)) { | |||
| NodeIndexIO exist_node_info = NodeIndexIO(node, reuse_in_index, kIn); | |||
| if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Update symbol mapping failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } else { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors.emplace(std::make_pair(symbol, std::vector<NodeIndexIO>{cur_node_info})); | |||
| anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Handle input of subgraph | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::HandleSubgraphInput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||
| // Data in subgraph | |||
| uint32_t index = 0; | |||
| if (!ge::AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||
| GE_LOGE("Get attr ATTR_NAME_PARENT_NODE_INDEX failed, node:%s.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodePtr parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||
| GE_CHECK_NOTNULL(parent_node); | |||
| InDataAnchorPtr parent_in_anchor = parent_node->GetInDataAnchor(index); | |||
| GE_CHECK_NOTNULL(parent_in_anchor); | |||
| OutDataAnchorPtr peer_out_anchor = parent_in_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor != nullptr) { | |||
| // Data has and only has one input | |||
| NodeIndexIO cur_node_info = NodeIndexIO(node, 0, kIn); | |||
| NodeIndexIO exist_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
| if (UpdateRefMapping(cur_node_info, exist_node_info, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Update symbol mapping failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Handle input of Merge op | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| std::vector<NodeIndexIO> exist_node_infos; | |||
| std::vector<NodeIndexIO> cur_node_infos; | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string next_name; | |||
| if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { | |||
| ComputeGraphPtr graph = node->GetOwnerComputeGraph(); | |||
| GE_CHECK_NOTNULL(graph); | |||
| ge::NodePtr next_node = graph->FindNode(next_name); | |||
| GE_CHECK_NOTNULL(next_node); | |||
| // NextIteration has and only has one output | |||
| peer_out_anchor = next_node->GetOutDataAnchor(0); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||
| cur_node_infos.emplace_back(NodeIndexIO(next_node, peer_out_anchor->GetIdx(), kOut)); | |||
| } | |||
| } else { | |||
| cur_node_infos.emplace_back(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn)); | |||
| exist_node_infos.emplace_back(NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut)); | |||
| } | |||
| } | |||
| size_t anchor_nums = 0; | |||
| NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||
| if (iter1 != anchor_to_symbol.end()) { | |||
| std::string temp_symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(temp_symbol); | |||
| if (iter2 != symbol_to_anchors.end()) { | |||
| if (iter2->second.size() > anchor_nums) { | |||
| max_node_index_io = temp_node_info; | |||
| anchor_nums = iter2->second.size(); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| std::string symbol; | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
| GRAPH_SUCCESS) || | |||
| symbol.empty()) { | |||
| GE_LOGE("Union symbol map anchor1:%s & anchor2:%s.", max_node_index_io.ToString().c_str(), | |||
| temp_node_info.ToString().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| auto iter = symbol_to_anchors.find(symbol); | |||
| if (iter != symbol_to_anchors.end()) { | |||
| for (auto &temp_node_info : cur_node_infos) { | |||
| GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||
| iter->second.emplace_back(temp_node_info); | |||
| anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Handle output of subgraph | |||
| /// @param [in] node | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); | |||
| GE_CHECK_NOTNULL(owner_graph); | |||
| NodePtr parent_node = owner_graph->GetParentNode(); | |||
| GE_CHECK_NOTNULL(parent_node); | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| GeTensorDesc in_tensor = op_desc->GetInputDesc(in_data_anchor->GetIdx()); | |||
| uint32_t index = 0; | |||
| if (!ge::AttrUtils::GetInt(in_tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||
| continue; | |||
| } | |||
| GE_CHECK_NOTNULL(parent_node->GetOutDataAnchor(index)); | |||
| // Union symbol of peer_out_anchor & parent_out_anchor | |||
| NodeIndexIO peer_node_info = NodeIndexIO(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx(), kOut); | |||
| NodeIndexIO parent_node_info = NodeIndexIO(parent_node, index, kOut); | |||
| std::string symbol; | |||
| if ((UnionSymbolMapping(peer_node_info, parent_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
| GRAPH_SUCCESS) || | |||
| symbol.empty()) { | |||
| GE_LOGE("Union symbol map anchor1:%s, anchor2:%s.", peer_node_info.ToString().c_str(), | |||
| parent_node_info.ToString().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodeIndexIO cur_node_info = NodeIndexIO(node, in_data_anchor->GetIdx(), kIn); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors[symbol].emplace_back(cur_node_info); | |||
| anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Union ref-mapping | |||
| /// @param [in] exist_node_info1 | |||
| /// @param [in] exist_node_info2 | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @param [out] symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||
| std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
| std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
| if (symbol1 == symbol2) { | |||
| symbol = symbol1; | |||
| GELOGI("no need to union."); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto iter1 = symbol_to_anchors.find(symbol1); | |||
| auto iter2 = symbol_to_anchors.find(symbol2); | |||
| if ((iter1 == symbol_to_anchors.end()) || (iter2 == symbol_to_anchors.end())) { | |||
| GE_LOGE("symbol %s or %s not exist.", symbol1.c_str(), symbol2.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto &max_iter = (iter1->second.size() > iter2->second.size() ? iter1 : iter2); | |||
| auto &min_iter = (iter1->second.size() > iter2->second.size() ? iter2 : iter1); | |||
| symbol = (iter1->second.size() > iter2->second.size() ? symbol1 : symbol2); | |||
| std::string min_symbol = (iter1->second.size() > iter2->second.size() ? symbol2 : symbol1); | |||
| for (auto &node_index_io : min_iter->second) { | |||
| GELOGD("Update anchor %s, symbol %s.", node_index_io.ToString().c_str(), symbol.c_str()); | |||
| max_iter->second.emplace_back(node_index_io); | |||
| auto iter = anchor_to_symbol.find(node_index_io.ToString()); | |||
| if (iter == anchor_to_symbol.end()) { | |||
| GE_LOGE("anchor %s not exist.", node_index_io.ToString().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (iter->second != min_symbol) { | |||
| GELOGW("not expected symbol of anchor %s, expect %s but %s exactly.", iter->first.c_str(), min_symbol.c_str(), | |||
| iter->second.c_str()); | |||
| } | |||
| iter->second = symbol; | |||
| } | |||
| GELOGI("Union symbol %s and %s succ.", symbol.c_str(), min_symbol.c_str()); | |||
| symbol_to_anchors.erase(min_iter); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Update symbol mapping with a new reference pair | |||
| /// @param [in] cur_node_info | |||
| /// @param [in] exist_node_info | |||
| /// @param [out] symbol_to_anchors | |||
| /// @param [out] anchor_to_symbol | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info, | |||
| std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| auto iter1 = anchor_to_symbol.find(exist_node_info.ToString()); | |||
| if (iter1 == anchor_to_symbol.end()) { | |||
| GE_LOGE("data_anchor %s is not visible before data_anchor %s, maybe TopoSorting is missing.", | |||
| exist_node_info.ToString().c_str(), cur_node_info.ToString().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(symbol); | |||
| if (iter2 == symbol_to_anchors.end()) { | |||
| GE_LOGE("symbol %s not found.", symbol.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| iter2->second.emplace_back(cur_node_info); | |||
| anchor_to_symbol.emplace(std::make_pair(cur_node_info.ToString(), symbol)); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Check if out_data_anchor is reference of input | |||
| /// @param [in] out_data_anchor | |||
| /// @param [out] reuse_in_index | |||
| /// @return bool | |||
| /// | |||
| bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index) { | |||
| if (out_data_anchor == nullptr) { | |||
| GELOGW("out_data_anchor is NULL."); | |||
| return false; | |||
| } | |||
| int32_t output_index = out_data_anchor->GetIdx(); | |||
| // pass-through op | |||
| NodePtr node = out_data_anchor->GetOwnerNode(); | |||
| std::string type = node->GetType(); | |||
| const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||
| if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||
| reuse_in_index = output_index; | |||
| GELOGI("Pass-Through node name[%s] index[%u].", node->GetName().c_str(), reuse_in_index); | |||
| return true; | |||
| } | |||
| // Merge op 0th output | |||
| if ((type == MERGE) && (output_index == 0)) { | |||
| reuse_in_index = 0; | |||
| GELOGI("Merge name[%s] output_index[0].", node->GetName().c_str()); | |||
| return true; | |||
| } | |||
| // ref op | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| GELOGW("op_desc is NULL."); | |||
| return false; | |||
| } | |||
| bool is_ref = false; | |||
| (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); | |||
| if (is_ref) { | |||
| const string &output_name = op_desc->GetOutputNameByIndex(output_index); | |||
| for (const auto &input_name : op_desc->GetAllInputNames()) { | |||
| if (!input_name.empty() && (output_name == input_name)) { | |||
| reuse_in_index = op_desc->GetInputIndexByName(input_name); | |||
| GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), | |||
| output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| // reuse input | |||
| auto output_op_desc = op_desc->GetOutputDescPtr(output_index); | |||
| bool reuse_input = false; | |||
| if (output_op_desc != nullptr) { | |||
| if ((TensorUtils::GetReuseInput(*output_op_desc, reuse_input) == GRAPH_SUCCESS) && reuse_input) { | |||
| uint32_t reuse_input_index = 0; | |||
| if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||
| reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||
| GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
| reuse_in_index); | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| /// | |||
| /// @brief Add node to graph | |||
| /// @param [in] op_desc | |||
| @@ -1561,13 +1979,14 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map<uint | |||
| /// | |||
| ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | |||
| owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | |||
| if (owner_graph_ == nullptr) { | |||
| if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "graph is NULL."; | |||
| error_msg = "graph / parent_node is NULL."; | |||
| return nullptr; | |||
| } | |||
| owner_graph_->SetParentNode(parent_node_); | |||
| owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); | |||
| BuildNodes(error_code, error_msg); | |||
| if (error_code != GRAPH_SUCCESS) { | |||
| @@ -1584,41 +2003,58 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string | |||
| return nullptr; | |||
| } | |||
| BuildInputs(error_code, error_msg); | |||
| AddDataNodes(error_code, error_msg); | |||
| if (error_code != GRAPH_SUCCESS) { | |||
| return nullptr; | |||
| } | |||
| BuildOutputs(error_code, error_msg); | |||
| AddRetValNodes(error_code, error_msg); | |||
| if (error_code != GRAPH_SUCCESS) { | |||
| return nullptr; | |||
| } | |||
| if (AddNetOutputNode(error_code, error_msg) == nullptr) { | |||
| // ATTR_NAME_SESSION_GRAPH_ID | |||
| std::string graph_id; | |||
| if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "Get attr session_graph_id failed."; | |||
| return nullptr; | |||
| } | |||
| if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "Set attr session_graph_id failed."; | |||
| return nullptr; | |||
| } | |||
| // refresh node name | |||
| for (const NodePtr &node : owner_graph_->GetDirectNode()) { | |||
| if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { | |||
| continue; | |||
| } | |||
| node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); | |||
| } | |||
| return owner_graph_; | |||
| } | |||
| /// | |||
| /// @brief Build inputs | |||
| /// @brief Add data nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &error_msg) { | |||
| void CompleteGraphBuilder::AddDataNodes(graphStatus &error_code, std::string &error_msg) { | |||
| for (auto &input : graph_inputs_) { | |||
| NodePtr data_node = AddDateNode(input.first, error_code, error_msg); | |||
| NodePtr data_node = AddDataNode(input.first, error_code, error_msg); | |||
| if (data_node == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||
| error_msg = "AddDataNodes failed: add node Data:" + std::to_string(input.first) + +" failed."; | |||
| return; | |||
| } | |||
| if (owner_graph_->AddInputNode(data_node) == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||
| error_msg = "AddDataNodes failed: add input node Data:" + std::to_string(input.first) + +" failed."; | |||
| return; | |||
| } | |||
| @@ -1627,7 +2063,7 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
| std::vector<uint32_t> anchor_indes = input.second.second; | |||
| if (input_names.size() != anchor_indes.size()) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: num of input_names and indexs not equal."; | |||
| error_msg = "AddDataNodes failed: num of input_names and indexs not equal."; | |||
| return; | |||
| } | |||
| if (input_names.empty()) { | |||
| @@ -1641,29 +2077,29 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
| auto iter = node_names_.find(input_name); | |||
| if (iter == node_names_.end()) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: node " + input_name + " not exist in graph."; | |||
| error_msg = "AddDataNodes failed: node " + input_name + " not exist in graph."; | |||
| return; | |||
| } | |||
| NodePtr in_node = node_names_[input_name]; | |||
| if (in_node == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: node " + input_name + " is NULL."; | |||
| error_msg = "AddDataNodes failed: node " + input_name + " is NULL."; | |||
| return; | |||
| } | |||
| if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), in_node->GetInDataAnchor(ind)) != GRAPH_SUCCESS) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||
| error_msg = "AddDataNodes failed: add data-edge Data:" + std::to_string(input.first) + ":0->" + input_name + | |||
| ":" + std::to_string(ind) + " failed."; | |||
| return; | |||
| } | |||
| } | |||
| GELOGD("BuildInputs : Add %u input succ.", input.first); | |||
| GELOGD("AddDataNodes : Add %u input succ.", input.first); | |||
| } | |||
| GELOGD("BuildInputs succ."); | |||
| GELOGD("AddDataNodes succ."); | |||
| } | |||
| /// | |||
| @@ -1673,13 +2109,13 @@ void CompleteGraphBuilder::BuildInputs(graphStatus &error_code, std::string &err | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||
| NodePtr CompleteGraphBuilder::AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg) { | |||
| std::string data_name = "Data_" + std::to_string(index); | |||
| OpDescBuilder op_desc_builder(data_name, "Data"); | |||
| OpDescPtr op_desc = op_desc_builder.AddInput("x").AddOutput("y").Build(); | |||
| if (op_desc == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: create op_desc " + data_name + " failed."; | |||
| error_msg = "AddDataNode failed: create op_desc " + data_name + " failed."; | |||
| return nullptr; | |||
| } | |||
| @@ -1687,7 +2123,7 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||
| if (index_iter != input_mapping_.end()) { | |||
| if (!ge::AttrUtils::SetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, index_iter->second)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||
| error_msg = "AddDataNode failed: set attr ATTR_NAME_PARENT_NODE_INDEX for " + data_name + " failed."; | |||
| return nullptr; | |||
| } | |||
| } | |||
| @@ -1695,189 +2131,83 @@ NodePtr CompleteGraphBuilder::AddDateNode(uint32_t index, graphStatus &error_cod | |||
| NodePtr data_node = owner_graph_->AddNode(op_desc); | |||
| if (data_node == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildInputs failed: add node " + data_name + " failed."; | |||
| error_msg = "AddDataNode failed: add node " + data_name + " failed."; | |||
| return nullptr; | |||
| } | |||
| node_names_[data_name] = data_node; | |||
| return data_node; | |||
| } | |||
| /// | |||
| /// @brief Build outputs | |||
| /// @brief Add RetVal nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void CompleteGraphBuilder::BuildOutputs(graphStatus &error_code, std::string &error_msg) { | |||
| std::map<std::string, std::vector<int32_t>> out_nodes_map; | |||
| std::vector<std::pair<NodePtr, int32_t>> out_nodes_info; | |||
| for (auto &pair : graph_outputs_) { | |||
| std::string output = pair.first; | |||
| int32_t ind = pair.second; | |||
| auto out_iter = node_names_.find(output); | |||
| void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string &error_msg) { | |||
| size_t output_num = graph_outputs_.size(); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| int32_t index = graph_outputs_[i].second; | |||
| auto out_iter = node_names_.find(graph_outputs_[i].first); | |||
| if (out_iter == node_names_.end()) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildOutputs failed: node " + output + " not exist in graph."; | |||
| error_msg = "AddRetValNode failed: node " + graph_outputs_[i].first + " not exist in graph."; | |||
| return; | |||
| } | |||
| NodePtr out_node = node_names_[output]; | |||
| if (out_node == nullptr) { | |||
| NodePtr node = out_iter->second; | |||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildOutputs failed: node " + output + " is NULL."; | |||
| error_msg = "AddRetValNode failed: node is NULL."; | |||
| return; | |||
| } | |||
| OutDataAnchorPtr out_anchor = out_node->GetOutDataAnchor(ind); | |||
| if (out_anchor == nullptr) { | |||
| std::string name = node->GetName() + "_RetVal"; | |||
| OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||
| if (ret_val_desc == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "BuildOutputs failed: anchor " + output + ":" + std::to_string(ind) + " is NULL."; | |||
| error_msg = "AddRetValNode " + name + " failed: op_desc is NULL."; | |||
| return; | |||
| } | |||
| auto iter = out_nodes_map.find(output); | |||
| if (iter == out_nodes_map.end()) { | |||
| std::vector<int32_t> vec = {ind}; | |||
| out_nodes_map[output] = vec; | |||
| } else { | |||
| out_nodes_map[output].emplace_back(ind); | |||
| ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); | |||
| if ((ret_val_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) || | |||
| (ret_val_desc->AddOutputDesc(tensor) != GRAPH_SUCCESS)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "AddRetValNode " + name + " failed: add input_desc / output_desc failed."; | |||
| return; | |||
| } | |||
| out_nodes_info.emplace_back(std::make_pair(out_node, ind)); | |||
| GELOGD("BuildOutputs : AddOutputAnchor %s:%u succ.", output.c_str(), ind); | |||
| } | |||
| owner_graph_->SetGraphOutNodes(out_nodes_map); | |||
| owner_graph_->SetGraphOutNodesInfo(out_nodes_info); | |||
| GELOGD("BuildOutputs succ."); | |||
| } | |||
| /// | |||
| /// @brief Add NetOutput node | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return NodePtr | |||
| /// | |||
| NodePtr CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { | |||
| std::string log_msg = "AddNetOutputNode name:" + std::string(kNodeNameNetOutput) + ", type:" + NETOUTPUT; | |||
| OpDescPtr net_output_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(kNodeNameNetOutput, NETOUTPUT)); | |||
| if (net_output_desc == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = log_msg + " failed: op_desc is NULL."; | |||
| return nullptr; | |||
| } | |||
| std::vector<std::pair<NodePtr, int32_t>> out_nodes_info = owner_graph_->GetGraphOutNodesInfo(); | |||
| error_code = BuildInOutForNetOutput(out_nodes_info, net_output_desc); | |||
| if (error_code != GRAPH_SUCCESS) { | |||
| error_msg = log_msg + " failed: add input/output tensor failed."; | |||
| return nullptr; | |||
| } | |||
| NodePtr net_output_node = owner_graph_->AddNode(net_output_desc); | |||
| if (net_output_node == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = log_msg + " failed: add node failed."; | |||
| return nullptr; | |||
| } | |||
| error_code = AddEdgeForNetOutput(out_nodes_info, net_output_node); | |||
| if (error_code != GRAPH_SUCCESS) { | |||
| error_msg = log_msg + " failed: link edge failed."; | |||
| return nullptr; | |||
| } | |||
| GELOGD("%s succ.", log_msg.c_str()); | |||
| return net_output_node; | |||
| } | |||
| /// | |||
| /// @brief Add input/output tensor for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_desc | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus CompleteGraphBuilder::BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| OpDescPtr &net_output_desc) { | |||
| size_t output_num = out_nodes_info.size(); | |||
| for (size_t i = 0; i < output_num; i++) { | |||
| NodePtr src_node = out_nodes_info[i].first; | |||
| uint32_t src_index = out_nodes_info[i].second; | |||
| if ((src_node == nullptr) || (src_node->GetOpDesc() == nullptr)) { | |||
| GE_LOGE("AddInOutForNetOutputOp failed: src_node is NULL."); | |||
| return GRAPH_FAILED; | |||
| if (!(ge::AttrUtils::SetStr(ret_val_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_RetVal") && | |||
| ge::AttrUtils::SetInt(ret_val_desc, RETVAL_ATTR_NAME_INDEX, i))) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "AddRetValNode " + name + " failed: set FRAMEWORK_ORIGINAL_TYPE / RETVAL_ATTR_NAME_INDEX failed."; | |||
| return; | |||
| } | |||
| ge::GeTensorDesc in_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||
| auto iter = output_mapping_.find(i); | |||
| if (iter != output_mapping_.end()) { | |||
| if (!ge::AttrUtils::SetInt(in_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
| GE_LOGE("AddInOutForNetOutputOp failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
| return GRAPH_FAILED; | |||
| if (!ge::AttrUtils::SetInt(ret_val_desc, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "AddRetValNode " + name + " failed: set attr PARENT_NODE_INDEX failed."; | |||
| return; | |||
| } | |||
| } | |||
| if (net_output_desc->AddInputDesc(in_desc) != SUCCESS) { | |||
| GE_LOGE("AddInOutForNetOutputOp failed: add input_desc failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); | |||
| TensorUtils::SetOutputTensor(out_desc, true); | |||
| if (net_output_desc->AddOutputDesc(out_desc) != SUCCESS) { | |||
| GE_LOGE("AddInOutForNetOutputOp failed: add output_desc failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| GELOGD("Add input/output tensor for NetOutput node succ."); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Add edge for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_node | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus CompleteGraphBuilder::AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| const NodePtr &net_output_node) { | |||
| if (net_output_node == nullptr) { | |||
| GE_LOGE("AddEdgeForNetOutputOp failed: NetOutput is NULL."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| size_t out_num = out_nodes_info.size(); | |||
| for (size_t i = 0; i < out_num; i++) { | |||
| NodePtr src_node = out_nodes_info[i].first; | |||
| uint32_t ind = out_nodes_info[i].second; | |||
| if (src_node == nullptr) { | |||
| GE_LOGE("AddEdgeForNetOutputOp failed: src_node is NULL."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (GraphUtils::AddEdge(src_node->GetOutDataAnchor(ind), net_output_node->GetInDataAnchor(i)) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Add data-edge %s:%u->%s:%zu failed.", src_node->GetName().c_str(), ind, | |||
| net_output_node->GetName().c_str(), i); | |||
| return GRAPH_FAILED; | |||
| NodePtr ret_val_node = owner_graph_->AddNode(ret_val_desc); | |||
| if (ret_val_node == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "AddRetValNode " + name + " failed: add node failed."; | |||
| return; | |||
| } | |||
| } | |||
| std::vector<NodePtr> leaf_nodes; | |||
| for (auto &node : owner_graph_->GetDirectNode()) { | |||
| if (node->GetOutNodes().empty()) { | |||
| leaf_nodes.emplace_back(node); | |||
| } | |||
| } | |||
| for (auto &node : leaf_nodes) { | |||
| if (GraphUtils::AddEdge(node->GetOutControlAnchor(), net_output_node->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Add ctrl-edge %s->%s failed.", node->GetName().c_str(), net_output_node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| if (GraphUtils::AddEdge(node->GetOutDataAnchor(index), ret_val_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { | |||
| error_code = GRAPH_FAILED; | |||
| error_msg = "AddRetValNode " + name + " failed: add data-edge " + node->GetName() + ":" + std::to_string(index) + | |||
| "->" + ret_val_node->GetName() + ":0 failed."; | |||
| return; | |||
| } | |||
| } | |||
| GELOGD("Add edge for NetOutput node succ."); | |||
| return GRAPH_SUCCESS; | |||
| GELOGD("AddRetValNodes succ."); | |||
| } | |||
| /// | |||
| @@ -1999,4 +2329,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & | |||
| GELOGD("Build exist nodes succ."); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) { | |||
| std::vector<NodePtr> stack_input; | |||
| std::map<NodePtr, uint32_t> map_in_edge_num; | |||
| graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Sort nodes failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; | |||
| std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, | |||
| [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); | |||
| std::queue<NodePtr> stack; | |||
| NodePtr cur_node = nullptr; | |||
| std::map<string, NodePtr> name_node_map; | |||
| vector<string> nodes_name; | |||
| while (!stack_input.empty() || !stack.empty()) { | |||
| if (!stack.empty()) { | |||
| cur_node = stack.front(); | |||
| stack.pop(); | |||
| } else { | |||
| cur_node = stack_input.back(); | |||
| stack_input.pop_back(); | |||
| } | |||
| node_vec.emplace_back(cur_node); | |||
| compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); | |||
| for (const auto &iter : name_node_map) { | |||
| nodes_name.emplace_back(iter.first); | |||
| } | |||
| std::sort(nodes_name.begin(), nodes_name.end()); | |||
| for (const auto &iter : nodes_name) { | |||
| stack.push(name_node_map[iter]); | |||
| } | |||
| name_node_map.clear(); | |||
| nodes_name.clear(); | |||
| } | |||
| // If they are not equal, there is a closed loop | |||
| if (node_vec.size() != compute_graph->nodes_.size()) { | |||
| std::set<Node *> itered_nodes_set; | |||
| for (auto &node : node_vec) { | |||
| itered_nodes_set.insert(node.get()); | |||
| } | |||
| GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", | |||
| compute_graph->nodes_.size(), node_vec.size()); | |||
| for (auto &node : compute_graph->nodes_) { | |||
| if (itered_nodes_set.count(node.get()) == 0) { | |||
| GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); | |||
| } | |||
| } | |||
| return GRAPH_FAILED; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -21,6 +21,7 @@ | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/anchor.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/types.h" | |||
| #include "utils/tensor_utils.h" | |||
| #include "utils/type_utils.h" | |||
| @@ -28,6 +29,26 @@ namespace ge { | |||
| std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | |||
| std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | |||
| bool OpShapeIsUnknown(const OpDescPtr &desc) { | |||
| for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||
| auto ge_shape = ptr->GetShape(); | |||
| for (const auto &dim : ge_shape.GetDims()) { | |||
| if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||
| auto ge_shape = ptr->GetShape(); | |||
| for (const auto &dim : ge_shape.GetDims()) { | |||
| if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) { | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AddSendEventId(const NodePtr &node, | |||
| const uint32_t &event_id) { | |||
| GE_CHECK_NOTNULL(node); | |||
| @@ -282,18 +303,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||
| continue; | |||
| } | |||
| auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->GetInputDescPtr(peer_anchor->GetIdx()); | |||
| auto peer_input_desc = peer_anchor->GetOwnerNode()->GetOpDesc()->MutableInputDesc(peer_anchor->GetIdx()); | |||
| if (peer_input_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "peer_input_desc is nullptr"); | |||
| continue; | |||
| } | |||
| output_tensor.SetOriginFormat(peer_input_desc->GetOriginFormat()); | |||
| output_tensor.SetFormat(peer_input_desc->GetFormat()); | |||
| auto peer_op_desc = peer_anchor->GetOwnerNode()->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(peer_op_desc == nullptr, GELOGE(GRAPH_FAILED, "peer opdesc is null"); continue); | |||
| GE_IF_BOOL_EXEC(peer_op_desc->UpdateInputDesc(peer_anchor->GetIdx(), output_tensor) != GRAPH_SUCCESS, | |||
| GELOGE(GRAPH_FAILED, "peer opdesc is null"); | |||
| continue); | |||
| GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||
| output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||
| peer_input_desc->SetShape(output_tensor.GetShape()); | |||
| peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||
| peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||
| peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||
| ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||
| static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
| GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||
| peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| @@ -361,6 +387,41 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
| input_desc->SetShape(shape); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||
| auto desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL(desc); | |||
| auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| is_unknow = OpShapeIsUnknown(desc); | |||
| return GRAPH_SUCCESS; | |||
| } else { | |||
| auto owner_graph = node.GetOwnerComputeGraph(); | |||
| GE_CHECK_NOTNULL(owner_graph); | |||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
| if (root_graph == nullptr) { | |||
| GE_LOGE("Node %s gets null root graph", node.GetName().c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| for (auto &sub_graph_name : sub_graph_names) { | |||
| auto sub_graph = root_graph->GetSubgraph(sub_graph_name); | |||
| GE_CHECK_NOTNULL(sub_graph); | |||
| for (const auto &node_ptr : sub_graph->GetDirectNode()) { | |||
| auto status = GetNodeUnknownShapeStatus(*node_ptr, is_unknow); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GE_LOGE("get node unknown shape status failed!"); | |||
| return status; | |||
| } | |||
| if (is_unknow) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::string NodeUtils::GetNodeType(const Node &node) { | |||
| if (node.GetType() != FRAMEWORKOP) { | |||
| return node.GetType(); | |||
| @@ -381,9 +442,9 @@ ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||
| return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||
| } | |||
| graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||
| graphStatus NodeUtils::SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph) { | |||
| if (subgraph == nullptr) { | |||
| GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||
| GE_LOGE("Failed to set subgraph to node %s index %u, null subgraph", node.GetName().c_str(), index); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto op_desc = node.GetOpDesc(); | |||
| @@ -395,11 +456,105 @@ graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) | |||
| GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||
| auto ret = op_desc->SetSubgraphInstanceName(index, subgraph->GetName()); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GE_LOGE("Failed to set subgraph to node %s index %u", node.GetName().c_str(), index); | |||
| return ret; | |||
| } | |||
| subgraph->SetParentNode(node.shared_from_this()); | |||
| subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||
| root_graph->AddSubgraph(subgraph); | |||
| return root_graph->AddSubgraph(subgraph); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| /// | |||
| /// Check if node is input of subgraph | |||
| /// @param [in] node | |||
| /// @return bool | |||
| /// | |||
| bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||
| (node->GetOwnerComputeGraph()->GetParentNode() == nullptr)) { | |||
| return false; | |||
| } | |||
| return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | |||
| } | |||
| /// | |||
| /// Check if node is output of subgraph | |||
| /// @param [in] node | |||
| /// @return bool | |||
| /// | |||
| bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr) || | |||
| (node->GetOwnerComputeGraph()->GetParentNode() == nullptr) || (node->GetType() != NETOUTPUT)) { | |||
| return false; | |||
| } | |||
| for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | |||
| if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph original input node. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||
| GE_CHECK_NOTNULL_EXEC(node, return nullptr); | |||
| uint32_t parent_index = 0; | |||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||
| return nullptr; | |||
| } | |||
| // Subgraph Data Node, check for constant input. | |||
| const ComputeGraphPtr &graph = node->GetOwnerComputeGraph(); | |||
| GE_CHECK_NOTNULL_EXEC(graph, return nullptr); | |||
| const NodePtr &parent_node = graph->GetParentNode(); | |||
| GE_CHECK_NOTNULL_EXEC(parent_node, return nullptr); | |||
| const InDataAnchorPtr &in_anchor = parent_node->GetInDataAnchor(parent_index); | |||
| GE_CHECK_NOTNULL_EXEC(in_anchor, return nullptr); | |||
| const OutDataAnchorPtr &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL_EXEC(peer_out_anchor, return nullptr); | |||
| return peer_out_anchor->GetOwnerNode(); | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| /// @param [out] string | |||
| /// @return bool | |||
| /// | |||
| bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { | |||
| GE_CHECK_NOTNULL_EXEC(in_node, return false); | |||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
| op_type = in_node->GetType(); | |||
| return true; | |||
| } | |||
| if (in_node->GetType() == DATA) { | |||
| std::string const_type; | |||
| if (!AttrUtils::GetStr(in_node->GetOpDesc(), ATTR_NAME_PARENT_CONST_TYPE, const_type)) { | |||
| return false; | |||
| } | |||
| if ((const_type == CONSTANT) || (const_type == CONSTANTOP)) { | |||
| op_type = const_type; | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| } // namespace ge | |||
| @@ -469,7 +469,7 @@ OpDescUtils::SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights) | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| ge::GeAttrValue::NamedAttrs named_attrs; | |||
| ge::GeAttrValue::NAMED_ATTRS named_attrs; | |||
| (void)ge::AttrUtils::SetListTensor(named_attrs, "key", weights); | |||
| vector<ge::GeTensorPtr> copy_weights; | |||
| (void)ge::AttrUtils::MutableListTensor(named_attrs, "key", copy_weights); | |||
| @@ -578,7 +578,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||
| inputs_.emplace_back(name); | |||
| inputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add input | |||
| /// @param [in] name | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name, | |||
| const GeTensorDesc &tensor) { | |||
| inputs_.emplace_back(std::make_pair(name, tensor)); | |||
| return *this; | |||
| } | |||
| @@ -591,7 +603,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||
| uint32_t num) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| inputs_.emplace_back(name + std::to_string(i)); | |||
| inputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||
| } | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add dynamic input | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput( | |||
| const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| inputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -602,7 +629,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||
| outputs_.emplace_back(name); | |||
| outputs_.emplace_back(std::make_pair(name, GeTensorDesc())); | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add output | |||
| /// @param [in] name | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name, | |||
| const GeTensorDesc &tensor) { | |||
| outputs_.emplace_back(std::make_pair(name, tensor)); | |||
| return *this; | |||
| } | |||
| @@ -615,7 +654,22 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::Add | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||
| uint32_t num) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| outputs_.emplace_back(name + std::to_string(i)); | |||
| outputs_.emplace_back(std::make_pair(name + std::to_string(i), GeTensorDesc())); | |||
| } | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add dynamic output | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @param [in] tensor | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput( | |||
| const std::string &name, uint32_t num, const GeTensorDesc &tensor) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| outputs_.emplace_back(std::make_pair(name + std::to_string(i), tensor)); | |||
| } | |||
| return *this; | |||
| } | |||
| @@ -632,14 +686,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||
| } | |||
| for (auto &input : inputs_) { | |||
| if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
| if (op_desc->AddInputDesc(input.first, input.second) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||
| return nullptr; | |||
| } | |||
| } | |||
| for (auto &output : outputs_) { | |||
| if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
| if (op_desc->AddOutputDesc(output.first, output.second) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||
| return nullptr; | |||
| } | |||
| @@ -647,4 +701,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() | |||
| return op_desc; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgraphInstanceName( | |||
| const std::string &subgraph_name, const std::string &subgraph_instance_name, OpDescPtr &op_desc) { | |||
| const auto &subgraph_names_to_index = op_desc->GetSubgraphNameIndexes(); | |||
| auto iter = subgraph_names_to_index.find(subgraph_name); | |||
| if (iter == subgraph_names_to_index.end()) { | |||
| GELOGE(GRAPH_PARAM_INVALID, | |||
| "Failed to set subgraph instance %s for node %s type %s, the subgraph name %s does not exists", | |||
| subgraph_instance_name.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||
| subgraph_name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | |||
| } | |||
| } // namespace ge | |||
| @@ -282,6 +282,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| case FORMAT_FRACTAL_Z_3D: | |||
| case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||
| case FORMAT_NDC1HWC0: | |||
| case FORMAT_FRACTAL_Z_C04: | |||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||
| break; | |||
| default: | |||
| @@ -56,6 +56,7 @@ static const std::map<Format, std::string> kFormatToStringMap = { | |||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
| {FORMAT_CN, "CN"}, | |||
| {FORMAT_NC, "NC"}, | |||
| {FORMAT_FRACTAL_ZN_LSTM, "FRACTAL_ZN_LSTM"}, | |||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
| {FORMAT_ALL, "ALL"}}; | |||
| @@ -76,7 +77,8 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
| "FRACTAL_NZ", | |||
| "NDC1HWC0", | |||
| "FORMAT_FRACTAL_Z_3D", | |||
| "FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||
| "FORMAT_FRACTAL_Z_3D_TRANSPOSE" | |||
| "FORMAT_FRACTAL_ZN_LSTM"}; | |||
| static const std::map<std::string, Format> kDataFormatMap = { | |||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||
| @@ -119,6 +121,7 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||
| {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||
| {"CN", FORMAT_CN}, | |||
| {"NC", FORMAT_NC}, | |||
| {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
| {"ALL", FORMAT_ALL}}; | |||
| @@ -13,15 +13,18 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # libge_compiler.so & libge_train.so | |||
| # libge_compiler.so & libge_runner.so | |||
| # will later be integrated into libgraph_runner.so, works for both training and inference | |||
| # compiling proto files generates some warnings, use no-unused-variable to suppress them | |||
| set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
| file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../proto/fusion_model.proto" | |||
| "../proto/optimizer_priority.proto" | |||
| ) | |||
| file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB PROTO_CLIENT_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../proto/ge_api.proto" | |||
| ) | |||
| file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../proto/om.proto" | |||
| "../proto/task.proto" | |||
| "../proto/insert_op.proto" | |||
| @@ -30,57 +33,46 @@ file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../proto/op_mapping_info.proto" | |||
| ) | |||
| ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | |||
| ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST}) | |||
| # include directories | |||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||
| include_directories(${GE_SOURCE_DIR}) | |||
| include_directories(${GE_SOURCE_DIR}/src) | |||
| include_directories(${GE_SOURCE_DIR}/inc) | |||
| include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||
| include_directories(${GE_SOURCE_DIR}/inc/external) | |||
| include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
| include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
| include_directories(${GE_SOURCE_DIR}/inc/framework/common) | |||
| include_directories(${GE_SOURCE_DIR}/inc/runtime) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| ######### libge_train.so ############# | |||
| ######### libge_runner.so ############# | |||
| # need to remove dependencies on pb files later | |||
| file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "client/ge_api.cc" | |||
| "common/formats/format_transfers/*.cc" | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| "engine_manager/dnnengine_manager.cc" | |||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||
| "generator/ge_generator.cc" | |||
| "generator/generator_api.cc" | |||
| "graph/build/graph_builder.cc" | |||
| "graph/build/label_allocator.cc" | |||
| "graph/build/logical_stream_allocator.cc" | |||
| "graph/build/model_builder.cc" | |||
| "graph/build/run_context.cc" | |||
| "graph/build/stream_allocator.cc" | |||
| "graph/build/stream_graph_optimizer.cc" | |||
| "graph/build/task_generator.cc" | |||
| "graph/common/bcast.cc" | |||
| "graph/common/omg_util.cc" | |||
| "graph/common/transop_util.cc" | |||
| "graph/build/*.cc" | |||
| "graph/common/*.cc" | |||
| "graph/execute/graph_execute.cc" | |||
| "graph/label/*.cc" | |||
| "graph/load/graph_loader.cc" | |||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
| "graph/load/new_model_manager/data_dumper.cc" | |||
| "graph/load/new_model_manager/data_inputer.cc" | |||
| "graph/load/new_model_manager/davinci_model.cc" | |||
| "graph/load/new_model_manager/davinci_model_parser.cc" | |||
| "graph/load/new_model_manager/model_manager.cc" | |||
| "graph/load/new_model_manager/model_output.cc" | |||
| "graph/load/new_model_manager/model_utils.cc" | |||
| "graph/load/new_model_manager/*.cc" | |||
| "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
| @@ -89,8 +81,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/hccl_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/kernel_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_set_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
| @@ -99,15 +93,9 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/new_model_manager/tbe_handle_store.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| "graph/manager/graph_mem_allocator.cc" | |||
| "graph/manager/graph_var_manager.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/hcom_util.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| @@ -115,27 +103,10 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/dynamic_shape_partition.cc" | |||
| "graph/partition/engine_place.cc" | |||
| "graph/partition/graph_partition.cc" | |||
| "graph/passes/addn_pass.cc" | |||
| "graph/passes/aicpu_constant_folding_pass.cc" | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/atomic_addr_clean_pass.cc" | |||
| "graph/passes/base_pass.cc" | |||
| "graph/passes/cast_remove_pass.cc" | |||
| "graph/passes/cast_translate_pass.cc" | |||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||
| "graph/passes/compile_nodes_pass.cc" | |||
| "graph/passes/constant_folding_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| "graph/passes/control_op_attr_pass.cc" | |||
| "graph/passes/control_trigger_pass.cc" | |||
| "graph/passes/dimension_adjust_pass.cc" | |||
| "graph/passes/dimension_compute_pass.cc" | |||
| "graph/passes/dropout_pass.cc" | |||
| "graph/passes/end_graph_pass.cc" | |||
| "graph/passes/enter_pass.cc" | |||
| "graph/passes/flow_ctrl_pass.cc" | |||
| "graph/passes/*.cc" | |||
| "graph/passes/folding_kernel/add_kernel.cc" | |||
| "graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||
| "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||
| @@ -171,51 +142,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/folding_kernel/sub_kernel.cc" | |||
| "graph/passes/folding_kernel/transdata_kernel.cc" | |||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||
| "graph/passes/folding_pass.cc" | |||
| "graph/passes/get_original_format_pass.cc" | |||
| "graph/passes/guarantee_const_pass.cc" | |||
| "graph/passes/hccl_memcpy_pass.cc" | |||
| "graph/passes/identify_reference_pass.cc" | |||
| "graph/passes/identity_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| "graph/passes/isolated_op_remove_pass.cc" | |||
| "graph/passes/iterator_op_pass.cc" | |||
| "graph/passes/link_gen_mask_nodes_pass.cc" | |||
| "graph/passes/merge_pass.cc" | |||
| "graph/passes/multi_batch_pass.cc" | |||
| "graph/passes/net_output_pass.cc" | |||
| "graph/passes/next_iteration_pass.cc" | |||
| "graph/passes/no_use_reshape_remove_pass.cc" | |||
| "graph/passes/pass_manager.cc" | |||
| "graph/passes/pass_utils.cc" | |||
| "graph/passes/permute_pass.cc" | |||
| "graph/passes/placeholder_with_default_pass.cc" | |||
| "graph/passes/prevent_gradient_pass.cc" | |||
| "graph/passes/print_op_pass.cc" | |||
| "graph/passes/prune_pass.cc" | |||
| "graph/passes/reshape_remove_pass.cc" | |||
| "graph/passes/resource_pair_add_control_pass.cc" | |||
| "graph/passes/resource_pair_remove_control_pass.cc" | |||
| "graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
| "graph/passes/save_pass.cc" | |||
| "graph/passes/shape_operate_op_remove_pass.cc" | |||
| "graph/passes/snapshot_pass.cc" | |||
| "graph/passes/stop_gradient_pass.cc" | |||
| "graph/passes/switch_logic_remove_pass.cc" | |||
| "graph/passes/switch_op_pass.cc" | |||
| "graph/passes/switch_pass.cc" | |||
| "graph/passes/transop_breadth_fusion_pass.cc" | |||
| "graph/passes/transop_depth_fusion_pass.cc" | |||
| "graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||
| "graph/passes/transop_without_reshape_fusion_pass.cc" | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/var_is_initialized_op_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/variable_prepare_op_pass.cc" | |||
| "graph/passes/variable_ref_delete_op_pass.cc" | |||
| "graph/preprocess/graph_preprocess.cc" | |||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | |||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
| @@ -231,22 +157,17 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| ) | |||
| ######### libge_train.so ############# | |||
| add_library(ge_train SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_train PRIVATE | |||
| ######### libge_runner.so ############# | |||
| add_library(ge_runner SHARED ${TRAIN_SRC_LIST} ${PROTO_SRCS} ${PROTO_CLIENT_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_runner PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| DAVINCI_SUPPORT_PROFILING | |||
| REUSE_MEMORY=1 | |||
| DAVINCI_TRAIN | |||
| DAVINCI_CLOUD | |||
| FMK_SUPPORT_DEBUG | |||
| PLATFORM_CLOUD) | |||
| target_link_libraries(ge_train | |||
| DAVINCI_CLOUD) | |||
| target_link_libraries(ge_runner | |||
| graph | |||
| ge_common | |||
| "-Wl,--whole-archive" | |||
| ge_memory | |||
| "-Wl,--no-whole-archive" | |||
| ${PROTOBUF_LIBRARY} | |||
| ${register} | |||
| ${c_sec} | |||
| @@ -267,33 +188,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| "engine_manager/dnnengine_manager.cc" | |||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||
| "generator/ge_generator.cc" | |||
| "generator/generator_api.cc" | |||
| "graph/build/graph_builder.cc" | |||
| "graph/build/label_allocator.cc" | |||
| "graph/build/logical_stream_allocator.cc" | |||
| "graph/build/model_builder.cc" | |||
| "graph/build/run_context.cc" | |||
| "graph/build/stream_allocator.cc" | |||
| "graph/build/stream_graph_optimizer.cc" | |||
| "graph/build/task_generator.cc" | |||
| "graph/common/bcast.cc" | |||
| "graph/common/omg_util.cc" | |||
| "graph/common/transop_util.cc" | |||
| "graph/build/*.cc" | |||
| "graph/common/*.cc" | |||
| "graph/execute/graph_execute.cc" | |||
| "graph/label/*.cc" | |||
| "graph/load/graph_loader.cc" | |||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
| "graph/load/new_model_manager/data_dumper.cc" | |||
| "graph/load/new_model_manager/data_inputer.cc" | |||
| "graph/load/new_model_manager/davinci_model.cc" | |||
| "graph/load/new_model_manager/davinci_model_parser.cc" | |||
| "graph/load/new_model_manager/model_manager.cc" | |||
| "graph/load/new_model_manager/model_output.cc" | |||
| "graph/load/new_model_manager/model_utils.cc" | |||
| "graph/load/new_model_manager/*.cc" | |||
| "graph/load/new_model_manager/task_info/end_graph_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_record_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/event_wait_task_info.cc" | |||
| @@ -301,8 +207,10 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/fusion_stop_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/kernel_ex_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/kernel_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_set_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
| @@ -311,41 +219,18 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/new_model_manager/tbe_handle_store.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| "graph/manager/graph_mem_allocator.cc" | |||
| "graph/manager/graph_var_manager.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/dynamic_shape_partition.cc" | |||
| "graph/partition/engine_place.cc" | |||
| "graph/partition/graph_partition.cc" | |||
| "graph/passes/addn_pass.cc" | |||
| "graph/passes/aicpu_constant_folding_pass.cc" | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/atomic_addr_clean_pass.cc" | |||
| "graph/passes/base_pass.cc" | |||
| "graph/passes/cast_remove_pass.cc" | |||
| "graph/passes/cast_translate_pass.cc" | |||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||
| "graph/passes/compile_nodes_pass.cc" | |||
| "graph/passes/constant_folding_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| "graph/passes/control_op_attr_pass.cc" | |||
| "graph/passes/control_trigger_pass.cc" | |||
| "graph/passes/dimension_adjust_pass.cc" | |||
| "graph/passes/dimension_compute_pass.cc" | |||
| "graph/passes/dropout_pass.cc" | |||
| "graph/passes/end_graph_pass.cc" | |||
| "graph/passes/enter_pass.cc" | |||
| "graph/passes/flow_ctrl_pass.cc" | |||
| "graph/passes/*.cc" | |||
| "graph/passes/folding_kernel/add_kernel.cc" | |||
| "graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||
| "graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||
| @@ -380,87 +265,33 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
| "graph/passes/folding_kernel/sub_kernel.cc" | |||
| "graph/passes/folding_kernel/transdata_kernel.cc" | |||
| "graph/passes/folding_kernel/transpose_kernel.cc" | |||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||
| "graph/passes/folding_pass.cc" | |||
| "graph/passes/get_original_format_pass.cc" | |||
| "graph/passes/guarantee_const_pass.cc" | |||
| "graph/passes/hccl_memcpy_pass.cc" | |||
| "graph/passes/identify_reference_pass.cc" | |||
| "graph/passes/identity_pass.cc" | |||
| "graph/passes/infershape_pass.cc" | |||
| "graph/passes/isolated_op_remove_pass.cc" | |||
| "graph/passes/iterator_op_pass.cc" | |||
| "graph/passes/link_gen_mask_nodes_pass.cc" | |||
| "graph/passes/merge_pass.cc" | |||
| "graph/passes/multi_batch_pass.cc" | |||
| "graph/passes/net_output_pass.cc" | |||
| "graph/passes/next_iteration_pass.cc" | |||
| "graph/passes/no_use_reshape_remove_pass.cc" | |||
| "graph/passes/pass_manager.cc" | |||
| "graph/passes/pass_utils.cc" | |||
| "graph/passes/permute_pass.cc" | |||
| "graph/passes/placeholder_with_default_pass.cc" | |||
| "graph/passes/prevent_gradient_pass.cc" | |||
| "graph/passes/print_op_pass.cc" | |||
| "graph/passes/prune_pass.cc" | |||
| "graph/passes/reshape_remove_pass.cc" | |||
| "graph/passes/resource_pair_add_control_pass.cc" | |||
| "graph/passes/resource_pair_remove_control_pass.cc" | |||
| "graph/passes/same_transdata_breadth_fusion_pass.cc" | |||
| "graph/passes/save_pass.cc" | |||
| "graph/passes/shape_operate_op_remove_pass.cc" | |||
| "graph/passes/snapshot_pass.cc" | |||
| "graph/passes/stop_gradient_pass.cc" | |||
| "graph/passes/switch_logic_remove_pass.cc" | |||
| "graph/passes/switch_op_pass.cc" | |||
| "graph/passes/switch_pass.cc" | |||
| "graph/passes/transop_breadth_fusion_pass.cc" | |||
| "graph/passes/transop_depth_fusion_pass.cc" | |||
| "graph/passes/transop_nearby_allreduce_fusion_pass.cc" | |||
| "graph/passes/transop_without_reshape_fusion_pass.cc" | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/var_is_initialized_op_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/variable_prepare_op_pass.cc" | |||
| "graph/passes/variable_ref_delete_op_pass.cc" | |||
| "graph/preprocess/graph_preprocess.cc" | |||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | |||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
| "graph/preprocess/multi_batch_copy_graph.cc" | |||
| "init/gelib.cc" | |||
| "ir_build/atc_ir_common.cc" | |||
| "ir_build/ge_ir_build.cc" | |||
| "model/ge_model.cc" | |||
| "omm/csa_interact.cc" | |||
| "opskernel_manager/ops_kernel_manager.cc" | |||
| "session/inner_session.cc" | |||
| "session/session_manager.cc" | |||
| "single_op/single_op.cc" | |||
| "single_op/single_op_manager.cc" | |||
| "single_op/single_op_model.cc" | |||
| "single_op/stream_resource.cc" | |||
| "single_op/task/build_task_utils.cc" | |||
| "single_op/task/op_task.cc" | |||
| "single_op/task/tbe_task_builder.cc" | |||
| ########################################## | |||
| # "ir_build/ge_ir_build.cc" | |||
| # "offline/atc_ir_common.cc" | |||
| "single_op/*.cc" | |||
| "single_op/task/*.cc" | |||
| ) | |||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_compiler PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| DAVINCI_SUPPORT_PROFILING | |||
| REUSE_MEMORY=1 | |||
| FMK_HOST_INFER | |||
| PLATFORM_CLOUD) | |||
| FMK_HOST_INFER) | |||
| target_link_libraries(ge_compiler | |||
| graph | |||
| ge_common | |||
| "-Wl,--whole-archive" | |||
| ge_memory | |||
| "-Wl,--no-whole-archive" | |||
| ${PROTOBUF_LIBRARY} | |||
| ${register} | |||
| ${c_sec} | |||
| @@ -469,5 +300,6 @@ target_link_libraries(ge_compiler | |||
| ${msprof} | |||
| ${runtime} | |||
| ${resouce} | |||
| ${error_manager} | |||
| rt | |||
| dl) | |||
| @@ -13,21 +13,21 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # libge_client.so & libge_client_train.so | |||
| # libge_client.so | |||
| # add all proto files, generate corresponding .h and .cc files | |||
| set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
| file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../../proto/ge_api.proto" | |||
| ) | |||
| file(GLOB_RECURSE PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../../proto/ge_ir.proto" | |||
| "../../proto/task.proto" | |||
| "../../proto/om.proto" | |||
| "../../proto/insert_op.proto" | |||
| ) | |||
| file(GLOB_RECURSE SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "ge_api.cc" | |||
| ) | |||
| @@ -49,30 +49,9 @@ include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| ######### libge_client_train.so ############# | |||
| add_library(ge_client_train SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_client_train PRIVATE | |||
| Werror | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| REUSE_MEMORY=1 | |||
| PLATFORM_CLOUD | |||
| DAVINCI_CLOUD) | |||
| target_link_libraries(ge_client_train | |||
| graph | |||
| ge_train | |||
| ge_common | |||
| ${PROTOBUF_LIBRARY} | |||
| ${register} | |||
| ${c_sec} | |||
| ${slog} | |||
| ${mmpa} | |||
| ${runtime} | |||
| rt | |||
| dl) | |||
| ############ libge_client.so ################ | |||
| add_library(ge_client SHARED ${SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_client_train PRIVATE | |||
| target_compile_definitions(ge_client PRIVATE | |||
| Werror | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| REUSE_MEMORY=1 | |||
| @@ -32,17 +32,18 @@ | |||
| using domi::GetContext; | |||
| using domi::OpRegistry; | |||
| using domi::RealPath; | |||
| using domi::StringUtils; | |||
| using std::map; | |||
| using std::string; | |||
| using std::vector; | |||
| namespace ge { | |||
| static const int32_t kMaxStrLen = 128; | |||
| namespace { | |||
| const int32_t kMaxStrLen = 128; | |||
| } | |||
| static bool kGeInitialized = false; | |||
| static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | |||
| namespace ge { | |||
| void GetOpsProtoPath(std::string &opsproto_path) { | |||
| GELOGI("Enter get ops proto path schedule"); | |||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||
| @@ -394,8 +395,8 @@ Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc | |||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||
| } | |||
| Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> &inputs, | |||
| std::vector<TensorInfo> &outputs, std::function<void(Status)> callback) { | |||
| Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<InputTensorInfo> &inputs, | |||
| RunAsyncCallback callback) { | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "SessionConstructor failed"); | |||
| @@ -405,8 +406,7 @@ Status Session::RunGraphAsync(uint32_t graph_id, const std::vector<TensorInfo> & | |||
| GELOGW( | |||
| "The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||
| Status ret = | |||
| ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, outputs, callback); | |||
| Status ret = ge::GELib::GetInstance()->SessionManagerObj().RunGraphAsync(sessionId_, graph_id, inputs, callback); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "SessionManager RunGraphAsync failed"); | |||
| return FAILED; | |||
| @@ -28,7 +28,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "debug/memory_dumper.cc" | |||
| "fmk_error_codes.cc" | |||
| "formats/format_transfers/datatype_transfer.cc" | |||
| "formats/format_transfers/format_transfer.cc" | |||
| "formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc" | |||
| "formats/format_transfers/format_transfer_dhwcn_fracz3D.cc" | |||
| "formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc" | |||
| @@ -41,6 +40,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc" | |||
| "formats/format_transfers/format_transfer_nc1hwc0_nchw.cc" | |||
| "formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc" | |||
| "formats/format_transfers/format_transfer_nchw_fz_c04.cc" | |||
| "formats/format_transfers/format_transfer_nchw_nc1hwc0.cc" | |||
| "formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc" | |||
| "formats/format_transfers/format_transfer_transpose.cc" | |||
| @@ -54,6 +54,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "helper/om_file_helper.cc" | |||
| "math/fp16_math.cc" | |||
| "model_parser/base.cc" | |||
| "model_saver.cc" | |||
| "op/attr_value_util.cc" | |||
| "op/ge_op_utils.cc" | |||
| "properties_manager.cc" | |||
| @@ -61,9 +62,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "thread_pool.cc" | |||
| "types.cc" | |||
| "util.cc" | |||
| "model_saver.cc" | |||
| ############################### | |||
| "op/attr_define.cc" | |||
| ) | |||
| ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||
| @@ -73,6 +71,7 @@ include_directories(${CMAKE_CURRENT_LIST_DIR}) | |||
| include_directories(${CMAKE_CURRENT_LIST_DIR}/op) | |||
| include_directories(${GE_SOURCE_DIR}/src/ge) | |||
| include_directories(${GE_SOURCE_DIR}/inc) | |||
| include_directories(${GE_SOURCE_DIR}/inc/common/util) | |||
| include_directories(${GE_SOURCE_DIR}/inc/external) | |||
| include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
| include_directories(${GE_SOURCE_DIR}/inc/framework) | |||
| @@ -96,5 +95,6 @@ target_link_libraries(ge_common | |||
| ${slog} | |||
| ${mmpa} | |||
| ${resource} | |||
| ${error_manager} | |||
| rt | |||
| dl) | |||
| @@ -17,7 +17,6 @@ | |||
| #include "common/auth/file_saver.h" | |||
| #include <fcntl.h> | |||
| #include <securec.h> | |||
| #include <unistd.h> | |||
| #include <cstdlib> | |||
| @@ -29,10 +28,6 @@ | |||
| #include "framework/common/debug/log.h" | |||
| #include "framework/common/util.h" | |||
| using domi::CreateDirectory; | |||
| using domi::ModelEncryptType; | |||
| using ge::ModelBufferData; | |||
| namespace { | |||
| const int kFileOpSuccess = 0; | |||
| } // namespace | |||
| @@ -270,4 +265,4 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status FileSaver::SaveToFile(co | |||
| } | |||
| return ret; | |||
| } | |||
| } // namespace ge | |||
| } // namespace ge | |||
| @@ -26,30 +26,26 @@ | |||
| #include "graph/buffer.h" | |||
| #include "mmpa/mmpa_api.h" | |||
| using domi::ModelFileHeader; | |||
| using domi::ModelPartition; | |||
| using domi::ModelPartitionTable; | |||
| struct PROC_PARAM { | |||
| uint8_t *model_name; | |||
| /* ISV Ek buffer */ | |||
| // ISV Ek buffer | |||
| uint8_t *model_key; | |||
| uint32_t model_key_len; | |||
| /* ISV root certificate buffer */ | |||
| // ISV root certificate buffer | |||
| uint8_t *root_cert; | |||
| uint32_t root_cert_len; | |||
| /* ISV private key buffer */ | |||
| // ISV private key buffer | |||
| uint8_t *pri_key; | |||
| uint32_t pri_key_len; | |||
| /* Raw AI Module Image buffer */ | |||
| // Raw AI Module Image buffer | |||
| uint8_t *ai_image; | |||
| uint32_t ai_image_len; | |||
| /* ISV HW key buffer */ | |||
| // ISV HW key buffer | |||
| uint8_t *hw_key; | |||
| uint32_t hw_key_len; | |||
| }; | |||
| @@ -66,11 +62,11 @@ using std::string; | |||
| class FileSaver { | |||
| public: | |||
| /** | |||
| * @ingroup domi_common | |||
| * @brief save model, no encryption | |||
| * @return Status result | |||
| */ | |||
| /// | |||
| /// @ingroup domi_common | |||
| /// @brief save model, no encryption | |||
| /// @return Status result | |||
| /// | |||
| static Status SaveToFile(const string &file_path, const ge::ModelData &model, | |||
| const ModelFileHeader *model_file_header = nullptr); | |||
| @@ -84,26 +80,26 @@ class FileSaver { | |||
| static Status SaveToFile(const string &file_path, const void *data, int len); | |||
| protected: | |||
| /** | |||
| * @ingroup domi_common | |||
| * @brief Check validity of the file path | |||
| * @return Status result | |||
| */ | |||
| /// | |||
| /// @ingroup domi_common | |||
| /// @brief Check validity of the file path | |||
| /// @return Status result | |||
| /// | |||
| static Status CheckPath(const string &file_path); | |||
| static Status WriteData(const void *data, uint32_t size, int32_t fd); | |||
| static Status OpenFile(int32_t &fd, const std::string &file_path); | |||
| /** | |||
| * @ingroup domi_common | |||
| * @brief save model to file | |||
| * @param [in] file_path file output path | |||
| * @param [in] file_header file header info | |||
| * @param [in] data model data | |||
| * @param [in] len model length | |||
| * @return Status result | |||
| */ | |||
| /// | |||
| /// @ingroup domi_common | |||
| /// @brief save model to file | |||
| /// @param [in] file_path file output path | |||
| /// @param [in] file_header file header info | |||
| /// @param [in] data model data | |||
| /// @param [in] len model length | |||
| /// @return Status result | |||
| /// | |||
| static Status SaveWithFileHeader(const string &file_path, const ModelFileHeader &file_header, const void *data, | |||
| int len); | |||
| @@ -16,6 +16,7 @@ | |||
| #include "framework/omg/omg_inner_types.h" | |||
| using ge::OmgContext; | |||
| namespace domi { | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OmgContext &GetContext() { | |||
| static OmgContext context; | |||
| @@ -155,7 +155,7 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||
| void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
| bool enum2str) { | |||
| if (nullptr == field || nullptr == reflection) { | |||
| if ((field == nullptr) || (reflection == nullptr)) { | |||
| Message2Json(message, black_fields, json, enum2str); | |||
| return; | |||
| } | |||
| @@ -28,7 +28,9 @@ | |||
| using std::string; | |||
| static const int kInvalidFd = (-1); | |||
| namespace { | |||
| const int kInvalidFd = (-1); | |||
| } // namespace | |||
| namespace ge { | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY MemoryDumper::MemoryDumper() : fd_(kInvalidFd) {} | |||
| @@ -16,7 +16,7 @@ | |||
| #include "common/formats/format_transfers/datatype_transfer.h" | |||
| #include <stdint.h> | |||
| #include <cstdint> | |||
| #include <map> | |||
| #include <utility> | |||
| @@ -27,8 +27,6 @@ | |||
| #include "graph/utils/type_utils.h" | |||
| #include "securec.h" | |||
| using ge::fp16_t; | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -134,10 +132,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| } | |||
| auto trans_mode = iter->second; | |||
| if (args.src_data_size == 0) { | |||
| GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); | |||
| return PARAM_INVALID; | |||
| } | |||
| int size = GetSizeByDataType(args.dst_data_type); | |||
| if (size <= 0) { | |||
| GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | |||
| @@ -149,6 +143,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| return PARAM_INVALID; | |||
| } | |||
| size_t total_size = static_cast<size_t>(args.src_data_size * size); | |||
| result.length = total_size; | |||
| if (total_size == 0) { | |||
| GELOGI("In TransDataType, total_size is zero, has no data."); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | |||
| @@ -162,7 +162,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||
| return INTERNAL_ERROR; | |||
| } | |||
| result.data = dst; | |||
| result.length = total_size; | |||
| return SUCCESS; | |||
| } | |||
| @@ -21,7 +21,7 @@ | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| #include "external/graph/types.h" | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| @@ -1,69 +0,0 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include <map> | |||
| #include <utility> | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/utils/type_utils.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| struct FormatTransferRegistry { | |||
| Status RegisterBuilder(Format src, Format dst, FormatTransferBuilder builder) { | |||
| src_dst_builder[src][dst] = std::move(builder); | |||
| return SUCCESS; | |||
| } | |||
| std::map<Format, std::map<Format, FormatTransferBuilder>> src_dst_builder; | |||
| }; | |||
| FormatTransferRegistry &GetFormatTransferRegistry() { | |||
| static FormatTransferRegistry registry; | |||
| return registry; | |||
| } | |||
| } // namespace | |||
| std::shared_ptr<FormatTransfer> BuildFormatTransfer(const TransArgs &args) { | |||
| auto registry = GetFormatTransferRegistry(); | |||
| auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||
| if (dst_builder == registry.src_dst_builder.end()) { | |||
| return nullptr; | |||
| } | |||
| auto builder_iter = dst_builder->second.find(args.dst_format); | |||
| if (builder_iter == dst_builder->second.end()) { | |||
| return nullptr; | |||
| } | |||
| return builder_iter->second(); | |||
| } | |||
| bool FormatTransferExists(const TransArgs &args) { | |||
| auto registry = GetFormatTransferRegistry(); | |||
| auto dst_builder = registry.src_dst_builder.find(args.src_format); | |||
| if (dst_builder == registry.src_dst_builder.end()) { | |||
| return false; | |||
| } | |||
| return dst_builder->second.count(args.dst_format) > 0; | |||
| } | |||
| FormatTransferRegister::FormatTransferRegister(FormatTransferBuilder builder, Format src, Format dst) { | |||
| (void)GetFormatTransferRegistry().RegisterBuilder(src, dst, std::move(builder)); | |||
| // RegisterBuilder() always return success, no need to check value | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -27,7 +27,9 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| bool CheckDataTypeSupported(const DataType &data_type) { return (data_type == DT_FLOAT || data_type == DT_FLOAT16); } | |||
| bool CheckDataTypeSupported(const DataType &data_type) { | |||
| return (data_type == DT_FLOAT || data_type == DT_FLOAT16 || data_type == DT_INT8); | |||
| } | |||
| Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
| auto src_shape = args.src_shape; | |||
| @@ -51,10 +53,11 @@ Status CheckArgsForC1hwncoc0ToHwcn(const TransArgs &args) { | |||
| GELOGE(PARAM_INVALID, "Failed to check dst shape %s", ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / kCubeSize + 1 || | |||
| auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||
| if (src_shape.at(kC1hwncoc0C1) != (dst_shape.at(kHwcnC) - 1) / cube_size + 1 || | |||
| src_shape.at(kC1hwncoc0H) != dst_shape.at(kHwcnH) || src_shape.at(kC1hwncoc0W) != dst_shape.at(kHwcnW) || | |||
| src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != kCubeSize || | |||
| src_shape.at(kC1hwncoc0C0) != kCubeSize) { | |||
| src_shape.at(kC1hwncoc0N) != dst_shape.at(kHwcnN) || src_shape.at(kC1hwncoc0Co) != cube_size || | |||
| src_shape.at(kC1hwncoc0C0) != cube_size) { | |||
| GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | |||
| ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| @@ -78,6 +81,7 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
| auto c0 = args.src_shape.at(kC1hwncoc0C0); | |||
| auto co = args.src_shape.at(kC1hwncoc0Co); | |||
| auto c = args.dst_shape.at(kHwcnC); | |||
| auto cube_size = GetCubeSizeByDataType(args.src_data_type); | |||
| int64_t cn = c * n; | |||
| int64_t wcn = w * cn; | |||
| int64_t coc0 = co * c0; | |||
| @@ -93,8 +97,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, int size | |||
| int64_t c_head_addr = w_head_addr + c_idx * n; | |||
| for (int64_t n_idx = 0; n_idx < n; n_idx++) { | |||
| int64_t dst_idx = c_head_addr + n_idx; | |||
| int64_t c1_idx = c_idx / kCubeSize; | |||
| int64_t c0_idx = c_idx % kCubeSize; | |||
| int64_t c1_idx = c_idx / cube_size; | |||
| int64_t c0_idx = c_idx % cube_size; | |||
| int64_t co_idx = c0_idx; | |||
| int64_t src_idx = c1_idx * hwncoc0 + h_idx * wncoc0 + w_idx * ncoc0 + n_idx * coc0 + co_idx * c0 + c0_idx; | |||
| auto src_offset = src_idx * size; | |||
| @@ -130,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (total_size <= 0) { | |||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||
| if (total_size == 0 && src_size == 0) { | |||
| result.length = static_cast<size_t>(total_size); | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -18,7 +18,7 @@ | |||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWCN_FRACTAL_Z_3D_H_ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -18,7 +18,7 @@ | |||
| #define GE_COMMON_FORMATS_FORMAT_TRANSFERS_FORMAT_TRANSFER_DHWNC_FRACTAL_Z_3D_TRANSPOSE_H_ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
| Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||
| Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = total_ele_cnt * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||
| Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||
| Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (total_size <= 0) { | |||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||
| if (total_size == 0 && src_size == 0) { | |||
| result.length = static_cast<size_t>(total_size); | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||
| @@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||
| if (total_size <= 0) { | |||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||
| if (total_size == 0 && src_size == 0) { | |||
| result.length = static_cast<size_t>(total_size); | |||
| return SUCCESS; | |||
| } | |||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||
| return PARAM_INVALID; | |||
| @@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||
| GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||
| @@ -19,7 +19,7 @@ | |||
| #include <vector> | |||
| #include "common/formats/format_transfers/format_transfer.h" | |||
| #include "register/register_format_transfer.h" | |||
| namespace ge { | |||
| namespace formats { | |||