Merge pull request !22 from yanghaoran/mastertags/v0.3.0-alpha
| @@ -18,7 +18,6 @@ | |||||
| #define INC_COMMON_BLOCKING_QUEUE_H_ | #define INC_COMMON_BLOCKING_QUEUE_H_ | ||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <condition_variable> | #include <condition_variable> | ||||
| #include <list> | #include <list> | ||||
| #include <mutex> | #include <mutex> | ||||
| @@ -87,7 +86,7 @@ class BlockingQueue { | |||||
| is_stoped_ = false; | is_stoped_ = false; | ||||
| } | } | ||||
| // if the queue stop , the function to release the unprocessed items will be call | |||||
| // if the queue is stoped ,need call this function to release the unprocessed items | |||||
| std::list<T> GetRemainItems() { | std::list<T> GetRemainItems() { | ||||
| std::unique_lock<std::mutex> lock(mutex_); | std::unique_lock<std::mutex> lock(mutex_); | ||||
| @@ -19,10 +19,10 @@ | |||||
| #include <stdint.h> | #include <stdint.h> | ||||
| /// | |||||
| /// @ingroup dnn | |||||
| /// @brief struct define of dynamic aipp batch parameter. | |||||
| /// | |||||
| /** | |||||
| * @ingroup dnn | |||||
| * @brief struct define of dynamic aipp batch parameter. | |||||
| */ | |||||
| typedef struct tagAippDynamicBatchPara { | typedef struct tagAippDynamicBatchPara { | ||||
| int8_t cropSwitch; // crop switch | int8_t cropSwitch; // crop switch | ||||
| int8_t scfSwitch; // resize switch | int8_t scfSwitch; // resize switch | ||||
| @@ -66,10 +66,10 @@ typedef struct tagAippDynamicBatchPara { | |||||
| int8_t reserve1[16]; // 32B assign, for ub copy | int8_t reserve1[16]; // 32B assign, for ub copy | ||||
| } kAippDynamicBatchPara; | } kAippDynamicBatchPara; | ||||
| /// | |||||
| /// @ingroup dnn | |||||
| /// @brief struct definition of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte | |||||
| /// | |||||
| /** | |||||
| * @ingroup dnn | |||||
| * @brief struct define of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte | |||||
| */ | |||||
| typedef struct tagAippDynamicPara { | typedef struct tagAippDynamicPara { | ||||
| uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | ||||
| int8_t cscSwitch; // csc switch | int8_t cscSwitch; // csc switch | ||||
| @@ -61,19 +61,19 @@ typedef enum tagHiAiNpuModuleId { | |||||
| HIAI_DP = 23, | HIAI_DP = 23, | ||||
| } HiAiNpuModuleId; | } HiAiNpuModuleId; | ||||
| // bit 31-bit30 to be hiai local | |||||
| /* bit 31-bit30 to be hiai local */ | |||||
| #define HIAI_NPULOCAL_MASK 0xC0000000 | #define HIAI_NPULOCAL_MASK 0xC0000000 | ||||
| #define SHIFT_LOCAL_MASK 30 | #define SHIFT_LOCAL_MASK 30 | ||||
| #define HIAI_NPULOCAL_VAL_MASK 0x3 | #define HIAI_NPULOCAL_VAL_MASK 0x3 | ||||
| // bit 29 -bit28 to be hiai aicpu code type | |||||
| /* bit 29 -bit28 to be hiai aicpu code type */ | |||||
| #define HIAI_CODE_TYPE_MASK 0x30000000 | #define HIAI_CODE_TYPE_MASK 0x30000000 | ||||
| #define SHIFT_CODE_MASK 28 | #define SHIFT_CODE_MASK 28 | ||||
| #define HIAI_CODE_TYPE_VAL_MASK 0x3 | #define HIAI_CODE_TYPE_VAL_MASK 0x3 | ||||
| // bit 27 -bit25 to be hiai error level | |||||
| /* bit 27 -bit25 to be hiai error level */ | |||||
| #define HIAI_ERROR_LEVEL_MASK 0x0E000000 | #define HIAI_ERROR_LEVEL_MASK 0x0E000000 | ||||
| #define SHIFT_ERROR_LVL_MASK 25 | #define SHIFT_ERROR_LVL_MASK 25 | ||||
| #define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | #define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | ||||
| // bit 24 -bit17 to be hiai mod | |||||
| /* bit 24 -bit17 to be hiai mod */ | |||||
| #define HIAI_MODE_ID_MASK 0x01FE0000 | #define HIAI_MODE_ID_MASK 0x01FE0000 | ||||
| #define SHIFT_MODE_MASK 17 | #define SHIFT_MODE_MASK 17 | ||||
| #define HIAI_MODE_ID_VAL_MASK 0xFF | #define HIAI_MODE_ID_VAL_MASK 0xFF | ||||
| @@ -19,13 +19,12 @@ | |||||
| #include <runtime/rt.h> | #include <runtime/rt.h> | ||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| using std::string; | using std::string; | ||||
| namespace ge { | namespace ge { | ||||
| // DAVINCI_TRAIN/DAVINCI_CLOUD is not needed when GETaskKernelHcclInfo needed | |||||
| // when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD | |||||
| struct GETaskKernelHcclInfo { | struct GETaskKernelHcclInfo { | ||||
| string hccl_type; | string hccl_type; | ||||
| void *inputDataAddr; | void *inputDataAddr; | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "./ge_task_info.h" | #include "./ge_task_info.h" | ||||
| #include "./ops_kernel_info_types.h" | #include "./ops_kernel_info_types.h" | ||||
| #include "cce/aicpu_engine_struct.h" | #include "cce/aicpu_engine_struct.h" | ||||
| @@ -29,7 +28,6 @@ | |||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "proto/task.pb.h" | #include "proto/task.pb.h" | ||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| using std::to_string; | using std::to_string; | ||||
| @@ -47,7 +45,7 @@ class OpsKernelInfoStore { | |||||
| // initialize opsKernelInfoStore | // initialize opsKernelInfoStore | ||||
| virtual Status Initialize(const map<string, string> &options) = 0; | virtual Status Initialize(const map<string, string> &options) = 0; | ||||
| // finalize opsKernelInfoStore | |||||
| // close opsKernelInfoStore | |||||
| virtual Status Finalize() = 0; | virtual Status Finalize() = 0; | ||||
| virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | ||||
| @@ -57,18 +55,20 @@ class OpsKernelInfoStore { | |||||
| // get all opsKernelInfo | // get all opsKernelInfo | ||||
| virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | ||||
| // check whether opsKernelInfoStore is supported based on the operator attribute | |||||
| // whether the opsKernelInfoStore is supported based on the operator attribute | |||||
| virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | ||||
| virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | ||||
| bool realQuery = false) const { | bool realQuery = false) const { | ||||
| return CheckSupported(opDescPtr, un_supported_reason); | return CheckSupported(opDescPtr, un_supported_reason); | ||||
| } | } | ||||
| // opsFlag opsFlag[0] indicates constant folding is supported or not | |||||
| virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | |||||
| // requirement of memory allocation | |||||
| // memory allocation requirement | |||||
| virtual Status CalcOpRunningParam(Node &node) = 0; | virtual Status CalcOpRunningParam(Node &node) = 0; | ||||
| // generate task for op | |||||
| // generate task for op。 | |||||
| virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | ||||
| // only call fe engine interface to compile single op | // only call fe engine interface to compile single op | ||||
| @@ -77,10 +77,10 @@ class OpsKernelInfoStore { | |||||
| // load task for op | // load task for op | ||||
| virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | ||||
| // only to call aicpu interface for generating task struct | |||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | ||||
| // only to call aicpu interface for generating task struct | |||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -37,6 +37,7 @@ struct RunContext { | |||||
| ge::Buffer weightsBuffer; | ge::Buffer weightsBuffer; | ||||
| std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | ||||
| std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | ||||
| std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | |||||
| }; | }; | ||||
| struct Task { | struct Task { | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include "./graph_optimizer_types.h" | #include "./graph_optimizer_types.h" | ||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
| @@ -39,19 +38,19 @@ class GraphOptimizer { | |||||
| // close graphOptimizer | // close graphOptimizer | ||||
| virtual Status Finalize() = 0; | virtual Status Finalize() = 0; | ||||
| // optimize original graph for FE quant optimization | |||||
| // optimize original graph for FE quant optimize | |||||
| virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | ||||
| // optimize original graph used in the graph preparation stage | |||||
| // optimize original graph, using in graph preparation stage | |||||
| virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | ||||
| // optimize fused graph | // optimize fused graph | ||||
| virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | ||||
| // optimize the whole graph which will be used after graph merged | |||||
| // optimize whole graph, using after graph merged stage | |||||
| virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | ||||
| // get attributes of graph optimizer | |||||
| // get attribute of graph optimizer | |||||
| virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | ||||
| // optimize streamed Graph | // optimize streamed Graph | ||||
| @@ -19,8 +19,6 @@ | |||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <string> | #include <string> | ||||
| using std::string; | |||||
| namespace ge { | namespace ge { | ||||
| enum OPTIMIZER_SCOPE { | enum OPTIMIZER_SCOPE { | ||||
| UNIT = 0, | UNIT = 0, | ||||
| @@ -28,7 +26,7 @@ enum OPTIMIZER_SCOPE { | |||||
| }; | }; | ||||
| struct GraphOptimizerAttribute { | struct GraphOptimizerAttribute { | ||||
| string engineName; | |||||
| std::string engineName; | |||||
| OPTIMIZER_SCOPE scope; | OPTIMIZER_SCOPE scope; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <set> | |||||
| namespace ge { | namespace ge { | ||||
| // Option key: graph run mode | // Option key: graph run mode | ||||
| @@ -38,9 +39,11 @@ const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | |||||
| const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | ||||
| const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||||
| // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | ||||
| const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
| const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
| const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | |||||
| // Option key: memory init | // Option key: memory init | ||||
| const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | ||||
| @@ -141,19 +144,43 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||||
| // congigure outputDatatype to setting net output type | // congigure outputDatatype to setting net output type | ||||
| const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | ||||
| // congigure opSelectImplmode to setting op select implmode | |||||
| const std::string kOpSelectImplmode = "ge.opSelectImplmode"; | |||||
| // configure whether to enable hcom parallel by session constructor options param, | // configure whether to enable hcom parallel by session constructor options param, | ||||
| // its value should be "0" or "1", default value is "0" | // its value should be "0" or "1", default value is "0" | ||||
| const std::string HCOM_PARALLEL = "ge.hcomParallel"; | const std::string HCOM_PARALLEL = "ge.hcomParallel"; | ||||
| // configure whether to use dynamic batch size | |||||
| const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | |||||
| // configure whether to use dynamic image size | |||||
| const char *const kDynamicImageSize = "ge.dynamicImageSize"; | |||||
| // Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, | // Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, | ||||
| // example: GA|RL, support configure multiple, split by | | // example: GA|RL, support configure multiple, split by | | ||||
| const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | ||||
| // Configure soc version , example: "Ascend310" | |||||
| const std::string SOC_VERSION = "ge.socVersion"; | |||||
| // Configure core type "VectorEngine", default value is "AIcoreEngine" | // Configure core type "VectorEngine", default value is "AIcoreEngine" | ||||
| const std::string CORE_TYPE = "ge.engineType"; | const std::string CORE_TYPE = "ge.engineType"; | ||||
| // Configure soc version , example: "Ascend310" | |||||
| const std::string SOC_VERSION = "ge.socVersion"; | |||||
| // Configure AICORE NUM | |||||
| const std::string AICORE_NUM = "ge.aicoreNum"; | |||||
| // Configure L1FUSION | |||||
| const std::string L1_FUSION = "ge.l1Fusion"; | |||||
| // Configure Small Channel flag | |||||
| const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||||
| // Configure Compress Weight flag | |||||
| const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||||
| // Configure fusion switch file path | |||||
| const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||||
| // Save original model | // Save original model | ||||
| const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | ||||
| @@ -194,6 +221,28 @@ struct TensorInfo { | |||||
| DataDesc data; // tensor data | DataDesc data; // tensor data | ||||
| ShapeDesc shapeInfo; // tensor shape | ShapeDesc shapeInfo; // tensor shape | ||||
| }; | }; | ||||
| // for ir build | |||||
| namespace ir_option { | |||||
| static const char *const INPUT_FORMAT = "input_format"; | |||||
| static const char *const INPUT_SHAPE = "input_shape"; | |||||
| static const char *const OP_NAME_MAP = "op_name_map"; | |||||
| static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||||
| static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||||
| static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||||
| static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); | |||||
| static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; | |||||
| static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||||
| static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||||
| static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||||
| static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||||
| // for interface: aclgrphBuildModel | |||||
| const std::set<std::string> ir_builder_suppported_options = { | |||||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||||
| DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
| AUTO_TUNE_MODE}; | |||||
| // for interface: aclgrphBuildInitialize | |||||
| const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||||
| } // namespace ir_option | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ | #endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ | ||||
| @@ -0,0 +1,75 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GE_IR_BUILD_H_ | |||||
| #define INC_EXTERNAL_GE_IR_BUILD_H_ | |||||
| #include <string> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include "graph/graph.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| namespace ge { | |||||
| struct ModelBufferData { | |||||
| std::shared_ptr<uint8_t> data = nullptr; | |||||
| uint64_t length; | |||||
| }; | |||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief build model.Notice the model is stored in buffer | |||||
| * | |||||
| * @param global_options[IN] global init params for build | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||||
| * @retval OtherValues Failure | |||||
| */ | |||||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | |||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief build model.Notice the model is stored in buffer | |||||
| * | |||||
| */ | |||||
| void aclgrphBuildFinalize(); | |||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief build model.Notice the model is stored in buffer | |||||
| * | |||||
| * @param graph[IN] the graph ready to build | |||||
| * @param options[IN] options used for build | |||||
| * @param model[OUT] builded model | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||||
| * @retval OtherValues Failure | |||||
| */ | |||||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||||
| ModelBufferData &model); | |||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief save model buffer to file | |||||
| * | |||||
| * @param output_file[IN] the file path to be saved | |||||
| * @param model[IN] model buffer data | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||||
| * @retval OtherValues Failure | |||||
| */ | |||||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | |||||
| }; // namespace ge | |||||
| #endif | |||||
| @@ -22,7 +22,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "./ge_error_codes.h" | |||||
| using std::make_shared; | using std::make_shared; | ||||
| using std::map; | using std::map; | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/operator.h" | |||||
| #include "./operator.h" | |||||
| namespace ge { | namespace ge { | ||||
| class GraphImpl; | class GraphImpl; | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/tensor.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "./tensor.h" | |||||
| #include "./types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class InferenceContext; | class InferenceContext; | ||||
| @@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||||
| static std::unique_ptr<InferenceContext> Create(); | static std::unique_ptr<InferenceContext> Create(); | ||||
| private: | private: | ||||
| InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
| explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||||
| std::shared_ptr<InferenceContextImpl> inference_context_impl_; | std::shared_ptr<InferenceContextImpl> inference_context_impl_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -23,9 +23,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "external/graph/inference_context.h" | |||||
| #include "external/graph/tensor.h" | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./inference_context.h" | |||||
| #include "./tensor.h" | |||||
| #ifndef USER_GE_LOGI | #ifndef USER_GE_LOGI | ||||
| #define USER_GE_LOGI(...) | #define USER_GE_LOGI(...) | ||||
| @@ -22,8 +22,8 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph//operator.h" | |||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "./operator.h" | |||||
| #include "./ge_error_codes.h" | |||||
| namespace ge { | namespace ge { | ||||
| using OpCreator = std::function<Operator(const std::string &)>; | using OpCreator = std::function<Operator(const std::string &)>; | ||||
| @@ -22,10 +22,10 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/operator.h" | |||||
| #include "external/graph/operator_factory.h" | |||||
| #include "external/graph/tensor.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "./operator.h" | |||||
| #include "./operator_factory.h" | |||||
| #include "./tensor.h" | |||||
| #include "./types.h" | |||||
| namespace ge { | namespace ge { | ||||
| using std::function; | using std::function; | ||||
| @@ -60,7 +60,7 @@ class OpReg { | |||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| void __##x() { \ | void __##x() { \ | ||||
| OpReg() | |||||
| OpReg() | |||||
| #define ATTR(x, Type, ...) \ | #define ATTR(x, Type, ...) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -86,7 +86,7 @@ class OpReg { | |||||
| void __attr_##x() { \ | void __attr_##x() { \ | ||||
| Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | ||||
| string attr_name(#x); \ | string attr_name(#x); \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define REQUIRED_ATTR(x, Type) \ | #define REQUIRED_ATTR(x, Type) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -112,7 +112,7 @@ class OpReg { | |||||
| void __required_attr_##x() { \ | void __required_attr_##x() { \ | ||||
| Operator::RequiredAttrRegister(#x); \ | Operator::RequiredAttrRegister(#x); \ | ||||
| string attr_name(#x); \ | string attr_name(#x); \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define INPUT(x, t) \ | #define INPUT(x, t) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -137,7 +137,7 @@ class OpReg { | |||||
| private: \ | private: \ | ||||
| void __input_##x() { \ | void __input_##x() { \ | ||||
| Operator::InputRegister(#x); \ | Operator::InputRegister(#x); \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define OPTIONAL_INPUT(x, t) \ | #define OPTIONAL_INPUT(x, t) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -162,7 +162,7 @@ class OpReg { | |||||
| private: \ | private: \ | ||||
| void __optional_input_##x() { \ | void __optional_input_##x() { \ | ||||
| Operator::OptionalInputRegister(#x); \ | Operator::OptionalInputRegister(#x); \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define OUTPUT(x, t) \ | #define OUTPUT(x, t) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -179,7 +179,7 @@ class OpReg { | |||||
| private: \ | private: \ | ||||
| void __out_##x() { \ | void __out_##x() { \ | ||||
| Operator::OutputRegister(#x); \ | Operator::OutputRegister(#x); \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define DYNAMIC_INPUT(x, t) \ | #define DYNAMIC_INPUT(x, t) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -206,7 +206,7 @@ class OpReg { | |||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| void __dy_input_##x() { \ | void __dy_input_##x() { \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define DYNAMIC_OUTPUT(x, t) \ | #define DYNAMIC_OUTPUT(x, t) \ | ||||
| N(); \ | N(); \ | ||||
| @@ -227,18 +227,18 @@ class OpReg { | |||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| void __dy_output_##x() { \ | void __dy_output_##x() { \ | ||||
| (void)OpReg() | |||||
| (void)OpReg() | |||||
| #define PASTE(g_register, y) g_register##y | #define PASTE(g_register, y) g_register##y | ||||
| #define __OP_END_IMPL__(x, y) \ | |||||
| N(); \ | |||||
| } \ | |||||
| static_assert( \ | |||||
| std::is_same<x, _THIS_TYPE>::value, \ | |||||
| "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||||
| } \ | |||||
| ; \ | |||||
| static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||||
| #define __OP_END_IMPL__(x, y) \ | |||||
| N(); \ | |||||
| } \ | |||||
| static_assert( \ | |||||
| std::is_same<x, _THIS_TYPE>::value, \ | |||||
| "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||||
| } \ | |||||
| ; \ | |||||
| static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||||
| } | } | ||||
| #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | ||||
| @@ -286,7 +286,7 @@ class OpReg { | |||||
| // Common shape inferencer | // Common shape inferencer | ||||
| #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | ||||
| [](Operator op)->graphStatus { \ | |||||
| [](Operator op) -> graphStatus { \ | |||||
| auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | ||||
| auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | ||||
| TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | ||||
| @@ -300,7 +300,7 @@ graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, | |||||
| const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | ||||
| #define BROADCAST_INFER(in1_name, in2_name, out_name) \ | #define BROADCAST_INFER(in1_name, in2_name, out_name) \ | ||||
| [](Operator op)->graphStatus { \ | |||||
| [](Operator op) -> graphStatus { \ | |||||
| return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | ||||
| [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | ||||
| [&](const vector<int64_t> &y_shape) { \ | [&](const vector<int64_t> &y_shape) { \ | ||||
| @@ -22,8 +22,8 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/graph/ge_error_codes.h" | |||||
| #include "external/graph/types.h" | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class ShapeImpl; | class ShapeImpl; | ||||
| @@ -133,11 +133,13 @@ enum Format { | |||||
| FORMAT_FRACTAL_ZZ, | FORMAT_FRACTAL_ZZ, | ||||
| FORMAT_FRACTAL_NZ, | FORMAT_FRACTAL_NZ, | ||||
| FORMAT_NCDHW, | FORMAT_NCDHW, | ||||
| FORMAT_DHWCK, // 3D filter input tensor format | |||||
| FORMAT_DHWCN, // 3D filter input tensor format | |||||
| FORMAT_NDC1HWC0, | FORMAT_NDC1HWC0, | ||||
| FORMAT_FRACTAL_Z_3D, | FORMAT_FRACTAL_Z_3D, | ||||
| FORMAT_CN, | FORMAT_CN, | ||||
| FORMAT_NC, | FORMAT_NC, | ||||
| FORMAT_DHWNC, | |||||
| FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||||
| FORMAT_RESERVED, | FORMAT_RESERVED, | ||||
| FORMAT_ALL | FORMAT_ALL | ||||
| }; | }; | ||||
| @@ -47,6 +47,12 @@ class Tensor; | |||||
| class TBEPluginManager; | class TBEPluginManager; | ||||
| } // namespace ge | } // namespace ge | ||||
| namespace google { | |||||
| namespace protobuf { | |||||
| class Message; | |||||
| } | |||||
| } // namespace google | |||||
| namespace domi { | namespace domi { | ||||
| Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | ||||
| Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | ||||
| @@ -56,6 +62,8 @@ using google::protobuf::Message; | |||||
| class OpRegistrationDataImpl; | class OpRegistrationDataImpl; | ||||
| using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | ||||
| using FusionParseParamFunc = | |||||
| std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
| public: | public: | ||||
| @@ -71,15 +79,20 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
| OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | ||||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||||
| OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | ||||
| OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | ||||
| OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||||
| domi::ImplyType GetImplyType() const; | domi::ImplyType GetImplyType() const; | ||||
| std::string GetOmOptype() const; | std::string GetOmOptype() const; | ||||
| std::set<std::string> GetOriginOpTypeSet() const; | std::set<std::string> GetOriginOpTypeSet() const; | ||||
| domi::FrameworkType GetFrameworkType() const; | domi::FrameworkType GetFrameworkType() const; | ||||
| ParseParamFunc GetParseParamFn() const; | ParseParamFunc GetParseParamFn() const; | ||||
| FusionParseParamFunc GetFusionParseParamFn() const; | |||||
| private: | private: | ||||
| std::shared_ptr<OpRegistrationDataImpl> impl_; | std::shared_ptr<OpRegistrationDataImpl> impl_; | ||||
| @@ -103,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
| namespace ge { | namespace ge { | ||||
| using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
| using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||||
| public: | |||||
| HostCpuOp() = default; | |||||
| virtual ~HostCpuOp() = default; | |||||
| virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||||
| std::map<std::string, Tensor> &outputs) = 0; | |||||
| }; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||||
| public: | |||||
| HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||||
| }; | |||||
| #define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||||
| static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||||
| ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | ||||
| @@ -22,7 +22,7 @@ | |||||
| #define DECLARE_ERRORNO(sysid, modid, name, value) \ | #define DECLARE_ERRORNO(sysid, modid, name, value) \ | ||||
| const domi::Status name = \ | const domi::Status name = \ | ||||
| ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||||
| ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||||
| #define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | #define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | ||||
| @@ -33,6 +33,7 @@ using Status = uint32_t; | |||||
| DECLARE_ERRORNO(0, 0, SUCCESS, 0); | DECLARE_ERRORNO(0, 0, SUCCESS, 0); | ||||
| DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | ||||
| DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | ||||
| DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); | |||||
| } // namespace domi | } // namespace domi | ||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | ||||
| @@ -48,6 +48,10 @@ typedef enum tagDomiTensorFormat { | |||||
| DOMI_TENSOR_BN_WEIGHT, | DOMI_TENSOR_BN_WEIGHT, | ||||
| DOMI_TENSOR_CHWN, // Android NN Depth CONV | DOMI_TENSOR_CHWN, // Android NN Depth CONV | ||||
| DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | ||||
| DOMI_TENSOR_NDHWC, | |||||
| DOMI_TENSOR_NCDHW, | |||||
| DOMI_TENSOR_DHWCN, // 3D filter input tensor format | |||||
| DOMI_TENSOR_DHWNC, | |||||
| DOMI_TENSOR_RESERVED | DOMI_TENSOR_RESERVED | ||||
| } domiTensorFormat_t; | } domiTensorFormat_t; | ||||
| } // namespace domi | } // namespace domi | ||||
| @@ -18,11 +18,13 @@ | |||||
| #define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | #define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <unistd.h> | |||||
| #include <sys/syscall.h> | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "toolchain/slog.h" | #include "toolchain/slog.h" | ||||
| #define GE_MODULE_NAME GE | |||||
| #define GE_MODULE_NAME static_cast<int>(GE) | |||||
| // trace status of log | // trace status of log | ||||
| enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | ||||
| @@ -35,15 +37,20 @@ enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||||
| #define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | #define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | ||||
| #define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | #define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | ||||
| inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||||
| int32_t enable_event = 0; | |||||
| int32_t dlog_level = dlog_getlevel(module_name, &enable_event); | |||||
| if (dlog_level <= log_level) { | |||||
| inline bool IsLogEnable(int module_name, int log_level) { | |||||
| int32_t enable = CheckLogLevel(module_name, log_level); | |||||
| // 1:enable, 0:disable | |||||
| if (enable == 1) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| return false; | return false; | ||||
| } | } | ||||
| inline pid_t GetTid() { | |||||
| thread_local static pid_t tid = syscall(__NR_gettid); | |||||
| return tid; | |||||
| } | |||||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | ||||
| #define GE_TIMESTAMP_END(stage, stage_name) \ | #define GE_TIMESTAMP_END(stage, stage_name) \ | ||||
| @@ -68,29 +75,35 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | ||||
| call_num_of##stage) | call_num_of##stage) | ||||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||||
| dlog_error(static_cast<int>(MOD_NAME), "%s: ErrorNo: %d(%s) " fmt, __FUNCTION__, ERROR_CODE, \ | |||||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||||
| dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ | |||||
| ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | ||||
| #define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_WARN)) \ | |||||
| dlog_warn(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_INFO)) \ | |||||
| dlog_info(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_DEBUG)) \ | |||||
| dlog_debug(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(MOD_NAME, DLOG_WARN)) dlog_warn(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(MOD_NAME, DLOG_INFO)) dlog_info(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||||
| if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) dlog_debug(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \ | #define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \ | ||||
| Dlog(static_cast<int>(MOD_NAME), DLOG_OPLOG, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||||
| do { \ | |||||
| TraceStatus stat = value; \ | |||||
| const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||||
| int idx = static_cast<int>(stat); \ | |||||
| char *k = const_cast<char *>("status"); \ | |||||
| char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||||
| KeyValue kv = {k, v}; \ | |||||
| DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__); \ | |||||
| Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||||
| do { \ | |||||
| TraceStatus stat = value; \ | |||||
| const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||||
| int idx = static_cast<int>(stat); \ | |||||
| char *k = const_cast<char *>("status"); \ | |||||
| char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||||
| KeyValue kv = {k, v}; \ | |||||
| DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__); \ | |||||
| } while (0) | } while (0) | ||||
| // print memory when it is greater than 1KB. | |||||
| #define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \ | |||||
| do { \ | |||||
| if ((SIZE) > 1024) { \ | |||||
| GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast<size_t>(SIZE), (PURPOSE)); \ | |||||
| } \ | |||||
| } while (0); | |||||
| #endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | ||||
| @@ -29,7 +29,18 @@ | |||||
| using cce::CC_STATUS_SUCCESS; | using cce::CC_STATUS_SUCCESS; | ||||
| using cce::ccStatus_t; | using cce::ccStatus_t; | ||||
| #define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__) | |||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||||
| #define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||||
| #else | |||||
| #include <android/log.h> | |||||
| #if defined(BUILD_VERSION_PERF) | |||||
| #define DOMI_LOGE(fmt, ...) | |||||
| #else | |||||
| // The Android system has strict log control. Do not modify the log. | |||||
| #define DOMI_LOGE(fmt, ...) \ | |||||
| __android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
| #endif | |||||
| #endif | |||||
| // ge marco | // ge marco | ||||
| #define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
| @@ -44,7 +55,7 @@ using cce::ccStatus_t; | |||||
| #define GE_LOGE_IF(condition, ...) \ | #define GE_LOGE_IF(condition, ...) \ | ||||
| if ((condition)) { \ | if ((condition)) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| } | } | ||||
| // If expr is not SUCCESS, print the log and return the same value | // If expr is not SUCCESS, print the log and return the same value | ||||
| @@ -52,7 +63,7 @@ using cce::ccStatus_t; | |||||
| do { \ | do { \ | ||||
| const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
| if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| @@ -62,7 +73,7 @@ using cce::ccStatus_t; | |||||
| do { \ | do { \ | ||||
| const ge::Status _status = (expr); \ | const ge::Status _status = (expr); \ | ||||
| if (_status != ge::SUCCESS) { \ | if (_status != ge::SUCCESS) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| @@ -75,6 +86,15 @@ using cce::ccStatus_t; | |||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| // If expr is not GRAPH_SUCCESS, print the log and return FAILED | |||||
| #define GE_CHK_GRAPH_STATUS_RET(expr, ...) \ | |||||
| do { \ | |||||
| if ((expr) != ge::GRAPH_SUCCESS) { \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return FAILED; \ | |||||
| } \ | |||||
| } while (0); | |||||
| // If expr is not SUCCESS, print the log and execute a custom statement | // If expr is not SUCCESS, print the log and execute a custom statement | ||||
| #define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | #define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | ||||
| do { \ | do { \ | ||||
| @@ -91,25 +111,11 @@ using cce::ccStatus_t; | |||||
| (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | ||||
| (void)msg.append( \ | (void)msg.append( \ | ||||
| ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | ||||
| GE_LOGE("%s", msg.c_str()); \ | |||||
| DOMI_LOGE("%s", msg.c_str()); \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0); | } while (0); | ||||
| // If expr is not true, print the Info log and return the specified status | |||||
| #define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \ | |||||
| do { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| std::string msg; \ | |||||
| (void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \ | |||||
| (void)msg.append( \ | |||||
| StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||||
| GELOGI("%s", msg.c_str()); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0); | |||||
| // If expr is not true, print the log and return the specified status | // If expr is not true, print the log and return the specified status | ||||
| #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | ||||
| do { \ | do { \ | ||||
| @@ -124,7 +130,7 @@ using cce::ccStatus_t; | |||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (!b) { \ | if (!b) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| exec_expr; \ | exec_expr; \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| @@ -163,7 +169,7 @@ using cce::ccStatus_t; | |||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (b) { \ | if (b) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| exec_expr; \ | exec_expr; \ | ||||
| } \ | } \ | ||||
| }; | }; | ||||
| @@ -182,7 +188,7 @@ using cce::ccStatus_t; | |||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (b) { \ | if (b) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| exec_expr; \ | exec_expr; \ | ||||
| return; \ | return; \ | ||||
| } \ | } \ | ||||
| @@ -193,7 +199,7 @@ using cce::ccStatus_t; | |||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (b) { \ | if (b) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| exec_expr; \ | exec_expr; \ | ||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| @@ -210,62 +216,42 @@ using cce::ccStatus_t; | |||||
| // -----------------runtime related macro definitions------------------------------- | // -----------------runtime related macro definitions------------------------------- | ||||
| // If expr is not RT_ERROR_NONE, print the log | // If expr is not RT_ERROR_NONE, print the log | ||||
| #define GE_CHK_RT(expr) \ | |||||
| do { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| } \ | |||||
| #define GE_CHK_RT(expr) \ | |||||
| do { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| } \ | |||||
| } while (0); | } while (0); | ||||
| // If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | // If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | ||||
| #define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
| { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| #define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||||
| { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | } | ||||
| // If expr is not RT_ERROR_NONE, print the log and return | // If expr is not RT_ERROR_NONE, print the log and return | ||||
| #define GE_CHK_RT_RET(expr) \ | |||||
| do { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| return ge::RT_FAILED; \ | |||||
| } \ | |||||
| #define GE_CHK_RT_RET(expr) \ | |||||
| do { \ | |||||
| rtError_t _rt_ret = (expr); \ | |||||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||||
| return ge::RT_FAILED; \ | |||||
| } \ | |||||
| } while (0); | } while (0); | ||||
| // ------------------------cce related macro definitions---------------------------- | // ------------------------cce related macro definitions---------------------------- | ||||
| // If expr is not CC_STATUS_SUCCESS, print the log | // If expr is not CC_STATUS_SUCCESS, print the log | ||||
| #define GE_CHK_CCE(expr) \ | |||||
| do { \ | |||||
| ccStatus_t _cc_ret = (expr); \ | |||||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
| } \ | |||||
| } while (0); | |||||
| // If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression | |||||
| #define GE_CHK_CCE_EXEC(expr, exec_expr) \ | |||||
| do { \ | |||||
| ccStatus_t _cc_ret = (expr); \ | |||||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } while (0); | |||||
| // If expr is not CC_STATUS_SUCCESS, print the log and return | |||||
| #define GE_CHK_CCE_RET(expr) \ | |||||
| do { \ | |||||
| ccStatus_t _cc_ret = (expr); \ | |||||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
| return ge::CCE_FAILED; \ | |||||
| } \ | |||||
| #define GE_CHK_CCE(expr) \ | |||||
| do { \ | |||||
| ccStatus_t _cc_ret = (expr); \ | |||||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||||
| DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
| } \ | |||||
| } while (0); | } while (0); | ||||
| // If expr is true, execute exec_expr without printing logs | // If expr is true, execute exec_expr without printing logs | ||||
| @@ -281,37 +267,8 @@ using cce::ccStatus_t; | |||||
| try { \ | try { \ | ||||
| exec_expr0; \ | exec_expr0; \ | ||||
| } catch (const std::bad_alloc &) { \ | } catch (const std::bad_alloc &) { \ | ||||
| GE_LOGE("Make shared failed"); \ | |||||
| DOMI_LOGE("Make shared failed"); \ | |||||
| exec_expr1; \ | exec_expr1; \ | ||||
| } | } | ||||
| #define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \ | |||||
| do { \ | |||||
| if ((a) > 0) { \ | |||||
| if ((b) > 0) { \ | |||||
| if ((a) > (INT32_MAX / (b))) { \ | |||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } else { \ | |||||
| if ((b) < (INT32_MIN / (a))) { \ | |||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } \ | |||||
| } else { \ | |||||
| if ((b) > 0) { \ | |||||
| if ((a) < (INT32_MAX / (b))) { \ | |||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } else { \ | |||||
| if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \ | |||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } \ | |||||
| } \ | |||||
| } while (0); | |||||
| #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | ||||
| @@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | ||||
| @@ -204,15 +203,16 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, | |||||
| GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | ||||
| "Failed set graph finish rebuild in node searcher."); // 1343242301 | "Failed set graph finish rebuild in node searcher."); // 1343242301 | ||||
| GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | ||||
| // Optimize errocode | |||||
| GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted."); | |||||
| GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "NThe node of the graph not changed."); | |||||
| // Engine_manager module error code definition | // Engine_manager module error code definition | ||||
| GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | ||||
| GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | ||||
| GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | ||||
| // Optimize errocode | |||||
| GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||||
| GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph no changed."); // 1343242304 | |||||
| // Ops module error code definition | // Ops module error code definition | ||||
| GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | ||||
| GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | ||||
| @@ -24,8 +24,7 @@ | |||||
| #include "common/fmk_error_codes.h" | #include "common/fmk_error_codes.h" | ||||
| #include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
| using std::string; | |||||
| #include "external/graph/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| enum RuntimeType { HOST = 0, DEVICE = 1 }; | enum RuntimeType { HOST = 0, DEVICE = 1 }; | ||||
| @@ -56,7 +55,7 @@ struct DataBuffer { | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| /// @brief External inputdata | |||||
| /// @brief External input data | |||||
| /// | /// | ||||
| struct InputData { | struct InputData { | ||||
| uint32_t index; // Index of input data | uint32_t index; // Index of input data | ||||
| @@ -65,13 +64,14 @@ struct InputData { | |||||
| uint32_t model_id; // Model ID required for data processing | uint32_t model_id; // Model ID required for data processing | ||||
| uint64_t request_id = 0; // Request ID | uint64_t request_id = 0; // Request ID | ||||
| std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | ||||
| bool is_dynamic_batch = false; // Whether is dynamic batch size scene, default:false | |||||
| std::string batch_label; // Gear used for current inference in dynamic batch scene | |||||
| }; | }; | ||||
| // The definition of output result structure | |||||
| /// Output result structure definition | |||||
| struct OutputData { | struct OutputData { | ||||
| uint32_t index; // Index of input data | uint32_t index; // Index of input data | ||||
| uint32_t model_id; // The model ID corresponding to the processing result | uint32_t model_id; // The model ID corresponding to the processing result | ||||
| /// Output data cache, arranged in sequence of output operators. | /// Output data cache, arranged in sequence of output operators. | ||||
| /// If the operator has multiple outputs, | /// If the operator has multiple outputs, | ||||
| /// the data buffer order of the operator is the same as that defined in the | /// the data buffer order of the operator is the same as that defined in the | ||||
| @@ -142,11 +142,31 @@ struct Options { | |||||
| bool deployMode; | bool deployMode; | ||||
| bool isAICPUMode; | bool isAICPUMode; | ||||
| bool enable_atomic; | bool enable_atomic; | ||||
| string podName; | |||||
| std::string podName; | |||||
| int64_t rankId; | int64_t rankId; | ||||
| string rankTableFile; | |||||
| std::string rankTableFile; | |||||
| int32_t ge_hccl_flag = 0; | int32_t ge_hccl_flag = 0; | ||||
| int32_t physical_device_id; | int32_t physical_device_id; | ||||
| }; | }; | ||||
| // Profiling info of task | |||||
| struct TaskDescInfo { | |||||
| std::string op_name; | |||||
| uint32_t block_dim; | |||||
| uint32_t task_id; | |||||
| uint32_t stream_id; | |||||
| }; | |||||
| // Profiling info of graph | |||||
| struct ComputeGraphDescInfo { | |||||
| std::string op_name; | |||||
| std::string op_type; | |||||
| std::vector<Format> input_format; | |||||
| std::vector<std::vector<int64_t>> input_shape; | |||||
| std::vector<DataType> input_data_type; | |||||
| std::vector<Format> output_format; | |||||
| std::vector<std::vector<int64_t>> output_shape; | |||||
| std::vector<DataType> output_data_type; | |||||
| }; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ | #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <memory> | |||||
| #include "common/fmk_types.h" | #include "common/fmk_types.h" | ||||
| #include "common/helper/om_file_helper.h" | #include "common/helper/om_file_helper.h" | ||||
| @@ -33,36 +32,41 @@ class ModelHelper { | |||||
| ModelHelper() = default; | ModelHelper() = default; | ||||
| ~ModelHelper(); | ~ModelHelper(); | ||||
| Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file); | |||||
| Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); | |||||
| Status LoadModel(const ge::ModelData &model_data); | |||||
| Status SaveToOmModel(const GeModelPtr& ge_model, const SaveParam& save_param, const std::string& output_file, | |||||
| ge::ModelBufferData& model); | |||||
| Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file); | |||||
| Status LoadModel(const ge::ModelData& model_data); | |||||
| Status GetModelBufferData(ge::ModelBufferData& model); | |||||
| ModelFileHeader *GetFileHeader() { return file_header_; } | |||||
| ModelFileHeader* GetFileHeader() { return file_header_; } | |||||
| GeModelPtr GetGeModel(); | GeModelPtr GetGeModel(); | ||||
| void SetSaveMode(bool val) { is_offline_ = val; } | |||||
| bool GetSaveMode(void) const { return is_offline_; } | |||||
| static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model); | |||||
| static Status TransGeModelToModel(const GeModelPtr &geModelPtr, ModelPtr &modelPtr); | |||||
| static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||||
| static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||||
| private: | private: | ||||
| bool is_assign_model_ = false; | bool is_assign_model_ = false; | ||||
| ModelFileHeader *file_header_ = nullptr; | |||||
| bool is_offline_ = true; | |||||
| ModelFileHeader* file_header_ = nullptr; | |||||
| // Encrypted model need delete temp model and unencrypted model need not delete model | // Encrypted model need delete temp model and unencrypted model need not delete model | ||||
| uint8_t *model_addr_tmp_ = nullptr; | |||||
| uint8_t* model_addr_tmp_ = nullptr; | |||||
| uint32_t model_len_tmp_ = 0; | uint32_t model_len_tmp_ = 0; | ||||
| GeModelPtr model_; | GeModelPtr model_; | ||||
| ModelHelper(const ModelHelper &); | |||||
| ModelHelper &operator=(const ModelHelper &); | |||||
| Status GenerateGeModel(OmFileLoadHelper &om_load_helper); | |||||
| Status LoadModelData(OmFileLoadHelper &om_load_helper); | |||||
| void SetModelToGeModel(ge::Model &model); | |||||
| Status LoadWeights(OmFileLoadHelper &om_load_helper); | |||||
| Status LoadTask(OmFileLoadHelper &om_load_helper); | |||||
| Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); | |||||
| ModelHelper(const ModelHelper&); | |||||
| ModelHelper& operator=(const ModelHelper&); | |||||
| Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | |||||
| Status LoadModelData(OmFileLoadHelper& om_load_helper); | |||||
| void SetModelToGeModel(ge::Model& model); | |||||
| Status LoadWeights(OmFileLoadHelper& om_load_helper); | |||||
| Status LoadTask(OmFileLoadHelper& om_load_helper); | |||||
| Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||||
| Status ReleaseLocalModelData() noexcept; | Status ReleaseLocalModelData() noexcept; | ||||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | |||||
| const uint8_t *data, size_t size); | |||||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||||
| const uint8_t* data, size_t size); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | ||||
| @@ -20,10 +20,12 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "external/ge/ge_ir_build.h" | |||||
| #include "framework/common/fmk_types.h" | #include "framework/common/fmk_types.h" | ||||
| #include "framework/common/ge_types.h" | |||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "framework/common/ge_types.h" | |||||
| using ProcParam = struct PROC_PARAM; | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -80,9 +82,10 @@ class OmFileSaveHelper { | |||||
| const std::vector<ModelPartition> &GetModelPartitions() const; | const std::vector<ModelPartition> &GetModelPartitions() const; | ||||
| Status SaveModel(const SaveParam &save_param, const char *target_file); | |||||
| Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model, | |||||
| bool is_offline = true); | |||||
| Status SaveModelToFile(const char *output_file); | |||||
| Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); | |||||
| ModelFileHeader model_header_; | ModelFileHeader model_header_; | ||||
| OmFileContext context_; | OmFileContext context_; | ||||
| @@ -120,4 +120,4 @@ class L2CacheOptimize { | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||||
| #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||||
| @@ -649,6 +649,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_M | |||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | ||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | ||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
| @@ -801,6 +803,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_N | |||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | ||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | ||||
| // For constant folding | |||||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
| } // namespace domi | } // namespace domi | ||||
| #endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | #endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | ||||
| @@ -17,11 +17,12 @@ | |||||
| #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
| #define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | #define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | ||||
| #include <google/protobuf/map.h> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <string> | #include <string> | ||||
| #include <google/protobuf/map.h> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/types.h" | #include "common/types.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| using domi::AttrDef; | using domi::AttrDef; | ||||
| @@ -18,7 +18,6 @@ | |||||
| #define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | #define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | ||||
| #include <cce/dnn.h> | #include <cce/dnn.h> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| @@ -56,6 +55,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TR | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | ||||
| // FunctionOp | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||||
| class OpUtils { | class OpUtils { | ||||
| public: | public: | ||||
| /// | /// | ||||
| @@ -164,15 +172,23 @@ class OpUtils { | |||||
| /// | /// | ||||
| static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | ||||
| static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | ||||
| static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||||
| int64_t out_dim, int64_t stride); | |||||
| template <typename T> | |||||
| static void SliceData(const std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, | |||||
| int64_t begin, int64_t out_dim, int64_t stride); | |||||
| template <typename T> | |||||
| static Status SetDataByDataType(size_t out_size, const std::vector<char *> &chunk_input, | |||||
| const std::vector<char *> &chunk_output, GeTensor *output); | |||||
| template <typename T> | |||||
| static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector<int64_t> &input_dims, | |||||
| const std::vector<int64_t> &begin, const std::vector<int64_t> &output_dims, | |||||
| ge::GeTensor *output, const std::vector<int64_t> &stride); | |||||
| static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | ||||
| std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | ||||
| std::vector<int64_t> &stride); | std::vector<int64_t> &stride); | ||||
| /// | /// | ||||
| /// @ingroup domi_omg | /// @ingroup domi_omg | ||||
| /// @brief Convert the convolution weight data from [h, w, c, k] to [k, c, h, w] | |||||
| /// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w] | |||||
| /// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
| /// @param [in] H value of H dimension | /// @param [in] H value of H dimension | ||||
| /// @param [in] W value of W dimension | /// @param [in] W value of W dimension | ||||
| @@ -183,7 +199,7 @@ class OpUtils { | |||||
| static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | ||||
| /// | /// | ||||
| /// @ingroup domi_omg | /// @ingroup domi_omg | ||||
| /// @brief Converts the convolution weight data from [k, c, h, w] to [h, w, c, k]. | |||||
| /// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k]. | |||||
| /// @param [in] input Weight data in HWCK format | /// @param [in] input Weight data in HWCK format | ||||
| /// @param [in] K value of K dimension | /// @param [in] K value of K dimension | ||||
| /// @param [in] C value of C dimension | /// @param [in] C value of C dimension | ||||
| @@ -222,7 +238,6 @@ using CceTensorDescriptorPtr = std::shared_ptr<CceTensorDescriptor>; | |||||
| class CceTensorDescriptor { | class CceTensorDescriptor { | ||||
| public: | public: | ||||
| explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | ||||
| CceTensorDescriptor(const CceTensorDescriptor &) = delete; | CceTensorDescriptor(const CceTensorDescriptor &) = delete; | ||||
| CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include <math.h> | #include <math.h> | ||||
| #include <stdint.h> | #include <stdint.h> | ||||
| namespace domi { | |||||
| namespace ge { | |||||
| // general | // general | ||||
| const float DEFAULT_ALPHA_VALUE = 1.0; | const float DEFAULT_ALPHA_VALUE = 1.0; | ||||
| const float DEFAULT_BETA_VALUE = 0.0; | const float DEFAULT_BETA_VALUE = 0.0; | ||||
| @@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||||
| // Shufflechannel | // Shufflechannel | ||||
| const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | ||||
| } // namespace domi | |||||
| } // namespace ge | |||||
| #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | ||||
| @@ -25,7 +25,7 @@ | |||||
| /// MAKE_GUARD([&] { Release Resource 1 }) | /// MAKE_GUARD([&] { Release Resource 1 }) | ||||
| /// Acquire Resource 2 | /// Acquire Resource 2 | ||||
| // MAKE_GUARD([&] { Release Resource 2 }) | // MAKE_GUARD([&] { Release Resource 2 }) | ||||
| #define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | |||||
| #define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||||
| #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | ||||
| namespace ge { | namespace ge { | ||||
| @@ -156,6 +156,7 @@ REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); | |||||
| REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | ||||
| REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | ||||
| REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | ||||
| REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); | |||||
| REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | ||||
| REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | ||||
| REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | ||||
| @@ -188,6 +189,19 @@ REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration"); | |||||
| REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | ||||
| REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | ||||
| REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | ||||
| REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); | |||||
| REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); | |||||
| REGISTER_OPTYPE_DECLARE(_IF, "_If"); | |||||
| REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); | |||||
| REGISTER_OPTYPE_DECLARE(IF, "If"); | |||||
| REGISTER_OPTYPE_DECLARE(CASE, "Case"); | |||||
| REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); | |||||
| REGISTER_OPTYPE_DECLARE(WHILE, "While"); | |||||
| REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); | |||||
| REGISTER_OPTYPE_DECLARE(FOR, "For"); | |||||
| REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); | |||||
| REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); | |||||
| REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); | |||||
| REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | ||||
| REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | ||||
| REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | ||||
| @@ -424,6 +438,12 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||||
| REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
| REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
| REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | ||||
| REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||||
| REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||||
| REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||||
| REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||||
| REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | ||||
| REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | ||||
| @@ -1032,14 +1052,11 @@ struct BasicInfo { | |||||
| uint32_t workspace_size; // workspace | uint32_t workspace_size; // workspace | ||||
| uint32_t total_size; // total memory size | uint32_t total_size; // total memory size | ||||
| }; | }; | ||||
| #pragma pack() // Cancels single-byte alignment | #pragma pack() // Cancels single-byte alignment | ||||
| } // namespace ge | } // namespace ge | ||||
| namespace domi { | namespace domi { | ||||
| /// @brief Data structure definition related to task sinking | /// @brief Data structure definition related to task sinking | ||||
| /// Build model | |||||
| enum BuildMode { | enum BuildMode { | ||||
| GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | ||||
| GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | ||||
| @@ -30,6 +30,14 @@ | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "mmpa/mmpa_api.h" | #include "mmpa/mmpa_api.h" | ||||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size <= 0) { \ | |||||
| DOMI_LOGE(param[#size] is not a positive number); \ | |||||
| return PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | ||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| @@ -50,21 +58,6 @@ | |||||
| if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | ||||
| }); | }); | ||||
| #define GE_MAKE_GUARD_RTEVENT(var) \ | |||||
| GE_MAKE_GUARD(var, [&] { \ | |||||
| if (var) GE_CHK_RT(rtEventDestroy(var)); \ | |||||
| }); | |||||
| #define GE_MAKE_GUARD_TENSOR(var) \ | |||||
| GE_MAKE_GUARD(var, [&] { \ | |||||
| if (var) GE_CHK_CCE(ccDestroyTensorDescriptor(&var)); \ | |||||
| }); | |||||
| #define GE_MAKE_GUARD_FILTER_DESC(var) \ | |||||
| GE_MAKE_GUARD(var, [&] { \ | |||||
| if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | |||||
| }); | |||||
| // For propagating errors when calling a function. | // For propagating errors when calling a function. | ||||
| #define GE_RETURN_IF_ERROR(expr) \ | #define GE_RETURN_IF_ERROR(expr) \ | ||||
| do { \ | do { \ | ||||
| @@ -76,7 +69,7 @@ | |||||
| do { \ | do { \ | ||||
| const ::ge::Status _status = (expr); \ | const ::ge::Status _status = (expr); \ | ||||
| if (_status) { \ | if (_status) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -85,7 +78,7 @@ | |||||
| #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | ||||
| do { \ | do { \ | ||||
| if (condition) { \ | if (condition) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | return ge::FAILED; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -95,7 +88,7 @@ | |||||
| do { \ | do { \ | ||||
| bool _condition = (condition); \ | bool _condition = (condition); \ | ||||
| if (!_condition) { \ | if (!_condition) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return ge::FAILED; \ | return ge::FAILED; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -104,7 +97,7 @@ | |||||
| #define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | #define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | ||||
| do { \ | do { \ | ||||
| if (condition) { \ | if (condition) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return ge::PARAM_INVALID; \ | return ge::PARAM_INVALID; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -114,111 +107,90 @@ | |||||
| do { \ | do { \ | ||||
| bool _condition = (condition); \ | bool _condition = (condition); \ | ||||
| if (!_condition) { \ | if (!_condition) { \ | ||||
| GE_LOGE(__VA_ARGS__); \ | |||||
| DOMI_LOGE(__VA_ARGS__); \ | |||||
| return ge::PARAM_INVALID; \ | return ge::PARAM_INVALID; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | ||||
| #define GE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GE_LOGE(param[#val] must not be null.); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| #define GE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| DOMI_LOGE(param[#val] must not be null.); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GE_LOGE(param[#val] must not be null.); \ | |||||
| return; \ | |||||
| } \ | |||||
| // Check if the parameter is null. If yes, just return and record the error | |||||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| DOMI_LOGE(param[#val] must not be null.); \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | ||||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GE_LOGE(param[#val] must not be null.); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| DOMI_LOGE(param[#val] must not be null.); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check whether the parameter is null. If yes, return directly and record the error log | // Check whether the parameter is null. If yes, return directly and record the error log | ||||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GE_LOGE(param[#val] must not be null.); \ | |||||
| return; \ | |||||
| } \ | |||||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| DOMI_LOGE(param[#val] must not be null.); \ | |||||
| return; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check if the parameter is null. If yes, return false and record the error log | // Check if the parameter is null. If yes, return false and record the error log | ||||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| GE_LOGE(param[#val] must not be null.); \ | |||||
| return false; \ | |||||
| } \ | |||||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||||
| do { \ | |||||
| if (val == nullptr) { \ | |||||
| DOMI_LOGE(param[#val] must not be null.); \ | |||||
| return false; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check if the parameter is out of bounds | // Check if the parameter is out of bounds | ||||
| #define GE_CHECK_SIZE(size) \ | |||||
| do { \ | |||||
| if (size == 0) { \ | |||||
| GE_LOGE(param[#size] is out of range); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| #define GE_CHECK_SIZE(size) \ | |||||
| do { \ | |||||
| if (size == 0) { \ | |||||
| DOMI_LOGE(param[#size] is out of range); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Macros that define the size variable | |||||
| #define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||||
| uint32_t _var_name; \ | |||||
| do { \ | |||||
| uint32_t _expr_size = (_expr); \ | |||||
| uint32_t _sizeof_size = (_sizeof); \ | |||||
| if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||||
| GE_LOGE(byte size : #_var_name is out of range); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| _var_name = _sizeof_size * _expr_size; \ | |||||
| } while (0); | |||||
| // Check if the container is empty | // Check if the container is empty | ||||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
| do { \ | |||||
| if (vector.empty()) { \ | |||||
| GE_LOGE(param[#vector] is empty !); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||||
| do { \ | |||||
| if (size <= 0) { \ | |||||
| GE_LOGE(param[#size] is not a positive number); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||||
| do { \ | |||||
| if (vector.empty()) { \ | |||||
| DOMI_LOGE(param[#vector] is empty !); \ | |||||
| return ge::FAILED; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check if the value on the left is greater than or equal to the value on the right | // Check if the value on the left is greater than or equal to the value on the right | ||||
| #define GE_CHECK_GE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs < rhs) { \ | |||||
| GE_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| #define GE_CHECK_GE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs < rhs) { \ | |||||
| DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| // Check if the value on the left is less than or equal to the value on the right | // Check if the value on the left is less than or equal to the value on the right | ||||
| #define GE_CHECK_LE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs > rhs) { \ | |||||
| GE_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| #define GE_CHECK_LE(lhs, rhs) \ | |||||
| do { \ | |||||
| if (lhs > rhs) { \ | |||||
| DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||||
| return ge::PARAM_INVALID; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| #define GE_DELETE_NEW_SINGLE(var) \ | #define GE_DELETE_NEW_SINGLE(var) \ | ||||
| @@ -52,10 +52,10 @@ | |||||
| #define DLOG_DECLARE(level) \ | #define DLOG_DECLARE(level) \ | ||||
| void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | ||||
| namespace ge { | |||||
| namespace domi { | |||||
| DLOG_DECLARE(INFO); | DLOG_DECLARE(INFO); | ||||
| DLOG_DECLARE(WARNING); | DLOG_DECLARE(WARNING); | ||||
| DLOG_DECLARE(ERROR); | DLOG_DECLARE(ERROR); | ||||
| } // namespace ge | |||||
| } // namespace domi | |||||
| #endif // INC_FRAMEWORK_DLOG_LOG_H_ | #endif // INC_FRAMEWORK_DLOG_LOG_H_ | ||||
| @@ -38,7 +38,7 @@ struct DNNEngineAttribute { | |||||
| std::vector<std::string> mem_type; | std::vector<std::string> mem_type; | ||||
| uint32_t compute_cost; | uint32_t compute_cost; | ||||
| enum RuntimeType runtime_type; // HOST, DEVICE | enum RuntimeType runtime_type; // HOST, DEVICE | ||||
| // set this attribute if the inputformat of engine must be specific, otherwise set FORMAT_RESERVED | |||||
| // If engine input format must be specific, set this attribute, else set FORMAT_RESERVED | |||||
| Format engine_input_format; | Format engine_input_format; | ||||
| Format engine_output_format; | Format engine_output_format; | ||||
| }; | }; | ||||
| @@ -26,6 +26,7 @@ | |||||
| #include "common/types.h" | #include "common/types.h" | ||||
| #include "graph/tensor.h" | #include "graph/tensor.h" | ||||
| #include "runtime/base.h" | #include "runtime/base.h" | ||||
| #include "common/dynamic_aipp.h" | |||||
| namespace ge { | namespace ge { | ||||
| class ModelListenerAdapter; | class ModelListenerAdapter; | ||||
| @@ -33,12 +34,15 @@ class ModelListenerAdapter; | |||||
| class SingleOp; | class SingleOp; | ||||
| struct RunModelData { | struct RunModelData { | ||||
| uint32_t index; // Data index | |||||
| uint32_t model_id; // Model id | |||||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||||
| uint32_t timestamp; // Data creation time | |||||
| uint32_t timeout; // Processing timeout | |||||
| uint64_t request_id = 0; // Request ID | |||||
| uint32_t index; // Data index | |||||
| uint32_t modelId; | |||||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||||
| uint32_t timestamp; // Data creation time | |||||
| uint32_t timeout; // Processing timeout | |||||
| uint64_t request_id = 0; // Request ID | |||||
| uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 | |||||
| uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 | |||||
| uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 | |||||
| }; | }; | ||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | ||||
| @@ -46,12 +50,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| GeExecutor(); | GeExecutor(); | ||||
| ~GeExecutor() = default; | ~GeExecutor() = default; | ||||
| ge::Status Initialize(); | ge::Status Initialize(); | ||||
| ge::Status Finalize(); | |||||
| // Load model | // Load model | ||||
| ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | ||||
| std::shared_ptr<ge::ModelListener> listener); | std::shared_ptr<ge::ModelListener> listener); | ||||
| ge::Status UnloadModel(uint32_t model_id); | |||||
| ge::Status UnloadModel(uint32_t modelId); | |||||
| ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | ||||
| @@ -59,6 +64,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
| std::vector<ge::TensorDesc> &output_desc); | std::vector<ge::TensorDesc> &output_desc); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Set dynamic batch size | |||||
| /// @param [in] model_id: model id allocate from manager | |||||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||||
| /// @param [in] length: length of dynamic input addr | |||||
| /// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario | |||||
| /// @return execute result | |||||
| /// | |||||
| ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Set dynamic image info | |||||
| /// @param [in] model_id: model id allocate from manager | |||||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||||
| /// @param [in] length: length of dynamic input addr | |||||
| /// @param [in] image_height: image height entered by user in dynamic multi-resolution scenario | |||||
| /// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario | |||||
| /// @return execute result | |||||
| /// | |||||
| ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, | |||||
| uint64_t image_width); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Get dynamic batch_info | |||||
| /// @param [in] model_id | |||||
| /// @param [out] batch_info | |||||
| /// @return execute result | |||||
| /// | |||||
| ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Set dynamic image info | |||||
| /// @param [in] model_id: model id allocate from manager | |||||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||||
| /// @param [in] length: length of dynamic input addr | |||||
| /// @param [in] aippBatchPara: kAippDynamicBatchPara vector by user in dynamic aipp | |||||
| /// @param [in] aippParms: kAippDynamicPara by user in dynamic aipp | |||||
| /// @return execute result | |||||
| /// | |||||
| ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, | |||||
| const std::vector<kAippDynamicBatchPara> &aippBatchPara, | |||||
| const kAippDynamicPara &aippParms); | |||||
| ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | ||||
| std::vector<ge::TensorDesc> &output_desc); | std::vector<ge::TensorDesc> &output_desc); | ||||
| @@ -147,7 +198,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| /// | /// | ||||
| ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | ||||
| static ge::Status LoadSingleOp(const std::string &model_name, const ge::ModelData &model_data, void *stream, | |||||
| static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream, | |||||
| SingleOp **single_op); | SingleOp **single_op); | ||||
| static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | ||||
| @@ -156,8 +207,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| static ge::Status ReleaseSingleOpResource(void *stream); | static ge::Status ReleaseSingleOpResource(void *stream); | ||||
| private: | private: | ||||
| static bool is_init_; | |||||
| std::vector<std::shared_ptr<ModelListenerAdapter>> listener_adapters_; | |||||
| static bool isInit_; | |||||
| }; | }; | ||||
| ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "ge/ge_ir_build.h" | |||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| @@ -45,6 +45,8 @@ class GeGenerator { | |||||
| Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | ||||
| const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | ||||
| Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief: Build single OP in Model. | /// @brief: Build single OP in Model. | ||||
| @@ -58,6 +60,8 @@ class GeGenerator { | |||||
| const std::vector<GeTensor> &outputs, const std::string &model_file_name); | const std::vector<GeTensor> &outputs, const std::string &model_file_name); | ||||
| private: | private: | ||||
| Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | |||||
| ge::ModelBufferData &model, bool is_offline = true); | |||||
| class Impl; | class Impl; | ||||
| std::shared_ptr<Impl> impl_; | std::shared_ptr<Impl> impl_; | ||||
| @@ -24,7 +24,6 @@ extern "C" { | |||||
| #endif | #endif | ||||
| typedef uint32_t Status_t; | typedef uint32_t Status_t; | ||||
| using Status_t = uint32_t; | |||||
| typedef void *OpAttr_t; | typedef void *OpAttr_t; | ||||
| typedef void *OpTensor_t; | typedef void *OpTensor_t; | ||||
| @@ -23,7 +23,7 @@ | |||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| namespace ge { | namespace ge { | ||||
| const int64_t kMemAlignSize = 512; | |||||
| const int64_t MEM_ALIGN_SIZE = 512; | |||||
| class MemoryAssigner { | class MemoryAssigner { | ||||
| public: | public: | ||||
| explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | ||||
| @@ -39,4 +39,4 @@ class MemoryAssigner { | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||||
| #endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||||
| @@ -31,7 +31,6 @@ | |||||
| using domi::DOMI_TENSOR_ND; | using domi::DOMI_TENSOR_ND; | ||||
| using domi::DOMI_TENSOR_RESERVED; | using domi::DOMI_TENSOR_RESERVED; | ||||
| using domi::domiTensorFormat_t; | using domi::domiTensorFormat_t; | ||||
| using domi::FMK_TYPE_RESERVED; | |||||
| using domi::FrameworkType; | using domi::FrameworkType; | ||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| @@ -44,10 +43,10 @@ namespace ge { | |||||
| * @brief run model | * @brief run model | ||||
| */ | */ | ||||
| enum RunMode { | enum RunMode { | ||||
| kGeOmModel = 0, // generate offline model file | |||||
| kModelToJson = 1, // convert to JSON file | |||||
| kOnlyPreCheck = 3, // only for pre-check | |||||
| kPbtxtToJson = 5 // pbtxt to json | |||||
| GEN_OM_MODEL = 0, // generate offline model file | |||||
| MODEL_TO_JSON = 1, // convert to JSON file | |||||
| ONLY_PRE_CHECK = 3, // only for pre-check | |||||
| PBTXT_TO_JSON = 5 // pbtxt to json | |||||
| }; | }; | ||||
| /// | /// | ||||
| @@ -56,10 +55,10 @@ enum RunMode { | |||||
| /// | /// | ||||
| enum HighPrecisionMode { | enum HighPrecisionMode { | ||||
| // the FP16 high-precision function is disabled in common mode | // the FP16 high-precision function is disabled in common mode | ||||
| kHighPrecisonDefault = 0, | |||||
| HIGH_PRECISION_DEFAULT = 0, | |||||
| // high-precision mode, in which FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) is enable | |||||
| kHighPrecisionFP16 = 1 | |||||
| // high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) | |||||
| HIGH_PRECISION_FP16 = 1 | |||||
| }; | }; | ||||
| /// | /// | ||||
| @@ -99,21 +98,23 @@ struct OmgContext { | |||||
| // preferential format used by the entire network | // preferential format used by the entire network | ||||
| domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
| domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | ||||
| RunMode run_mode = kOnlyPreCheck; | |||||
| RunMode run_mode = ONLY_PRE_CHECK; | |||||
| bool train_flag = false; | bool train_flag = false; | ||||
| // whether to use FP16 high precision | // whether to use FP16 high precision | ||||
| int32_t fp16_high_precision = kHighPrecisonDefault; | |||||
| int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT; | |||||
| std::string output_type; | std::string output_type; | ||||
| // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | ||||
| // network require special processing based on the specific network. | |||||
| // e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||||
| // fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||||
| // convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||||
| // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||||
| // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||||
| // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||||
| // operator needs to be deleted for the convolution kernel. | |||||
| std::string net_name; | std::string net_name; | ||||
| // whether to enable dynamic batch | |||||
| bool enable_l2dynamic = false; | |||||
| // Whether to use dynamic batch size or dynamic image size | |||||
| bool is_dynamic_input = false; | |||||
| std::string dynamic_batch_size; | |||||
| std::string dynamic_image_size; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -32,15 +32,7 @@ class PlatformVersionManager { | |||||
| PlatformVersionManager() = delete; | PlatformVersionManager() = delete; | ||||
| ~PlatformVersionManager() = delete; | ~PlatformVersionManager() = delete; | ||||
| static Status GetPlatformVersion(std::string &ver) { | static Status GetPlatformVersion(std::string &ver) { | ||||
| #if defined PLATFORM_PHOENIX | |||||
| ver = "3.51.z"; | |||||
| #elif defined PLATFORM_ORLANDO | |||||
| ver = "3.31.z"; | |||||
| #elif defined PLATFORM_MINI | |||||
| ver = "1.11.z"; | ver = "1.11.z"; | ||||
| #elif defined PLATFORM_CLOUD | |||||
| ver = "1.61.z"; | |||||
| #endif | |||||
| std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | ||||
| GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | ||||
| @@ -20,13 +20,17 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/range_vistor.h" | #include "graph/range_vistor.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| namespace ge { | namespace ge { | ||||
| enum AnchorStatus { ANCHOR_SUSPEND = 0, ANCHOR_CONST = 1, ANCHOR_DATA = 2, ANCHOR_RESERVED = 3 }; | |||||
| enum AnchorStatus { | |||||
| ANCHOR_SUSPEND = 0, // dat null | |||||
| ANCHOR_CONST = 1, | |||||
| ANCHOR_DATA = 2, // Effective | |||||
| ANCHOR_RESERVED = 3 | |||||
| }; | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -81,17 +85,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||||
| virtual ~Anchor() = default; | virtual ~Anchor() = default; | ||||
| protected: | protected: | ||||
| // Whether the two anchors are equal | |||||
| // Whether the two anchor is equal | |||||
| virtual bool Equal(AnchorPtr anchor) const = 0; | virtual bool Equal(AnchorPtr anchor) const = 0; | ||||
| virtual bool IsTypeOf(TYPE type) const; | virtual bool IsTypeOf(TYPE type) const; | ||||
| public: | public: | ||||
| // Get all peer anchors connected to current anchor | // Get all peer anchors connected to current anchor | ||||
| Vistor<AnchorPtr> GetPeerAnchors() const; | Vistor<AnchorPtr> GetPeerAnchors() const; | ||||
| // Get the first peer anchor | |||||
| // Get peer anchor size | |||||
| size_t GetPeerAnchorsSize() const; | |||||
| // Get first peer anchor | |||||
| AnchorPtr GetFirstPeerAnchor() const; | AnchorPtr GetFirstPeerAnchor() const; | ||||
| // Get the node which is the owner of the anchor | |||||
| // Get the anchor belong to which node | |||||
| NodePtr GetOwnerNode() const; | NodePtr GetOwnerNode() const; | ||||
| // Remove all links with the anchor | // Remove all links with the anchor | ||||
| @@ -100,22 +106,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||||
| // Remove link with the given anchor | // Remove link with the given anchor | ||||
| graphStatus Unlink(const AnchorPtr &peer); | graphStatus Unlink(const AnchorPtr &peer); | ||||
| // Replace the peeranchor with the new peeranchor | |||||
| // Replace peer with new peers | |||||
| graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | ||||
| // Judge if the anchor is linked with the given anchor | // Judge if the anchor is linked with the given anchor | ||||
| bool IsLinkedWith(const AnchorPtr &peer); | bool IsLinkedWith(const AnchorPtr &peer); | ||||
| // Get the anchor index of the node | |||||
| // Get anchor index of the node | |||||
| int GetIdx() const; | int GetIdx() const; | ||||
| // Set the anchor index of the node | |||||
| // set anchor index of the node | |||||
| void SetIdx(int index); | void SetIdx(int index); | ||||
| protected: | protected: | ||||
| // All peer anchors connected to current anchor | // All peer anchors connected to current anchor | ||||
| vector<std::weak_ptr<Anchor>> peer_anchors_; | vector<std::weak_ptr<Anchor>> peer_anchors_; | ||||
| // The owner nodes of the anchor | |||||
| // The owner node of anchor | |||||
| std::weak_ptr<Node> owner_node_; | std::weak_ptr<Node> owner_node_; | ||||
| // The index of current anchor | // The index of current anchor | ||||
| int idx_; | int idx_; | ||||
| @@ -167,7 +173,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataA | |||||
| virtual ~InDataAnchor() = default; | virtual ~InDataAnchor() = default; | ||||
| // Get source out data anchor | |||||
| // Get source out data anchor | |||||
| OutDataAnchorPtr GetPeerOutAnchor() const; | OutDataAnchorPtr GetPeerOutAnchor() const; | ||||
| // Build connection from OutDataAnchor to InDataAnchor | // Build connection from OutDataAnchor to InDataAnchor | ||||
| @@ -19,10 +19,10 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
| namespace ge { | namespace ge { | ||||
| class GeAttrValue; | class GeAttrValue; | ||||
| class _GeSerializable { | class _GeSerializable { | ||||
| public: | public: | ||||
| @@ -107,7 +107,6 @@ class _GeSerializable { | |||||
| static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | ||||
| }; | }; | ||||
| #define _GE_FI(a) #a, a | #define _GE_FI(a) #a, a | ||||
| #define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | #define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | ||||
| #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | ||||
| @@ -130,23 +129,23 @@ class _GeSerializable { | |||||
| #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | ||||
| _GE_FI(a1) \ | _GE_FI(a1) \ | ||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
| _GE_FI(a11) | |||||
| _GE_FI(a11) | |||||
| #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | ||||
| _GE_FI(a1) \ | _GE_FI(a1) \ | ||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
| _GE_FI(a11), _GE_FI(a12) | |||||
| _GE_FI(a11), _GE_FI(a12) | |||||
| #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | ||||
| _GE_FI(a1) \ | _GE_FI(a1) \ | ||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||||
| #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | ||||
| _GE_FI(a1) \ | _GE_FI(a1) \ | ||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||||
| #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | ||||
| _GE_FI(a1) \ | _GE_FI(a1) \ | ||||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | ||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||||
| #define _GE_PRIVATE_ARGS_GLUE(x, y) x y | #define _GE_PRIVATE_ARGS_GLUE(x, y) x y | ||||
| @@ -17,12 +17,11 @@ | |||||
| #ifndef INC_GRAPH_BUFFER_H_ | #ifndef INC_GRAPH_BUFFER_H_ | ||||
| #define INC_GRAPH_BUFFER_H_ | #define INC_GRAPH_BUFFER_H_ | ||||
| #include <graph/types.h> | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/types.h" | |||||
| namespace ge { | namespace ge { | ||||
| #ifdef HOST_VISIBILITY | #ifdef HOST_VISIBILITY | ||||
| @@ -72,7 +71,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||||
| GeIrProtoHelper<proto::AttrDef> data_; | GeIrProtoHelper<proto::AttrDef> data_; | ||||
| std::string *buffer_ = nullptr; | std::string *buffer_ = nullptr; | ||||
| // Create buffer from protobuf obj | |||||
| // Create from protobuf obj | |||||
| Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | ||||
| Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | ||||
| @@ -17,7 +17,6 @@ | |||||
| #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | ||||
| #define INC_GRAPH_COMPUTE_GRAPH_H_ | #define INC_GRAPH_COMPUTE_GRAPH_H_ | ||||
| #include <deque> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -63,7 +62,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | ||||
| explicit ComputeGraph(const std::string &name); | explicit ComputeGraph(const std::string &name); | ||||
| virtual ~ComputeGraph(); | |||||
| ~ComputeGraph() override; | |||||
| std::string GetName() const; | std::string GetName() const; | ||||
| void SetName(const std::string &name); | void SetName(const std::string &name); | ||||
| @@ -81,7 +80,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| Vistor<NodePtr> GetOutputNodes() const; | Vistor<NodePtr> GetOutputNodes() const; | ||||
| NodePtr FindNode(const std::string &name) const; | NodePtr FindNode(const std::string &name) const; | ||||
| // Add node | |||||
| // AddNode with NodePtr | |||||
| NodePtr AddNode(NodePtr node); | NodePtr AddNode(NodePtr node); | ||||
| NodePtr AddNode(OpDescPtr op); | NodePtr AddNode(OpDescPtr op); | ||||
| NodePtr AddNodeFront(NodePtr node); | NodePtr AddNodeFront(NodePtr node); | ||||
| @@ -94,9 +93,40 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| graphStatus RemoveOutputNode(const NodePtr &node); | graphStatus RemoveOutputNode(const NodePtr &node); | ||||
| graphStatus RemoveConstInput(const NodePtr &node); | graphStatus RemoveConstInput(const NodePtr &node); | ||||
| /// Add a subgraph to this graph. The subgraph must has a parent graph and parent node, | |||||
| /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph | |||||
| /// must be called before add it to the root graph. and subgraph->GetParentNode()->GetOwnerGraph() | |||||
| /// must equal to subgraph->GetOwnerGraph(). | |||||
| /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. | |||||
| /// The subgraph's name SHOULD(not must) be the same as the parameter `name` | |||||
| graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| void RemoveSubgraph(const std::string &name); | |||||
| void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||||
| std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const; | |||||
| std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const; | |||||
| // obsolete | |||||
| std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | ||||
| // obsolete | |||||
| graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | ||||
| /// | |||||
| /// @brief Update input-mapping | |||||
| /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
| /// | |||||
| /// @brief Update output-mapping | |||||
| /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
| graphStatus TopologicalSorting(); | graphStatus TopologicalSorting(); | ||||
| bool IsValid() const; | bool IsValid() const; | ||||
| void Dump() const; | void Dump() const; | ||||
| @@ -127,6 +157,11 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| } | } | ||||
| } | } | ||||
| shared_ptr<ComputeGraph> GetParentGraph(); | |||||
| void SetParentGraph(const shared_ptr<ComputeGraph> &parent); | |||||
| shared_ptr<Node> GetParentNode(); | |||||
| void SetParentNode(const shared_ptr<Node> &parent); | |||||
| const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | ||||
| void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | ||||
| @@ -138,8 +173,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| uint32_t GetInputSize() const { return input_size_; } | uint32_t GetInputSize() const { return input_size_; } | ||||
| /// | /// | ||||
| /// Set iteration needed. | |||||
| /// If set is true, it means this graph need run iteration some | |||||
| /// Set is need train iteration. | |||||
| /// If set true, it means this graph need to be run iteration some | |||||
| /// times(according variant "npu_runconfig/iterations_per_loop"). | /// times(according variant "npu_runconfig/iterations_per_loop"). | ||||
| /// @param need_iteration is need iteration | /// @param need_iteration is need iteration | ||||
| /// | /// | ||||
| @@ -150,7 +185,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| const std::string GetOutput(); | const std::string GetOutput(); | ||||
| /// | /// | ||||
| /// Get need_iteration. | |||||
| /// Get is need train iteration. | |||||
| /// @return is need iteration | /// @return is need iteration | ||||
| /// | /// | ||||
| bool GetNeedIteration() const { return need_iteration_; } | bool GetNeedIteration() const { return need_iteration_; } | ||||
| @@ -201,6 +236,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| std::deque<NodePtr> &stack); | std::deque<NodePtr> &stack); | ||||
| graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
| std::map<string, NodePtr> &breadth_node_map); | std::map<string, NodePtr> &breadth_node_map); | ||||
| graphStatus TopologicalSortingSubgraph(); | |||||
| graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | ||||
| size_t GetInEdgeSize(const NodePtr &node); | size_t GetInEdgeSize(const NodePtr &node); | ||||
| size_t GetOutEdgeSize(const NodePtr &node); | size_t GetOutEdgeSize(const NodePtr &node); | ||||
| @@ -210,31 +246,38 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | ||||
| const std::vector<NodePtr> &l_node_ptr_vector) const; | const std::vector<NodePtr> &l_node_ptr_vector) const; | ||||
| ProtoAttrMapHelper attrs_; | |||||
| friend class ModelSerializeImp; | friend class ModelSerializeImp; | ||||
| friend class GraphDebugImp; | friend class GraphDebugImp; | ||||
| friend class OnnxUtils; | friend class OnnxUtils; | ||||
| std::string name_; | |||||
| uint32_t graph_id_ = 0; | |||||
| ProtoAttrMapHelper attrs_; | |||||
| std::vector<NodePtr> nodes_; | std::vector<NodePtr> nodes_; | ||||
| std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||||
| std::vector<NodePtr> target_nodes_info_; | |||||
| std::vector<NodePtr> input_nodes_; | std::vector<NodePtr> input_nodes_; | ||||
| std::vector<std::string> inputs_order_; | |||||
| uint32_t input_size_ = 1; | |||||
| std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||||
| uint32_t output_size_ = 1; | |||||
| std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||||
| std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | ||||
| std::string name_; | |||||
| std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_; | |||||
| std::weak_ptr<ComputeGraph> parent_graph_; | |||||
| std::weak_ptr<Node> parent_node_; | |||||
| // the members followed should not in the ComputeGraph class | |||||
| bool is_valid_flag_; | bool is_valid_flag_; | ||||
| bool is_summary_graph_ = false; | bool is_summary_graph_ = false; | ||||
| // Indicates whether it is need iteration | // Indicates whether it is need iteration | ||||
| bool need_iteration_ = false; | bool need_iteration_ = false; | ||||
| std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | ||||
| std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||||
| // TaskIdx -> op_name Map | // TaskIdx -> op_name Map | ||||
| std::map<uint32_t, std::string> op_name_map_; | std::map<uint32_t, std::string> op_name_map_; | ||||
| std::vector<std::string> inputs_order_; | |||||
| uint32_t output_size_ = 1; | |||||
| uint32_t input_size_ = 1; | |||||
| std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||||
| std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||||
| std::vector<NodePtr> target_nodes_info_; | |||||
| uint64_t session_id_ = 0; | uint64_t session_id_ = 0; | ||||
| uint32_t graph_id_ = 0; | |||||
| ge::Format data_format_ = ge::FORMAT_ND; | ge::Format data_format_ = ge::FORMAT_ND; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,7 +18,6 @@ | |||||
| #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
| #include <string> | #include <string> | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -59,6 +58,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | ||||
| @@ -75,8 +76,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | ||||
| // GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||||
| // ATTR_NAME_WEIGHTS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | ||||
| @@ -124,6 +124,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | ||||
| @@ -141,10 +148,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | ||||
| @@ -166,15 +178,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | ||||
| // _Arg | // _Arg | ||||
| @@ -263,7 +275,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||||
| // Huberloss | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||||
| // SSDRealDivTileMul | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||||
| // SSDSumMulRealDivMean | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||||
| /// ConcatFive2Four | |||||
| /// ConcatFour2Five | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||||
| // Scale | // Scale | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | ||||
| @@ -300,7 +334,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | ||||
| // Roipooling | // Roipooling | ||||
| @@ -313,6 +346,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||||
| // DetectionOutput | // DetectionOutput | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | ||||
| @@ -371,6 +405,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||||
| // Permute | // Permute | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||||
| // SSD Normalize | // SSD Normalize | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | ||||
| @@ -411,9 +446,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | ||||
| // Log | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||||
| // Pack | // Pack | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | ||||
| // Dynamic stitch | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||||
| // Unpack | // Unpack | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | ||||
| // Gathernd | // Gathernd | ||||
| @@ -422,8 +463,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||||
| // Argmax | // Argmax | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||||
| // Upsample | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||||
| // Relu | // Relu | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | ||||
| @@ -439,6 +488,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_AT | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; | |||||
| // Squeeze | // Squeeze | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | ||||
| @@ -461,6 +511,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; | |||||
| // Generate_rpn_proposal | // Generate_rpn_proposal | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | ||||
| @@ -493,6 +544,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | ||||
| static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||||
| // ENTER | // ENTER | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | ||||
| @@ -518,6 +570,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | ||||
| // RetinaNet | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||||
| // MatMul | // MatMul | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | ||||
| @@ -566,10 +621,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||||
| // Upsample | // Upsample | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | ||||
| // PadV2 | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||||
| // MirrorPad | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||||
| // Filler | // Filler | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | ||||
| @@ -590,36 +665,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | ||||
| @@ -634,6 +679,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
| @@ -642,12 +689,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
| // Public attribute | // Public attribute | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | ||||
| @@ -685,11 +726,178 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | ||||
| // Used for operators that do not generate task | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; | |||||
| // Used for operators that output reuse input | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||||
| // L2_normalize | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||||
| // HCOM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||||
| // Log time stamp | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||||
| // SpaceToDepth/DepthToSpace | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||||
| // SparseSoftmaxCrossEntropyWithLogits | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||||
| // MaxPoolGradWithArgmax | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||||
| // AvgPoolGrad | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||||
| // Varible | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||||
| // Assign | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||||
| // ShapeN | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||||
| // Space2bacth batch2space | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||||
| // Depth_to_space space_to_depth | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||||
| // FakeQuantWithMinMaxVars | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||||
| // Mobilenet_ssd_conv_fusion | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||||
| // Lsh project | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||||
| // Control flow | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||||
| // GatherV2 attr def | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||||
| // Reshape attr def | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||||
| // Axis attr def | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||||
| // The node link with SparseSoftmaxCrossEntropyWithLogits | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||||
| // For constant folding | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||||
| // Used for mark the active label list to find stream of activated node | // Used for mark the active label list to find stream of activated node | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; | |||||
| // Multi batch | // Multi batch | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | ||||
| @@ -697,7 +905,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| // Control flow | // Control flow | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | ||||
| @@ -709,6 +916,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | ||||
| // Function Op | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||||
| // Used for mark the active node is for loop, type:bool | // Used for mark the active node is for loop, type:bool | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | ||||
| @@ -742,6 +952,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INS | |||||
| // For inserted op | // For inserted op | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | ||||
| // For compress weight | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; | |||||
| // For data dump | // For data dump | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | ||||
| @@ -752,6 +965,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | ||||
| // used for l1 fusion and other fusion in future | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||||
| // used for label switch | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||||
| // Varible | // Varible | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | ||||
| @@ -20,10 +20,8 @@ | |||||
| #include <atomic> | #include <atomic> | ||||
| #include <memory> | #include <memory> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/attr_value_serializable.h" | #include "graph/attr_value_serializable.h" | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| namespace ge { | namespace ge { | ||||
| #define DEF_TYPE_DEC(type, name) \ | #define DEF_TYPE_DEC(type, name) \ | ||||
| inline void set_##name(const type &value) { name = value; } \ | inline void set_##name(const type &value) { name = value; } \ | ||||
| @@ -49,10 +47,9 @@ namespace ge { | |||||
| inline void add_##name(type value) { name.push_back(value); } \ | inline void add_##name(type value) { name.push_back(value); } \ | ||||
| inline std::vector<type> *mutable_##name() { return &name; } | inline std::vector<type> *mutable_##name() { return &name; } | ||||
| #define DEF_TYPE_BYTES_DEC(name) \ | |||||
| inline void clear_##name() { name.ClearBuffer(); } \ | |||||
| inline void set_##name(const void *value, size_t size) { \ | |||||
| name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||||
| #define DEF_TYPE_BYTES_DEC(name) \ | |||||
| inline void clear_##name() { name.ClearBuffer(); } \ | |||||
| inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||||
| inline Buffer *mutable_##name() { return &name; } | inline Buffer *mutable_##name() { return &name; } | ||||
| struct CompressInfo { | struct CompressInfo { | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/detail/any_map.h" | #include "graph/detail/any_map.h" | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| @@ -96,7 +95,7 @@ class GeIrProtoHelper { | |||||
| } | } | ||||
| } | } | ||||
| // protoMsg_ is part of protoOwner_ and they have the same runtime | |||||
| // protoMsg_ is part of protoOwner_, they have the same runtime | |||||
| ProtoMsgOwner protoOwner_ = nullptr; | ProtoMsgOwner protoOwner_ = nullptr; | ||||
| ProtoType *protoMsg_ = nullptr; | ProtoType *protoMsg_ = nullptr; | ||||
| friend class GeIrProtoHelper<typename std::conditional< | friend class GeIrProtoHelper<typename std::conditional< | ||||
| @@ -21,9 +21,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/model.h" | |||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| @@ -48,15 +46,15 @@ struct NodeNameNodeReq { | |||||
| class ModelSerializeImp { | class ModelSerializeImp { | ||||
| public: | public: | ||||
| bool SerializeModel(const Model &model, proto::ModelDef *modeProto); | |||||
| bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); | |||||
| bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto); | |||||
| bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); | |||||
| bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | ||||
| bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto); | |||||
| bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
| bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto); | |||||
| bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||||
| bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| @@ -139,15 +138,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
| template <typename vector_type> | template <typename vector_type> | ||||
| // To cols | // To cols | ||||
| using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, | |||||
| int>::type; | |||||
| using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type; | |||||
| template <typename one_type> | template <typename one_type> | ||||
| using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | ||||
| template <typename val_type> | template <typename val_type> | ||||
| using enable_if_type_valid_t = | using enable_if_type_valid_t = | ||||
| typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||||
| typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||||
| template <typename seriliable_type> | template <typename seriliable_type> | ||||
| using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | ||||
| @@ -18,7 +18,6 @@ | |||||
| #define INC_GRAPH_GE_CONTEXT_H_ | #define INC_GRAPH_GE_CONTEXT_H_ | ||||
| #include <string> | #include <string> | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -42,4 +41,4 @@ class GEContext { | |||||
| GEContext &GetContext(); | GEContext &GetContext(); | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_GE_CONTEXT_H_ | |||||
| #endif // INC_GRAPH_GE_CONTEXT_H_ | |||||
| @@ -20,7 +20,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| using std::map; | using std::map; | ||||
| @@ -42,5 +41,4 @@ class GEThreadLocalContext { | |||||
| GEThreadLocalContext &GetThreadLocalContext(); | GEThreadLocalContext &GetThreadLocalContext(); | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ | #endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ | ||||
| @@ -21,12 +21,10 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| namespace ge { | namespace ge { | ||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | ||||
| public: | public: | ||||
| @@ -43,6 +41,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||||
| int64_t GetShapeSize() const; | int64_t GetShapeSize() const; | ||||
| std::string ToString() const; | std::string ToString() const; | ||||
| /// | |||||
| /// @brief Check is unknown shape | |||||
| /// @return bool | |||||
| /// | |||||
| bool IsUnknownShape() const; | |||||
| /// | |||||
| /// @brief Check is a scalar | |||||
| /// @return bool | |||||
| /// | |||||
| bool IsScalar() const; | |||||
| GeShape(const GeShape &other); | GeShape(const GeShape &other); | ||||
| GeShape(GeShape &&other); | GeShape(GeShape &&other); | ||||
| GeShape &operator=(const GeShape &other); | GeShape &operator=(const GeShape &other); | ||||
| @@ -51,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||||
| private: | private: | ||||
| GeIrProtoHelper<proto::ShapeDef> shape_def_; | GeIrProtoHelper<proto::ShapeDef> shape_def_; | ||||
| friend class GeTensorDesc; | friend class GeTensorDesc; | ||||
| // Create geshape from proto obj | |||||
| // Create from proto obj | |||||
| GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | ||||
| void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | ||||
| @@ -112,7 +122,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||||
| void Init(); | void Init(); | ||||
| // Create getensordesc from proto obj | |||||
| // Create from proto obj | |||||
| GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | ||||
| friend class GeTensor; | friend class GeTensor; | ||||
| friend class GeAttrValueImp; | friend class GeAttrValueImp; | ||||
| @@ -159,10 +169,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||||
| friend class GeAttrValueImp; | friend class GeAttrValueImp; | ||||
| friend class ModelSerializeImp; | friend class ModelSerializeImp; | ||||
| friend class OnnxUtils; | friend class OnnxUtils; | ||||
| // Create getensor from proto obj | |||||
| // Create from proto obj | |||||
| GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | ||||
| GeIrProtoHelper<proto::TensorDef> tensor_def_; | GeIrProtoHelper<proto::TensorDef> tensor_def_; | ||||
| // Reference from tensorDef_, cab not use it directly | |||||
| // Reference from tensorDef_, do not direct use | |||||
| mutable GeTensorDesc __desc_; | mutable GeTensorDesc __desc_; | ||||
| GeTensorDesc &DescReference() const; | GeTensorDesc &DescReference() const; | ||||
| }; | }; | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| @@ -62,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||||
| using AttrHolder::HasAttr; | using AttrHolder::HasAttr; | ||||
| using AttrHolder::SetAttr; | using AttrHolder::SetAttr; | ||||
| graphStatus Save(Buffer &buffer) const; | |||||
| graphStatus Save(Buffer &buffer, bool is_dump = false) const; | |||||
| graphStatus SaveToFile(const string &file_name) const; | graphStatus SaveToFile(const string &file_name) const; | ||||
| // Model will be rewrite | // Model will be rewrite | ||||
| @@ -19,7 +19,6 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include "graph/buffer.h" | #include "graph/buffer.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/model.h" | #include "graph/model.h" | ||||
| @@ -27,7 +26,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class ModelSerialize { | class ModelSerialize { | ||||
| public: | public: | ||||
| Buffer SerializeModel(const Model &model); | |||||
| Buffer SerializeModel(const Model &model, bool is_dump = false); | |||||
| Model UnserializeModel(const uint8_t *data, size_t len); | Model UnserializeModel(const uint8_t *data, size_t len); | ||||
| Model UnserializeModel(ge::proto::ModelDef &model_def); | Model UnserializeModel(ge::proto::ModelDef &model_def); | ||||
| @@ -113,25 +113,25 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | ||||
| // All inData nodes | |||||
| // All in Data nodes | |||||
| Vistor<NodePtr> GetInDataNodes() const; | Vistor<NodePtr> GetInDataNodes() const; | ||||
| // All inControl nodes | |||||
| // All in Control nodes | |||||
| Vistor<NodePtr> GetInControlNodes() const; | Vistor<NodePtr> GetInControlNodes() const; | ||||
| // GetInAllNodes = InDataNodes + InControlNodes | // GetInAllNodes = InDataNodes + InControlNodes | ||||
| Vistor<NodePtr> GetInAllNodes() const; | Vistor<NodePtr> GetInAllNodes() const; | ||||
| // All outData nodes | |||||
| // All out Data nodes | |||||
| Vistor<NodePtr> GetOutDataNodes() const; | Vistor<NodePtr> GetOutDataNodes() const; | ||||
| uint32_t GetOutDataNodesSize() const; | uint32_t GetOutDataNodesSize() const; | ||||
| // All outControl nodes | |||||
| // All out Control nodes | |||||
| Vistor<NodePtr> GetOutControlNodes() const; | Vistor<NodePtr> GetOutControlNodes() const; | ||||
| // GetOutAllNodes = OutDataNodes + InControlNodes | // GetOutAllNodes = OutDataNodes + InControlNodes | ||||
| Vistor<NodePtr> GetOutAllNodes() const; | Vistor<NodePtr> GetOutAllNodes() const; | ||||
| // Get all indata nodes and its outanchor | |||||
| // Get all in data nodes and its out-anchor | |||||
| Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | ||||
| // Get all outdata nodes and its inanchor | |||||
| // Get all out data nodes and its in-anchor | |||||
| Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | ||||
| graphStatus InferShapeAndType() const; | graphStatus InferShapeAndType() const; | ||||
| @@ -176,7 +176,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | ||||
| NodePtr GetOrigNode(void) { return orig_node_; } | |||||
| NodePtr GetOrigNode() { return orig_node_; } | |||||
| private: | private: | ||||
| bool NodeMembersAreEqual(const Node &r_node) const; | bool NodeMembersAreEqual(const Node &r_node) const; | ||||
| @@ -23,7 +23,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <vector> | #include <vector> | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include "graph/range_vistor.h" | #include "graph/range_vistor.h" | ||||
| @@ -108,6 +107,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| size_t GetInputsSize() const; | size_t GetInputsSize() const; | ||||
| size_t GetAllInputsSize() const; | |||||
| graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | ||||
| graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | ||||
| @@ -122,6 +123,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | ||||
| uint32_t GetAllOutputsDescSize() const; | |||||
| Vistor<GeTensorDesc> GetAllOutputsDesc() const; | Vistor<GeTensorDesc> GetAllOutputsDesc() const; | ||||
| Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | ||||
| @@ -132,6 +135,10 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | ||||
| ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; | |||||
| ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; | |||||
| graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
| graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | ||||
| @@ -140,7 +147,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| bool IsOptionalInput(uint32_t index) const; | bool IsOptionalInput(uint32_t index) const; | ||||
| std::map<string, uint32_t> GetAllInputName(); | |||||
| std::map<string, uint32_t> GetAllInputName() const; | |||||
| void SetAllInputName(const std::map<string, uint32_t> &input_name_idx); | |||||
| std::vector<string> GetAllOptionalInputName() const; | |||||
| std::map<string, uint32_t> GetAllOutputName(); | std::map<string, uint32_t> GetAllOutputName(); | ||||
| @@ -225,6 +236,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| std::string GetOpEngineName() const; | std::string GetOpEngineName() const; | ||||
| graphStatus AddSubgraphName(const std::string &name); | |||||
| const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||||
| std::string GetSubgraphInstanceName(uint32_t index) const; | |||||
| const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||||
| void AddSubgraphInstanceName(std::string name); | |||||
| void RemoveSubgraphInstanceName(const std::string &name); | |||||
| protected: | protected: | ||||
| ProtoAttrMapHelper MutableAttrMap() override; | ProtoAttrMapHelper MutableAttrMap() override; | ||||
| ConstProtoAttrMapHelper GetAttrMap() const override; | ConstProtoAttrMapHelper GetAttrMap() const override; | ||||
| @@ -236,9 +255,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | ||||
| GeIrProtoHelper<ge::proto::OpDef> op_def_; | GeIrProtoHelper<ge::proto::OpDef> op_def_; | ||||
| std::vector<std::string> subgraph_instance_names_; | |||||
| std::map<std::string, uint32_t> subgraph_names_to_index_; | |||||
| vector<GeTensorDescPtr> inputs_desc_{}; | vector<GeTensorDescPtr> inputs_desc_{}; | ||||
| map<string, uint32_t> input_name_idx_{}; | |||||
| std::unordered_set<string> optional_input_names_{}; | |||||
| vector<GeTensorDescPtr> outputs_desc_{}; | vector<GeTensorDescPtr> outputs_desc_{}; | ||||
| map<string, uint32_t> output_name_idx_{}; | map<string, uint32_t> output_name_idx_{}; | ||||
| std::function<graphStatus(Operator &)> infer_func_ = nullptr; | std::function<graphStatus(Operator &)> infer_func_ = nullptr; | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/operator_factory.h" | #include "graph/operator_factory.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -47,7 +46,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { | |||||
| static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | ||||
| private: | |||||
| static shared_ptr<std::map<string, OpCreator>> operator_creators_; | static shared_ptr<std::map<string, OpCreator>> operator_creators_; | ||||
| static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | ||||
| static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | ||||
| @@ -18,8 +18,8 @@ | |||||
| #define INC_GRAPH_SHAPE_REFINER_H_ | #define INC_GRAPH_SHAPE_REFINER_H_ | ||||
| #include <string> | #include <string> | ||||
| #include "external/graph/inference_context.h" | #include "external/graph/inference_context.h" | ||||
| #include "external/graph/ge_error_codes.h" | #include "external/graph/ge_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| @@ -27,8 +27,10 @@ namespace ge { | |||||
| // ShapeRefiner performs shape inference for compute graphs | // ShapeRefiner performs shape inference for compute graphs | ||||
| class ShapeRefiner { | class ShapeRefiner { | ||||
| public: | public: | ||||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); | |||||
| static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||||
| static graphStatus InferShapeAndType(const NodePtr &node); | static graphStatus InferShapeAndType(const NodePtr &node); | ||||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||||
| private: | private: | ||||
| static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | ||||
| @@ -14,8 +14,8 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
| #define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
| #ifndef INC_GRAPH_USR_TYPES_H_ | |||||
| #define INC_GRAPH_USR_TYPES_H_ | |||||
| #include <atomic> | #include <atomic> | ||||
| #include <memory> | #include <memory> | ||||
| @@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||||
| #undef USR_TYPE_BYTES_DEC | #undef USR_TYPE_BYTES_DEC | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||||
| #endif // INC_GRAPH_USR_TYPES_H_ | |||||
| @@ -99,8 +99,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
| static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | ||||
| static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | ||||
| // Value will be moved | // Value will be moved | ||||
| static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, | |||||
| vector<Buffer> &listBuffer); | |||||
| static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||||
| static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | ||||
| static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | ||||
| @@ -116,6 +115,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||||
| static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | ||||
| static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); | |||||
| class AttrHolderAdapter { | class AttrHolderAdapter { | ||||
| public: | public: | ||||
| AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | ||||
| @@ -137,6 +137,18 @@ class GraphUtils { | |||||
| static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | ||||
| const std::vector<OpDescPtr> &vec_op_desc); | const std::vector<OpDescPtr> &vec_op_desc); | ||||
| /// | |||||
| /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst | |||||
| /// @param [in] src | |||||
| /// @param [in] dsts | |||||
| /// @param [in] insert_node | |||||
| /// @param [in] input_index | |||||
| /// @param [in] output_index | |||||
| /// @return graphStatus | |||||
| /// | |||||
| static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||||
| static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | ||||
| static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | ||||
| @@ -145,16 +157,12 @@ class GraphUtils { | |||||
| static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | ||||
| static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph); | |||||
| static bool MatchDumpStr(const std::string &suffix); | static bool MatchDumpStr(const std::string &suffix); | ||||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | ||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | ||||
| static bool CheckGlobalStepNode(const ge::NodePtr &node); | |||||
| static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | ||||
| static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | ||||
| @@ -252,6 +260,315 @@ class GraphUtils { | |||||
| /// @return success: GRAPH_SUCESS | /// @return success: GRAPH_SUCESS | ||||
| /// | /// | ||||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | ||||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||||
| }; | |||||
| class ComputeGraphBuilder { | |||||
| public: | |||||
| ComputeGraphBuilder() : owner_graph_(nullptr) {} | |||||
| ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; | |||||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; | |||||
| ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; | |||||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; | |||||
| ~ComputeGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, | |||||
| const std::string &dst_name, uint32_t in_anchor_ind); | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return ComputeGraphBuilder | |||||
| /// | |||||
| virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); | |||||
| /// | |||||
| /// @brief Build graph | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; | |||||
| /// @brief Get node with name | |||||
| /// @param [in] name | |||||
| /// @return NodePtr | |||||
| /// | |||||
| NodePtr GetNode(const std::string &name); | |||||
| protected: | |||||
| /// | |||||
| /// @brief Build nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildNodes(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build data-links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildDataLinks(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build ctrl-links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); | |||||
| ComputeGraphPtr owner_graph_; | |||||
| // node_name -> node | |||||
| std::map<std::string, NodePtr> node_names_; | |||||
| std::vector<OpDescPtr> nodes_; | |||||
| // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind> | |||||
| std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_; | |||||
| // src_node_name -> dst_node_name | |||||
| std::vector<std::pair<std::string, std::string>> ctrl_links_; | |||||
| }; | |||||
| class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
| public: | |||||
| explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||||
| CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | |||||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | |||||
| CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | |||||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; | |||||
| ~CompleteGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
| uint32_t in_anchor_ind) override; | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
| /// | |||||
| /// @brief Set index_th input anchor for graph | |||||
| /// @param [in] index | |||||
| /// @param [in] node_names | |||||
| /// @param [in] anchor_inds | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names, | |||||
| const std::vector<uint32_t> &anchor_inds); | |||||
| /// | |||||
| /// @brief Set index_th input of graph as useless | |||||
| /// @param [in] index | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetUselessInput(uint32_t index); | |||||
| /// | |||||
| /// @brief Add output anchor for graph | |||||
| /// @param [in] owner_node_name | |||||
| /// @param [in] anchor_ind | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | |||||
| /// | |||||
| /// @brief Set parent-node of graph | |||||
| /// @param [in] parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); | |||||
| /// | |||||
| /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node | |||||
| /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||||
| /// | |||||
| /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind | |||||
| /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node | |||||
| /// @return CompleteGraphBuilder | |||||
| /// | |||||
| CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||||
| /// | |||||
| /// @brief Build graph | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
| private: | |||||
| /// | |||||
| /// @brief Build inputs | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Add data node | |||||
| /// @param [in] index | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build outputs | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Add NetOutput node | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return NodePtr | |||||
| /// | |||||
| NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Add input/output tensor for NetOutput node | |||||
| /// @param [in] out_nodes_info | |||||
| /// @param [out] net_output_desc | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
| OpDescPtr &net_output_desc); | |||||
| /// | |||||
| /// @brief Add edge for NetOutput node | |||||
| /// @param [in] out_nodes_info | |||||
| /// @param [out] net_output_node | |||||
| /// @return graphStatus | |||||
| /// | |||||
| graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||||
| const NodePtr &net_output_node); | |||||
| std::string name_; | |||||
| NodePtr parent_node_; | |||||
| std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | |||||
| std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | |||||
| // index_of_graph_input -> in_anchor_index_of_parent_node | |||||
| std::map<uint32_t, uint32_t> input_mapping_; | |||||
| // index_of_graph_output -> out_anchor_index_of_parent_node | |||||
| std::map<uint32_t, uint32_t> output_mapping_; | |||||
| }; | |||||
| class PartialGraphBuilder : public ComputeGraphBuilder { | |||||
| public: | |||||
| PartialGraphBuilder() = default; | |||||
| PartialGraphBuilder(const PartialGraphBuilder &) = delete; | |||||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; | |||||
| PartialGraphBuilder(const PartialGraphBuilder &&) = delete; | |||||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; | |||||
| ~PartialGraphBuilder() = default; | |||||
| /// | |||||
| /// @brief Add node to graph | |||||
| /// @param [in] op_desc | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||||
| /// | |||||
| /// @brief Add data-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] out_anchor_ind | |||||
| /// @param [in] dst_name | |||||
| /// @param [in] in_anchor_ind | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||||
| uint32_t in_anchor_ind) override; | |||||
| /// | |||||
| /// @brief Add ctrl-link among nodes in graph | |||||
| /// @param [in] src_name | |||||
| /// @param [in] dst_name | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||||
| /// | |||||
| /// @brief Set owner graph | |||||
| /// @param [in] graph | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); | |||||
| /// | |||||
| /// @brief Add exist node | |||||
| /// @param [in] node | |||||
| /// @return PartialGraphBuilder | |||||
| /// | |||||
| PartialGraphBuilder &AddExistNode(const NodePtr &node); | |||||
| /// | |||||
| /// @brief Build multi nodes with links | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return ComputeGraphPtr | |||||
| /// | |||||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||||
| private: | |||||
| /// | |||||
| /// @brief Build exist nodes | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildExistNodes(graphStatus &error_code, std::string &error_msg); | |||||
| std::vector<NodePtr> exist_nodes_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -56,6 +56,11 @@ class NodeUtils { | |||||
| static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
| static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | ||||
| static std::string GetNodeType(const Node &node); | |||||
| static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||||
| static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||||
| private: | private: | ||||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | ||||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | ||||
| @@ -20,7 +20,6 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/def_types.h" | #include "graph/def_types.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| @@ -29,7 +28,6 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class OpDesc; | class OpDesc; | ||||
| using OpDescPtr = std::shared_ptr<OpDesc>; | using OpDescPtr = std::shared_ptr<OpDesc>; | ||||
| class OpDescUtils { | class OpDescUtils { | ||||
| @@ -39,55 +37,108 @@ class OpDescUtils { | |||||
| OpDescUtils() = default; | OpDescUtils() = default; | ||||
| ~OpDescUtils() = default; | ~OpDescUtils() = default; | ||||
| static bool HasQuantizeFactorParams(const OpDescPtr &op_desc); | |||||
| static bool HasQuantizeFactorParams(const OpDesc &op_desc); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant); | |||||
| static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant); | |||||
| static graphStatus SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant); | |||||
| static vector<ge::NodePtr> GetConstInputNode(const ge::Node &node); | |||||
| static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr> &input_nodes); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr &node); | |||||
| static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | |||||
| static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); | |||||
| static bool HasQuantizeFactorParams(const OpDesc& op_desc); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); | |||||
| static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); | |||||
| static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); | |||||
| static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); | |||||
| static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node); | |||||
| static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node); | |||||
| static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node); | |||||
| static vector<GeTensorPtr> MutableWeights(const ge::Node& node); | |||||
| static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | ||||
| static graphStatus SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights); | |||||
| static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights); | |||||
| static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights); | |||||
| static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights); | |||||
| static graphStatus ClearWeights(ge::NodePtr node); | static graphStatus ClearWeights(ge::NodePtr node); | ||||
| static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | ||||
| static bool ClearInputDesc(const ge::NodePtr &node); | |||||
| static bool ClearOutputDesc(const ge::OpDescPtr &op_desc, uint32_t index); | |||||
| static bool ClearOutputDesc(const ge::NodePtr &node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::Node &node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr &node); | |||||
| static size_t GetNonConstInputsSize(const ge::Node &node); | |||||
| static bool ClearInputDesc(const ge::NodePtr& node); | |||||
| static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); | |||||
| static bool ClearOutputDesc(const ge::NodePtr& node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::Node& node); | |||||
| static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node); | |||||
| static size_t GetNonConstInputsSize(const ge::Node& node); | |||||
| static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | ||||
| // Index: Indicate the index of all non const inputs | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const = 0); | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const = 0); | |||||
| static bool GetNonConstInputIndex(const ge::Node &node, size_t index_non_const, size_t &index); | |||||
| static bool GetNonConstInputIndex(const ge::ConstNodePtr &node, size_t index_non_const, size_t &index); | |||||
| // Index: Indicate the index of all inputs | |||||
| static bool IsNonConstInput(const ge::Node &node, size_t index = 0); | |||||
| static bool IsNonConstInput(const ge::ConstNodePtr &node, size_t index = 0); | |||||
| static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr &node); | |||||
| static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr); | |||||
| // Index: Indicates the index of all non const inputs | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); | |||||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); | |||||
| static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); | |||||
| static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); | |||||
| // Index: Indicates the index of all inputs | |||||
| static bool IsNonConstInput(const ge::Node& node, size_t index = 0); | |||||
| static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); | |||||
| static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node); | |||||
| static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); | |||||
| static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | ||||
| static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | ||||
| static OpDescPtr GetOpDescFromOperator(const Operator &oprt); | |||||
| static OpDescPtr GetOpDescFromOperator(const Operator& oprt); | |||||
| static OpDescPtr CreateConstOp(const GeTensorPtr &tensor_ptr); | |||||
| static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||||
| private: | private: | ||||
| static GeTensorPtr MutableWeights(ge::OpDesc &op_desc); | |||||
| static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||||
| static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | ||||
| static graphStatus SetWeights(ge::OpDesc &op_desc, const GeTensorPtr weight); | |||||
| static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); | |||||
| static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | ||||
| }; | }; | ||||
| class OpDescBuilder { | |||||
| public: | |||||
| OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} | |||||
| OpDescBuilder(const OpDescBuilder&) = delete; | |||||
| OpDescBuilder& operator=(const OpDescBuilder&) = delete; | |||||
| OpDescBuilder(const OpDescBuilder&&) = delete; | |||||
| OpDescBuilder& operator=(const OpDescBuilder&&) = delete; | |||||
| ~OpDescBuilder() = default; | |||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddInput(const std::string& name); | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddOutput(const std::string& name); | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||||
| /// | |||||
| /// @brief Build op_desc | |||||
| /// @return OpDescPtr | |||||
| /// | |||||
| OpDescPtr Build(); | |||||
| private: | |||||
| std::string name_; | |||||
| std::string type_; | |||||
| std::vector<std::string> inputs_; | |||||
| std::vector<std::string> outputs_; | |||||
| }; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | ||||
| @@ -18,15 +18,14 @@ | |||||
| #define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | #define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/def_types.h" | #include "graph/def_types.h" | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| namespace ge { | namespace ge { | ||||
| class TensorUtils { | class TensorUtils { | ||||
| public: | public: | ||||
| static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, uint32_t &size); | |||||
| static void SetSize(GeTensorDesc &tensorDesc, uint32_t size); | |||||
| static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); | |||||
| static void SetSize(GeTensorDesc &tensorDesc, int64_t size); | |||||
| static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | ||||
| static uint32_t GetWeightSize(const GeTensor &tensor); | static uint32_t GetWeightSize(const GeTensor &tensor); | ||||
| static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | ||||
| @@ -62,16 +61,16 @@ class TensorUtils { | |||||
| static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | ||||
| /// | /// | ||||
| /// calculate mem size of the tensor. | |||||
| /// calculate tensor mem size. | |||||
| /// @param shape tensor shape | /// @param shape tensor shape | ||||
| /// @param format tensor format | /// @param format tensor format | ||||
| /// @param data_type tensor data type | /// @param data_type tensor data type | ||||
| /// @param mem_size -1 means unknown shape,others means mem size | |||||
| /// @return GRAPH_SUCCESS:success, others:failed | |||||
| /// @param mem_size -1 means unknown shape,other means mem size | |||||
| /// @return GRAPH_SUCCESS:success, other:failed | |||||
| /// | /// | ||||
| static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | ||||
| static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||||
| static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||||
| static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
| static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ | #endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ | ||||
| @@ -58,6 +58,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
| include_directories(${GE_SOURCE_DIR}/inc/graph) | include_directories(${GE_SOURCE_DIR}/inc/graph) | ||||
| include_directories(${GE_SOURCE_DIR}/inc/common) | include_directories(${GE_SOURCE_DIR}/inc/common) | ||||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||||
| include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | ||||
| include_directories(${CMAKE_BINARY_DIR}) | include_directories(${CMAKE_BINARY_DIR}) | ||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
| @@ -26,6 +26,8 @@ Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), id | |||||
| bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | ||||
| size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } | |||||
| Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | ||||
| vector<AnchorPtr> ret; | vector<AnchorPtr> ret; | ||||
| for (const auto &anchor : peer_anchors_) { | for (const auto &anchor : peer_anchors_) { | ||||
| @@ -32,8 +32,7 @@ Buffer::Buffer(const Buffer &other) { | |||||
| buffer_ = other.buffer_; | buffer_ = other.buffer_; | ||||
| } | } | ||||
| // default | |||||
| Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { | |||||
| Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default | |||||
| auto proto_msg = data_.GetProtoMsg(); | auto proto_msg = data_.GetProtoMsg(); | ||||
| if (proto_msg != nullptr) { | if (proto_msg != nullptr) { | ||||
| try { | try { | ||||
| @@ -15,9 +15,7 @@ | |||||
| */ | */ | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include <deque> | #include <deque> | ||||
| #include "./format_refiner.h" | #include "./format_refiner.h" | ||||
| #include "./ge_context.h" | #include "./ge_context.h" | ||||
| #include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
| @@ -41,7 +39,7 @@ const size_t OUTPUT_PARAM_SIZE = 2; | |||||
| } // namespace | } // namespace | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | ||||
| : nodes_(), input_nodes_(), sub_graph_(), name_(name), is_valid_flag_(false), need_iteration_(false) { | |||||
| : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||||
| attrs_.InitDefault(); | attrs_.InitDefault(); | ||||
| } | } | ||||
| ComputeGraph::~ComputeGraph() {} | ComputeGraph::~ComputeGraph() {} | ||||
| @@ -154,7 +152,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | ||||
| const ComputeGraph &r_graph) const { | const ComputeGraph &r_graph) const { | ||||
| return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | |||||
| return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && | |||||
| IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | ||||
| VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | ||||
| IsEqual(this->name_, r_graph.name_, "graph.name_") && | IsEqual(this->name_, r_graph.name_, "graph.name_") && | ||||
| @@ -398,6 +396,165 @@ graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &su | |||||
| } | } | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) { | |||||
| if (subgraph == nullptr) { | |||||
| GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto parent_graph = subgraph->GetParentGraph(); | |||||
| if (parent_graph == nullptr) { | |||||
| GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto parent_node = subgraph->GetParentNode(); | |||||
| if (parent_node == nullptr) { | |||||
| GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (parent_node->GetOwnerComputeGraph() != parent_graph) { | |||||
| GE_LOGE( | |||||
| "Try to add a subgraph which parent node's parent graph is not equal to " | |||||
| "the subgraph's parent graph, subgraph name %s, parent node name %s", | |||||
| subgraph->GetName().c_str(), parent_graph->GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (!this->parent_graph_.expired()) { | |||||
| GE_LOGE("The subgraphs can only be added to the root graph"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (name != subgraph->GetName()) { | |||||
| GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||||
| } | |||||
| sub_graph_.push_back(subgraph); | |||||
| names_to_subgraph_[name] = subgraph; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) { | |||||
| if (subgraph == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| return AddSubgraph(subgraph->GetName(), subgraph); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { | |||||
| auto iter = names_to_subgraph_.find(name); | |||||
| if (iter == names_to_subgraph_.end()) { | |||||
| return; | |||||
| } | |||||
| for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { | |||||
| if (*vec_iter == iter->second) { | |||||
| sub_graph_.erase(vec_iter); | |||||
| break; | |||||
| } | |||||
| } | |||||
| names_to_subgraph_.erase(iter); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( | |||||
| const std::shared_ptr<ComputeGraph> &subgraph) { | |||||
| if (subgraph != nullptr) { | |||||
| RemoveSubgraph(subgraph->GetName()); | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | |||||
| const std::string &name) const { | |||||
| auto iter = names_to_subgraph_.find(name); | |||||
| return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | |||||
| ComputeGraph::GetAllSubgraphs() const { | |||||
| return sub_graph_; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() { | |||||
| return parent_graph_.lock(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( | |||||
| const shared_ptr<ComputeGraph> &parent) { | |||||
| parent_graph_ = parent; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() { | |||||
| return parent_node_.lock(); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) { | |||||
| parent_node_ = parent; | |||||
| } | |||||
| /// | |||||
| /// @brief Update input-mapping | |||||
| /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||||
| /// @return graphStatus | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | |||||
| for (auto &input : input_nodes_) { | |||||
| uint32_t cur_index = 0; | |||||
| if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||||
| continue; | |||||
| } | |||||
| auto iter = input_mapping.find(cur_index); | |||||
| if (iter == input_mapping.end()) { | |||||
| continue; | |||||
| } | |||||
| if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
| GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @brief Update output-mapping | |||||
| /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||||
| /// @return graphStatus | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
| ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | |||||
| NodePtr net_output = FindNode(kNodeNameNetOutput); | |||||
| if (net_output == nullptr) { | |||||
| GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = net_output->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| size_t num = op_desc->GetInputsSize(); | |||||
| for (size_t i = 0; i < num; i++) { | |||||
| GeTensorDesc tensor = op_desc->GetInputDesc(i); | |||||
| uint32_t cur_index = 0; | |||||
| if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||||
| continue; | |||||
| } | |||||
| auto iter = output_mapping.find(cur_index); | |||||
| if (iter == output_mapping.end()) { | |||||
| continue; | |||||
| } | |||||
| if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
| GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { | |||||
| GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | ||||
| std::vector<NodePtr> node_vec = nodes_; | std::vector<NodePtr> node_vec = nodes_; | ||||
| for (const auto &node : GetAllNodes()) { | for (const auto &node : GetAllNodes()) { | ||||
| @@ -551,6 +708,23 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | ||||
| auto ret = TopologicalSortingSubgraph(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Sub graph partition Failed"); | |||||
| return ret; | |||||
| } | |||||
| // partition sub graph | |||||
| for (const auto &sub_graph : GetAllSubgraphs()) { | |||||
| ret = sub_graph->TopologicalSortingSubgraph(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Sub graph topological sort Failed"); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||||
| std::vector<NodePtr> node_vec; | std::vector<NodePtr> node_vec; | ||||
| std::map<NodePtr, uint32_t> map_in_edge_num; | std::map<NodePtr, uint32_t> map_in_edge_num; | ||||
| bool use_BFS = false; | bool use_BFS = false; | ||||
| @@ -598,6 +772,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
| node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | ||||
| nodes_.push_back(node); | nodes_.push_back(node); | ||||
| } | } | ||||
| is_valid_flag_ = true; | is_valid_flag_ = true; | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -614,7 +789,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
| verify_isolated = true; | verify_isolated = true; | ||||
| } | } | ||||
| } | } | ||||
| for (const auto &node : GetAllNodes()) { | |||||
| for (const auto &node : GetDirectNode()) { | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
| map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | ||||
| if (map_in_edge_num[node] == 0) { | if (map_in_edge_num[node] == 0) { | ||||
| @@ -640,16 +815,16 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||||
| /// 2. Compare two indices, if not match, swap the positions of two inputs | /// 2. Compare two indices, if not match, swap the positions of two inputs | ||||
| /// *: Remind: stack is reverse-order | /// *: Remind: stack is reverse-order | ||||
| for (size_t i = 0; i < stack.size(); ++i) { | for (size_t i = 0; i < stack.size(); ++i) { | ||||
| // [stack: should not be null] | |||||
| // If not found in 'inputs_order_', skip it | |||||
| auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||||
| GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||||
| auto inx_i = it_i - inputs_order_.begin(); | |||||
| for (size_t j = i + 1; j < stack.size(); ++j) { | for (size_t j = i + 1; j < stack.size(); ++j) { | ||||
| // If not found in 'inputs_order_', skip it | // If not found in 'inputs_order_', skip it | ||||
| auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||||
| GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||||
| auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | ||||
| GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | ||||
| // Compare index, swap them if it should be | // Compare index, swap them if it should be | ||||
| auto inx_i = it_i - inputs_order_.begin(); | |||||
| auto inx_j = it_j - inputs_order_.begin(); | auto inx_j = it_j - inputs_order_.begin(); | ||||
| GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | ||||
| } | } | ||||
| @@ -663,7 +838,7 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||||
| return in_edge_size; | return in_edge_size; | ||||
| } | } | ||||
| for (const auto &anchor : node->GetAllInDataAnchors()) { | for (const auto &anchor : node->GetAllInDataAnchors()) { | ||||
| in_edge_size = in_edge_size + anchor->GetPeerAnchors().size(); | |||||
| in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); | |||||
| // Break flow control data loop. | // Break flow control data loop. | ||||
| OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | ||||
| if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | ||||
| @@ -680,10 +855,11 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| if (node->GetInControlAnchor() != nullptr) { | if (node->GetInControlAnchor() != nullptr) { | ||||
| in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchors().size(); | |||||
| in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); | |||||
| } | } | ||||
| return in_edge_size; | return in_edge_size; | ||||
| } | } | ||||
| size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | ||||
| size_t out_edge_size = 0; | size_t out_edge_size = 0; | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| @@ -699,7 +875,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||||
| } | } | ||||
| } | } | ||||
| if (node->GetOutControlAnchor() != nullptr) { | if (node->GetOutControlAnchor() != nullptr) { | ||||
| if (out_edge_size > (UINT32_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||||
| if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | ||||
| @@ -724,17 +900,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | peer_in_anchor->GetOwnerNode()->GetName().c_str())); | ||||
| } | } | ||||
| } | } | ||||
| GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | |||||
| return ); | |||||
| for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | |||||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
| } | |||||
| for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInDataAnchors()) { | |||||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
| auto out_control_anchor = node->GetOutControlAnchor(); | |||||
| if (out_control_anchor != nullptr) { | |||||
| for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { | |||||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
| } | |||||
| for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { | |||||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -18,21 +18,9 @@ | |||||
| #define COMMON_GRAPH_DEBUG_GE_LOG_H_ | #define COMMON_GRAPH_DEBUG_GE_LOG_H_ | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "toolchain/slog.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #define GE_MOD_ID GE | |||||
| #ifdef _MSC_VER | |||||
| #define FUNC_NAME __FUNCTION__ | |||||
| #else | |||||
| #define FUNC_NAME __PRETTY_FUNCTION__ | |||||
| #endif | |||||
| #define D_GE_LOGE(fmt, ...) \ | |||||
| dlog_error(static_cast<int>(GE_MOD_ID), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
| #define GE_LOGE(...) D_GE_LOGE(__VA_ARGS__) | |||||
| #define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||||
| #define GE_LOGI_IF(condition, ...) \ | #define GE_LOGI_IF(condition, ...) \ | ||||
| if ((condition)) { \ | if ((condition)) { \ | ||||
| @@ -44,15 +32,15 @@ | |||||
| GELOGW(__VA_ARGS__); \ | GELOGW(__VA_ARGS__); \ | ||||
| } | } | ||||
| #define GE_LOGE_IF(condition, ...) \ | |||||
| if ((condition)) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| #define GE_LOGE_IF(condition, ...) \ | |||||
| if ((condition)) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| } | } | ||||
| #define GE_CHK_STATUS_RET_NOLOG(expr) \ | #define GE_CHK_STATUS_RET_NOLOG(expr) \ | ||||
| do { \ | do { \ | ||||
| const ge::graphStatus _status = (expr); \ | const ge::graphStatus _status = (expr); \ | ||||
| if (_status != ge::GRAPH_SUCCESS) { \ | |||||
| if (ge::SUCCESS != _status) { \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -61,7 +49,7 @@ | |||||
| do { \ | do { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (!b) { \ | if (!b) { \ | ||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -85,7 +73,7 @@ | |||||
| do { \ | do { \ | ||||
| const ge::graphStatus _status = (expr); \ | const ge::graphStatus _status = (expr); \ | ||||
| if (_status) { \ | if (_status) { \ | ||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | return _status; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -95,7 +83,7 @@ | |||||
| { \ | { \ | ||||
| bool b = (expr); \ | bool b = (expr); \ | ||||
| if (b) { \ | if (b) { \ | ||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| exec_expr; \ | exec_expr; \ | ||||
| } \ | } \ | ||||
| } | } | ||||
| @@ -119,63 +107,41 @@ | |||||
| } while (0) | } while (0) | ||||
| // If expr is not true, the log is printed and a custom statement is executed | // If expr is not true, the log is printed and a custom statement is executed | ||||
| #define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | |||||
| // If expr is not true, the log is printed and a custom statement is executed | |||||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| #define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | } | ||||
| // If expr is not true, the log is printed and a custom statement is executed | // If expr is not true, the log is printed and a custom statement is executed | ||||
| #define GE_CHK_BOOL_EXEC_DEBUG(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGD(__VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||||
| { \ | |||||
| bool b = (expr); \ | |||||
| if (!b) { \ | |||||
| GELOGI(__VA_ARGS__); \ | |||||
| exec_expr; \ | |||||
| } \ | |||||
| } | } | ||||
| // If expr is not GRAPH_SUCCESS, print the log and return the same value | // If expr is not GRAPH_SUCCESS, print the log and return the same value | ||||
| #define GE_CHK_STATUS_RET(expr, ...) \ | |||||
| do { \ | |||||
| const ge::graphStatus _status = (expr); \ | |||||
| if (_status != ge::GRAPH_SUCCESS) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| #define GE_CHK_STATUS_RET(expr, ...) \ | |||||
| do { \ | |||||
| const ge::graphStatus _status = (expr); \ | |||||
| if (ge::SUCCESS != _status) { \ | |||||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||||
| return _status; \ | |||||
| } \ | |||||
| } while (0) | } while (0) | ||||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||||
| try { \ | |||||
| exec_expr0; \ | |||||
| } catch (...) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, "Make shared failed"); \ | |||||
| exec_expr1; \ | |||||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||||
| try { \ | |||||
| exec_expr0; \ | |||||
| } catch (...) { \ | |||||
| GELOGE(ge::FAILED, "Make shared failed"); \ | |||||
| exec_expr1; \ | |||||
| } | } | ||||
| /// CCE related macro definition | |||||
| /// If expr is not CC_STATUS_GRAPH_SUCCESS, print the log and return | |||||
| #define GE_CHK_CCE_RET(expr) \ | |||||
| do { \ | |||||
| ccgraphStatus_t _cc_ret = (expr); \ | |||||
| if (_cc_ret != CC_STATUS_GRAPH_SUCCESS) { \ | |||||
| GELOGE(ge::GRAPH_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||||
| return ge::GRAPH_FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | #endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/debug/ge_log.h" | #include "graph/debug/ge_log.h" | ||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| @@ -15,12 +15,10 @@ | |||||
| */ | */ | ||||
| #include "graph/debug/graph_debug.h" | #include "graph/debug/graph_debug.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include <vector> | #include <vector> | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #define TAB " " | #define TAB " " | ||||
| @@ -16,13 +16,11 @@ | |||||
| #ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | #ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | ||||
| #define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | #define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | ||||
| #include <cstdint> | #include <cstdint> | ||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include <string> | #include <string> | ||||
| #include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
| #include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| @@ -15,9 +15,7 @@ | |||||
| */ | */ | ||||
| #include "detail/attributes_holder.h" | #include "detail/attributes_holder.h" | ||||
| #include <map> | #include <map> | ||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| @@ -14,14 +14,12 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "graph/format_refiner.h" | |||||
| #include "format_refiner.h" | |||||
| #include <deque> | #include <deque> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <set> | #include <set> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <unordered_set> | #include <unordered_set> | ||||
| #include "./compute_graph.h" | #include "./compute_graph.h" | ||||
| #include "./ge_error_codes.h" | #include "./ge_error_codes.h" | ||||
| #include "./graph/ge_tensor.h" | #include "./graph/ge_tensor.h" | ||||
| @@ -57,6 +55,7 @@ graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | ||||
| std::vector<ge::NodePtr> &data_nodes, | std::vector<ge::NodePtr> &data_nodes, | ||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | std::unordered_map<ge::NodePtr, bool> &node_status) { | ||||
| @@ -82,10 +81,10 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
| // consider special node save process | // consider special node save process | ||||
| // get all input desc format | // get all input desc format | ||||
| bool node_is_all_nd = false; | bool node_is_all_nd = false; | ||||
| for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetInputsSize()); i++) { | |||||
| auto input_desc = op_desc->GetInputDesc(i); | |||||
| auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||||
| for (uint32_t i = 0; i < input_size; i++) { | |||||
| // Operator pre-set format but not origin format | // Operator pre-set format but not origin format | ||||
| auto input_format = input_desc.GetFormat(); | |||||
| auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||||
| // Pre-save data node and default infer fail | // Pre-save data node and default infer fail | ||||
| if (node_ptr->GetType() == DATA) { | if (node_ptr->GetType() == DATA) { | ||||
| data_nodes.push_back(node_ptr); | data_nodes.push_back(node_ptr); | ||||
| @@ -95,9 +94,9 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||||
| } | } | ||||
| } | } | ||||
| // Get all output desc format | // Get all output desc format | ||||
| for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) { | |||||
| GeTensorDesc output_desc = op_desc->GetOutputDesc(i); | |||||
| auto output_format = output_desc.GetFormat(); | |||||
| auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
| for (uint32_t i = 0; i < output_size; i++) { | |||||
| auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); | |||||
| if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | ||||
| node_is_all_nd = true; | node_is_all_nd = true; | ||||
| } | } | ||||
| @@ -145,7 +144,8 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
| GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | ||||
| auto in_data_anchor_idx = in_anchor->GetIdx(); | auto in_data_anchor_idx = in_anchor->GetIdx(); | ||||
| auto to_be_set_format = (node->GetOpDesc()->GetInputDesc(in_data_anchor_idx)).GetOriginFormat(); | |||||
| auto to_be_set_format = | |||||
| node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat(); | |||||
| if (to_be_set_format == FORMAT_ND) { | if (to_be_set_format == FORMAT_ND) { | ||||
| GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | ||||
| continue; | continue; | ||||
| @@ -162,7 +162,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
| } | } | ||||
| // Check format whether have been set | // Check format whether have been set | ||||
| int idx = peer_out_data_anchor->GetIdx(); | int idx = peer_out_data_anchor->GetIdx(); | ||||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(idx); | |||||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
| if (dim_num == 0) { | if (dim_num == 0) { | ||||
| @@ -182,7 +182,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | ge_tensor_desc.SetOriginFormat(to_be_set_format); | ||||
| ge_tensor_desc.SetFormat(to_be_set_format); | ge_tensor_desc.SetFormat(to_be_set_format); | ||||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(idx, ge_tensor_desc); | |||||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||||
| // Call operator infer format api (forward) to get out format | // Call operator infer format api (forward) to get out format | ||||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | ||||
| @@ -205,7 +205,8 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
| GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | ||||
| GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | ||||
| auto out_data_anchor_idx = out_data_anchor->GetIdx(); | auto out_data_anchor_idx = out_data_anchor->GetIdx(); | ||||
| auto to_be_set_format = (node->GetOpDesc()->GetOutputDesc(out_data_anchor_idx)).GetOriginFormat(); | |||||
| auto to_be_set_format = | |||||
| node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat(); | |||||
| if (to_be_set_format == FORMAT_ND) { | if (to_be_set_format == FORMAT_ND) { | ||||
| GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | ||||
| continue; | continue; | ||||
| @@ -222,7 +223,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||||
| } | } | ||||
| // Check format whether have been set | // Check format whether have been set | ||||
| int idx = peer_in_data_anchor->GetIdx(); | int idx = peer_in_data_anchor->GetIdx(); | ||||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(idx); | |||||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | ||||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | ||||
| if (dim_num == 0) { | if (dim_num == 0) { | ||||
| @@ -285,9 +286,9 @@ void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = | |||||
| graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | ||||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | std::unordered_map<ge::NodePtr, bool> &node_status) { | ||||
| bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | ||||
| bool need_process = ((!is_first_infer) && (is_internal_format == false) && (data_format != FORMAT_ND)); | |||||
| bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||||
| if (!need_process) { | if (!need_process) { | ||||
| GELOGI("no necessary to do DataNodeFormatProcess.IsFirstInfer: %d, data_format:%s", is_first_infer, | |||||
| GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||||
| TypeUtils::FormatToSerialString(data_format).c_str()); | TypeUtils::FormatToSerialString(data_format).c_str()); | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -378,9 +379,9 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||||
| /// Notice: ignore 5D formats | /// Notice: ignore 5D formats | ||||
| auto data_format = graph->GetDataFormat(); | auto data_format = graph->GetDataFormat(); | ||||
| status = DataNodeFormatProcess(data_nodes, data_format, node_status); | status = DataNodeFormatProcess(data_nodes, data_format, node_status); | ||||
| // Set infer flag to false | // Set infer flag to false | ||||
| SetInferOrigineFormatFlag(false); | SetInferOrigineFormatFlag(false); | ||||
| return status; | return status; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -42,6 +42,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||||
| const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | ||||
| const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||||
| const std::string ATTR_NAME_PAD = "pad"; | const std::string ATTR_NAME_PAD = "pad"; | ||||
| const std::string ATTR_NAME_PADS = "pad"; | const std::string ATTR_NAME_PADS = "pad"; | ||||
| @@ -83,6 +85,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||||
| const std::string ATTR_NAME_AXIS = "axis"; | const std::string ATTR_NAME_AXIS = "axis"; | ||||
| const std::string ATTR_NAME_BROADCAST = "broadcast"; | const std::string ATTR_NAME_BROADCAST = "broadcast"; | ||||
| const std::string ATTR_NAME_OUTPUT = "output"; | |||||
| const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | ||||
| const std::string ATTR_NAME_TIDX = "t_idx"; | const std::string ATTR_NAME_TIDX = "t_idx"; | ||||
| @@ -103,6 +106,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||||
| const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | ||||
| const std::string ATTR_NAME_AIPP = "aipp"; | const std::string ATTR_NAME_AIPP = "aipp"; | ||||
| const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||||
| const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
| const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | ||||
| const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | ||||
| @@ -111,6 +121,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||||
| const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | ||||
| const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | ||||
| const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | ||||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | ||||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | ||||
| @@ -122,15 +133,11 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||||
| const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | ||||
| const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | ||||
| const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | ||||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||||
| const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
| const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
| const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||||
| const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | ||||
| const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | ||||
| const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||||
| const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||||
| const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||||
| // To be deleted | // To be deleted | ||||
| const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | ||||
| @@ -144,15 +151,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||||
| const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
| const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | ||||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
| // Refinedet | // Refinedet | ||||
| const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | ||||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
| const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | ||||
| const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | ||||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||||
| // _Arg | // _Arg | ||||
| const std::string ATTR_NAME_INDEX = "index"; | const std::string ATTR_NAME_INDEX = "index"; | ||||
| @@ -242,6 +247,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||||
| const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | ||||
| const std::string BATCHNORM_ATTR_SCALE = "scale"; | const std::string BATCHNORM_ATTR_SCALE = "scale"; | ||||
| const std::string BATCHNORM_ATTR_BIAS = "bias"; | const std::string BATCHNORM_ATTR_BIAS = "bias"; | ||||
| const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||||
| const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||||
| const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||||
| // huberloss | |||||
| const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||||
| // SSDRealDivTileMul | |||||
| const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||||
| // SSDSumMulRealDivMean | |||||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||||
| // ConcatFive2Four | |||||
| // ConcatFour2Five | |||||
| const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||||
| const std::string SSD_CLASS_NUM = "class_num"; | |||||
| const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||||
| const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||||
| const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||||
| const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||||
| // Scale | // Scale | ||||
| const std::string SCALE_ATTR_SCALE = "scale"; | const std::string SCALE_ATTR_SCALE = "scale"; | ||||
| @@ -346,6 +375,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||||
| // Permute | // Permute | ||||
| const std::string PERMUTE_ATTR_ORDER = "order"; | const std::string PERMUTE_ATTR_ORDER = "order"; | ||||
| const std::string PERMUTE_ATTR_PERM = "perm"; | |||||
| // SSD Normalize | // SSD Normalize | ||||
| const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | ||||
| @@ -373,6 +403,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | ||||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | ||||
| // RefinedetDetectionOutput | |||||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||||
| // PRelu | // PRelu | ||||
| const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | ||||
| @@ -386,11 +420,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||||
| const std::string POWER_ATTR_NAME_SCALE = "scale"; | const std::string POWER_ATTR_NAME_SCALE = "scale"; | ||||
| const std::string POWER_ATTR_NAME_SHIFT = "shift"; | const std::string POWER_ATTR_NAME_SHIFT = "shift"; | ||||
| // log | |||||
| const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||||
| const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||||
| const std::string LOG_ATTR_NAME_BASE = "base"; | |||||
| // Pack | // Pack | ||||
| const std::string PACK_ATTR_NAME_NUM = "N"; | const std::string PACK_ATTR_NAME_NUM = "N"; | ||||
| // Unpack | // Unpack | ||||
| const std::string UNPACK_ATTR_NAME_NUM = "num"; | const std::string UNPACK_ATTR_NAME_NUM = "num"; | ||||
| const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
| // Gathernd | // Gathernd | ||||
| const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | ||||
| const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | ||||
| @@ -400,6 +439,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||||
| const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | ||||
| const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | ||||
| const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | ||||
| const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||||
| const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||||
| const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||||
| // upsample | |||||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||||
| // Relu | // Relu | ||||
| const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | ||||
| @@ -416,6 +462,7 @@ const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; | |||||
| const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | ||||
| const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | ||||
| const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | ||||
| const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; | |||||
| // Squeeze | // Squeeze | ||||
| const std::string SQUEEZE_ATTR_AXIS = "axis"; | const std::string SQUEEZE_ATTR_AXIS = "axis"; | ||||
| @@ -438,6 +485,7 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; | |||||
| const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | ||||
| const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | ||||
| const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | ||||
| const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; | |||||
| // Generate_rpn_proposal | // Generate_rpn_proposal | ||||
| const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | ||||
| @@ -536,19 +584,42 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||||
| const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | ||||
| // Rnn | // Rnn | ||||
| const std::string RNN_MODE_ = "rnn_"; | |||||
| const std::string CNN_RNN = "cnn_rnn"; | |||||
| const std::string RNN_TENSORFLOW = "rnn_tensorflow"; | |||||
| const std::string RNN_MODE_STATIC = "rnn_static"; | |||||
| const std::string MUTI_RNN = "multi_rnn"; | const std::string MUTI_RNN = "multi_rnn"; | ||||
| const std::string CNN_RNN = "cnn_rnn"; | |||||
| const std::string RNN_MODE_ = "rnn_"; | |||||
| const std::string CELL_MODE = "mode"; | const std::string CELL_MODE = "mode"; | ||||
| const std::string LSTM_CELL = "lstm_cell"; | const std::string LSTM_CELL = "lstm_cell"; | ||||
| const std::string GRU_CELL = "gru_cell"; | const std::string GRU_CELL = "gru_cell"; | ||||
| const std::string RNN_HT = "ht"; | const std::string RNN_HT = "ht"; | ||||
| const std::string RNN_XT_HT = "xt_ht"; | const std::string RNN_XT_HT = "xt_ht"; | ||||
| const std::string RNN_BATCH_SIZE = "batch_size"; | const std::string RNN_BATCH_SIZE = "batch_size"; | ||||
| const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||||
| const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||||
| const std::string LSTM_ACTIVATE = "lstm_activate"; | |||||
| const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||||
| const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||||
| const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||||
| const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||||
| const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||||
| // Upsample | // Upsample | ||||
| const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | ||||
| // PadV2 | |||||
| const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||||
| const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||||
| const std::string PADV2_ATTR_NAME_T = "T"; | |||||
| const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
| const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
| // MirrorPad | |||||
| const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||||
| const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||||
| const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||||
| const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||||
| // Filler | // Filler | ||||
| const std::string FILLER_TYPE = "filler_type"; | const std::string FILLER_TYPE = "filler_type"; | ||||
| const std::string FILLER_VALUE = "filler_value"; | const std::string FILLER_VALUE = "filler_value"; | ||||
| @@ -559,9 +630,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||||
| // TopKV2 | // TopKV2 | ||||
| const std::string TOPKV2_ATTR_K = "k"; | const std::string TOPKV2_ATTR_K = "k"; | ||||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
| // Calibaration | // Calibaration | ||||
| const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | ||||
| const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | ||||
| @@ -616,6 +684,8 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||||
| const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | ||||
| const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||||
| const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | ||||
| const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | ||||
| @@ -630,6 +700,8 @@ const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; | |||||
| const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | ||||
| const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||||
| // Public attribute | // Public attribute | ||||
| const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | ||||
| @@ -661,17 +733,145 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||||
| const std::string TARGET_TYPE_LITE = "LITE"; | const std::string TARGET_TYPE_LITE = "LITE"; | ||||
| // l2_normalize | |||||
| const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||||
| const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||||
| const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||||
| const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||||
| const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||||
| const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||||
| const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||||
| // HCOM | |||||
| const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
| const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
| const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
| const std::string HCOM_ATTR_GROUP = "group"; | |||||
| const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||||
| const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||||
| const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||||
| const std::string HCOM_ATTR_FUSION = "fusion"; | |||||
| const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
| const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
| // SpaceToDepth/DepthToSpace | |||||
| const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||||
| // SparseSoftmaxCrossEntropyWithLogits | |||||
| const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||||
| // MaxPoolGradWithArgmax | |||||
| const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||||
| // AvgPoolGrad | |||||
| const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||||
| // Pad | |||||
| const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||||
| // Varible | |||||
| const std::string VAR_ATTR_FORMAT = "_var_format"; | |||||
| const std::string VAR_ATTR_NAME = "var_name"; | |||||
| const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||||
| const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||||
| const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||||
| const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||||
| const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||||
| const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||||
| const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||||
| const std::string VAR_ATTR_SHAPE = "shape"; | |||||
| const std::string HALF_VAR_NAME_END = "_fp16"; | |||||
| const std::string VAR_ATTR_INITED = "var_is_inited"; | |||||
| const std::string VAR_ATTR_CONTAINER = "container"; | |||||
| const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||||
| const std::string VAR_ATTR_DTYPE = "dtype"; | |||||
| const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
| const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||||
| const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
| const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
| const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
| const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
| // Assign | |||||
| const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||||
| // space2bacth batch2space | |||||
| const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||||
| const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||||
| // depth_to_space space_to_depth | |||||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||||
| // FakeQuantWithMinMaxVars | |||||
| const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||||
| const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||||
| // mobilenet_ssd_conv_fusion | |||||
| const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||||
| const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||||
| const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||||
| // lsh project | |||||
| const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||||
| // log time stamp | |||||
| const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||||
| const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||||
| // ShapeN | |||||
| const std::string SHAPEN_ATTR_N = "N"; | |||||
| const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||||
| const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||||
| // GatherV2 attr def | |||||
| const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||||
| const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||||
| const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||||
| // Reshape attr def | |||||
| const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||||
| const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||||
| // axis attr def | |||||
| const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||||
| const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||||
| const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||||
| const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||||
| // For constant folding | |||||
| const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||||
| const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | ||||
| const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | ||||
| const std::string ATTR_NAME_REFERENCE = "reference"; | const std::string ATTR_NAME_REFERENCE = "reference"; | ||||
| const std::string ATTR_NAME_NOTASK = "_no_task"; | |||||
| const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; | |||||
| const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; | |||||
| const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; | |||||
| const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; | |||||
| const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | ||||
| // Used for mark the active label list stream of activated node | // Used for mark the active label list stream of activated node | ||||
| const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | ||||
| // Used for l2cache, true: the memory of all inputs is used for the last time. | |||||
| const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; | |||||
| // Multi batch | // Multi batch | ||||
| const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | ||||
| const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | ||||
| @@ -682,6 +882,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||||
| const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | ||||
| const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | ||||
| const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | ||||
| const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||||
| const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||||
| const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | ||||
| const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | ||||
| @@ -691,6 +893,9 @@ const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; | |||||
| const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | ||||
| // Function Op | |||||
| const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||||
| // Used for mark the active node is for loop, type:bool | // Used for mark the active node is for loop, type:bool | ||||
| const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | ||||
| @@ -702,6 +907,20 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||||
| const std::string MODEL_ATTR_SESSION_ID = "session_id"; | const std::string MODEL_ATTR_SESSION_ID = "session_id"; | ||||
| // l1 fusion and other fusion in future | |||||
| const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||||
| const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||||
| const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||||
| const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||||
| const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||||
| const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | |||||
| const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | |||||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | |||||
| const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||||
| const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||||
| const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||||
| // Atomic addr clean attrs | // Atomic addr clean attrs | ||||
| const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | ||||
| const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | ||||
| @@ -722,6 +941,9 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||||
| // For inserted op | // For inserted op | ||||
| const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | ||||
| // For compress weight | |||||
| const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; | |||||
| // For data dump | // For data dump | ||||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | ||||
| const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | ||||
| @@ -732,24 +954,17 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | ||||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
| // Variable | |||||
| const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||||
| const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||||
| const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||||
| const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||||
| const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||||
| // HCOM | |||||
| const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||||
| const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||||
| const std::string HCOM_ATTR_SHAPE = "shape"; | |||||
| const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||||
| // functional ops attr | |||||
| const std::string ATTR_NAME_TCOND = "Tcond"; | |||||
| const std::string ATTR_NAME_TIN = "Tin"; | |||||
| const std::string ATTR_NAME_TOUT = "Tout"; | |||||
| const std::string ATTR_NAME_THEN_BRANCH = "then_branch"; | |||||
| const std::string ATTR_NAME_ELSE_BRANCH = "else_branch"; | |||||
| const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||||
| // used for label switch | |||||
| const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||||
| const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||||
| const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | ||||
| const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | ||||
| // Dynamic stitch | |||||
| const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -22,7 +22,7 @@ | |||||
| #include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
| #include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
| #include "detail/model_serialize_imp.h" | #include "detail/model_serialize_imp.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "debug/ge_attr_define.h" | |||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| @@ -53,7 +53,7 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||||
| GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | ||||
| GeAttrValue value; | GeAttrValue value; | ||||
| (void)GetAttr(key, value); | |||||
| GetAttr(key, value); | |||||
| return value; | return value; | ||||
| } | } | ||||
| @@ -1081,6 +1081,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||||
| if (!GetListInt(std::move(obj), name, int64_list)) { | if (!GetListInt(std::move(obj), name, int64_list)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (size_t i = 0; i < int64_list.size(); ++i) { | for (size_t i = 0; i < int64_list.size(); ++i) { | ||||
| if (int64_list[i] > INT32_MAX) { | if (int64_list[i] > INT32_MAX) { | ||||
| GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | ||||
| @@ -1098,6 +1099,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||||
| if (!GetListInt(std::move(obj), name, int64_list)) { | if (!GetListInt(std::move(obj), name, int64_list)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| for (size_t i = 0; i < int64_list.size(); ++i) { | for (size_t i = 0; i < int64_list.size(); ++i) { | ||||
| if (int64_list[i] > UINT32_MAX) { | if (int64_list[i] > UINT32_MAX) { | ||||
| GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | ||||
| @@ -1215,6 +1217,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||||
| GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | ||||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | op_desc->extAttrs_ = org_op_desc->extAttrs_; | ||||
| if (op_desc->HasAttr("_input_name_idx_key")) { | |||||
| if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); | |||||
| } | |||||
| } | |||||
| if (op_desc->HasAttr("_input_name_idx_value")) { | |||||
| if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); | |||||
| } | |||||
| } | |||||
| if (op_desc->HasAttr("_opt_input")) { | |||||
| if (op_desc->DelAttr("_opt_input") != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); | |||||
| } | |||||
| } | |||||
| return op_desc; | return op_desc; | ||||
| } | } | ||||
| @@ -1237,11 +1256,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | op_desc->extAttrs_ = org_op_desc->extAttrs_; | ||||
| op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); | |||||
| op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), | |||||
| org_op_desc->optional_input_names_.end()); | |||||
| op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||||
| op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | ||||
| op_desc->infer_func_ = org_op_desc->infer_func_; | op_desc->infer_func_ = org_op_desc->infer_func_; | ||||
| @@ -1250,4 +1264,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||||
| return op_desc; | return op_desc; | ||||
| } | } | ||||
| std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { | |||||
| auto holder = obj.get(); | |||||
| if (holder == nullptr) { | |||||
| return ""; | |||||
| } | |||||
| auto attrs_map = holder->GetAttrMap(); | |||||
| if (attrs_map.GetProtoMsg() == nullptr) { | |||||
| return ""; | |||||
| } | |||||
| std::map<std::string, std::string> ordered_attrs; | |||||
| for (auto &attr : *(attrs_map.GetProtoMsg())) { | |||||
| ordered_attrs[attr.first] = attr.second.SerializeAsString(); | |||||
| } | |||||
| std::stringstream ss; | |||||
| for (auto &attr : ordered_attrs) { | |||||
| ss << attr.first << ":" << attr.second << ";"; | |||||
| } | |||||
| return ss.str(); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -163,6 +163,34 @@ int64_t GeShape::GetShapeSize() const { | |||||
| return res; | return res; | ||||
| } | } | ||||
| /// | |||||
| /// @brief Check is unknown shape | |||||
| /// @return bool | |||||
| /// /// | |||||
| bool GeShape::IsUnknownShape() const { | |||||
| auto proto_msg = shape_def_.GetProtoMsg(); | |||||
| if (proto_msg != nullptr) { | |||||
| for (auto i : proto_msg->dim()) { | |||||
| if (i < 0) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| /// | |||||
| /// @brief Check is a scalar | |||||
| /// @return bool | |||||
| /// | |||||
| bool GeShape::IsScalar() const { | |||||
| auto proto_msg = shape_def_.GetProtoMsg(); | |||||
| if (proto_msg != nullptr) { | |||||
| return proto_msg->dim().empty(); | |||||
| } | |||||
| return false; | |||||
| } | |||||
| const string TENSOR_UTILS_SIZE = "size"; | const string TENSOR_UTILS_SIZE = "size"; | ||||
| const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | ||||
| const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | ||||
| @@ -639,14 +667,14 @@ GeTensor &GeTensor::operator=(const GeTensor &other) { | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | ||||
| uint32_t &size) { | |||||
| int64_t &size) { | |||||
| auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | ||||
| GE_CHECK_NOTNULL(tensor_descriptor_msg); | GE_CHECK_NOTNULL(tensor_descriptor_msg); | ||||
| size = static_cast<uint32_t>(tensor_descriptor_msg->size()); | |||||
| size = static_cast<int64_t>(tensor_descriptor_msg->size()); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, uint32_t size) { | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { | |||||
| auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | ||||
| if (tensor_descriptor_msg != nullptr) { | if (tensor_descriptor_msg != nullptr) { | ||||
| tensor_descriptor_msg->set_size(size); | tensor_descriptor_msg->set_size(size); | ||||
| @@ -49,6 +49,7 @@ void Model::Init() { | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | ||||
| (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | ||||
| version_ = 0; | version_ = 0; | ||||
| @@ -77,9 +78,9 @@ void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } | |||||
| Graph Model::GetGraph() const { return graph_; } | Graph Model::GetGraph() const { return graph_; } | ||||
| graphStatus Model::Save(Buffer &buffer) const { | |||||
| graphStatus Model::Save(Buffer &buffer, bool is_dump) const { | |||||
| ModelSerialize serialize; | ModelSerialize serialize; | ||||
| buffer = serialize.SerializeModel(*this); | |||||
| buffer = serialize.SerializeModel(*this, is_dump); | |||||
| return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | ||||
| } | } | ||||
| @@ -113,7 +114,7 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||||
| } | } | ||||
| int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | ||||
| if (fd < 0) { | if (fd < 0) { | ||||
| GELOGE(GRAPH_FAILED, "open file failed, file path [%s] ", real_path); | |||||
| GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| bool ret = ge_proto.SerializeToFileDescriptor(fd); | bool ret = ge_proto.SerializeToFileDescriptor(fd); | ||||
| @@ -129,6 +130,10 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | GELOGE(GRAPH_FAILED, "close file descriptor fail."); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -152,7 +157,7 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||||
| } | } | ||||
| int fd = open(real_path, O_RDONLY); | int fd = open(real_path, O_RDONLY); | ||||
| if (fd < 0) { | if (fd < 0) { | ||||
| GELOGE(GRAPH_FAILED, "open file failed"); | |||||
| GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| @@ -170,6 +175,10 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | GELOGE(GRAPH_FAILED, "close file descriptor fail."); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return Load(model_def); | return Load(model_def); | ||||
| } | } | ||||
| @@ -15,10 +15,8 @@ | |||||
| */ | */ | ||||
| #include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
| #include <google/protobuf/text_format.h> | #include <google/protobuf/text_format.h> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| @@ -26,6 +24,7 @@ | |||||
| #include "graph/detail/model_serialize_imp.h" | #include "graph/detail/model_serialize_imp.h" | ||||
| #include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
| #include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
| #include "debug/ge_op_types.h" | |||||
| using std::string; | using std::string; | ||||
| @@ -84,20 +83,29 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||||
| if (op_desc == nullptr || op_def_proto == nullptr) { | if (op_desc == nullptr || op_def_proto == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "Input Para Invalid"); | GELOGE(GRAPH_FAILED, "Input Para Invalid"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (op_desc->op_def_.GetProtoMsg() != nullptr) { | if (op_desc->op_def_.GetProtoMsg() != nullptr) { | ||||
| *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | ||||
| // Delete unnecessary attr | |||||
| if (is_dump) { | |||||
| auto attr = op_def_proto->mutable_attr(); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); | |||||
| attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); | |||||
| GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), | |||||
| attr->erase(ATTR_NAME_WEIGHTS)); | |||||
| } | |||||
| op_def_proto->clear_input_desc(); | op_def_proto->clear_input_desc(); | ||||
| op_def_proto->clear_output_desc(); | op_def_proto->clear_output_desc(); | ||||
| // Input descs | // Input descs | ||||
| if (op_desc->GetInputsSize() > 0) { | |||||
| auto size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||||
| if (op_desc->GetAllInputsSize() > 0) { | |||||
| auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||||
| for (uint32_t i = 0; i < size; i++) { | for (uint32_t i = 0; i < size; i++) { | ||||
| auto tensor_desc = op_desc->GetInputDescPtr(i); | |||||
| auto tensor_desc = op_desc->GetInputDescPtrDfault(i); | |||||
| if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | ||||
| *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | ||||
| } | } | ||||
| @@ -117,12 +125,12 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) { | |||||
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||||
| if (node == nullptr || op_def_proto == nullptr) { | if (node == nullptr || op_def_proto == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) { | |||||
| if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { | |||||
| GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -134,7 +142,8 @@ bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_ | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | ||||
| proto::GraphDef *graph_proto) { | |||||
| proto::GraphDef *graph_proto, | |||||
| bool is_dump) { | |||||
| if (graph == nullptr || graph_proto == nullptr) { | if (graph == nullptr || graph_proto == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "Input para Invalid"); | GELOGE(GRAPH_FAILED, "Input para Invalid"); | ||||
| return false; | return false; | ||||
| @@ -156,7 +165,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||||
| *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | ||||
| } | } | ||||
| for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
| if (!SerializeNode(node, graph_proto->add_op())) { | |||||
| if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { | |||||
| if (node->GetOpDesc() != nullptr) { | if (node->GetOpDesc() != nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | ||||
| } | } | ||||
| @@ -166,7 +175,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||||
| return true; | return true; | ||||
| } | } | ||||
| bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) { | |||||
| bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { | |||||
| if (model_proto == nullptr) { | if (model_proto == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | ||||
| return false; | return false; | ||||
| @@ -183,7 +192,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||||
| GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| if (!SerializeGraph(compute_graph, model_proto->add_graph())) { | |||||
| if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { | |||||
| GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -390,10 +399,10 @@ bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf:: | |||||
| return true; | return true; | ||||
| } | } | ||||
| Buffer ModelSerialize::SerializeModel(const Model &model) { | |||||
| Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { | |||||
| proto::ModelDef model_def; | proto::ModelDef model_def; | ||||
| ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
| if (!imp.SerializeModel(model, &model_def)) { | |||||
| if (!imp.SerializeModel(model, &model_def, is_dump)) { | |||||
| return Buffer(); | return Buffer(); | ||||
| } | } | ||||
| #if !defined(__ANDROID__) && !defined(ANDROID) | #if !defined(__ANDROID__) && !defined(ANDROID) | ||||
| @@ -401,7 +401,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||||
| vec.push_back(in_anchor); | vec.push_back(in_anchor); | ||||
| } | } | ||||
| } | } | ||||
| // Push back in_control_anchor_ | |||||
| // Push back in_control_anchor_ | |||||
| if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | ||||
| (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | ||||
| auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | ||||
| @@ -512,7 +512,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||||
| auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | ||||
| for (const auto &out_anchor : peer_out_anchors) { | for (const auto &out_anchor : peer_out_anchors) { | ||||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, " in_control_anchor_ peer out data anchors is nullptr"); | |||||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); | |||||
| auto node = out_anchor->GetOwnerNode(); | auto node = out_anchor->GetOwnerNode(); | ||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | ||||
| vec.push_back(node); | vec.push_back(node); | ||||
| @@ -521,7 +521,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||||
| auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | ||||
| for (const auto &out_control_anchor : peer_out_control_anchors) { | for (const auto &out_control_anchor : peer_out_control_anchors) { | ||||
| GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | ||||
| " in_control_anchor_ peer out control anchors is nullptr"); | |||||
| "in_control_anchor_ peer out control anchors is nullptr"); | |||||
| auto node = out_control_anchor->GetOwnerNode(); | auto node = out_control_anchor->GetOwnerNode(); | ||||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | ||||
| vec.push_back(node); | vec.push_back(node); | ||||
| @@ -785,6 +785,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(co | |||||
| GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | ||||
| "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | ||||
| op_desc->GetInputsSize()); | op_desc->GetInputsSize()); | ||||
| GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | ||||
| "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | ||||
| op_desc->GetOutputsSize()); | op_desc->GetOutputsSize()); | ||||
| @@ -61,6 +61,12 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||||
| const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | ||||
| const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||||
| const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||||
| const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | ||||
| op_def_.InitDefault(); | op_def_.InitDefault(); | ||||
| if (op_def_.GetProtoMsg() != nullptr) { | if (op_def_.GetProtoMsg() != nullptr) { | ||||
| @@ -202,7 +208,8 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d | |||||
| } | } | ||||
| graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | ||||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||||
| GELOGI("input %s is exist, update it", name.c_str()); | GELOGI("input %s is exist, update it", name.c_str()); | ||||
| graphStatus ret = UpdateInputDesc(name, input_desc); | graphStatus ret = UpdateInputDesc(name, input_desc); | ||||
| return ret; | return ret; | ||||
| @@ -214,15 +221,17 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| inputs_desc_.push_back(in_desc); | inputs_desc_.push_back(in_desc); | ||||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||||
| (void)input_name_idx.insert(make_pair(name, index)); | |||||
| SetAllInputName(input_name_idx); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| } | } | ||||
| graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | ||||
| auto input_name_idx = GetAllInputName(); | |||||
| for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
| string input_name = name + std::to_string(i); | string input_name = name + std::to_string(i); | ||||
| GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||||
| GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||||
| "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | ||||
| std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | ||||
| @@ -234,12 +243,13 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||||
| (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | ||||
| // Update index in input_name_idx | // Update index in input_name_idx | ||||
| for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||||
| for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||||
| it->second += 1; | it->second += 1; | ||||
| } | } | ||||
| (void)input_name_idx_.insert(make_pair(input_name, 0)); | |||||
| (void)input_name_idx.insert(make_pair(input_name, 0)); | |||||
| } | } | ||||
| SetAllInputName(input_name_idx); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -270,10 +280,19 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int | |||||
| graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | ||||
| if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | ||||
| (void)optional_input_names_.insert(name); | |||||
| vector<string> optional_input_names; | |||||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
| optional_input_names.push_back(name); | |||||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| std::vector<string> OpDesc::GetAllOptionalInputName() const { | |||||
| vector<string> optional_input_names; | |||||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
| return optional_input_names; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
| OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | ||||
| GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | ||||
| @@ -288,11 +307,12 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | ||||
| return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && | |||||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||||
| IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && | |||||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||||
| return ( | |||||
| IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && | |||||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||||
| IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && | |||||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | ||||
| @@ -366,8 +386,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD | |||||
| } | } | ||||
| graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | ||||
| auto it = input_name_idx_.find(name); | |||||
| if (it == input_name_idx_.end()) { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it = input_name_idx.find(name); | |||||
| if (it == input_name_idx.end()) { | |||||
| GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| @@ -387,8 +408,9 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & | |||||
| } | } | ||||
| bool OpDesc::InputIsSet(const string &name) const { | bool OpDesc::InputIsSet(const string &name) const { | ||||
| auto it = input_name_idx_.find(name); | |||||
| if (it != input_name_idx_.end()) { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it = input_name_idx.find(name); | |||||
| if (it != input_name_idx.end()) { | |||||
| GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | ||||
| auto tensor_desc = inputs_desc_[it->second]; | auto tensor_desc = inputs_desc_[it->second]; | ||||
| GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | ||||
| @@ -406,18 +428,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc | |||||
| } | } | ||||
| GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | ||||
| auto it = input_name_idx_.find(name); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it = input_name_idx.find(name); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | ||||
| return *(inputs_desc_[it->second].get()); | return *(inputs_desc_[it->second].get()); | ||||
| } | } | ||||
| GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | ||||
| auto input_name_idx = GetAllInputName(); | |||||
| vector<string> names; | vector<string> names; | ||||
| if (input_name_idx_.empty()) { | |||||
| if (input_name_idx.empty()) { | |||||
| return OpDesc::Vistor<string>(shared_from_this(), names); | return OpDesc::Vistor<string>(shared_from_this(), names); | ||||
| } | } | ||||
| for (std::pair<string, uint32_t> input : input_name_idx_) { | |||||
| for (std::pair<string, uint32_t> input : input_name_idx) { | |||||
| names.push_back(input.first); | names.push_back(input.first); | ||||
| } | } | ||||
| return OpDesc::Vistor<string>(shared_from_this(), names); | return OpDesc::Vistor<string>(shared_from_this(), names); | ||||
| @@ -483,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() co | |||||
| return size; | return size; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | ||||
| int index = static_cast<int>(outputs_desc_.size()); | int index = static_cast<int>(outputs_desc_.size()); | ||||
| return AddOutputDesc("__output" + std::to_string(index), output_desc); | return AddOutputDesc("__output" + std::to_string(index), output_desc); | ||||
| @@ -548,6 +574,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||||
| return outputs_desc_[index]; | return outputs_desc_[index]; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||||
| return static_cast<uint32_t>(outputs_desc_.size()); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | ||||
| vector<GeTensorDesc> temp{}; | vector<GeTensorDesc> temp{}; | ||||
| for (const auto &it : outputs_desc_) { | for (const auto &it : outputs_desc_) { | ||||
| @@ -580,6 +610,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI | |||||
| } | } | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr | |||||
| OpDesc::GetInputDescPtrDfault(uint32_t index) const { | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); | |||||
| return inputs_desc_[(int32_t)index]; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it = input_name_idx.find(name); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>()); | |||||
| return inputs_desc_[it->second]; | |||||
| } | |||||
| graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | ||||
| if (is_push_back) { | if (is_push_back) { | ||||
| for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
| @@ -603,12 +646,45 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||||
| } | } | ||||
| bool OpDesc::IsOptionalInput(const string &name) const { | bool OpDesc::IsOptionalInput(const string &name) const { | ||||
| return optional_input_names_.find(name) != optional_input_names_.end(); | |||||
| vector<string> optional_input_names; | |||||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||||
| for (auto &item : optional_input_names) { | |||||
| if (item == name) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | } | ||||
| bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | ||||
| std::map<string, uint32_t> OpDesc::GetAllInputName() { return input_name_idx_; } | |||||
| std::map<string, uint32_t> OpDesc::GetAllInputName() const { | |||||
| std::map<string, uint32_t> input_name_idx; | |||||
| std::vector<string> key; | |||||
| std::vector<uint32_t> value; | |||||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||||
| (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||||
| if (key.size() != value.size()) { | |||||
| GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||||
| } else { | |||||
| for (uint32_t i = 0; i < key.size(); ++i) { | |||||
| input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||||
| } | |||||
| } | |||||
| return input_name_idx; | |||||
| } | |||||
| void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) { | |||||
| std::vector<string> key; | |||||
| std::vector<uint32_t> value; | |||||
| for (auto &item : input_name_idx) { | |||||
| key.emplace_back(item.first); | |||||
| value.emplace_back(item.second); | |||||
| } | |||||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||||
| (void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||||
| } | |||||
| std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; } | std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; } | ||||
| @@ -619,6 +695,7 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||||
| auto factory_map_size = input_name_idx.size(); | auto factory_map_size = input_name_idx.size(); | ||||
| // It indicates that some inputs have no optionalname. | // It indicates that some inputs have no optionalname. | ||||
| // The redundant optionalname of factory needs to be deleted and then assigned | // The redundant optionalname of factory needs to be deleted and then assigned | ||||
| auto all_input_name_idx = GetAllInputName(); | |||||
| if (input_map_size < factory_map_size) { | if (input_map_size < factory_map_size) { | ||||
| GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | ||||
| factory_map_size); | factory_map_size); | ||||
| @@ -631,22 +708,23 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||||
| } | } | ||||
| if (input_name_idx.size() == input_map_size) { | if (input_name_idx.size() == input_map_size) { | ||||
| GELOGI("UpdateInputName"); | GELOGI("UpdateInputName"); | ||||
| input_name_idx_ = input_name_idx; | |||||
| all_input_name_idx = input_name_idx; | |||||
| } else { | } else { | ||||
| ret = false; | ret = false; | ||||
| GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | ||||
| } | } | ||||
| } else if (input_map_size == factory_map_size) { | } else if (input_map_size == factory_map_size) { | ||||
| input_name_idx_ = input_name_idx; | |||||
| all_input_name_idx = input_name_idx; | |||||
| } else { | } else { | ||||
| ret = false; | ret = false; | ||||
| GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | ||||
| } | } | ||||
| SetAllInputName(all_input_name_idx); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | ||||
| size_t output_map_size = GetAllOutputsDesc().size(); | |||||
| size_t output_map_size = GetAllOutputsDescSize(); | |||||
| size_t factory_map_size = output_name_idx.size(); | size_t factory_map_size = output_name_idx.size(); | ||||
| if (output_map_size < factory_map_size) { | if (output_map_size < factory_map_size) { | ||||
| GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | ||||
| @@ -754,17 +832,17 @@ graphStatus OpDesc::OpVerify() { | |||||
| } | } | ||||
| graphStatus OpDesc::CommonVerify() const { | graphStatus OpDesc::CommonVerify() const { | ||||
| for (string iname : GetAllInputNames()) { | |||||
| for (const string &iname : GetAllInputNames()) { | |||||
| // Checking shape of all inputs | // Checking shape of all inputs | ||||
| vector<int64_t> ishape = GetInputDesc(iname).GetShape().GetDims(); | |||||
| vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||||
| for (int64_t dim : ishape) { | for (int64_t dim : ishape) { | ||||
| GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | ||||
| iname.c_str()); | iname.c_str()); | ||||
| } | } | ||||
| } | } | ||||
| // Check all attributes defined | // Check all attributes defined | ||||
| const auto all_attributes = GetAllAttrs(); | |||||
| for (const auto name : GetAllAttrNames()) { | |||||
| const auto &all_attributes = GetAllAttrs(); | |||||
| for (const auto &name : GetAllAttrNames()) { | |||||
| GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | ||||
| "operator attribute %s is empty.", name.c_str()); | "operator attribute %s is empty.", name.c_str()); | ||||
| } | } | ||||
| @@ -773,19 +851,21 @@ graphStatus OpDesc::CommonVerify() const { | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { | ||||
| auto it = input_name_idx_.begin(); | |||||
| for (; it != input_name_idx_.end(); ++it) { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it = input_name_idx.begin(); | |||||
| for (; it != input_name_idx.end(); ++it) { | |||||
| if (it->second == index) { | if (it->second == index) { | ||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), ""); | |||||
| return it->first; | return it->first; | ||||
| } | } | ||||
| int OpDesc::GetInputIndexByName(const string &name) const { | int OpDesc::GetInputIndexByName(const string &name) const { | ||||
| auto it_find = input_name_idx_.find(name); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| auto it_find = input_name_idx.find(name); | |||||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); | |||||
| return static_cast<int>(it_find->second); | return static_cast<int>(it_find->second); | ||||
| } | } | ||||
| @@ -1065,10 +1145,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, | ||||
| const int &index) { | const int &index) { | ||||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||||
| auto input_name_idx = GetAllInputName(); | |||||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||||
| GELOGI("Restore input name index is existed. name[%s]", name.c_str()); | GELOGI("Restore input name index is existed. name[%s]", name.c_str()); | ||||
| } | } | ||||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||||
| (void)input_name_idx.insert(make_pair(name, index)); | |||||
| SetAllInputName(input_name_idx); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -1104,4 +1186,45 @@ graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | |||||
| } | } | ||||
| return (graphStatus)infer_format_func_(op); | return (graphStatus)infer_format_func_(op); | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { | |||||
| if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) { | |||||
| return ""; | |||||
| } | |||||
| return subgraph_instance_names_.at(index); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames() | |||||
| const { | |||||
| return subgraph_instance_names_; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||||
| subgraph_instance_names_.emplace_back(std::move(name)); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | |||||
| for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | |||||
| if (*iter == name) { | |||||
| subgraph_instance_names_.erase(iter); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | |||||
| auto iter = subgraph_names_to_index_.find(name); | |||||
| if (iter != subgraph_names_to_index_.end()) { | |||||
| GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto size = subgraph_names_to_index_.size(); | |||||
| subgraph_names_to_index_[name] = size; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes() | |||||
| const { | |||||
| return subgraph_names_to_index_; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -20,8 +20,7 @@ | |||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| using std::function; | |||||
| using std::vector; | |||||
| using namespace std; | |||||
| namespace ge { | namespace ge { | ||||
| @@ -15,13 +15,12 @@ | |||||
| */ | */ | ||||
| #include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
| #include <stdint.h> | #include <stdint.h> | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <mutex> | #include <mutex> | ||||
| #include <queue> | #include <queue> | ||||
| #include <set> | #include <set> | ||||
| #include "array_ops.h" | |||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| @@ -33,7 +32,6 @@ | |||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| #include "graph/operator_factory.h" | |||||
| #include "graph/usr_types.h" | #include "graph/usr_types.h" | ||||
| #include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
| #include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
| @@ -48,10 +46,6 @@ using std::string; | |||||
| using std::to_string; | using std::to_string; | ||||
| using std::vector; | using std::vector; | ||||
| namespace { | |||||
| const char *const kValue = "value"; | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| class OpIO { | class OpIO { | ||||
| public: | public: | ||||
| @@ -148,6 +142,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
| for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | ||||
| is_input_const.push_back(false); | is_input_const.push_back(false); | ||||
| } | } | ||||
| is_input_const[dst_index] = is_const; | is_input_const[dst_index] = is_const; | ||||
| op_desc_->SetIsInputConst(is_input_const); | op_desc_->SetIsInputConst(is_input_const); | ||||
| @@ -179,8 +174,8 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
| GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | ||||
| op_desc_->GetName().c_str()); | op_desc_->GetName().c_str()); | ||||
| auto out_op_impl = out_handler->GetOwner(); | auto out_op_impl = out_handler->GetOwner(); | ||||
| GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]", | |||||
| dst_name.c_str()); | |||||
| GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, | |||||
| "out_handler invalid. name[%s]", dst_name.c_str()); | |||||
| bool is_const = false; | bool is_const = false; | ||||
| if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | ||||
| is_const = true; | is_const = true; | ||||
| @@ -193,7 +188,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
| op_desc_->SetIsInputConst(is_input_const); | op_desc_->SetIsInputConst(is_input_const); | ||||
| OpIO in_handler(dst_name, dst_index, shared_from_this()); | OpIO in_handler(dst_name, dst_index, shared_from_this()); | ||||
| GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed."); | |||||
| GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); | |||||
| out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | ||||
| auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | ||||
| @@ -210,7 +205,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
| void AddControlInputImp(const ge::Operator &src_oprt) { | void AddControlInputImp(const ge::Operator &src_oprt) { | ||||
| if (src_oprt.operator_impl_ == nullptr) { | if (src_oprt.operator_impl_ == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "Src operator impl is nullptr"); | |||||
| GELOGE(FAILED, "Src operator impl is nullptr"); | |||||
| return; | return; | ||||
| } | } | ||||
| for (auto &input : control_input_link_) { | for (auto &input : control_input_link_) { | ||||
| @@ -520,9 +515,9 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||||
| if (peer_node_ptr->GetOpDesc() != nullptr) { | if (peer_node_ptr->GetOpDesc() != nullptr) { | ||||
| const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | ||||
| if (op_descType == CONSTANTOP) { | if (op_descType == CONSTANTOP) { | ||||
| return const_op.GetAttr(kValue, data); | |||||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||||
| } else if (op_descType == CONSTANT) { | } else if (op_descType == CONSTANT) { | ||||
| return const_op.GetAttr(kValue, data); | |||||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -542,9 +537,9 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) | |||||
| Operator const_op(out_handle.GetOwner()); | Operator const_op(out_handle.GetOwner()); | ||||
| const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | ||||
| if (op_desc_impl_type == CONSTANTOP) { | if (op_desc_impl_type == CONSTANTOP) { | ||||
| return const_op.GetAttr(kValue, data); | |||||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||||
| } else if (op_desc_impl_type == CONSTANT) { | } else if (op_desc_impl_type == CONSTANT) { | ||||
| return const_op.GetAttr(kValue, data); | |||||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||||
| } | } | ||||
| } | } | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| @@ -709,6 +704,7 @@ void Operator::InputRegister(const string &name) { | |||||
| void Operator::OptionalInputRegister(const string &name) { | void Operator::OptionalInputRegister(const string &name) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| // [No need to verify return value] | |||||
| (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | ||||
| GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | ||||
| } | } | ||||
| @@ -716,24 +712,28 @@ void Operator::OptionalInputRegister(const string &name) { | |||||
| void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| // [No need to verify return value] | |||||
| (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); | (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); | ||||
| } | } | ||||
| void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| // [No need to verify return value] | |||||
| (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | ||||
| } | } | ||||
| void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| // [No need to verify return value] | |||||
| (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | ||||
| } | } | ||||
| void Operator::OutputRegister(const string &name) { | void Operator::OutputRegister(const string &name) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| // [No need to verify return value] | |||||
| (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | ||||
| } | } | ||||
| @@ -757,7 +757,8 @@ int Operator::GetDynamicInputNum(const string &name) const { | |||||
| void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | ||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | ||||
| (void)AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, | |||||
| "Set %s int failed", name.c_str()); | |||||
| (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | ||||
| } | } | ||||
| @@ -765,7 +766,8 @@ int Operator::GetDynamicOutputNum(const string &name) const { | |||||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | ||||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | ||||
| int num = 0; | int num = 0; | ||||
| (void)AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num); | |||||
| GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, | |||||
| "Get %s int failed", name.c_str()); | |||||
| return num; | return num; | ||||
| } | } | ||||
| @@ -1141,7 +1143,9 @@ class GraphBuilderImpl { | |||||
| GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | |||||
| "User Input do not include operator such as \ | |||||
| Data, Variable operator or operator that has output but no input."); | |||||
| auto ret = WalkAllOperators(vec_inputs); | auto ret = WalkAllOperators(vec_inputs); | ||||
| GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | ||||
| @@ -1163,7 +1167,8 @@ class GraphBuilderImpl { | |||||
| que.pop(); | que.pop(); | ||||
| for (const auto &op_impl : vec_tem) { | for (const auto &op_impl : vec_tem) { | ||||
| GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | ||||
| GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue) | |||||
| GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, | |||||
| "This node %s has created.", op_impl->GetName().c_str()) | |||||
| auto node_ptr = graph_->AddNode(op_impl->op_desc_); | auto node_ptr = graph_->AddNode(op_impl->op_desc_); | ||||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | ||||
| all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | ||||
| @@ -1202,10 +1207,13 @@ class GraphBuilderImpl { | |||||
| for (const auto &node_info : all_nodes_info_) { | for (const auto &node_info : all_nodes_info_) { | ||||
| auto src_op_impl_ptr = node_info.first; | auto src_op_impl_ptr = node_info.first; | ||||
| auto src_node_ptr = node_info.second; | auto src_node_ptr = node_info.second; | ||||
| GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | ||||
| auto out_links = src_op_impl_ptr->output_links_; | auto out_links = src_op_impl_ptr->output_links_; | ||||
| GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, | |||||
| "Src operator impl's op_desc is null."); | |||||
| auto &op_desc = src_op_impl_ptr->op_desc_; | auto &op_desc = src_op_impl_ptr->op_desc_; | ||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | |||||
| for (const auto &out : out_links) { | for (const auto &out : out_links) { | ||||
| auto src_idx = op_desc->GetOutputIndexByName(out.first); | auto src_idx = op_desc->GetOutputIndexByName(out.first); | ||||
| GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | ||||
| @@ -1216,7 +1224,9 @@ class GraphBuilderImpl { | |||||
| for (const auto &dst_opio : out.second) { | for (const auto &dst_opio : out.second) { | ||||
| auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | ||||
| GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | ||||
| GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | ||||
| auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | ||||
| GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | ||||
| @@ -1260,8 +1270,7 @@ inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | |||||
| ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ||||
| auto graph_builder_impl = GraphBuilderImpl(name); | auto graph_builder_impl = GraphBuilderImpl(name); | ||||
| ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | ||||
| GE_IF_BOOL_EXEC(compute_graph == nullptr, return compute_graph); | |||||
| GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Computer graph is nullptr"); | |||||
| compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | ||||
| if (HasSameNameNode(compute_graph)) { | if (HasSameNameNode(compute_graph)) { | ||||
| GELOGW("Compute do not allow has same name nodes."); | GELOGW("Compute do not allow has same name nodes."); | ||||
| @@ -15,13 +15,11 @@ | |||||
| */ | */ | ||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include <algorithm> | |||||
| #include <cstdlib> | #include <cstdlib> | ||||
| #include <algorithm> | |||||
| #include <functional> | #include <functional> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <sstream> | #include <sstream> | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/debug/ge_log.h" | #include "graph/debug/ge_log.h" | ||||
| @@ -155,7 +153,7 @@ void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { | |||||
| // Load .so file | // Load .so file | ||||
| for (auto elem : file_list) { | for (auto elem : file_list) { | ||||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||||
| if (handle == nullptr) { | if (handle == nullptr) { | ||||
| GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | ||||
| continue; | continue; | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "./ge_context.h" | #include "./ge_context.h" | ||||
| #include "./ge_global_options.h" | #include "./ge_global_options.h" | ||||
| #include "./ge_local_context.h" | #include "./ge_local_context.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| @@ -87,4 +86,5 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||||
| uint64_t GEContext::TraceId() { return trace_id_; } | uint64_t GEContext::TraceId() { return trace_id_; } | ||||
| void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -22,6 +22,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| #include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
| #include "external/graph/operator.h" | #include "external/graph/operator.h" | ||||
| @@ -34,6 +35,122 @@ | |||||
| #include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| constexpr const char *kRefIndex = "parent_node_index"; | |||||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| for (const auto &name : sub_graph_names) { | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||||
| if (node_sub->GetType() != DATA) { | |||||
| continue; | |||||
| } | |||||
| int ref_i; | |||||
| auto data_opdesc = node_sub->GetOpDesc(); | |||||
| if (data_opdesc == nullptr) { | |||||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||||
| if (input_desc == nullptr) { | |||||
| GE_LOGE( | |||||
| "The ref index(%d) on the data %s on the sub graph %s " | |||||
| "parent node %s are incompatible, inputs num %u", | |||||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", | |||||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||||
| if (sub_graph_names.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||||
| for (const auto &name : sub_graph_names) { | |||||
| auto sub_graph = root_graph->GetSubgraph(name); | |||||
| if (sub_graph == nullptr) { | |||||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr netoutput = nullptr; | |||||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||||
| auto sub_node = sub_nodes.at(i - 1); | |||||
| if (sub_node->GetType() == NETOUTPUT) { | |||||
| netoutput = sub_node; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (netoutput == nullptr) { | |||||
| GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto netoutput_opdesc = netoutput->GetOpDesc(); | |||||
| if (netoutput_opdesc == nullptr) { | |||||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||||
| node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||||
| auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||||
| if (edge_desc == nullptr) { | |||||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), | |||||
| node->GetName().c_str(), edge_anchor->GetIdx()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| int ref_i; | |||||
| if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||||
| // if there is no ref index on the TensorDesc, it means the output data will be ignored outer. | |||||
| continue; | |||||
| } | |||||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||||
| if (output_desc == nullptr) { | |||||
| GE_LOGE( | |||||
| "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||||
| "parent node %s are incompatible, outputs num %u", | |||||
| ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||||
| node->GetAllOutDataAnchorsSize()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace | |||||
| void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "node is null"); | GELOGE(GRAPH_FAILED, "node is null"); | ||||
| @@ -42,7 +159,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | ge::OpDescPtr op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | ||||
| std::string str; | std::string str; | ||||
| if (!op_desc->GetAllInputsDescPtr().empty()) { | |||||
| if (op_desc->GetInputsSize() != 0) { | |||||
| std::string input_desc_str = "input shape: "; | std::string input_desc_str = "input shape: "; | ||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | ||||
| input_desc_str += "["; | input_desc_str += "["; | ||||
| @@ -56,7 +173,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| str += input_desc_str; | str += input_desc_str; | ||||
| } | } | ||||
| if (!op_desc->GetAllOutputsDescPtr().empty()) { | |||||
| if (op_desc->GetAllOutputsDescSize() != 0) { | |||||
| std::string output_desc_str = "output shape: "; | std::string output_desc_str = "output shape: "; | ||||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | ||||
| if (output_desc == nullptr) { | if (output_desc == nullptr) { | ||||
| @@ -76,13 +193,24 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| } | } | ||||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | ||||
| return InferShapeAndType(node, op, true); | |||||
| } | |||||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { | |||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | ||||
| const auto &op_type = op_desc->GetType(); | const auto &op_type = op_desc->GetType(); | ||||
| graphStatus ret; | |||||
| if (before_subgraph) { | |||||
| ret = UpdateSubGraphDataNodes(node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| // Get infer func and execute | // Get infer func and execute | ||||
| graphStatus ret = op_desc->CallInferFunc(op); | |||||
| ret = op_desc->CallInferFunc(op); | |||||
| if (ret == GRAPH_PARAM_INVALID) { | if (ret == GRAPH_PARAM_INVALID) { | ||||
| // Op ir no infer func, try to get infer func from operator factory | // Op ir no infer func, try to get infer func from operator factory | ||||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | ||||
| @@ -113,7 +241,14 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||||
| ret = op_desc->CallInferFunc(op); | ret = op_desc->CallInferFunc(op); | ||||
| GELOGI("op CallInferFunc second. ret: %u", ret); | GELOGI("op CallInferFunc second. ret: %u", ret); | ||||
| } | } | ||||
| return ret; | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| if (!before_subgraph) { | |||||
| return UpdateParentNodeOutTensor(node); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | ||||
| @@ -179,8 +314,11 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||||
| namespace { | namespace { | ||||
| std::unordered_map<NodePtr, InferenceContextPtr> context_map; | std::unordered_map<NodePtr, InferenceContextPtr> context_map; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | ||||
| return InferShapeAndType(node, true); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, | |||||
| bool before_subgraph) { | |||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | ||||
| if (node->Verify() != GRAPH_SUCCESS) { | if (node->Verify() != GRAPH_SUCCESS) { | ||||
| GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | ||||
| @@ -199,7 +337,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | Operator op = OpDescUtils::CreateOperatorFromNode(node); | ||||
| op.SetInferenceContext(inference_context); | op.SetInferenceContext(inference_context); | ||||
| graphStatus status = InferShapeAndType(node, op); | |||||
| graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | ||||
| (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | ||||
| } else { | } else { | ||||
| @@ -353,6 +353,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | ||||
| } | } | ||||
| @@ -516,13 +517,14 @@ graphStatus Tensor::IsValid() { | |||||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | GELOGW("mul overflow: %lu, %u", shape_size, type_length); | ||||
| } else { | } else { | ||||
| if (shape_size * type_length != data_size) { | if (shape_size * type_length != data_size) { | ||||
| // [Just log] Constructor | |||||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | ||||
| data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -539,7 +541,7 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||||
| tensor_desc.GetDataType()); | tensor_desc.GetDataType()); | ||||
| ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | ||||
| ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | ||||
| auto size = static_cast<uint32_t>(tensor_desc.GetSize()); | |||||
| auto size = tensor_desc.GetSize(); | |||||
| TensorUtils::SetSize(ge_tensor_desc, size); | TensorUtils::SetSize(ge_tensor_desc, size); | ||||
| auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | ||||
| @@ -552,7 +554,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||||
| ge_tensor_desc.GetDataType()); | ge_tensor_desc.GetDataType()); | ||||
| tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | ||||
| tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | ||||
| uint32_t size = 0; | |||||
| int64_t size = 0; | |||||
| (void)TensorUtils::GetSize(ge_tensor_desc, size); | (void)TensorUtils::GetSize(ge_tensor_desc, size); | ||||
| tensor_desc.SetSize(size); | tensor_desc.SetSize(size); | ||||
| @@ -15,18 +15,21 @@ | |||||
| */ | */ | ||||
| #include "graph/utils/ge_ir_utils.h" | #include "graph/utils/ge_ir_utils.h" | ||||
| #include <utility> | #include <utility> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| namespace { | namespace { | ||||
| const char *const kControlAnchorIndex = ":-1"; | const char *const kControlAnchorIndex = ":-1"; | ||||
| const char *const kNodeTypeForSubgraph = "subgraph"; | const char *const kNodeTypeForSubgraph = "subgraph"; | ||||
| const char *const kPrefixForInputDesc = "input_desc_attr_"; | |||||
| const char *const kPrefixForOutputDesc = "output_desc_attr_"; | |||||
| const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | ||||
| const int8_t kMaxRecursionDepth = 10; | const int8_t kMaxRecursionDepth = 10; | ||||
| const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | ||||
| const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | ||||
| const int64_t kInputPrefixLength = 5; | |||||
| const int64_t kOutputPrefixLength = 6; | |||||
| using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -198,7 +201,7 @@ void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_A | |||||
| void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | ||||
| ::google::protobuf::RepeatedField<bool> data) { | ::google::protobuf::RepeatedField<bool> data) { | ||||
| if (node_proto == nullptr) { | if (node_proto == nullptr) { | ||||
| GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); | |||||
| GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); | |||||
| return; | return; | ||||
| } | } | ||||
| if (!data.empty()) { | if (!data.empty()) { | ||||
| @@ -320,7 +323,16 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
| auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | ||||
| "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | ||||
| const auto &tensor_desc_map = tensor_descriptor->attr(); | |||||
| std::string suffix = ":" + std::to_string(i); | |||||
| AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); | |||||
| } else { | |||||
| GELOGW("Tensor descriptor is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } else { | |||||
| GELOGW("Input desc is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -360,16 +372,25 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
| auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | ||||
| "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | ||||
| const auto &tensor_desc_map = tensor_descriptor->attr(); | |||||
| std::string suffix = ":" + std::to_string(i); | |||||
| AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); | |||||
| } else { | |||||
| GELOGW("Tensor descriptor is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } else { | |||||
| GELOGW("Output desc is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto) { | |||||
| GE_CHK_BOOL_EXEC(op_def != nullptr, return, "Opdef is nullptr"); | |||||
| const auto &op_def_attr_map = op_def->attr(); | |||||
| for (const auto &item : op_def_attr_map) { | |||||
| void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( | |||||
| const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto, | |||||
| const std::string &prefix, const std::string &suffix) { | |||||
| for (const auto &item : attr_map) { | |||||
| auto attr_name = item.first; | auto attr_name = item.first; | ||||
| auto attr_def = item.second; | auto attr_def = item.second; | ||||
| auto attr_type = attr_def.value_case(); | auto attr_type = attr_def.value_case(); | ||||
| @@ -377,36 +398,40 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||||
| const auto &tensor_def = attr_def.t(); | const auto &tensor_def = attr_def.t(); | ||||
| const auto &tensor_desc = tensor_def.desc(); | const auto &tensor_desc = tensor_def.desc(); | ||||
| auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_dtype:", &data_type); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, | |||||
| &data_type); | |||||
| auto dims = tensor_desc.shape().dim(); | auto dims = tensor_desc.shape().dim(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name + "_desc_shape:", dims); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, | |||||
| dims); | |||||
| auto layout = tensor_desc.layout(); | auto layout = tensor_desc.layout(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_layout:", &layout); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, | |||||
| &layout); | |||||
| auto device_type = tensor_desc.device_type(); | auto device_type = tensor_desc.device_type(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | ||||
| attr_name + "_desc_device_type:", &device_type); | |||||
| prefix + attr_name + "_desc_device_type" + suffix, &device_type); | |||||
| if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
| auto data = tensor_def.data(); | auto data = tensor_def.data(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_data", &data); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, | |||||
| &data); | |||||
| } | } | ||||
| } | } | ||||
| if (attr_type == ge::proto::AttrDef::kS) { | if (attr_type == ge::proto::AttrDef::kS) { | ||||
| if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
| auto str_value = attr_def.s(); | auto str_value = attr_def.s(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name, &str_value); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); | |||||
| } | } | ||||
| } | } | ||||
| if (attr_type == ge::proto::AttrDef::kI) { | if (attr_type == ge::proto::AttrDef::kI) { | ||||
| auto int_value = attr_def.i(); | auto int_value = attr_def.i(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||||
| } | } | ||||
| if (attr_type == ge::proto::AttrDef::kF) { | if (attr_type == ge::proto::AttrDef::kF) { | ||||
| auto float_value = attr_def.f(); | auto float_value = attr_def.f(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, attr_name, &float_value); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); | |||||
| } | } | ||||
| if (attr_type == ge::proto::AttrDef::kB) { | if (attr_type == ge::proto::AttrDef::kB) { | ||||
| auto int_value = static_cast<int64_t>(attr_def.b()); | auto int_value = static_cast<int64_t>(attr_def.b()); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||||
| } | } | ||||
| if (attr_type == ge::proto::AttrDef::kList) { | if (attr_type == ge::proto::AttrDef::kList) { | ||||
| const auto &list_value = attr_def.list(); | const auto &list_value = attr_def.list(); | ||||
| @@ -415,21 +440,21 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||||
| ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | ||||
| if (kDumpLevel == DUMP_ALL) { | if (kDumpLevel == DUMP_ALL) { | ||||
| const auto &strings = list_value.s(); | const auto &strings = list_value.s(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, attr_name, strings); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); | |||||
| } | } | ||||
| } | } | ||||
| if (list_value_type == | if (list_value_type == | ||||
| ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | ||||
| const auto &floats = list_value.f(); | const auto &floats = list_value.f(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, attr_name, floats); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); | |||||
| } | } | ||||
| if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | ||||
| const auto &ints = list_value.i(); | const auto &ints = list_value.i(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, ints); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); | |||||
| } | } | ||||
| if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | ||||
| const auto &bools = list_value.b(); | const auto &bools = list_value.b(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, bools); | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -481,8 +506,15 @@ void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto | |||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | ||||
| const auto &is_input_const = op_def->is_input_const(); | const auto &is_input_const = op_def->is_input_const(); | ||||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | ||||
| AddAttrProtoForAttrsFromOpDef(op_def, node_proto); | |||||
| const auto &op_def_attr_map = op_def->attr(); | |||||
| AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); | |||||
| } else { | |||||
| GELOGE(FAILED, "Opdef is nullptr"); | |||||
| return; | |||||
| } | } | ||||
| } else { | |||||
| GELOGE(FAILED, "Opdesc is nullptr"); | |||||
| return; | |||||
| } | } | ||||
| } | } | ||||
| @@ -526,15 +558,13 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||||
| node_proto->clear_input(); | node_proto->clear_input(); | ||||
| // 1. Add input by in data edge | // 1. Add input by in data edge | ||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| if (in_data_anchor != nullptr) { | |||||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||||
| std::to_string(peer_out_anchor->GetIdx())); | |||||
| } else { | |||||
| // Add "" input | |||||
| node_proto->add_input(""); | |||||
| } | |||||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||||
| std::to_string(peer_out_anchor->GetIdx())); | |||||
| } else { | |||||
| // Add "" input | |||||
| node_proto->add_input(""); | |||||
| } | } | ||||
| } | } | ||||
| @@ -547,6 +577,9 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | ||||
| } | } | ||||
| } | } | ||||
| } else { | |||||
| GELOGE(FAILED, "Incontrol anchor is nullptr"); | |||||
| return false; | |||||
| } | } | ||||
| // 3. Add output for Netron visual support | // 3. Add output for Netron visual support | ||||
| @@ -584,7 +617,7 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||||
| } | } | ||||
| const auto &op_desc = node->GetOpDesc(); | const auto &op_desc = node->GetOpDesc(); | ||||
| if (op_desc != nullptr) { | if (op_desc != nullptr) { | ||||
| auto size_out = op_desc->GetOutputsSize(); | |||||
| uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||||
| if (size_out > 0) { | if (size_out > 0) { | ||||
| for (uint32_t i = 0; i < size_out; i++) { | for (uint32_t i = 0; i < size_out; i++) { | ||||
| const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | ||||
| @@ -598,7 +631,13 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||||
| auto dim = shape->add_dim(); | auto dim = shape->add_dim(); | ||||
| dim->set_dim_value(d); | dim->set_dim_value(d); | ||||
| } | } | ||||
| } else { | |||||
| GELOGW("Shape is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } else { | |||||
| GELOGW("Ge tensor is nullptr"); | |||||
| continue; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -666,7 +705,7 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||||
| } | } | ||||
| // For subgraphs: a subgraph is represented by a node | // For subgraphs: a subgraph is represented by a node | ||||
| for (const auto &sub_compute_graph : compute_graph->sub_graph_) { | |||||
| for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { | |||||
| if (sub_compute_graph != nullptr) { | if (sub_compute_graph != nullptr) { | ||||
| auto node_proto = graph_proto->add_node(); | auto node_proto = graph_proto->add_node(); | ||||
| if (node_proto == nullptr) { | if (node_proto == nullptr) { | ||||
| @@ -679,6 +718,10 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||||
| attr->set_name("graph"); | attr->set_name("graph"); | ||||
| attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | ||||
| auto sub_graph_proto = attr->mutable_g(); | auto sub_graph_proto = attr->mutable_g(); | ||||
| if (sub_graph_proto == nullptr) { | |||||
| GELOGW("Sub graph proto is nullptr"); | |||||
| continue; | |||||
| } | |||||
| if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | ||||
| GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | ||||
| continue; | continue; | ||||
| @@ -831,56 +874,116 @@ void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t | |||||
| value = attr_proto.i(); | value = attr_proto.i(); | ||||
| } | } | ||||
| void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc) { | |||||
| if (op_desc == nullptr || op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "op_desc or op_desc->MutableInputDesc(index) is nullptr"); | |||||
| void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_desc, int32_t index, | |||||
| OpDescPtr &op_desc) { | |||||
| if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr", | |||||
| op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); | |||||
| return; | return; | ||||
| } | } | ||||
| if (attr_name_for_input_output_desc == "input_desc_dtype") { | |||||
| if (attr_name_for_input_desc == "input_desc_dtype") { | |||||
| auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | ||||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | ||||
| } else if (attr_name_for_input_output_desc == "input_desc_shape") { | |||||
| } else if (attr_name_for_input_desc == "input_desc_shape") { | |||||
| std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
| DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
| GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | ||||
| } else if (attr_name_for_input_output_desc == "input_desc_layout") { | |||||
| } else if (attr_name_for_input_desc == "input_desc_layout") { | |||||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | ||||
| } else if (attr_name_for_input_output_desc == "input_desc_origin_shape") { | |||||
| } else if (attr_name_for_input_desc == "input_desc_origin_shape") { | |||||
| std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
| DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
| GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | ||||
| } else if (attr_name_for_input_output_desc == "input_desc_origin_layout") { | |||||
| } else if (attr_name_for_input_desc == "input_desc_origin_layout") { | |||||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | ||||
| } else if (attr_name_for_input_output_desc == "output_desc_dtype") { | |||||
| } else if (attr_name_for_input_desc == "input_desc_size") { | |||||
| int64_t input_size = 0; | |||||
| auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
| DecodeAttribute(attr_proto, input_size); | |||||
| tensor_descriptor->set_size(input_size); | |||||
| } else if (attr_name_for_input_desc == "input_desc_data_offset") { | |||||
| auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
| int64_t offset = 0; | |||||
| DecodeAttribute(attr_proto, offset); | |||||
| tensor_descriptor->set_data_offset(offset); | |||||
| } else { | |||||
| return; | |||||
| } | |||||
| } | |||||
| void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc) { | |||||
| if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr", | |||||
| op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); | |||||
| return; | |||||
| } | |||||
| if (attr_name_for_output_desc == "output_desc_dtype") { | |||||
| auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | ||||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | ||||
| } else if (attr_name_for_input_output_desc == "output_desc_shape") { | |||||
| } else if (attr_name_for_output_desc == "output_desc_shape") { | |||||
| std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
| DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
| GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | ||||
| } else if (attr_name_for_input_output_desc == "output_desc_layout") { | |||||
| } else if (attr_name_for_output_desc == "output_desc_layout") { | |||||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | ||||
| } else if (attr_name_for_input_output_desc == "output_desc_origin_shape") { | |||||
| } else if (attr_name_for_output_desc == "output_desc_origin_shape") { | |||||
| std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
| DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
| GeShape ge_shape(ints); | GeShape ge_shape(ints); | ||||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | ||||
| } else if (attr_name_for_input_output_desc == "output_desc_origin_layout") { | |||||
| } else if (attr_name_for_output_desc == "output_desc_origin_layout") { | |||||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | ||||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | ||||
| } else if (attr_name_for_output_desc == "output_desc_size") { | |||||
| int64_t output_size = 0; | |||||
| auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
| DecodeAttribute(attr_proto, output_size); | |||||
| tensor_descriptor->set_size(output_size); | |||||
| } else if (attr_name_for_output_desc == "output_desc_data_offset") { | |||||
| auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||||
| int64_t offset = 0; | |||||
| DecodeAttribute(attr_proto, offset); | |||||
| tensor_descriptor->set_data_offset(offset); | |||||
| } else { | |||||
| return; | |||||
| } | |||||
| } | |||||
| void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc) { | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "op_desc is nullptr"); | |||||
| return; | |||||
| } | |||||
| if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { | |||||
| DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||||
| } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { | |||||
| DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||||
| } else { | } else { | ||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { | |||||
| auto attr_map = op_def.mutable_attr(); | |||||
| const auto &attr_name = attr_proto.name(); | |||||
| ge::proto::AttrDef op_attr; | |||||
| int64_t value = 0; | |||||
| DecodeAttribute(attr_proto, value); | |||||
| op_attr.set_i(value); | |||||
| attr_map->insert(AttrDefPair(attr_name, op_attr)); | |||||
| } | |||||
| void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | ||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | ||||
| @@ -910,6 +1013,16 @@ void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_pr | |||||
| std::vector<std::int64_t> ints; | std::vector<std::int64_t> ints; | ||||
| DecodeAttribute(attr_proto, ints); | DecodeAttribute(attr_proto, ints); | ||||
| op_desc->SetDstIndex(ints); | op_desc->SetDstIndex(ints); | ||||
| } else if (attr_name == "fusion_scope") { | |||||
| DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); | |||||
| } else if (attr_name == "input_i") { | |||||
| std::vector<std::int64_t> ints; | |||||
| DecodeAttribute(attr_proto, ints); | |||||
| op_desc->SetInputOffset(ints); | |||||
| } else if (attr_name == "output_i") { | |||||
| std::vector<std::int64_t> ints; | |||||
| DecodeAttribute(attr_proto, ints); | |||||
| op_desc->SetOutputOffset(ints); | |||||
| } else { | } else { | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -939,20 +1052,14 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_ | |||||
| auto size_in = attr.i(); | auto size_in = attr.i(); | ||||
| for (int64_t i = 0; i < size_in; i++) { | for (int64_t i = 0; i < size_in; i++) { | ||||
| GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
| if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
| GELOGW("Add inputdesc failed"); | |||||
| continue; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); | |||||
| } | } | ||||
| } | } | ||||
| if (attr.name() == "output_desc_nums") { | if (attr.name() == "output_desc_nums") { | ||||
| auto size_out = attr.i(); | auto size_out = attr.i(); | ||||
| for (int64_t i = 0; i < size_out; i++) { | for (int64_t i = 0; i < size_out; i++) { | ||||
| GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
| if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
| GELOGW("add inputdesc failed"); | |||||
| continue; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -970,10 +1077,7 @@ bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_p | |||||
| } | } | ||||
| graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | ||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | |||||
| return false; | |||||
| } | |||||
| GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); | |||||
| /// 1. Decode all nodes first, node should include input | /// 1. Decode all nodes first, node should include input | ||||
| /// and output nodes and nodes which represent sub graphs | /// and output nodes and nodes which represent sub graphs | ||||
| std::map<std::string, NodePtr> node_map; | std::map<std::string, NodePtr> node_map; | ||||
| @@ -131,6 +131,10 @@ class OnnxUtils { | |||||
| static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | ||||
| static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map, | |||||
| onnx::NodeProto *node_proto, const std::string &prefix = "", | |||||
| const std::string &suffix = ""); | |||||
| static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | ||||
| static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | ||||
| @@ -172,10 +176,20 @@ class OnnxUtils { | |||||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | ||||
| static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_output_desc, int32_t index, | |||||
| OpDescPtr &op_desc); | |||||
| static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||||
| const std::string &attr_name_for_input_desc, int32_t index, | |||||
| OpDescPtr &op_desc); | |||||
| static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | ||||
| const std::string &attr_name_for_input_output_desc, int32_t index, | const std::string &attr_name_for_input_output_desc, int32_t index, | ||||
| OpDescPtr &op_desc); | OpDescPtr &op_desc); | ||||
| static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); | |||||
| static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | ||||
| static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | ||||
| @@ -15,10 +15,12 @@ | |||||
| */ | */ | ||||
| #include "utils/node_utils.h" | #include "utils/node_utils.h" | ||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
| #include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
| @@ -109,6 +111,7 @@ graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_pt | |||||
| graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | ||||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | ||||
| "node or in_data_anchor is nullptr"); | "node or in_data_anchor is nullptr"); | ||||
| bool find_flag = false; | bool find_flag = false; | ||||
| uint32_t index = 0; | uint32_t index = 0; | ||||
| vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | ||||
| @@ -358,4 +361,45 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||||
| input_desc->SetShape(shape); | input_desc->SetShape(shape); | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| std::string NodeUtils::GetNodeType(const Node &node) { | |||||
| if (node.GetType() != FRAMEWORKOP) { | |||||
| return node.GetType(); | |||||
| } | |||||
| std::string type; | |||||
| (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||||
| return type; | |||||
| } | |||||
| ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||||
| } | |||||
| graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||||
| if (subgraph == nullptr) { | |||||
| GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||||
| subgraph->SetParentNode(node.shared_from_this()); | |||||
| subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||||
| root_graph->AddSubgraph(subgraph); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -15,9 +15,7 @@ | |||||
| */ | */ | ||||
| #include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
| #include <algorithm> | #include <algorithm> | ||||
| #include "debug/ge_attr_define.h" | #include "debug/ge_attr_define.h" | ||||
| #include "debug/ge_op_types.h" | #include "debug/ge_op_types.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| @@ -209,6 +207,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | ||||
| const vector<ge::NodePtr> &input_nodes) { | const vector<ge::NodePtr> &input_nodes) { | ||||
| vector<ConstGeTensorPtr> ret; | vector<ConstGeTensorPtr> ret; | ||||
| for (const auto &input_node : input_nodes) { | for (const auto &input_node : input_nodes) { | ||||
| auto temp_weight = MutableWeights(input_node->GetOpDesc()); | auto temp_weight = MutableWeights(input_node->GetOpDesc()); | ||||
| if (temp_weight == nullptr) { | if (temp_weight == nullptr) { | ||||
| @@ -379,7 +378,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
| if (NodeUtils::IsAnchorStatusSet(*node)) { | if (NodeUtils::IsAnchorStatusSet(*node)) { | ||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
| if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | ||||
| (void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| } | } | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -389,7 +388,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | ||||
| (void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -572,4 +571,80 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| /// | |||||
| /// @brief Add input | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||||
| inputs_.emplace_back(name); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic input | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||||
| uint32_t num) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| inputs_.emplace_back(name + std::to_string(i)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add output | |||||
| /// @param [in] name | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||||
| outputs_.emplace_back(name); | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Add dynamic output | |||||
| /// @param [in] name | |||||
| /// @param [in] num | |||||
| /// @return OpDescBuilder | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||||
| uint32_t num) { | |||||
| for (uint32_t i = 0; i < num; i++) { | |||||
| outputs_.emplace_back(name + std::to_string(i)); | |||||
| } | |||||
| return *this; | |||||
| } | |||||
| /// | |||||
| /// @brief Build op_desc | |||||
| /// @return OpDescPtr | |||||
| /// | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { | |||||
| OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_)); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); | |||||
| return nullptr; | |||||
| } | |||||
| for (auto &input : inputs_) { | |||||
| if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| for (auto &output : outputs_) { | |||||
| if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return op_desc; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -15,7 +15,6 @@ | |||||
| */ | */ | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include <cmath> | #include <cmath> | ||||
| #include "debug/ge_log.h" | #include "debug/ge_log.h" | ||||
| @@ -276,6 +275,14 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||||
| break; | break; | ||||
| case FORMAT_FRACTAL_NZ: | case FORMAT_FRACTAL_NZ: | ||||
| case FORMAT_FRACTAL_ZZ: | case FORMAT_FRACTAL_ZZ: | ||||
| case FORMAT_NDHWC: | |||||
| case FORMAT_NCDHW: | |||||
| case FORMAT_DHWCN: | |||||
| case FORMAT_DHWNC: | |||||
| case FORMAT_FRACTAL_Z_3D: | |||||
| case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||||
| case FORMAT_NDC1HWC0: | |||||
| case FORMAT_FRACTAL_Z_C04: | |||||
| graph_status = CalcElementCntByDims(dims, element_cnt); | graph_status = CalcElementCntByDims(dims, element_cnt); | ||||
| break; | break; | ||||
| default: | default: | ||||
| @@ -351,21 +358,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTens | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
| TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||||
| TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
| graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | ||||
| if (graph_status != GRAPH_SUCCESS) { | if (graph_status != GRAPH_SUCCESS) { | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| // 64-byte alignment, if size is 0, align to 32 bytes | // 64-byte alignment, if size is 0, align to 32 bytes | ||||
| if (size_temp > (UINT32_MAX - kNum2 * kDataMemAlignSize)) { | |||||
| GELOGW("The updated mem size %u is bigger than UINT32_MAX", size_temp); | |||||
| if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { | |||||
| GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); | |||||
| } else { | } else { | ||||
| size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | ||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
| TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||||
| TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||||
| GeShape output_shape = desc_temp.GetShape(); | GeShape output_shape = desc_temp.GetShape(); | ||||
| Format format = desc_temp.GetFormat(); | Format format = desc_temp.GetFormat(); | ||||
| DataType data_type = desc_temp.GetDataType(); | DataType data_type = desc_temp.GetDataType(); | ||||
| @@ -376,13 +383,13 @@ TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_ | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| if ((output_mem_size > UINT32_MAX) || (output_mem_size < 0)) { | |||||
| GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %u]", | |||||
| output_mem_size, UINT32_MAX); | |||||
| if (output_mem_size < 0) { | |||||
| GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", | |||||
| output_mem_size, INT64_MAX); | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| size_temp = static_cast<uint32_t>(output_mem_size); | |||||
| size_temp = output_mem_size; | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -19,43 +19,45 @@ | |||||
| namespace ge { | namespace ge { | ||||
| static const std::map<Format, std::string> kFormatToStringMap = { | static const std::map<Format, std::string> kFormatToStringMap = { | ||||
| {FORMAT_NCHW, "NCHW"}, | |||||
| {FORMAT_NHWC, "NHWC"}, | |||||
| {FORMAT_ND, "ND"}, | |||||
| {FORMAT_NC1HWC0, "NC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||||
| {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||||
| {FORMAT_NHWC1C0, "NHWC1C0"}, | |||||
| {FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||||
| {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||||
| {FORMAT_C1HWNC0, "C1HWNC0"}, | |||||
| {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||||
| {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||||
| {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||||
| {FORMAT_CHWN, "CHWN"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||||
| {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||||
| {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||||
| {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||||
| {FORMAT_HWCN, "HWCN"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||||
| {FORMAT_MD, "MD"}, | |||||
| {FORMAT_NDHWC, "NDHWC"}, | |||||
| {FORMAT_NCDHW, "NCDHW"}, | |||||
| {FORMAT_DHWCK, "DHWCK"}, | |||||
| {FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||||
| {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||||
| {FORMAT_CN, "CN"}, | |||||
| {FORMAT_NC, "NC"}, | |||||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||||
| {FORMAT_ALL, "ALL"}}; | |||||
| {FORMAT_NCHW, "NCHW"}, | |||||
| {FORMAT_NHWC, "NHWC"}, | |||||
| {FORMAT_ND, "ND"}, | |||||
| {FORMAT_NC1HWC0, "NC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||||
| {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||||
| {FORMAT_NHWC1C0, "NHWC1C0"}, | |||||
| {FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||||
| {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||||
| {FORMAT_C1HWNC0, "C1HWNC0"}, | |||||
| {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||||
| {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||||
| {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||||
| {FORMAT_CHWN, "CHWN"}, | |||||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||||
| {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||||
| {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||||
| {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||||
| {FORMAT_HWCN, "HWCN"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||||
| {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||||
| {FORMAT_MD, "MD"}, | |||||
| {FORMAT_NDHWC, "NDHWC"}, | |||||
| {FORMAT_NCDHW, "NCDHW"}, | |||||
| {FORMAT_DHWCN, "DHWCN"}, | |||||
| {FORMAT_DHWNC, "DHWNC"}, | |||||
| {FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||||
| {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||||
| {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, | |||||
| {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||||
| {FORMAT_CN, "CN"}, | |||||
| {FORMAT_NC, "NC"}, | |||||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||||
| {FORMAT_ALL, "ALL"}}; | |||||
| static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | ||||
| "FRACTAL_Z", | "FRACTAL_Z", | ||||
| @@ -73,137 +75,140 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||||
| "FRACTAL_ZZ", | "FRACTAL_ZZ", | ||||
| "FRACTAL_NZ", | "FRACTAL_NZ", | ||||
| "NDC1HWC0", | "NDC1HWC0", | ||||
| "FORMAT_FRACTAL_Z_3D"}; | |||||
| "FORMAT_FRACTAL_Z_3D", | |||||
| "FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||||
| static const std::map<std::string, Format> kDataFormatMap = { | static const std::map<std::string, Format> kDataFormatMap = { | ||||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"ND", FORMAT_ND}}; | |||||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||||
| static const std::map<std::string, Format> kStringToFormatMap = { | static const std::map<std::string, Format> kStringToFormatMap = { | ||||
| {"NCHW", FORMAT_NCHW}, | |||||
| {"NHWC", FORMAT_NHWC}, | |||||
| {"ND", FORMAT_ND}, | |||||
| {"NC1HWC0", FORMAT_NC1HWC0}, | |||||
| {"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||||
| {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||||
| {"NHWC1C0", FORMAT_NHWC1C0}, | |||||
| {"FSR_NCHW", FORMAT_FSR_NCHW}, | |||||
| {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||||
| {"C1HWNC0", FORMAT_C1HWNC0}, | |||||
| {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||||
| {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||||
| {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||||
| {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||||
| {"CHWN", FORMAT_CHWN}, | |||||
| {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||||
| {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||||
| {"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||||
| {"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||||
| {"HWCN", FORMAT_HWCN}, | |||||
| {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||||
| {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||||
| {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||||
| {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||||
| {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||||
| {"MD", FORMAT_MD}, | |||||
| {"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||||
| {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||||
| {"NDHWC", FORMAT_NDHWC}, | |||||
| {"NCDHW", FORMAT_NCDHW}, | |||||
| {"DHWCK", FORMAT_DHWCK}, | |||||
| {"NDC1HWC0", FORMAT_NDC1HWC0}, | |||||
| {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||||
| {"CN", FORMAT_CN}, | |||||
| {"NC", FORMAT_NC}, | |||||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||||
| {"ALL", FORMAT_ALL}}; | |||||
| {"NCHW", FORMAT_NCHW}, | |||||
| {"NHWC", FORMAT_NHWC}, | |||||
| {"ND", FORMAT_ND}, | |||||
| {"NC1HWC0", FORMAT_NC1HWC0}, | |||||
| {"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||||
| {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||||
| {"NHWC1C0", FORMAT_NHWC1C0}, | |||||
| {"FSR_NCHW", FORMAT_FSR_NCHW}, | |||||
| {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||||
| {"C1HWNC0", FORMAT_C1HWNC0}, | |||||
| {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||||
| {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||||
| {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||||
| {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||||
| {"CHWN", FORMAT_CHWN}, | |||||
| {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||||
| {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||||
| {"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||||
| {"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||||
| {"HWCN", FORMAT_HWCN}, | |||||
| {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||||
| {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||||
| {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||||
| {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||||
| {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||||
| {"MD", FORMAT_MD}, | |||||
| {"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||||
| {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||||
| {"NDHWC", FORMAT_NDHWC}, | |||||
| {"NCDHW", FORMAT_NCDHW}, | |||||
| {"DHWCN", FORMAT_DHWCN}, | |||||
| {"DHWNC", FORMAT_DHWNC}, | |||||
| {"NDC1HWC0", FORMAT_NDC1HWC0}, | |||||
| {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||||
| {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||||
| {"CN", FORMAT_CN}, | |||||
| {"NC", FORMAT_NC}, | |||||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||||
| {"ALL", FORMAT_ALL}}; | |||||
| static const std::map<DataType, std::string> kDataTypeToStringMap = { | static const std::map<DataType, std::string> kDataTypeToStringMap = { | ||||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||||
| {DT_FLOAT, "DT_FLOAT"}, // float type | |||||
| {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||||
| {DT_INT8, "DT_INT8"}, // int8 type | |||||
| {DT_INT16, "DT_INT16"}, // int16 type | |||||
| {DT_UINT16, "DT_UINT16"}, // uint16 type | |||||
| {DT_UINT8, "DT_UINT8"}, // uint8 type | |||||
| {DT_INT32, "DT_INT32"}, // uint32 type | |||||
| {DT_INT64, "DT_INT64"}, // int64 type | |||||
| {DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||||
| {DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||||
| {DT_BOOL, "DT_BOOL"}, // bool type | |||||
| {DT_DOUBLE, "DT_DOUBLE"}, // double type | |||||
| {DT_DUAL, "DT_DUAL"}, // dual output type | |||||
| {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||||
| {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||||
| {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||||
| {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||||
| {DT_QINT8, "DT_QINT8"}, // qint8 type | |||||
| {DT_QINT16, "DT_QINT16"}, // qint16 type | |||||
| {DT_QINT32, "DT_QINT32"}, // qint32 type | |||||
| {DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||||
| {DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||||
| {DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||||
| {DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||||
| {DT_STRING, "DT_STRING"}, // string type | |||||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||||
| {DT_FLOAT, "DT_FLOAT"}, // float type | |||||
| {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||||
| {DT_INT8, "DT_INT8"}, // int8 type | |||||
| {DT_INT16, "DT_INT16"}, // int16 type | |||||
| {DT_UINT16, "DT_UINT16"}, // uint16 type | |||||
| {DT_UINT8, "DT_UINT8"}, // uint8 type | |||||
| {DT_INT32, "DT_INT32"}, // uint32 type | |||||
| {DT_INT64, "DT_INT64"}, // int64 type | |||||
| {DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||||
| {DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||||
| {DT_BOOL, "DT_BOOL"}, // bool type | |||||
| {DT_DOUBLE, "DT_DOUBLE"}, // double type | |||||
| {DT_DUAL, "DT_DUAL"}, // dual output type | |||||
| {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||||
| {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||||
| {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||||
| {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||||
| {DT_QINT8, "DT_QINT8"}, // qint8 type | |||||
| {DT_QINT16, "DT_QINT16"}, // qint16 type | |||||
| {DT_QINT32, "DT_QINT32"}, // qint32 type | |||||
| {DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||||
| {DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||||
| {DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||||
| {DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||||
| {DT_STRING, "DT_STRING"}, // string type | |||||
| }; | }; | ||||
| static const std::map<std::string, DataType> kStringTodataTypeMap = { | static const std::map<std::string, DataType> kStringTodataTypeMap = { | ||||
| {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||||
| {"DT_FLOAT", DT_FLOAT}, // float type | |||||
| { | |||||
| "DT_FLOAT16", | |||||
| DT_FLOAT16, | |||||
| }, // fp16 type | |||||
| {"DT_INT8", DT_INT8}, // int8 type | |||||
| {"DT_INT16", DT_INT16}, // int16 type | |||||
| {"DT_UINT16", DT_UINT16}, // uint16 type | |||||
| {"DT_UINT8", DT_UINT8}, // uint8 type | |||||
| {"DT_INT32", DT_INT32}, // uint32 type | |||||
| {"DT_INT64", DT_INT64}, // int64 type | |||||
| {"DT_UINT32", DT_UINT32}, // unsigned int32 | |||||
| {"DT_UINT64", DT_UINT64}, // unsigned int64 | |||||
| {"DT_BOOL", DT_BOOL}, // bool type | |||||
| {"DT_DOUBLE", DT_DOUBLE}, // double type | |||||
| {"DT_DUAL", DT_DUAL}, // dual output type | |||||
| {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||||
| {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||||
| {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||||
| {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||||
| {"DT_QINT8", DT_QINT8}, // qint8 type | |||||
| {"DT_QINT16", DT_QINT16}, // qint16 type | |||||
| {"DT_QINT32", DT_QINT32}, // qint32 type | |||||
| {"DT_QUINT8", DT_QUINT8}, // quint8 type | |||||
| {"DT_QUINT16", DT_QUINT16}, // quint16 type | |||||
| {"DT_RESOURCE", DT_RESOURCE}, // resource type | |||||
| {"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||||
| {"DT_STRING", DT_STRING}, // string type | |||||
| {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||||
| {"DT_FLOAT", DT_FLOAT}, // float type | |||||
| { | |||||
| "DT_FLOAT16", | |||||
| DT_FLOAT16, | |||||
| }, // fp16 type | |||||
| {"DT_INT8", DT_INT8}, // int8 type | |||||
| {"DT_INT16", DT_INT16}, // int16 type | |||||
| {"DT_UINT16", DT_UINT16}, // uint16 type | |||||
| {"DT_UINT8", DT_UINT8}, // uint8 type | |||||
| {"DT_INT32", DT_INT32}, // uint32 type | |||||
| {"DT_INT64", DT_INT64}, // int64 type | |||||
| {"DT_UINT32", DT_UINT32}, // unsigned int32 | |||||
| {"DT_UINT64", DT_UINT64}, // unsigned int64 | |||||
| {"DT_BOOL", DT_BOOL}, // bool type | |||||
| {"DT_DOUBLE", DT_DOUBLE}, // double type | |||||
| {"DT_DUAL", DT_DUAL}, // dual output type | |||||
| {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||||
| {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||||
| {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||||
| {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||||
| {"DT_QINT8", DT_QINT8}, // qint8 type | |||||
| {"DT_QINT16", DT_QINT16}, // qint16 type | |||||
| {"DT_QINT32", DT_QINT32}, // qint32 type | |||||
| {"DT_QUINT8", DT_QUINT8}, // quint8 type | |||||
| {"DT_QUINT16", DT_QUINT16}, // quint16 type | |||||
| {"DT_RESOURCE", DT_RESOURCE}, // resource type | |||||
| {"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||||
| {"DT_STRING", DT_STRING}, // string type | |||||
| }; | }; | ||||
| static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | ||||
| {DT_BOOL, sizeof(bool)}, | |||||
| {DT_INT64, sizeof(int64_t)}, | |||||
| {DT_UINT64, sizeof(int64_t)}, | |||||
| {DT_FLOAT, sizeof(float)}, | |||||
| {DT_INT32, sizeof(int32_t)}, | |||||
| {DT_UINT32, sizeof(int32_t)}, | |||||
| {DT_INT8, sizeof(char)}, | |||||
| {DT_UINT8, sizeof(char)}, | |||||
| {DT_INT16, sizeof(int16_t)}, | |||||
| {DT_UINT16, sizeof(int16_t)}, | |||||
| {DT_FLOAT16, sizeof(int16_t)}, | |||||
| {DT_DOUBLE, sizeof(double)}, | |||||
| {DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||||
| {DT_COMPLEX64, sizeof(int64_t)}, | |||||
| {DT_COMPLEX128, sizeof(int64_t) * 2}, | |||||
| {DT_QINT8, sizeof(int8_t)}, | |||||
| {DT_QINT16, sizeof(int16_t)}, | |||||
| {DT_QINT32, sizeof(int32_t)}, | |||||
| {DT_QUINT8, sizeof(uint8_t)}, | |||||
| {DT_QUINT16, sizeof(uint16_t)}, | |||||
| {DT_STRING_REF, sizeof(uint64_t) * 2}, | |||||
| {DT_STRING, sizeof(uint64_t)}, | |||||
| {DT_RESOURCE, sizeof(uint64_t)}, | |||||
| {DT_BOOL, sizeof(bool)}, | |||||
| {DT_INT64, sizeof(int64_t)}, | |||||
| {DT_UINT64, sizeof(int64_t)}, | |||||
| {DT_FLOAT, sizeof(float)}, | |||||
| {DT_INT32, sizeof(int32_t)}, | |||||
| {DT_UINT32, sizeof(int32_t)}, | |||||
| {DT_INT8, sizeof(char)}, | |||||
| {DT_UINT8, sizeof(char)}, | |||||
| {DT_INT16, sizeof(int16_t)}, | |||||
| {DT_UINT16, sizeof(int16_t)}, | |||||
| {DT_FLOAT16, sizeof(int16_t)}, | |||||
| {DT_DOUBLE, sizeof(double)}, | |||||
| {DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||||
| {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||||
| {DT_COMPLEX64, sizeof(int64_t)}, | |||||
| {DT_COMPLEX128, sizeof(int64_t) * 2}, | |||||
| {DT_QINT8, sizeof(int8_t)}, | |||||
| {DT_QINT16, sizeof(int16_t)}, | |||||
| {DT_QINT32, sizeof(int32_t)}, | |||||
| {DT_QUINT8, sizeof(uint8_t)}, | |||||
| {DT_QUINT16, sizeof(uint16_t)}, | |||||
| {DT_STRING_REF, sizeof(uint64_t) * 2}, | |||||
| {DT_STRING, sizeof(uint64_t)}, | |||||
| {DT_RESOURCE, sizeof(uint64_t)}, | |||||
| }; | }; | ||||
| bool TypeUtils::IsDataTypeValid(DataType dt) { | bool TypeUtils::IsDataTypeValid(DataType dt) { | ||||
| @@ -13,7 +13,7 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| # ============================================================================ | # ============================================================================ | ||||
| # libge.so & libge_train.so | |||||
| # libge_compiler.so & libge_train.so | |||||
| # will later be integrated into libgraph_runner.so, works for both training and inference | # will later be integrated into libgraph_runner.so, works for both training and inference | ||||
| # compiling proto files generates some warnings, use no-unused-variable to suppress them | # compiling proto files generates some warnings, use no-unused-variable to suppress them | ||||
| set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | ||||
| @@ -49,7 +49,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| ######### libge_train.so ############# | ######### libge_train.so ############# | ||||
| # need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
| file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "common/formats/format_transfers/*.cc" | "common/formats/format_transfers/*.cc" | ||||
| "common/formats/formats.cc" | "common/formats/formats.cc" | ||||
| "common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
| @@ -57,20 +57,24 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
| "common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
| "engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||||
| "generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
| "generator/generator_api.cc" | "generator/generator_api.cc" | ||||
| "graph/build/graph_build.cc" | |||||
| "graph/build/graph_builder.cc" | |||||
| "graph/build/label_allocator.cc" | |||||
| "graph/build/logical_stream_allocator.cc" | "graph/build/logical_stream_allocator.cc" | ||||
| "graph/build/model_builder.cc" | "graph/build/model_builder.cc" | ||||
| "graph/build/optimize_stream_graph.cc" | |||||
| "graph/build/run_context.cc" | "graph/build/run_context.cc" | ||||
| "graph/build/stream_allocator.cc" | "graph/build/stream_allocator.cc" | ||||
| "graph/build/stream_graph_optimizer.cc" | |||||
| "graph/build/task_generator.cc" | "graph/build/task_generator.cc" | ||||
| "graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
| "graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
| "graph/common/transop_util.cc" | "graph/common/transop_util.cc" | ||||
| "graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
| "graph/label/*.cc" | |||||
| "graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
| "graph/load/new_model_manager/data_dumper.cc" | "graph/load/new_model_manager/data_dumper.cc" | ||||
| "graph/load/new_model_manager/data_inputer.cc" | "graph/load/new_model_manager/data_inputer.cc" | ||||
| "graph/load/new_model_manager/davinci_model.cc" | "graph/load/new_model_manager/davinci_model.cc" | ||||
| @@ -92,10 +96,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||||
| "graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
| "graph/load/new_model_manager/tbe_handle_store.cc" | "graph/load/new_model_manager/tbe_handle_store.cc" | ||||
| "graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
| "graph/manager/custom/custom_op.cc" | |||||
| "graph/manager/graph_context.cc" | "graph/manager/graph_context.cc" | ||||
| "graph/manager/graph_manager.cc" | "graph/manager/graph_manager.cc" | ||||
| "graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
| @@ -105,12 +111,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| "graph/manager/util/hcom_util.cc" | "graph/manager/util/hcom_util.cc" | ||||
| "graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| "graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
| "graph/optimize/graph_functiondef.cc" | |||||
| "graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
| "graph/optimize/graph_optimizer.cc" | |||||
| "graph/optimize/optimizer/allreduce_fusion_pass.cc" | "graph/optimize/optimizer/allreduce_fusion_pass.cc" | ||||
| "graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
| "graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
| @@ -120,7 +123,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
| "graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
| "graph/passes/base_pass.cc" | "graph/passes/base_pass.cc" | ||||
| "graph/passes/cast_remove_pass.cc" | |||||
| "graph/passes/cast_translate_pass.cc" | "graph/passes/cast_translate_pass.cc" | ||||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||||
| "graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
| "graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
| "graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
| @@ -159,12 +164,14 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/folding_kernel/shape_kernel.cc" | "graph/passes/folding_kernel/shape_kernel.cc" | ||||
| "graph/passes/folding_kernel/shape_n_kernel.cc" | "graph/passes/folding_kernel/shape_n_kernel.cc" | ||||
| "graph/passes/folding_kernel/size_kernel.cc" | "graph/passes/folding_kernel/size_kernel.cc" | ||||
| "graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
| "graph/passes/folding_kernel/slice_kernel.cc" | "graph/passes/folding_kernel/slice_kernel.cc" | ||||
| "graph/passes/folding_kernel/squeeze_kernel.cc" | "graph/passes/folding_kernel/squeeze_kernel.cc" | ||||
| "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | ||||
| "graph/passes/folding_kernel/strided_slice_kernel.cc" | "graph/passes/folding_kernel/strided_slice_kernel.cc" | ||||
| "graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
| "graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||||
| "graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
| "graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
| "graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
| @@ -179,7 +186,6 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
| "graph/passes/net_output_pass.cc" | "graph/passes/net_output_pass.cc" | ||||
| "graph/passes/next_iteration_pass.cc" | "graph/passes/next_iteration_pass.cc" | ||||
| "graph/passes/no_reshape_op_remove_pass.cc" | |||||
| "graph/passes/no_use_reshape_remove_pass.cc" | "graph/passes/no_use_reshape_remove_pass.cc" | ||||
| "graph/passes/pass_manager.cc" | "graph/passes/pass_manager.cc" | ||||
| "graph/passes/pass_utils.cc" | "graph/passes/pass_utils.cc" | ||||
| @@ -188,6 +194,7 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
| "graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
| "graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
| "graph/passes/replace_with_empty_const_pass.cc" | |||||
| "graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
| "graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
| "graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
| @@ -206,14 +213,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
| "graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
| "graph/passes/unused_op_remove_pass.cc" | "graph/passes/unused_op_remove_pass.cc" | ||||
| "graph/passes/update_net_output_pass.cc" | |||||
| "graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
| "graph/passes/variable_format_pass.cc" | "graph/passes/variable_format_pass.cc" | ||||
| "graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
| "graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
| "graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
| "graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
| "graph/preprocess/insert_op/base_insert_op.cc" | |||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
| "graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
| @@ -223,13 +228,8 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
| "session/inner_session.cc" | "session/inner_session.cc" | ||||
| "session/session_manager.cc" | "session/session_manager.cc" | ||||
| "single_op/single_op.cc" | |||||
| "single_op/single_op_manager.cc" | |||||
| "single_op/single_op_model.cc" | |||||
| "single_op/stream_resource.cc" | |||||
| "single_op/task/build_task_utils.cc" | |||||
| "single_op/task/op_task.cc" | |||||
| "single_op/task/tbe_task_builder.cc" | |||||
| "single_op/*.cc" | |||||
| "single_op/task/*.cc" | |||||
| ) | ) | ||||
| @@ -261,9 +261,9 @@ target_link_libraries(ge_train | |||||
| rt | rt | ||||
| dl) | dl) | ||||
| ######### libge.so ############# | |||||
| ######### libge_compiler.so ############# | |||||
| # need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
| file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "common/formats/format_transfers/*.cc" | "common/formats/format_transfers/*.cc" | ||||
| "common/formats/formats.cc" | "common/formats/formats.cc" | ||||
| "common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
| @@ -271,20 +271,24 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
| "common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
| "engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||||
| "generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
| "generator/generator_api.cc" | "generator/generator_api.cc" | ||||
| "graph/build/graph_build.cc" | |||||
| "graph/build/graph_builder.cc" | |||||
| "graph/build/label_allocator.cc" | |||||
| "graph/build/logical_stream_allocator.cc" | "graph/build/logical_stream_allocator.cc" | ||||
| "graph/build/model_builder.cc" | "graph/build/model_builder.cc" | ||||
| "graph/build/optimize_stream_graph.cc" | |||||
| "graph/build/run_context.cc" | "graph/build/run_context.cc" | ||||
| "graph/build/stream_allocator.cc" | "graph/build/stream_allocator.cc" | ||||
| "graph/build/stream_graph_optimizer.cc" | |||||
| "graph/build/task_generator.cc" | "graph/build/task_generator.cc" | ||||
| "graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
| "graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
| "graph/common/transop_util.cc" | "graph/common/transop_util.cc" | ||||
| "graph/execute/graph_execute.cc" | "graph/execute/graph_execute.cc" | ||||
| "graph/label/*.cc" | |||||
| "graph/load/graph_loader.cc" | "graph/load/graph_loader.cc" | ||||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||||
| "graph/load/new_model_manager/data_dumper.cc" | "graph/load/new_model_manager/data_dumper.cc" | ||||
| "graph/load/new_model_manager/data_inputer.cc" | "graph/load/new_model_manager/data_inputer.cc" | ||||
| "graph/load/new_model_manager/davinci_model.cc" | "graph/load/new_model_manager/davinci_model.cc" | ||||
| @@ -305,10 +309,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | ||||
| "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||||
| "graph/load/new_model_manager/task_info/task_info.cc" | "graph/load/new_model_manager/task_info/task_info.cc" | ||||
| "graph/load/new_model_manager/tbe_handle_store.cc" | "graph/load/new_model_manager/tbe_handle_store.cc" | ||||
| "graph/load/output/output.cc" | "graph/load/output/output.cc" | ||||
| "graph/manager/custom/custom_op.cc" | |||||
| "graph/manager/graph_context.cc" | "graph/manager/graph_context.cc" | ||||
| "graph/manager/graph_manager.cc" | "graph/manager/graph_manager.cc" | ||||
| "graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
| @@ -317,13 +323,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| "graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| "graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
| "graph/optimize/graph_functiondef.cc" | |||||
| "graph/optimize/graph_optimize.cc" | "graph/optimize/graph_optimize.cc" | ||||
| "graph/optimize/graph_optimizer.cc" | |||||
| "graph/optimize/optimizer/allreduce_fusion_inference_pass.cc" | |||||
| "graph/optimize/summary_optimize.cc" | "graph/optimize/summary_optimize.cc" | ||||
| "graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
| "graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
| @@ -332,7 +334,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
| "graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
| "graph/passes/base_pass.cc" | "graph/passes/base_pass.cc" | ||||
| "graph/passes/cast_remove_pass.cc" | |||||
| "graph/passes/cast_translate_pass.cc" | "graph/passes/cast_translate_pass.cc" | ||||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||||
| "graph/passes/compile_nodes_pass.cc" | "graph/passes/compile_nodes_pass.cc" | ||||
| "graph/passes/constant_folding_pass.cc" | "graph/passes/constant_folding_pass.cc" | ||||
| "graph/passes/constant_fuse_same_pass.cc" | "graph/passes/constant_fuse_same_pass.cc" | ||||
| @@ -371,12 +375,14 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/folding_kernel/shape_kernel.cc" | "graph/passes/folding_kernel/shape_kernel.cc" | ||||
| "graph/passes/folding_kernel/shape_n_kernel.cc" | "graph/passes/folding_kernel/shape_n_kernel.cc" | ||||
| "graph/passes/folding_kernel/size_kernel.cc" | "graph/passes/folding_kernel/size_kernel.cc" | ||||
| "graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
| "graph/passes/folding_kernel/slice_kernel.cc" | "graph/passes/folding_kernel/slice_kernel.cc" | ||||
| "graph/passes/folding_kernel/squeeze_kernel.cc" | "graph/passes/folding_kernel/squeeze_kernel.cc" | ||||
| "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | ||||
| "graph/passes/folding_kernel/strided_slice_kernel.cc" | "graph/passes/folding_kernel/strided_slice_kernel.cc" | ||||
| "graph/passes/folding_kernel/sub_kernel.cc" | "graph/passes/folding_kernel/sub_kernel.cc" | ||||
| "graph/passes/folding_kernel/transdata_kernel.cc" | "graph/passes/folding_kernel/transdata_kernel.cc" | ||||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||||
| "graph/passes/folding_pass.cc" | "graph/passes/folding_pass.cc" | ||||
| "graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
| "graph/passes/guarantee_const_pass.cc" | "graph/passes/guarantee_const_pass.cc" | ||||
| @@ -391,7 +397,6 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/multi_batch_pass.cc" | "graph/passes/multi_batch_pass.cc" | ||||
| "graph/passes/net_output_pass.cc" | "graph/passes/net_output_pass.cc" | ||||
| "graph/passes/next_iteration_pass.cc" | "graph/passes/next_iteration_pass.cc" | ||||
| "graph/passes/no_reshape_op_remove_pass.cc" | |||||
| "graph/passes/no_use_reshape_remove_pass.cc" | "graph/passes/no_use_reshape_remove_pass.cc" | ||||
| "graph/passes/pass_manager.cc" | "graph/passes/pass_manager.cc" | ||||
| "graph/passes/pass_utils.cc" | "graph/passes/pass_utils.cc" | ||||
| @@ -400,6 +405,7 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/prevent_gradient_pass.cc" | "graph/passes/prevent_gradient_pass.cc" | ||||
| "graph/passes/print_op_pass.cc" | "graph/passes/print_op_pass.cc" | ||||
| "graph/passes/prune_pass.cc" | "graph/passes/prune_pass.cc" | ||||
| "graph/passes/replace_with_empty_const_pass.cc" | |||||
| "graph/passes/reshape_remove_pass.cc" | "graph/passes/reshape_remove_pass.cc" | ||||
| "graph/passes/resource_pair_add_control_pass.cc" | "graph/passes/resource_pair_add_control_pass.cc" | ||||
| "graph/passes/resource_pair_remove_control_pass.cc" | "graph/passes/resource_pair_remove_control_pass.cc" | ||||
| @@ -418,14 +424,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
| "graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
| "graph/passes/unused_op_remove_pass.cc" | "graph/passes/unused_op_remove_pass.cc" | ||||
| "graph/passes/update_net_output_pass.cc" | |||||
| "graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
| "graph/passes/variable_format_pass.cc" | "graph/passes/variable_format_pass.cc" | ||||
| "graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
| "graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
| "graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
| "graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
| "graph/preprocess/insert_op/base_insert_op.cc" | |||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
| "graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
| @@ -442,16 +446,19 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "single_op/task/build_task_utils.cc" | "single_op/task/build_task_utils.cc" | ||||
| "single_op/task/op_task.cc" | "single_op/task/op_task.cc" | ||||
| "single_op/task/tbe_task_builder.cc" | "single_op/task/tbe_task_builder.cc" | ||||
| ########################################## | |||||
| # "ir_build/ge_ir_build.cc" | |||||
| # "offline/atc_ir_common.cc" | |||||
| ) | ) | ||||
| add_library(ge SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
| target_compile_definitions(ge PRIVATE | |||||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||||
| target_compile_definitions(ge_compiler PRIVATE | |||||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | PROTOBUF_INLINE_NOT_IN_HEADERS=0 | ||||
| DAVINCI_SUPPORT_PROFILING | DAVINCI_SUPPORT_PROFILING | ||||
| REUSE_MEMORY=1 | REUSE_MEMORY=1 | ||||
| FMK_HOST_INFER | FMK_HOST_INFER | ||||
| PLATFORM_CLOUD) | PLATFORM_CLOUD) | ||||
| target_link_libraries(ge | |||||
| target_link_libraries(ge_compiler | |||||
| graph | graph | ||||
| ge_common | ge_common | ||||
| "-Wl,--whole-archive" | "-Wl,--whole-archive" | ||||
| @@ -80,7 +80,7 @@ target_compile_definitions(ge_client_train PRIVATE | |||||
| PLATFORM_CLOUD) | PLATFORM_CLOUD) | ||||
| target_link_libraries(ge_client | target_link_libraries(ge_client | ||||
| graph | graph | ||||
| ge | |||||
| ge_compiler | |||||
| ge_common | ge_common | ||||
| ${PROTOBUF_LIBRARY} | ${PROTOBUF_LIBRARY} | ||||
| ${register} | ${register} | ||||
| @@ -61,14 +61,14 @@ Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||||
| const int kDecimal = 10; | const int kDecimal = 10; | ||||
| auto dump_op_env = std::getenv("DUMP_OP"); | auto dump_op_env = std::getenv("DUMP_OP"); | ||||
| int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | ||||
| auto disable_reuse_memory_iter = options.find("ge.exec.disableReuseMemory"); | |||||
| if (disable_reuse_memory_iter != options.end()) { | |||||
| if (disable_reuse_memory_iter->second == "0") { | |||||
| auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||||
| if (disableReuseMemoryIter != options.end()) { | |||||
| if (disableReuseMemoryIter->second == "0") { | |||||
| GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | ||||
| if (dump_op_flag) { | if (dump_op_flag) { | ||||
| GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | ||||
| } | } | ||||
| } else if (disable_reuse_memory_iter->second == "1") { | |||||
| } else if (disableReuseMemoryIter->second == "1") { | |||||
| GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | ||||
| } else { | } else { | ||||
| GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | ||||
| @@ -128,22 +128,29 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | OpsProtoManager *manager = OpsProtoManager::Instance(); | ||||
| std::map<string, string> option_tmp; | std::map<string, string> option_tmp; | ||||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | ||||
| GE_TIMESTAMP_START(GEInitialize); | |||||
| bool is_proto_init = manager->Initialize(option_tmp); | bool is_proto_init = manager->Initialize(option_tmp); | ||||
| GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); | |||||
| if (!is_proto_init) { | if (!is_proto_init) { | ||||
| GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // check options is valid | // check options is valid | ||||
| GE_TIMESTAMP_START(CheckOptionsValid); | |||||
| if (CheckOptionsValid(options) != SUCCESS) { | if (CheckOptionsValid(options) != SUCCESS) { | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | |||||
| GE_TIMESTAMP_START(InitPreparation); | |||||
| SaveDdkVersion(options); | SaveDdkVersion(options); | ||||
| GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||||
| // call Initialize | // call Initialize | ||||
| GELOGT(TRACE_RUNNING, "Initializing environment"); | GELOGT(TRACE_RUNNING, "Initializing environment"); | ||||
| GE_TIMESTAMP_START(GELibInitialize); | |||||
| Status ret = ge::GELib::Initialize(options); | Status ret = ge::GELib::Initialize(options); | ||||
| GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize"); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -170,17 +177,20 @@ Status GEFinalize() { | |||||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | std::lock_guard<std::mutex> lock(kGeReleaseMutex); | ||||
| // call Finalize | // call Finalize | ||||
| Status ret = SUCCESS; | |||||
| Status middle_ret; | |||||
| GELOGT(TRACE_RUNNING, "Finalizing environment"); | GELOGT(TRACE_RUNNING, "Finalizing environment"); | ||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GEFinalize Failed: GE not initialized"); | |||||
| return GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| Status ret = instance_ptr->Finalize(); | |||||
| GELOGI("GEFinalize finalize gelib ret=%u", ret); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "GEFinalize Failed"); | |||||
| return FAILED; | |||||
| std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance(); | |||||
| if (instancePtr == nullptr || !instancePtr->InitFlag()) { | |||||
| GELOGW("GEFinalize Failed: GE not initialized."); | |||||
| ret = GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| if (ret != GE_CLI_GE_NOT_INITIALIZED) { | |||||
| middle_ret = instancePtr->Finalize(); | |||||
| GELOGI("GEFinalize finalize gelib ret=%u", middle_ret); | |||||
| if (middle_ret != SUCCESS) { | |||||
| ret = middle_ret; | |||||
| } | |||||
| } | } | ||||
| if (kGeInitialized && ret == SUCCESS) { | if (kGeInitialized && ret == SUCCESS) { | ||||
| @@ -379,8 +389,6 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s | |||||
| } | } | ||||
| Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | ||||
| GELOGW( | |||||
| "The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | ||||
| } | } | ||||