| @@ -18,7 +18,6 @@ | |||
| #define INC_COMMON_BLOCKING_QUEUE_H_ | |||
| #include <stdint.h> | |||
| #include <condition_variable> | |||
| #include <list> | |||
| #include <mutex> | |||
| @@ -87,7 +86,7 @@ class BlockingQueue { | |||
| is_stoped_ = false; | |||
| } | |||
| // if the queue stops, the function to release the unprocessed items will be called | |||
| // if the queue is stopped, call this function to release the unprocessed items | |||
| std::list<T> GetRemainItems() { | |||
| std::unique_lock<std::mutex> lock(mutex_); | |||
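For illustration, the drain pattern implied by the comment above looks roughly like the sketch below. It assumes BlockingQueue also exposes Push() and Stop() (suggested by the is_stoped_ flag above but not shown in this excerpt); all names here are illustrative.

// Illustrative sketch only -- not part of the header above.
BlockingQueue<int> queue;
queue.Push(1);
queue.Push(2);
queue.Stop();                                    // assumed: wakes blocked consumers and marks the queue stopped
std::list<int> remain = queue.GetRemainItems();  // collect items that were never consumed
for (int item : remain) {
  // release or otherwise handle each unprocessed item here
}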
| @@ -19,10 +19,10 @@ | |||
| #include <stdint.h> | |||
| /// | |||
| /// @ingroup dnn | |||
| /// @brief struct definition of dynamic aipp batch parameter. | |||
| /// | |||
| /** | |||
| * @ingroup dnn | |||
| * @brief struct definition of dynamic aipp batch parameter. | |||
| */ | |||
| typedef struct tagAippDynamicBatchPara { | |||
| int8_t cropSwitch; // crop switch | |||
| int8_t scfSwitch; // resize switch | |||
| @@ -66,10 +66,10 @@ typedef struct tagAippDynamicBatchPara { | |||
| int8_t reserve1[16]; // 32B assign, for ub copy | |||
| } kAippDynamicBatchPara; | |||
| /// | |||
| /// @ingroup dnn | |||
| /// @brief struct definition of dynamic aipp parameter. lite:64+96*batchNum byte ; tiny:64+64*batchNum byte | |||
| /// | |||
| /** | |||
| * @ingroup dnn | |||
| * @brief struct definition of dynamic aipp parameter. lite: 64+96*batchNum bytes; tiny: 64+64*batchNum bytes | |||
| */ | |||
| typedef struct tagAippDynamicPara { | |||
| uint8_t inputFormat; // input format:YUV420SP_U8/XRGB8888_U8/RGB888_U8 | |||
| int8_t cscSwitch; // csc switch | |||
| @@ -61,19 +61,19 @@ typedef enum tagHiAiNpuModuleId { | |||
| HIAI_DP = 23, | |||
| } HiAiNpuModuleId; | |||
| // bits 31-30: hiai local flag | |||
| /* bits 31-30: hiai local flag */ | |||
| #define HIAI_NPULOCAL_MASK 0xC0000000 | |||
| #define SHIFT_LOCAL_MASK 30 | |||
| #define HIAI_NPULOCAL_VAL_MASK 0x3 | |||
| // bits 29-28: hiai aicpu code type | |||
| /* bits 29-28: hiai aicpu code type */ | |||
| #define HIAI_CODE_TYPE_MASK 0x30000000 | |||
| #define SHIFT_CODE_MASK 28 | |||
| #define HIAI_CODE_TYPE_VAL_MASK 0x3 | |||
| // bits 27-25: hiai error level | |||
| /* bits 27-25: hiai error level */ | |||
| #define HIAI_ERROR_LEVEL_MASK 0x0E000000 | |||
| #define SHIFT_ERROR_LVL_MASK 25 | |||
| #define HIAI_ERROR_LEVEL_VAL_MASK 0x7 | |||
| // bits 24-17: hiai module id | |||
| /* bits 24-17: hiai module id */ | |||
| #define HIAI_MODE_ID_MASK 0x01FE0000 | |||
| #define SHIFT_MODE_MASK 17 | |||
| #define HIAI_MODE_ID_VAL_MASK 0xFF | |||
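To make the bit layout above concrete, here is a small illustrative helper (not part of this header) that packs the four fields using the masks and shifts defined here:

// Illustrative sketch only: compose a hiai error code from its fields.
static inline uint32_t MakeHiAiErrorCode(uint32_t local, uint32_t code_type, uint32_t level, uint32_t module_id) {
  uint32_t code = 0;
  code |= (local & HIAI_NPULOCAL_VAL_MASK) << SHIFT_LOCAL_MASK;         // bits 31-30
  code |= (code_type & HIAI_CODE_TYPE_VAL_MASK) << SHIFT_CODE_MASK;     // bits 29-28
  code |= (level & HIAI_ERROR_LEVEL_VAL_MASK) << SHIFT_ERROR_LVL_MASK;  // bits 27-25
  code |= (module_id & HIAI_MODE_ID_VAL_MASK) << SHIFT_MODE_MASK;       // bits 24-17
  return code;
}
// Decoding goes the other way, e.g. module_id = (code & HIAI_MODE_ID_MASK) >> SHIFT_MODE_MASK.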
| @@ -19,13 +19,12 @@ | |||
| #include <runtime/rt.h> | |||
| #include <stdint.h> | |||
| #include <string> | |||
| #include <vector> | |||
| using std::string; | |||
| namespace ge { | |||
| // DAVINCI_TRAIN/DAVINCI_CLOUD is not needed when GETaskKernelHcclInfo is needed | |||
| // once GETaskKernelHcclInfo is eliminated, DAVINCI_TRAIN/DAVINCI_CLOUD will no longer be needed | |||
| struct GETaskKernelHcclInfo { | |||
| string hccl_type; | |||
| void *inputDataAddr; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "./ge_task_info.h" | |||
| #include "./ops_kernel_info_types.h" | |||
| #include "cce/aicpu_engine_struct.h" | |||
| @@ -29,7 +28,6 @@ | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "graph/node.h" | |||
| #include "proto/task.pb.h" | |||
| using std::map; | |||
| using std::string; | |||
| using std::to_string; | |||
| @@ -47,7 +45,7 @@ class OpsKernelInfoStore { | |||
| // initialize opsKernelInfoStore | |||
| virtual Status Initialize(const map<string, string> &options) = 0; | |||
| // finalize opsKernelInfoStore | |||
| // close opsKernelInfoStore | |||
| virtual Status Finalize() = 0; | |||
| virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | |||
| @@ -57,18 +55,20 @@ class OpsKernelInfoStore { | |||
| // get all opsKernelInfo | |||
| virtual void GetAllOpsKernelInfo(map<string, OpInfo> &infos) const = 0; | |||
| // check whether opsKernelInfoStore is supported based on the operator attribute | |||
| // check whether the operator is supported by this opsKernelInfoStore, based on the operator attributes | |||
| virtual bool CheckSupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason) const = 0; | |||
| virtual bool CheckAccuracySupported(const OpDescPtr &opDescPtr, std::string &un_supported_reason, | |||
| bool realQuery = false) const { | |||
| return CheckSupported(opDescPtr, un_supported_reason); | |||
| } | |||
| // opsFlag: opsFlag[0] indicates whether constant folding is supported | |||
| virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | |||
| // requirement of memory allocation | |||
| // memory allocation requirement | |||
| virtual Status CalcOpRunningParam(Node &node) = 0; | |||
| // generate task for op | |||
| // generate task for op | |||
| virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | |||
| // only call fe engine interface to compile single op | |||
| @@ -77,10 +77,10 @@ class OpsKernelInfoStore { | |||
| // load task for op | |||
| virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | |||
| // only to call aicpu interface for generating task struct | |||
| // only call aicpu interface to generate task struct | |||
| virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||
| // only to call aicpu interface for generating task struct | |||
| // only call aicpu interface to generate task struct | |||
| virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||
| }; | |||
| } // namespace ge | |||
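For orientation, a minimal sketch of what an implementation of this interface might look like; the class name is invented, the include path is assumed, and only the pure-virtual methods visible in this hunk are overridden, so treat it as a shape rather than a reference implementation.

// Illustrative sketch only (class name invented; include path assumed).
#include "common/opskernel/ops_kernel_info_store.h"

class DummyOpsKernelInfoStore : public ge::OpsKernelInfoStore {
 public:
  ge::Status Initialize(const std::map<std::string, std::string> &options) override { return ge::SUCCESS; }
  ge::Status Finalize() override { return ge::SUCCESS; }
  void GetAllOpsKernelInfo(std::map<std::string, ge::OpInfo> &infos) const override {}
  bool CheckSupported(const ge::OpDescPtr &op_desc, std::string &un_supported_reason) const override {
    return op_desc != nullptr;  // toy rule: accept every well-formed op
  }
  ge::Status CalcOpRunningParam(ge::Node &node) override { return ge::SUCCESS; }
  ge::Status GenerateTask(const ge::Node &node, ge::RunContext &context,
                          std::vector<domi::TaskDef> &tasks) override {
    return ge::SUCCESS;  // a real store would append domi::TaskDef entries to tasks here
  }
};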
| @@ -37,6 +37,7 @@ struct RunContext { | |||
| ge::Buffer weightsBuffer; | |||
| std::vector<rtStream_t> graphStreamList; // all streams of graph, order by ge stream id(0,1,...) | |||
| std::vector<rtEvent_t> graphEventList; // all events of graph, order by ge event id(0,1,...) | |||
| std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | |||
| }; | |||
| struct Task { | |||
| @@ -19,7 +19,6 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "./graph_optimizer_types.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/opskernel/ops_kernel_info_types.h" | |||
| @@ -39,19 +38,19 @@ class GraphOptimizer { | |||
| // close graphOptimizer | |||
| virtual Status Finalize() = 0; | |||
| // optimize original graph for FE quant optimization | |||
| // optimize the original graph for FE quant optimization | |||
| virtual Status OptimizeGraphPrepare(ComputeGraph &graph) { return SUCCESS; } | |||
| // optimize original graph used in the graph preparation stage | |||
| // optimize the original graph, used in the graph preparation stage | |||
| virtual Status OptimizeOriginalGraph(ComputeGraph &graph) = 0; | |||
| // optimize fused graph | |||
| virtual Status OptimizeFusedGraph(ComputeGraph &graph) = 0; | |||
| // optimize the whole graph which will be used after graph merged | |||
| // optimize the whole graph, used after the graph merge stage | |||
| virtual Status OptimizeWholeGraph(ComputeGraph &graph) = 0; | |||
| // get attributes of graph optimizer | |||
| // get attributes of the graph optimizer | |||
| virtual Status GetAttributes(GraphOptimizerAttribute &attrs) const = 0; | |||
| // optimize streamed Graph | |||
| @@ -19,8 +19,6 @@ | |||
| #include <stdint.h> | |||
| #include <string> | |||
| using std::string; | |||
| namespace ge { | |||
| enum OPTIMIZER_SCOPE { | |||
| UNIT = 0, | |||
| @@ -28,7 +26,7 @@ enum OPTIMIZER_SCOPE { | |||
| }; | |||
| struct GraphOptimizerAttribute { | |||
| string engineName; | |||
| std::string engineName; | |||
| OPTIMIZER_SCOPE scope; | |||
| }; | |||
| } // namespace ge | |||
| @@ -20,6 +20,7 @@ | |||
| #include <cstdint> | |||
| #include <string> | |||
| #include <vector> | |||
| #include <set> | |||
| namespace ge { | |||
| // Option key: graph run mode | |||
| @@ -38,9 +39,11 @@ const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | |||
| const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||
| const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
| // Hccl flag: if ge.exec.hcclFlag = 1, the opskernel plugin is loaded; if ge.exec.hcclFlag = 0, it is not | |||
| const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | |||
| const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | |||
| const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | |||
| // Option key: memory init | |||
| const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | |||
| @@ -141,19 +144,43 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||
| // configure outputDatatype to set the net output type | |||
| const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | |||
| // configure opSelectImplmode to set the op select implementation mode | |||
| const std::string kOpSelectImplmode = "ge.opSelectImplmode"; | |||
| // configure whether to enable hcom parallel by session constructor options param, | |||
| // its value should be "0" or "1", default value is "0" | |||
| const std::string HCOM_PARALLEL = "ge.hcomParallel"; | |||
| // configure whether to use dynamic batch size | |||
| const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | |||
| // configure whether to use dynamic image size | |||
| const char *const kDynamicImageSize = "ge.dynamicImageSize"; | |||
| // Configure auto tune mode; this option only takes effect while AUTO_TUNE_FLAG is Y. | |||
| // Example: GA|RL; multiple modes can be configured, separated by | | |||
| const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||
| // Configure soc version , example: "Ascend310" | |||
| const std::string SOC_VERSION = "ge.socVersion"; | |||
| // Configure core type "VectorEngine", default value is "AIcoreEngine" | |||
| const std::string CORE_TYPE = "ge.engineType"; | |||
| // Configure soc version, example: "Ascend310" | |||
| const std::string SOC_VERSION = "ge.socVersion"; | |||
| // Configure AICORE NUM | |||
| const std::string AICORE_NUM = "ge.aicoreNum"; | |||
| // Configure L1FUSION | |||
| const std::string L1_FUSION = "ge.l1Fusion"; | |||
| // Configure Small Channel flag | |||
| const std::string ENABLE_SMALL_CHANNEL = "ge.enableSmallChannel"; | |||
| // Configure Compress Weight flag | |||
| const std::string ENABLE_COMPRESS_WEIGHT = "ge.enableCompressWeight"; | |||
| // Configure fusion switch file path | |||
| const std::string FUSION_SWITCH_FILE = "ge.fusionSwitchFile"; | |||
| // Save original model | |||
| const std::string SAVE_ORIGINAL_MODEL = "ge.saveOriginalModel"; | |||
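These keys are consumed as entries of a plain string-to-string options map; a hedged example with placeholder values (the exact value formats are not specified in this excerpt):

// Illustrative sketch only; values are placeholders.
std::map<std::string, std::string> options;
options[ge::SOC_VERSION] = "Ascend310";   // see the "Configure soc version" comment above
options[ge::CORE_TYPE] = "VectorEngine";  // default is "AIcoreEngine"
options[ge::HCOM_PARALLEL] = "1";         // "0" or "1"
options[ge::AUTO_TUNE_MODE] = "GA|RL";    // multiple modes separated by '|'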
| @@ -194,6 +221,28 @@ struct TensorInfo { | |||
| DataDesc data; // tensor data | |||
| ShapeDesc shapeInfo; // tensor shape | |||
| }; | |||
| // for ir build | |||
| namespace ir_option { | |||
| static const char *const INPUT_FORMAT = "input_format"; | |||
| static const char *const INPUT_SHAPE = "input_shape"; | |||
| static const char *const OP_NAME_MAP = "op_name_map"; | |||
| static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||
| static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||
| static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||
| static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); | |||
| static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; | |||
| static const char *const HEAD_STREAM = ge::HEAD_STREAM.c_str(); | |||
| static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||
| static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | |||
| static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | |||
| // for interface: aclgrphBuildModel | |||
| const std::set<std::string> ir_builder_suppported_options = { | |||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||
| DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
| AUTO_TUNE_MODE}; | |||
| // for interface: aclgrphBuildInitialize | |||
| const std::set<std::string> global_options = {HEAD_STREAM, CORE_TYPE, SOC_VERSION}; | |||
| } // namespace ir_option | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GE_GE_API_TYPES_H_ | |||
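Because ir_builder_suppported_options and global_options are ordinary std::set<std::string> constants, a caller can pre-check user-supplied keys before invoking the build interfaces; a minimal hedged sketch:

// Illustrative sketch only.
bool IsIrBuildOptionSupported(const std::string &key) {
  return ge::ir_option::ir_builder_suppported_options.count(key) > 0;  // option accepted by aclgrphBuildModel
}
bool IsGlobalOptionSupported(const std::string &key) {
  return ge::ir_option::global_options.count(key) > 0;  // option accepted by aclgrphBuildInitialize
}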
| @@ -0,0 +1,75 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_EXTERNAL_GE_IR_BUILD_H_ | |||
| #define INC_EXTERNAL_GE_IR_BUILD_H_ | |||
| #include <string> | |||
| #include <map> | |||
| #include <memory> | |||
| #include "graph/graph.h" | |||
| #include "graph/ge_error_codes.h" | |||
| namespace ge { | |||
| struct ModelBufferData { | |||
| std::shared_ptr<uint8_t> data = nullptr; | |||
| uint64_t length; | |||
| }; | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief initialize the environment for model building | |||
| * | |||
| * @param global_options[IN] global init params for build | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief release the resources initialized by aclgrphBuildInitialize | |||
| * | |||
| */ | |||
| void aclgrphBuildFinalize(); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief build the model. Note that the model is stored in a buffer | |||
| * | |||
| * @param graph[IN] the graph ready to build | |||
| * @param build_options[IN] options used for build | |||
| * @param model[OUT] built model | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||
| ModelBufferData &model); | |||
| /** | |||
| * @ingroup AscendCL | |||
| * @brief save model buffer to file | |||
| * | |||
| * @param output_file[IN] the path of the file to save the model to | |||
| * @param model[IN] model buffer data | |||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||
| * @retval OtherValues Failure | |||
| */ | |||
| graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &model); | |||
| }; // namespace ge | |||
| #endif | |||
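Taken together, the three interfaces declared above suggest the following build flow. This is a sketch under assumptions: the graph is assumed to be populated elsewhere, error handling is trimmed, and the option values are placeholders.

// Illustrative sketch only.
#include "external/ge/ge_ir_build.h"  // assumed include path for the declarations above

void BuildAndSaveExample(const ge::Graph &graph) {
  std::map<std::string, std::string> global_options = {{ge::ir_option::SOC_VERSION, "Ascend310"}};
  if (ge::aclgrphBuildInitialize(global_options) != ge::GRAPH_SUCCESS) {
    return;  // initialization failed
  }
  std::map<std::string, std::string> build_options;  // e.g. INPUT_SHAPE, PRECISION_MODE
  ge::ModelBufferData model;
  if (ge::aclgrphBuildModel(graph, build_options, model) == ge::GRAPH_SUCCESS) {
    (void)ge::aclgrphSaveModel("example.om", model);  // writes the in-memory buffer to a file
  }
  ge::aclgrphBuildFinalize();
}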
| @@ -22,7 +22,7 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "./ge_error_codes.h" | |||
| using std::make_shared; | |||
| using std::map; | |||
| @@ -22,7 +22,7 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "external/graph/operator.h" | |||
| #include "./operator.h" | |||
| namespace ge { | |||
| class GraphImpl; | |||
| @@ -21,8 +21,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/tensor.h" | |||
| #include "external/graph/types.h" | |||
| #include "./tensor.h" | |||
| #include "./types.h" | |||
| namespace ge { | |||
| class InferenceContext; | |||
| @@ -69,7 +69,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InferenceContext { | |||
| static std::unique_ptr<InferenceContext> Create(); | |||
| private: | |||
| InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
| explicit InferenceContext(std::unique_ptr<InferenceContextImpl> &impl); | |||
| std::shared_ptr<InferenceContextImpl> inference_context_impl_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -23,9 +23,9 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "external/graph/inference_context.h" | |||
| #include "external/graph/tensor.h" | |||
| #include "./ge_error_codes.h" | |||
| #include "./inference_context.h" | |||
| #include "./tensor.h" | |||
| #ifndef USER_GE_LOGI | |||
| #define USER_GE_LOGI(...) | |||
| @@ -22,8 +22,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph//operator.h" | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "./operator.h" | |||
| #include "./ge_error_codes.h" | |||
| namespace ge { | |||
| using OpCreator = std::function<Operator(const std::string &)>; | |||
| @@ -22,10 +22,10 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/operator.h" | |||
| #include "external/graph/operator_factory.h" | |||
| #include "external/graph/tensor.h" | |||
| #include "external/graph/types.h" | |||
| #include "./operator.h" | |||
| #include "./operator_factory.h" | |||
| #include "./tensor.h" | |||
| #include "./types.h" | |||
| namespace ge { | |||
| using std::function; | |||
| @@ -60,7 +60,7 @@ class OpReg { | |||
| \ | |||
| private: \ | |||
| void __##x() { \ | |||
| OpReg() | |||
| OpReg() | |||
| #define ATTR(x, Type, ...) \ | |||
| N(); \ | |||
| @@ -86,7 +86,7 @@ class OpReg { | |||
| void __attr_##x() { \ | |||
| Operator::AttrRegister(#x, Op##Type(__VA_ARGS__)); \ | |||
| string attr_name(#x); \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define REQUIRED_ATTR(x, Type) \ | |||
| N(); \ | |||
| @@ -112,7 +112,7 @@ class OpReg { | |||
| void __required_attr_##x() { \ | |||
| Operator::RequiredAttrRegister(#x); \ | |||
| string attr_name(#x); \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define INPUT(x, t) \ | |||
| N(); \ | |||
| @@ -137,7 +137,7 @@ class OpReg { | |||
| private: \ | |||
| void __input_##x() { \ | |||
| Operator::InputRegister(#x); \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define OPTIONAL_INPUT(x, t) \ | |||
| N(); \ | |||
| @@ -162,7 +162,7 @@ class OpReg { | |||
| private: \ | |||
| void __optional_input_##x() { \ | |||
| Operator::OptionalInputRegister(#x); \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define OUTPUT(x, t) \ | |||
| N(); \ | |||
| @@ -179,7 +179,7 @@ class OpReg { | |||
| private: \ | |||
| void __out_##x() { \ | |||
| Operator::OutputRegister(#x); \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define DYNAMIC_INPUT(x, t) \ | |||
| N(); \ | |||
| @@ -206,7 +206,7 @@ class OpReg { | |||
| \ | |||
| private: \ | |||
| void __dy_input_##x() { \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define DYNAMIC_OUTPUT(x, t) \ | |||
| N(); \ | |||
| @@ -227,18 +227,18 @@ class OpReg { | |||
| \ | |||
| private: \ | |||
| void __dy_output_##x() { \ | |||
| (void)OpReg() | |||
| (void)OpReg() | |||
| #define PASTE(g_register, y) g_register##y | |||
| #define __OP_END_IMPL__(x, y) \ | |||
| N(); \ | |||
| } \ | |||
| static_assert( \ | |||
| std::is_same<x, _THIS_TYPE>::value, \ | |||
| "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||
| } \ | |||
| ; \ | |||
| static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||
| #define __OP_END_IMPL__(x, y) \ | |||
| N(); \ | |||
| } \ | |||
| static_assert( \ | |||
| std::is_same<x, _THIS_TYPE>::value, \ | |||
| "The class name entered into the OP_END_FACTORY_REG needs to be the same as the operator name you define."); \ | |||
| } \ | |||
| ; \ | |||
| static const OperatorCreatorRegister PASTE(g_register, y)(#x, [](const std::string &name) { return x(name); }); \ | |||
| } | |||
| #define OP_END_FACTORY_REG(x) __OP_END_IMPL__(x, __COUNTER__) | |||
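For context, these building-block macros are normally chained through the REG_OP(...) entry macro defined elsewhere in this header (not shown in this hunk); a hedged sketch of a registration with an invented operator name:

// Illustrative sketch only (REG_OP and TensorType/DT_* are assumed from the surrounding headers).
REG_OP(MyScale)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16}))
    .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16}))
    .ATTR(scale, Float, 1.0)
    .REQUIRED_ATTR(axis, Int)
    .OP_END_FACTORY_REG(MyScale)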
| @@ -286,7 +286,7 @@ class OpReg { | |||
| // Common shape inferencer | |||
| #define ELMTWISE_INFER_SHAPEANDTYPE(in_name, out_name) \ | |||
| [](Operator op)->graphStatus { \ | |||
| [](Operator op) -> graphStatus { \ | |||
| auto x_shape = op.GetInputDesc(in_name).GetShape().GetDims(); \ | |||
| auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | |||
| TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | |||
| @@ -300,7 +300,7 @@ graphStatus BroadCastInfer(const function<vector<int64_t>()> &get_in1_shape, | |||
| const function<void(const vector<int64_t> &y_shape)> &set_out_shape); | |||
| #define BROADCAST_INFER(in1_name, in2_name, out_name) \ | |||
| [](Operator op)->graphStatus { \ | |||
| [](Operator op) -> graphStatus { \ | |||
| return BroadCastInfer([&]() { return op.GetInputDesc(in1_name).GetShape().GetDims(); }, \ | |||
| [&]() { return op.GetInputDesc(in2_name).GetShape().GetDims(); }, \ | |||
| [&](const vector<int64_t> &y_shape) { \ | |||
| @@ -22,8 +22,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "external/graph/types.h" | |||
| #include "./ge_error_codes.h" | |||
| #include "./types.h" | |||
| namespace ge { | |||
| class ShapeImpl; | |||
| @@ -133,11 +133,13 @@ enum Format { | |||
| FORMAT_FRACTAL_ZZ, | |||
| FORMAT_FRACTAL_NZ, | |||
| FORMAT_NCDHW, | |||
| FORMAT_DHWCK, // 3D filter input tensor format | |||
| FORMAT_DHWCN, // 3D filter input tensor format | |||
| FORMAT_NDC1HWC0, | |||
| FORMAT_FRACTAL_Z_3D, | |||
| FORMAT_CN, | |||
| FORMAT_NC, | |||
| FORMAT_DHWNC, | |||
| FORMAT_FRACTAL_Z_3D_TRANSPOSE, // 3D filter(transpose) input tensor format | |||
| FORMAT_RESERVED, | |||
| FORMAT_ALL | |||
| }; | |||
| @@ -47,6 +47,12 @@ class Tensor; | |||
| class TBEPluginManager; | |||
| } // namespace ge | |||
| namespace google { | |||
| namespace protobuf { | |||
| class Message; | |||
| } | |||
| } // namespace google | |||
| namespace domi { | |||
| Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | |||
| Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | |||
| @@ -56,6 +62,8 @@ using google::protobuf::Message; | |||
| class OpRegistrationDataImpl; | |||
| using ParseParamFunc = std::function<domi::Status(const google::protobuf::Message *, ge::Operator &)>; | |||
| using FusionParseParamFunc = | |||
| std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| public: | |||
| @@ -71,15 +79,20 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| OpRegistrationData &ParseParamsFn(const ParseParamFunc &parseParamFn); | |||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | |||
| OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | |||
| OpRegistrationData &DelInputWithCond(int inputIdx, const std::string &attrName, bool attrValue); | |||
| OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||
| domi::ImplyType GetImplyType() const; | |||
| std::string GetOmOptype() const; | |||
| std::set<std::string> GetOriginOpTypeSet() const; | |||
| domi::FrameworkType GetFrameworkType() const; | |||
| ParseParamFunc GetParseParamFn() const; | |||
| FusionParseParamFunc GetFusionParseParamFn() const; | |||
| private: | |||
| std::shared_ptr<OpRegistrationDataImpl> impl_; | |||
| @@ -103,5 +116,27 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||
| namespace ge { | |||
| using OpRegistrationData = domi::OpRegistrationData; | |||
| using OpReceiver = domi::OpReceiver; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||
| public: | |||
| HostCpuOp() = default; | |||
| virtual ~HostCpuOp() = default; | |||
| virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||
| std::map<std::string, Tensor> &outputs) = 0; | |||
| }; | |||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||
| public: | |||
| HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||
| }; | |||
| #define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||
| #define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||
| static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||
| ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | |||
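A hedged sketch of how the host-CPU plumbing above is meant to be used: derive from ge::HostCpuOp, implement Compute, and register the type with REGISTER_HOST_CPU_OP_BUILDER. The operator name and behaviour here are invented.

// Illustrative sketch only.
class MyHostCpuNoOp : public ge::HostCpuOp {
 public:
  ge::graphStatus Compute(ge::Operator &op, const std::map<std::string, const ge::Tensor> &inputs,
                          std::map<std::string, ge::Tensor> &outputs) override {
    (void)op;
    (void)inputs;
    (void)outputs;  // a real kernel would fill the output tensors from the inputs here
    return ge::GRAPH_SUCCESS;
  }
};
REGISTER_HOST_CPU_OP_BUILDER("MyHostCpuNoOp", MyHostCpuNoOp);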
| @@ -22,7 +22,7 @@ | |||
| #define DECLARE_ERRORNO(sysid, modid, name, value) \ | |||
| const domi::Status name = \ | |||
| ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||
| ((0xFF & ((uint8_t)sysid)) << 24) | ((0xFF & ((uint8_t)modid)) << 16) | (0xFFFF & ((uint16_t)value)); | |||
| #define DECLARE_ERRORNO_COMMON(name, value) DECLARE_ERRORNO(SYSID_FWK, MODID_COMMON, name, value) | |||
| @@ -33,6 +33,7 @@ using Status = uint32_t; | |||
| DECLARE_ERRORNO(0, 0, SUCCESS, 0); | |||
| DECLARE_ERRORNO(0xFF, 0xFF, FAILED, 0xFFFFFFFF); | |||
| DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1); // 50331649 | |||
| DECLARE_ERRORNO(SYSID_FWK, 1, SCOPE_NOT_CHANGED, 201); | |||
| } // namespace domi | |||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_ERROR_CODES_H_ | |||
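A worked example of the packing above, using only values visible in this file: the // 50331649 annotation on PARAM_INVALID implies SYSID_FWK is 3 and MODID_COMMON is 0, since that is the only way the formula yields 0x03000001.

// DECLARE_ERRORNO_COMMON(PARAM_INVALID, 1)
//   => ((0xFF & 3) << 24) | ((0xFF & 0) << 16) | (0xFFFF & 1)
//   => 0x03000001 = 50331649   (matches the comment above)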
| @@ -48,6 +48,10 @@ typedef enum tagDomiTensorFormat { | |||
| DOMI_TENSOR_BN_WEIGHT, | |||
| DOMI_TENSOR_CHWN, // Android NN Depth CONV | |||
| DOMI_TENSOR_FILTER_HWCK, // filter input tensor format | |||
| DOMI_TENSOR_NDHWC, | |||
| DOMI_TENSOR_NCDHW, | |||
| DOMI_TENSOR_DHWCN, // 3D filter input tensor format | |||
| DOMI_TENSOR_DHWNC, | |||
| DOMI_TENSOR_RESERVED | |||
| } domiTensorFormat_t; | |||
| } // namespace domi | |||
| @@ -18,11 +18,13 @@ | |||
| #define INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | |||
| #include <cstdint> | |||
| #include <unistd.h> | |||
| #include <sys/syscall.h> | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| #include "toolchain/slog.h" | |||
| #define GE_MODULE_NAME GE | |||
| #define GE_MODULE_NAME static_cast<int>(GE) | |||
| // trace status of log | |||
| enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||
| @@ -35,15 +37,20 @@ enum TraceStatus { TRACE_INIT = 0, TRACE_RUNNING, TRACE_WAITING, TRACE_STOP }; | |||
| #define GELOGO(...) GE_LOG_OPLOG(GE_MODULE_NAME, __VA_ARGS__) | |||
| #define GELOGT(VALUE, ...) GE_LOG_TRACE(GE_MODULE_NAME, VALUE, __VA_ARGS__) | |||
| inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||
| int32_t enable_event = 0; | |||
| int32_t dlog_level = dlog_getlevel(module_name, &enable_event); | |||
| if (dlog_level <= log_level) { | |||
| inline bool IsLogEnable(int module_name, int log_level) { | |||
| int32_t enable = CheckLogLevel(module_name, log_level); | |||
| // 1:enable, 0:disable | |||
| if (enable == 1) { | |||
| return true; | |||
| } | |||
| return false; | |||
| } | |||
| inline pid_t GetTid() { | |||
| thread_local static pid_t tid = syscall(__NR_gettid); | |||
| return tid; | |||
| } | |||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
| #define GE_TIMESTAMP_END(stage, stage_name) \ | |||
| @@ -68,29 +75,35 @@ inline bool IsLogEnable(int module_name, int log_level) noexcept { | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | |||
| call_num_of##stage) | |||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
| dlog_error(static_cast<int>(MOD_NAME), "%s: ErrorNo: %d(%s) " fmt, __FUNCTION__, ERROR_CODE, \ | |||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
| dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ | |||
| ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | |||
| #define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_WARN)) \ | |||
| dlog_warn(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_INFO)) \ | |||
| dlog_info(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(static_cast<int>(MOD_NAME), DLOG_DEBUG)) \ | |||
| dlog_debug(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(MOD_NAME), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_WARN(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(MOD_NAME, DLOG_WARN)) dlog_warn(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_INFO(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(MOD_NAME, DLOG_INFO)) dlog_info(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_DEBUG(MOD_NAME, fmt, ...) \ | |||
| if (IsLogEnable(MOD_NAME, DLOG_DEBUG)) dlog_debug(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_EVENT(MOD_NAME, fmt, ...) dlog_event(MOD_NAME, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_OPLOG(MOD_NAME, fmt, ...) \ | |||
| Dlog(static_cast<int>(MOD_NAME), DLOG_OPLOG, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||
| do { \ | |||
| TraceStatus stat = value; \ | |||
| const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||
| int idx = static_cast<int>(stat); \ | |||
| char *k = const_cast<char *>("status"); \ | |||
| char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||
| KeyValue kv = {k, v}; \ | |||
| DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%s:" fmt, __FUNCTION__, ##__VA_ARGS__); \ | |||
| Dlog(MOD_NAME, DLOG_OPLOG, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOG_TRACE(MOD_NAME, value, fmt, ...) \ | |||
| do { \ | |||
| TraceStatus stat = value; \ | |||
| const char *const TraceStatStr[] = {"INIT", "RUNNING", "WAITING", "STOP"}; \ | |||
| int idx = static_cast<int>(stat); \ | |||
| char *k = const_cast<char *>("status"); \ | |||
| char *v = const_cast<char *>(TraceStatStr[idx]); \ | |||
| KeyValue kv = {k, v}; \ | |||
| DlogWithKV(static_cast<int>(MOD_NAME), DLOG_TRACE, &kv, 1, "%lu %s:" fmt, GetTid(), __FUNCTION__, ##__VA_ARGS__); \ | |||
| } while (0) | |||
| // print memory when it is greater than 1KB. | |||
| #define GE_PRINT_DYNAMIC_MEMORY(FUNC, PURPOSE, SIZE) \ | |||
| do { \ | |||
| if ((SIZE) > 1024) { \ | |||
| GELOGI("MallocMemory, func=%s, size=%zu, purpose=%s", (#FUNC), static_cast<size_t>(SIZE), (PURPOSE)); \ | |||
| } \ | |||
| } while (0); | |||
| #endif // INC_FRAMEWORK_COMMON_DEBUG_GE_LOG_H_ | |||
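A brief hedged example of the macros above at a call site; graph_id and buffer_size are hypothetical local variables, not part of this header.

// Illustrative sketch only.
GELOGT(TRACE_RUNNING, "graph %u is running", graph_id);            // DLOG_TRACE record tagged status=RUNNING
GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "weights buffer", buffer_size);  // only logs when buffer_size exceeds 1024 bytes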
| @@ -29,7 +29,18 @@ | |||
| using cce::CC_STATUS_SUCCESS; | |||
| using cce::ccStatus_t; | |||
| #define GE_LOGE(...) DAV_LOGE("GE", __VA_ARGS__) | |||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||
| #define DOMI_LOGE(...) DAV_LOGE("DOMI", __VA_ARGS__) | |||
| #else | |||
| #include <android/log.h> | |||
| #if defined(BUILD_VERSION_PERF) | |||
| #define DOMI_LOGE(fmt, ...) | |||
| #else | |||
| // The Android system has strict log control. Do not modify the log. | |||
| #define DOMI_LOGE(fmt, ...) \ | |||
| __android_log_print(ANDROID_LOG_ERROR, "NPU_FMK", "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||
| #endif | |||
| #endif | |||
| // ge macro | |||
| #define GE_LOGI_IF(condition, ...) \ | |||
| @@ -44,7 +55,7 @@ using cce::ccStatus_t; | |||
| #define GE_LOGE_IF(condition, ...) \ | |||
| if ((condition)) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| } | |||
| // If expr is not SUCCESS, print the log and return the same value | |||
| @@ -52,7 +63,7 @@ using cce::ccStatus_t; | |||
| do { \ | |||
| const ge::Status _status = (expr); \ | |||
| if (_status != ge::SUCCESS) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| @@ -62,7 +73,7 @@ using cce::ccStatus_t; | |||
| do { \ | |||
| const ge::Status _status = (expr); \ | |||
| if (_status != ge::SUCCESS) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| } \ | |||
| } while (0); | |||
| @@ -75,6 +86,15 @@ using cce::ccStatus_t; | |||
| } \ | |||
| } while (0); | |||
| // If expr is not GRAPH_SUCCESS, print the log and return FAILED | |||
| #define GE_CHK_GRAPH_STATUS_RET(expr, ...) \ | |||
| do { \ | |||
| if ((expr) != ge::GRAPH_SUCCESS) { \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return FAILED; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not SUCCESS, print the log and execute a custom statement | |||
| #define GE_CHK_STATUS_EXEC(expr, exec_expr, ...) \ | |||
| do { \ | |||
| @@ -91,25 +111,11 @@ using cce::ccStatus_t; | |||
| (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| GE_LOGE("%s", msg.c_str()); \ | |||
| DOMI_LOGE("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not true, print the Info log and return the specified status | |||
| #define GE_CHK_BOOL_RET_STATUS_LOGI(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| std::string msg; \ | |||
| (void)msg.append(StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| StringUtils::FormatString(" Check result false, status: 0x%X %s", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| GELOGI("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| #define GE_CHK_BOOL_RET_STATUS_NOLOG(expr, _status, ...) \ | |||
| do { \ | |||
| @@ -124,7 +130,7 @@ using cce::ccStatus_t; | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| @@ -163,7 +169,7 @@ using cce::ccStatus_t; | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (b) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| @@ -182,7 +188,7 @@ using cce::ccStatus_t; | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (b) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| return; \ | |||
| } \ | |||
| @@ -193,7 +199,7 @@ using cce::ccStatus_t; | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (b) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| return _status; \ | |||
| } \ | |||
| @@ -210,62 +216,42 @@ using cce::ccStatus_t; | |||
| // -----------------runtime related macro definitions------------------------------- | |||
| // If expr is not RT_ERROR_NONE, print the log | |||
| #define GE_CHK_RT(expr) \ | |||
| do { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| } \ | |||
| #define GE_CHK_RT(expr) \ | |||
| do { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not RT_ERROR_NONE, print the log and execute the exec_expr expression | |||
| #define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||
| { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| exec_expr; \ | |||
| } \ | |||
| #define GE_CHK_RT_EXEC(expr, exec_expr) \ | |||
| { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } | |||
| // If expr is not RT_ERROR_NONE, print the log and return | |||
| #define GE_CHK_RT_RET(expr) \ | |||
| do { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| GE_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| return ge::RT_FAILED; \ | |||
| } \ | |||
| #define GE_CHK_RT_RET(expr) \ | |||
| do { \ | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| return ge::RT_FAILED; \ | |||
| } \ | |||
| } while (0); | |||
| // ------------------------cce related macro definitions---------------------------- | |||
| // If expr is not CC_STATUS_SUCCESS, print the log | |||
| #define GE_CHK_CCE(expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not CC_STATUS_SUCCESS, print the log and execute the exec_expr expression | |||
| #define GE_CHK_CCE_EXEC(expr, exec_expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not CC_STATUS_SUCCESS, print the log and return | |||
| #define GE_CHK_CCE_RET(expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| GE_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| return ge::CCE_FAILED; \ | |||
| } \ | |||
| #define GE_CHK_CCE(expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is true, execute exec_expr without printing logs | |||
| @@ -281,37 +267,8 @@ using cce::ccStatus_t; | |||
| try { \ | |||
| exec_expr0; \ | |||
| } catch (const std::bad_alloc &) { \ | |||
| GE_LOGE("Make shared failed"); \ | |||
| DOMI_LOGE("Make shared failed"); \ | |||
| exec_expr1; \ | |||
| } | |||
| #define GE_CHECK_INT32_MUL_OVERFLOW(a, b, ...) \ | |||
| do { \ | |||
| if ((a) > 0) { \ | |||
| if ((b) > 0) { \ | |||
| if ((a) > (INT32_MAX / (b))) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } else { \ | |||
| if ((b) < (INT32_MIN / (a))) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } \ | |||
| } else { \ | |||
| if ((b) > 0) { \ | |||
| if ((a) < (INT32_MAX / (b))) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } else { \ | |||
| if (((a) != 0) && ((b) < (INT32_MAX / (a)))) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } \ | |||
| } \ | |||
| } while (0); | |||
| #endif // INC_FRAMEWORK_COMMON_DEBUG_LOG_H_ | |||
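For orientation, a hedged sketch of how the checking macros above read at a call site. The runtime call and the graph call are assumptions used for illustration, not part of this header.

// Illustrative sketch only.
ge::Status PrepareExample(ge::ComputeGraph &graph, void **dev_mem, size_t mem_size) {
  GE_CHK_RT_RET(rtMalloc(dev_mem, mem_size, RT_MEMORY_HBM));  // assumed runtime call; returns ge::RT_FAILED on error
  GE_CHK_GRAPH_STATUS_RET(graph.TopologicalSorting(), "topological sorting failed");  // returns FAILED on error
  return ge::SUCCESS;
}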
| @@ -152,7 +152,6 @@ GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_RUN_GRAPH_INVALID, 11, "Get computeGraph by g | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_INSERT_DYN_OP_FAILED, 12, "Graph which insert dynamic op failed."); // 1343242252 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_PREPROCESS_FAILED, 13, "Graph preprocess failed."); // 1343242253 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_GRAPH_FUSION_FAILED, 14, "Graph fusion failed."); // 1343242254 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_TINY_CAL_CHECK_FAILED, 15, "Check tiny calibration failed."); // 1343242255 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_OPTIMIZE_CALIBRATION_FAILED, 16, "Calibration failed."); // 1343242256 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_NUM_ZERO, 17, "Graph partition success, but subGraph num is 0."); // 1343242257 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_SUBGRAPH_ENGINENAME_REPEATED, 18, "Graph subGraph engine name is repeated."); // 1343242258 | |||
| @@ -204,15 +203,16 @@ GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_GET_GRAPH_REBUILD_FAILED, 60, | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_NODE_SEARCHER_SET_GRAPH_FINISH_REBUILD_GRAPH_FAILED, 61, | |||
| "Failed set graph finish rebuild in node searcher."); // 1343242301 | |||
| GE_ERRORNO_GRAPH(GE_GRAPH_VARIABLE_OP_PASS_FAILED, 62, "Failed to run variable pass."); // 1343242302 | |||
| // Optimize error code | |||
| GE_ERRORNO_GRAPH(TO_BE_DELETED, 200, "The node of the graph to be deleted."); | |||
| GE_ERRORNO_GRAPH(NOT_CHANGED, 201, "The node of the graph is not changed."); | |||
| // Engine_manager module error code definition | |||
| GE_ERRORNO_ENGINE(GE_ENG_INIT_FAILED, 0, "Failed to initialize engine."); // 1343246336 | |||
| GE_ERRORNO_ENGINE(GE_ENG_FINALIZE_FAILED, 1, "Engine finalize failed."); // 1343246337 | |||
| GE_ERRORNO_ENGINE(GE_ENG_MEMTYPE_ERROR, 2, "Memory type HBM is necessary when engine is in device"); // 1343246338 | |||
| // Optimize error code | |||
| GE_ERRORNO_GRAPH(TO_BE_DELETED, 63, "The node of the graph to be deleted."); // 1343242303 | |||
| GE_ERRORNO_GRAPH(NOT_CHANGED, 64, "The node of the graph is not changed."); // 1343242304 | |||
| // Ops module error code definition | |||
| GE_ERRORNO_OPS(GE_OPS_KERNEL_STORE_INIT_FAILED, 0, "Failed to initialize OpsKernelInfoStore."); // 1343250432 | |||
| GE_ERRORNO_OPS(GE_OPS_GRAPH_OPTIMIZER_INIT_FAILED, 1, "Failed to initialize GraphOptimizer."); // 1343250433 | |||
| @@ -24,8 +24,7 @@ | |||
| #include "common/fmk_error_codes.h" | |||
| #include "ge/ge_api_error_codes.h" | |||
| using std::string; | |||
| #include "external/graph/types.h" | |||
| namespace ge { | |||
| enum RuntimeType { HOST = 0, DEVICE = 1 }; | |||
| @@ -56,7 +55,7 @@ struct DataBuffer { | |||
| /// | |||
| /// @ingroup domi_ome | |||
| /// @brief External inputdata | |||
| /// @brief External input data | |||
| /// | |||
| struct InputData { | |||
| uint32_t index; // Index of input data | |||
| @@ -65,13 +64,14 @@ struct InputData { | |||
| uint32_t model_id; // Model ID required for data processing | |||
| uint64_t request_id = 0; // Request ID | |||
| std::vector<DataBuffer> blobs; // Actual input data, currently only supports one input | |||
| bool is_dynamic_batch = false; // Whether this is a dynamic batch size scenario, default: false | |||
| std::string batch_label; // Batch gear used for the current inference in the dynamic batch scenario | |||
| }; | |||
| // The definition of output result structure | |||
| /// Output result structure definition | |||
| struct OutputData { | |||
| uint32_t index; // Index of input data | |||
| uint32_t model_id; // The model ID corresponding to the processing result | |||
| /// Output data cache, arranged in sequence of output operators. | |||
| /// If the operator has multiple outputs, | |||
| /// the data buffer order of the operator is the same as that defined in the | |||
| @@ -142,11 +142,31 @@ struct Options { | |||
| bool deployMode; | |||
| bool isAICPUMode; | |||
| bool enable_atomic; | |||
| string podName; | |||
| std::string podName; | |||
| int64_t rankId; | |||
| string rankTableFile; | |||
| std::string rankTableFile; | |||
| int32_t ge_hccl_flag = 0; | |||
| int32_t physical_device_id; | |||
| }; | |||
| // Profiling info of task | |||
| struct TaskDescInfo { | |||
| std::string op_name; | |||
| uint32_t block_dim; | |||
| uint32_t task_id; | |||
| uint32_t stream_id; | |||
| }; | |||
| // Profiling info of graph | |||
| struct ComputeGraphDescInfo { | |||
| std::string op_name; | |||
| std::string op_type; | |||
| std::vector<Format> input_format; | |||
| std::vector<std::vector<int64_t>> input_shape; | |||
| std::vector<DataType> input_data_type; | |||
| std::vector<Format> output_format; | |||
| std::vector<std::vector<int64_t>> output_shape; | |||
| std::vector<DataType> output_data_type; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_GE_TYPES_H_ | |||
| @@ -19,7 +19,6 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <memory> | |||
| #include "common/fmk_types.h" | |||
| #include "common/helper/om_file_helper.h" | |||
| @@ -33,36 +32,41 @@ class ModelHelper { | |||
| ModelHelper() = default; | |||
| ~ModelHelper(); | |||
| Status SaveToOmModel(const GeModelPtr &ge_model, const SaveParam &save_param, const std::string &output_file); | |||
| Status SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::string &output_file); | |||
| Status LoadModel(const ge::ModelData &model_data); | |||
| Status SaveToOmModel(const GeModelPtr& ge_model, const SaveParam& save_param, const std::string& output_file, | |||
| ge::ModelBufferData& model); | |||
| Status SaveOriginalGraphToOmModel(const ge::Graph& graph, const std::string& output_file); | |||
| Status LoadModel(const ge::ModelData& model_data); | |||
| Status GetModelBufferData(ge::ModelBufferData& model); | |||
| ModelFileHeader *GetFileHeader() { return file_header_; } | |||
| ModelFileHeader* GetFileHeader() { return file_header_; } | |||
| GeModelPtr GetGeModel(); | |||
| void SetSaveMode(bool val) { is_offline_ = val; } | |||
| bool GetSaveMode(void) const { return is_offline_; } | |||
| static Status TransModelToGeModel(const ModelPtr &model, GeModelPtr &ge_model); | |||
| static Status TransGeModelToModel(const GeModelPtr &geModelPtr, ModelPtr &modelPtr); | |||
| static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||
| static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||
| private: | |||
| bool is_assign_model_ = false; | |||
| ModelFileHeader *file_header_ = nullptr; | |||
| bool is_offline_ = true; | |||
| ModelFileHeader* file_header_ = nullptr; | |||
| // For an encrypted model the temporary model buffer must be deleted; an unencrypted model does not need this | |||
| uint8_t *model_addr_tmp_ = nullptr; | |||
| uint8_t* model_addr_tmp_ = nullptr; | |||
| uint32_t model_len_tmp_ = 0; | |||
| GeModelPtr model_; | |||
| ModelHelper(const ModelHelper &); | |||
| ModelHelper &operator=(const ModelHelper &); | |||
| Status GenerateGeModel(OmFileLoadHelper &om_load_helper); | |||
| Status LoadModelData(OmFileLoadHelper &om_load_helper); | |||
| void SetModelToGeModel(ge::Model &model); | |||
| Status LoadWeights(OmFileLoadHelper &om_load_helper); | |||
| Status LoadTask(OmFileLoadHelper &om_load_helper); | |||
| Status LoadTBEKernelStore(OmFileLoadHelper &om_load_helper); | |||
| ModelHelper(const ModelHelper&); | |||
| ModelHelper& operator=(const ModelHelper&); | |||
| Status GenerateGeModel(OmFileLoadHelper& om_load_helper); | |||
| Status LoadModelData(OmFileLoadHelper& om_load_helper); | |||
| void SetModelToGeModel(ge::Model& model); | |||
| Status LoadWeights(OmFileLoadHelper& om_load_helper); | |||
| Status LoadTask(OmFileLoadHelper& om_load_helper); | |||
| Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | |||
| Status ReleaseLocalModelData() noexcept; | |||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_file_save_helper, ModelPartitionType type, | |||
| const uint8_t *data, size_t size); | |||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | |||
| const uint8_t* data, size_t size); | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_HELPER_MODEL_HELPER_H_ | |||
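A short hedged sketch of the load path above; reading the om file into ge::ModelData is assumed to happen elsewhere, and the ModelData field layout is not shown in this excerpt.

// Illustrative sketch only.
ge::ModelHelper helper;
ge::ModelData model_data;  // assumed to be filled from an om file by other framework code
if (helper.LoadModel(model_data) == ge::SUCCESS) {
  ge::GeModelPtr ge_model = helper.GetGeModel();
  // consume ge_model here
}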
| @@ -20,10 +20,12 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/ge/ge_ir_build.h" | |||
| #include "framework/common/fmk_types.h" | |||
| #include "framework/common/ge_types.h" | |||
| #include "framework/common/types.h" | |||
| #include "framework/common/ge_types.h" | |||
| using ProcParam = struct PROC_PARAM; | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -80,9 +82,10 @@ class OmFileSaveHelper { | |||
| const std::vector<ModelPartition> &GetModelPartitions() const; | |||
| Status SaveModel(const SaveParam &save_param, const char *target_file); | |||
| Status SaveModel(const SaveParam &save_param, const char *target_file, ge::ModelBufferData &model, | |||
| bool is_offline = true); | |||
| Status SaveModelToFile(const char *output_file); | |||
| Status SaveModelToFile(const char *output_file, ge::ModelBufferData &model, bool is_offline = true); | |||
| ModelFileHeader model_header_; | |||
| OmFileContext context_; | |||
| @@ -120,4 +120,4 @@ class L2CacheOptimize { | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||
| #endif // INC_FRAMEWORK_COMMON_L2_CACHE_OPTIMIZE_H_ | |||
| @@ -649,6 +649,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_M | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_EVENT_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_LABEL_NUM; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_MEMORY_SIZE; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
| @@ -801,6 +803,8 @@ extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_N | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
| // For constant folding | |||
| extern FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
| } // namespace domi | |||
| #endif // INC_FRAMEWORK_COMMON_OP_ATTR_DEFINE_H_ | |||
| @@ -17,11 +17,12 @@ | |||
| #ifndef INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | |||
| #define INC_FRAMEWORK_COMMON_OP_ATTR_VALUE_UTIL_H_ | |||
| #include <google/protobuf/map.h> | |||
| #include <unordered_map> | |||
| #include <string> | |||
| #include <google/protobuf/map.h> | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "proto/om.pb.h" | |||
| using domi::AttrDef; | |||
| @@ -18,7 +18,6 @@ | |||
| #define INC_FRAMEWORK_COMMON_OP_GE_OP_UTILS_H_ | |||
| #include <cce/dnn.h> | |||
| #include <memory> | |||
| #include <vector> | |||
| @@ -56,6 +55,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_TR | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_DATA_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t SWITCH_PRED_INPUT; | |||
| // FunctionOp | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t IF_COND_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_START_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_LIMIT_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DELTA_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const uint32_t FOR_DATA_INPUT; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const int NORMAL_TENSOR_SIZE; | |||
| class OpUtils { | |||
| public: | |||
| /// | |||
| @@ -164,15 +172,23 @@ class OpUtils { | |||
| /// | |||
| static Status ConvertAippParams(const GeAttrValue::NamedAttrs &aipp_attr, domi::AippOpParams *aipp_params); | |||
| static Status TransferDim(const std::vector<int64_t> &dim, std::vector<int64_t> &dim_vector); | |||
| static void SliceData(std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, int64_t begin, | |||
| int64_t out_dim, int64_t stride); | |||
| template <typename T> | |||
| static void SliceData(const std::vector<char *> &input, int64_t chunk_size, std::vector<char *> &output, | |||
| int64_t begin, int64_t out_dim, int64_t stride); | |||
| template <typename T> | |||
| static Status SetDataByDataType(size_t out_size, const std::vector<char *> &chunk_input, | |||
| const std::vector<char *> &chunk_output, GeTensor *output); | |||
| template <typename T> | |||
| static Status SetOutputSliceDataByDataType(void *data, int64_t data_size, const std::vector<int64_t> &input_dims, | |||
| const std::vector<int64_t> &begin, const std::vector<int64_t> &output_dims, | |||
| ge::GeTensor *output, const std::vector<int64_t> &stride); | |||
| static Status SetOutputSliceData(void *data, int64_t data_size, int32_t data_type, std::vector<int64_t> &input_dims, | |||
| std::vector<int64_t> &begin, std::vector<int64_t> &output_dims, ge::GeTensor *output, | |||
| std::vector<int64_t> &stride); | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Convert the convolution weight data from [h, w, c, k] to [k, c, h, w] | |||
| /// @brief Convert the convolutional weight data from [h, w, c, k] to [k, c, h, w] | |||
| /// @param [in] input Weight data in HWCK format | |||
| /// @param [in] H value of H dimension | |||
| /// @param [in] W value of W dimension | |||
| @@ -183,7 +199,7 @@ class OpUtils { | |||
| static void TransDataHWCK2KCHW(const void *input, int64_t H, int64_t W, int64_t C, int64_t K, void **output); | |||
| /// | |||
| /// @ingroup domi_omg | |||
| /// @brief Converts the convolution weight data from [k, c, h, w] to [h, w, c, k]. | |||
| /// @brief Converts the convolutional weight data from [k, c, h, w] to [h, w, c, k]. | |||
| /// @param [in] input Weight data in HWCK format | |||
| /// @param [in] K value of K dimension | |||
| /// @param [in] C value of C dimension | |||
| @@ -222,7 +238,6 @@ using CceTensorDescriptorPtr = std::shared_ptr<CceTensorDescriptor>; | |||
| class CceTensorDescriptor { | |||
| public: | |||
| explicit CceTensorDescriptor(ccTensorDescriptor_t cc_tensor); | |||
| CceTensorDescriptor(const CceTensorDescriptor &) = delete; | |||
| CceTensorDescriptor &operator=(const CceTensorDescriptor &) = delete; | |||
| @@ -22,7 +22,7 @@ | |||
| #include <math.h> | |||
| #include <stdint.h> | |||
| namespace domi { | |||
| namespace ge { | |||
| // general | |||
| const float DEFAULT_ALPHA_VALUE = 1.0; | |||
| const float DEFAULT_BETA_VALUE = 0.0; | |||
| @@ -421,5 +421,5 @@ const uint32_t MULTI_SHAPE_INPUT_NUM = 2; | |||
| // Shufflechannel | |||
| const uint32_t SHUFFLECHANNEL_DEFAULT_GROUP = 1; | |||
| } // namespace domi | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_OP_OP_PARSER_UTIL_H_ | |||
| @@ -25,7 +25,7 @@ | |||
| /// MAKE_GUARD([&] { Release Resource 1 }) | |||
| /// Acquire Resource 2 | |||
| /// MAKE_GUARD([&] { Release Resource 2 }) | |||
| #define GE_MAKE_GUARD(var, callback) ge::ScopeGuard make_guard_##var(callback) | |||
| #define GE_MAKE_GUARD(var, callback) ScopeGuard make_guard_##var(callback) | |||
| #define GE_DISMISS_GUARD(var) make_guard_##var.Dismiss() | |||
| namespace ge { | |||
| @@ -156,6 +156,7 @@ REGISTER_OPTYPE_DECLARE(GATHER, "Gather"); | |||
| REGISTER_OPTYPE_DECLARE(REALDIV, "RealDiv"); | |||
| REGISTER_OPTYPE_DECLARE(PACK, "Pack"); | |||
| REGISTER_OPTYPE_DECLARE(SLICE, "Slice"); | |||
| REGISTER_OPTYPE_DECLARE(SLICED, "SliceD"); | |||
| REGISTER_OPTYPE_DECLARE(FLOORDIV, "FloorDiv"); | |||
| REGISTER_OPTYPE_DECLARE(SQUEEZE, "Squeeze"); | |||
| REGISTER_OPTYPE_DECLARE(STRIDEDSLICE, "StridedSlice"); | |||
| @@ -188,6 +189,19 @@ REGISTER_OPTYPE_DECLARE(REFNEXTITERATION, "RefNextIteration"); | |||
| REGISTER_OPTYPE_DECLARE(EXIT, "Exit"); | |||
| REGISTER_OPTYPE_DECLARE(REFEXIT, "RefExit"); | |||
| REGISTER_OPTYPE_DECLARE(CONTROLTRIGGER, "ControlTrigger"); | |||
| REGISTER_OPTYPE_DECLARE(SYMBOLICGRADIENT, "SymbolicGradient"); | |||
| REGISTER_OPTYPE_DECLARE(REMOTECALL, "RemoteCall"); | |||
| REGISTER_OPTYPE_DECLARE(_IF, "_If"); | |||
| REGISTER_OPTYPE_DECLARE(STATELESSIF, "StatelessIf"); | |||
| REGISTER_OPTYPE_DECLARE(IF, "If"); | |||
| REGISTER_OPTYPE_DECLARE(CASE, "Case"); | |||
| REGISTER_OPTYPE_DECLARE(_WHILE, "_While"); | |||
| REGISTER_OPTYPE_DECLARE(WHILE, "While"); | |||
| REGISTER_OPTYPE_DECLARE(STATELESSWHILE, "StatelessWhile"); | |||
| REGISTER_OPTYPE_DECLARE(FOR, "For"); | |||
| REGISTER_OPTYPE_DECLARE(PARTITIONEDCALL, "PartitionedCall"); | |||
| REGISTER_OPTYPE_DECLARE(STATEFULPARTITIONEDCALL, "StatefulPartitionedCall"); | |||
| REGISTER_OPTYPE_DECLARE(FAKEPARAM, "FakeParam"); | |||
| REGISTER_OPTYPE_DECLARE(TRANSPOSE, "Transpose"); | |||
| REGISTER_OPTYPE_DECLARE(TRANSPOSED, "TransposeD"); | |||
| REGISTER_OPTYPE_DECLARE(CAST, "Cast"); | |||
| @@ -424,6 +438,12 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||
| REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | |||
| REGISTER_OPTYPE_DECLARE(SEND, "Send"); | |||
| REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | |||
| REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | |||
| REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | |||
| REGISTER_OPTYPE_DECLARE(ATOMICADDRCLEAN, "AtomicAddrClean"); | |||
| REGISTER_OPTYPE_DECLARE(ABS_GRAD, "AbsGrad"); | |||
| @@ -1032,14 +1052,11 @@ struct BasicInfo { | |||
| uint32_t workspace_size; // workspace | |||
| uint32_t total_size; // total memory size | |||
| }; | |||
| #pragma pack() // Cancels single-byte alignment | |||
| } // namespace ge | |||
| namespace domi { | |||
| /// @brief Data structure definition related to task sinking | |||
| /// Build model | |||
| enum BuildMode { | |||
| GEN_TASK_WITHOUT_L2FUSION = 3, // Carrying task data (L2 convergence function disabled) | |||
| GEN_TASK_WITHOUT_FUSION = 4, // Carrying task data (all convergence functions disabled) | |||
| @@ -30,6 +30,14 @@ | |||
| #include "framework/common/ge_inner_error_codes.h" | |||
| #include "mmpa/mmpa_api.h" | |||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
| do { \ | |||
| if (size <= 0) { \ | |||
| DOMI_LOGE(param[#size] is not a positive number); \ | |||
| return PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| #define CHECK_FALSE_EXEC(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| @@ -50,21 +58,6 @@ | |||
| if (var) GE_CHK_RT(rtStreamDestroy(var)); \ | |||
| }); | |||
| #define GE_MAKE_GUARD_RTEVENT(var) \ | |||
| GE_MAKE_GUARD(var, [&] { \ | |||
| if (var) GE_CHK_RT(rtEventDestroy(var)); \ | |||
| }); | |||
| #define GE_MAKE_GUARD_TENSOR(var) \ | |||
| GE_MAKE_GUARD(var, [&] { \ | |||
| if (var) GE_CHK_CCE(ccDestroyTensorDescriptor(&var)); \ | |||
| }); | |||
| #define GE_MAKE_GUARD_FILTER_DESC(var) \ | |||
| GE_MAKE_GUARD(var, [&] { \ | |||
| if (var) GE_CHK_CCE(ccDestroyFilterDescriptor(&var)); \ | |||
| }); | |||
| // For propagating errors when calling a function. | |||
| #define GE_RETURN_IF_ERROR(expr) \ | |||
| do { \ | |||
| @@ -76,7 +69,7 @@ | |||
| do { \ | |||
| const ::ge::Status _status = (expr); \ | |||
| if (_status) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -85,7 +78,7 @@ | |||
| #define GE_RETURN_WITH_LOG_IF_TRUE(condition, ...) \ | |||
| do { \ | |||
| if (condition) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -95,7 +88,7 @@ | |||
| do { \ | |||
| bool _condition = (condition); \ | |||
| if (!_condition) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -104,7 +97,7 @@ | |||
| #define GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(condition, ...) \ | |||
| do { \ | |||
| if (condition) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -114,111 +107,90 @@ | |||
| do { \ | |||
| bool _condition = (condition); \ | |||
| if (!_condition) { \ | |||
| GE_LOGE(__VA_ARGS__); \ | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
| #define GE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| GE_LOGE(param[#val] must not be null.); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, return PARAM_INVALID and record the error | |||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| GE_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| // Check if the parameter is null. If yes, just return and record the error | |||
| #define GE_CHECK_NOTNULL_JUST_RETURN(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| // Check whether the parameter is null. If so, execute the exec_expr expression and record the error log | |||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| GE_LOGE(param[#val] must not be null.); \ | |||
| exec_expr; \ | |||
| } \ | |||
| #define GE_CHECK_NOTNULL_EXEC(val, exec_expr) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } while (0) | |||
| // Check whether the parameter is null. If yes, return directly and record the error log | |||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| GE_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| #define GE_RT_VOID_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is null. If yes, return false and record the error log | |||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| GE_LOGE(param[#val] must not be null.); \ | |||
| return false; \ | |||
| } \ | |||
| #define GE_RT_FALSE_CHECK_NOTNULL(val) \ | |||
| do { \ | |||
| if (val == nullptr) { \ | |||
| DOMI_LOGE(param[#val] must not be null.); \ | |||
| return false; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the parameter is out of bounds | |||
| #define GE_CHECK_SIZE(size) \ | |||
| do { \ | |||
| if (size == 0) { \ | |||
| GE_LOGE(param[#size] is out of range); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_SIZE(size) \ | |||
| do { \ | |||
| if (size == 0) { \ | |||
| DOMI_LOGE(param[#size] is out of range); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Macros that define the size variable | |||
| #define GE_DEFINE_BYTE_SIZE(_var_name, _expr, _sizeof) \ | |||
| uint32_t _var_name; \ | |||
| do { \ | |||
| uint32_t _expr_size = (_expr); \ | |||
| uint32_t _sizeof_size = (_sizeof); \ | |||
| if (_expr_size > (0xffffffff) / _sizeof_size) { \ | |||
| GE_LOGE(byte size : #_var_name is out of range); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| _var_name = _sizeof_size * _expr_size; \ | |||
| } while (0); | |||
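| // Worked example (illustration only) of the overflow guard GE_DEFINE_BYTE_SIZE expresses above: | |||
| // check count * elem_size against the 32-bit range before committing to the product. The | |||
| // SafeByteSize helper and the sample count are made up for this sketch. | |||
| #include <cstdint> | |||
| #include <cstdio> | |||
| static bool SafeByteSize(uint32_t count, uint32_t elem_size, uint32_t &bytes) { | |||
|   if (elem_size != 0 && count > UINT32_MAX / elem_size) { | |||
|     return false;  // the product would overflow uint32_t | |||
|   } | |||
|   bytes = count * elem_size; | |||
|   return true; | |||
| } | |||
| int main() { | |||
|   uint32_t bytes = 0; | |||
|   if (!SafeByteSize(1u << 30, sizeof(float), bytes)) {  // 2^30 floats need 2^32 bytes: rejected | |||
|     std::printf("byte size is out of range\n"); | |||
|   } | |||
|   return 0; | |||
| } | |||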
| // Check if the container is empty | |||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
| do { \ | |||
| if (vector.empty()) { \ | |||
| GE_LOGE(param[#vector] is empty !); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| #define GE_CHECK_POSITIVE_SIZE_RANGE(size) \ | |||
| do { \ | |||
| if (size <= 0) { \ | |||
| GE_LOGE(param[#size] is not a positive number); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_VECTOR_NOT_EMPTY(vector) \ | |||
| do { \ | |||
| if (vector.empty()) { \ | |||
| DOMI_LOGE(param[#vector] is empty !); \ | |||
| return ge::FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the value on the left is greater than or equal to the value on the right | |||
| #define GE_CHECK_GE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs < rhs) { \ | |||
| GE_LOGE(param[#lhs] is less than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_GE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs < rhs) { \ | |||
| DOMI_LOGE(param[#lhs] is less than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
| // Check if the value on the left is less than or equal to the value on the right | |||
| #define GE_CHECK_LE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs > rhs) { \ | |||
| GE_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| #define GE_CHECK_LE(lhs, rhs) \ | |||
| do { \ | |||
| if (lhs > rhs) { \ | |||
| DOMI_LOGE(param[#lhs] is greater than[#rhs]); \ | |||
| return ge::PARAM_INVALID; \ | |||
| } \ | |||
| } while (0) | |||
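| // Usage sketch (illustration only): how the parameter-check macros above read in a function that | |||
| // returns ge::Status. The Validate function, its arguments and the include path are made up; | |||
| // ge::SUCCESS is assumed from the GE status codes, while PARAM_INVALID / FAILED come from the | |||
| // macro bodies above. | |||
| // #include "framework/common/debug/log.h"  // assumed include path for these macros | |||
| #include <cstdint> | |||
| #include <vector> | |||
| ge::Status Validate(const void *weights, const std::vector<int64_t> &dims, int64_t batch) { | |||
|   GE_CHECK_NOTNULL(weights);            // logs and returns ge::PARAM_INVALID on nullptr | |||
|   GE_CHECK_VECTOR_NOT_EMPTY(dims);      // logs and returns ge::FAILED when empty | |||
|   GE_CHECK_POSITIVE_SIZE_RANGE(batch);  // rejects batch <= 0 | |||
|   GE_CHECK_GE(batch, 1);                // lhs must not be less than rhs | |||
|   return ge::SUCCESS; | |||
| } | |||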
| #define GE_DELETE_NEW_SINGLE(var) \ | |||
| @@ -52,10 +52,10 @@ | |||
| #define DLOG_DECLARE(level) \ | |||
| void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | |||
| namespace ge { | |||
| namespace domi { | |||
| DLOG_DECLARE(INFO); | |||
| DLOG_DECLARE(WARNING); | |||
| DLOG_DECLARE(ERROR); | |||
| } // namespace ge | |||
| } // namespace domi | |||
| #endif // INC_FRAMEWORK_DLOG_LOG_H_ | |||
| @@ -38,7 +38,7 @@ struct DNNEngineAttribute { | |||
| std::vector<std::string> mem_type; | |||
| uint32_t compute_cost; | |||
| enum RuntimeType runtime_type; // HOST, DEVICE | |||
| // set this attribute if the inputformat of engine must be specific, otherwise set FORMAT_RESERVED | |||
| // If engine input format must be specific, set this attribute, else set FORMAT_RESERVED | |||
| Format engine_input_format; | |||
| Format engine_output_format; | |||
| }; | |||
| @@ -26,6 +26,7 @@ | |||
| #include "common/types.h" | |||
| #include "graph/tensor.h" | |||
| #include "runtime/base.h" | |||
| #include "common/dynamic_aipp.h" | |||
| namespace ge { | |||
| class ModelListenerAdapter; | |||
| @@ -33,12 +34,15 @@ class ModelListenerAdapter; | |||
| class SingleOp; | |||
| struct RunModelData { | |||
| uint32_t index; // Data index | |||
| uint32_t model_id; // Model id | |||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||
| uint32_t timestamp; // Data creation time | |||
| uint32_t timeout; // Processing timeout | |||
| uint64_t request_id = 0; // Request ID | |||
| uint32_t index; // Data index | |||
| uint32_t modelId; | |||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||
| uint32_t timestamp; // Data creation time | |||
| uint32_t timeout; // Processing timeout | |||
| uint64_t request_id = 0; // Request ID | |||
| uint64_t dynamic_batch_size = 0;    // Dynamic batch scenario: batch size requested by the user; 0 (default) means unused | |||
| uint64_t dynamic_image_height = 0;  // Dynamic image size scenario: image height requested by the user; 0 (default) means unused | |||
| uint64_t dynamic_image_width = 0;   // Dynamic image size scenario: image width requested by the user; 0 (default) means unused | |||
| }; | |||
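| // Usage sketch (illustration only): preparing a RunModelData for a dynamic-batch run using the | |||
| // fields declared above. The helper name and the batch value are made up, and DataBuffer | |||
| // population is elided because its layout is not part of this hunk. | |||
| ge::RunModelData MakeDynamicBatchInput(uint32_t loaded_model_id) { | |||
|   ge::RunModelData input_data; | |||
|   input_data.index = 0; | |||
|   input_data.modelId = loaded_model_id;  // id obtained from LoadModelOffline | |||
|   input_data.timeout = 0;                // no processing timeout | |||
|   input_data.dynamic_batch_size = 4;     // request batch 4; 0 (the default) means "not dynamic" | |||
|   // input_data.blobs would carry one DataBuffer per model input/output. | |||
|   return input_data; | |||
| } | |||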
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| @@ -46,12 +50,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| GeExecutor(); | |||
| ~GeExecutor() = default; | |||
| ge::Status Initialize(); | |||
| ge::Status Finalize(); | |||
| // Load model | |||
| ge::Status LoadModelOffline(uint32_t &model_id, const std::string &path, const std::string &key, int32_t priority, | |||
| std::shared_ptr<ge::ModelListener> listener); | |||
| ge::Status UnloadModel(uint32_t model_id); | |||
| ge::Status UnloadModel(uint32_t modelId); | |||
| ge::Status RunModel(const ge::RunModelData &input_data, ge::RunModelData &output_data); | |||
| @@ -59,6 +64,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
| std::vector<ge::TensorDesc> &output_desc); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Set dynamic batch size | |||
| /// @param [in] model_id: model id allocate from manager | |||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
| /// @param [in] length: length of dynamic input addr | |||
| /// @param [in] batch_size: batch size entered by user in dynamic multi-batch scenario | |||
| /// @return execute result | |||
| /// | |||
| ge::Status SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Set dynamic image info | |||
| /// @param [in] model_id: model id allocate from manager | |||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
| /// @param [in] length: length of dynamic input addr | |||
| /// @param [in] image_height: image height entered by user in dynamic multi-resolution scenario | |||
| /// @param [in] image_width: image width entered by user in dynamic multi-resolution scenario | |||
| /// @return execute result | |||
| /// | |||
| ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, | |||
| uint64_t image_width); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Get dynamic batch_info | |||
| /// @param [in] model_id | |||
| /// @param [out] batch_info | |||
| /// @return execute result | |||
| /// | |||
| ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Set dynamic aipp data | |||
| /// @param [in] model_id: model id allocate from manager | |||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
| /// @param [in] length: length of dynamic input addr | |||
| /// @param [in] aippBatchPara: vector of kAippDynamicBatchPara supplied by the user in dynamic aipp scenarios | |||
| /// @param [in] aippParms: kAippDynamicPara supplied by the user in dynamic aipp scenarios | |||
| /// @return execute result | |||
| /// | |||
| ge::Status SetDynamicAippData(uint32_t model_id, void *dynamic_input_addr, uint64_t length, | |||
| const std::vector<kAippDynamicBatchPara> &aippBatchPara, | |||
| const kAippDynamicPara &aippParms); | |||
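| // Usage sketch (illustration only) of the dynamic-shape entry points declared above: query the | |||
| // batch profiles the model supports, then write the chosen value into its dynamic input buffer. | |||
| // Model loading, device-memory allocation for dynamic_addr and error logging are elided, and | |||
| // ge::SUCCESS is assumed from the GE status codes. | |||
| #include <cstdint> | |||
| #include <vector> | |||
| ge::Status RunWithDynamicBatch(ge::GeExecutor &executor, uint32_t model_id, | |||
|                                void *dynamic_addr, uint64_t len) { | |||
|   std::vector<std::vector<int64_t>> batch_info; | |||
|   ge::Status ret = executor.GetDynamicBatchInfo(model_id, batch_info); | |||
|   if (ret != ge::SUCCESS) { | |||
|     return ret;  // the model may not have been built with dynamic batch profiles | |||
|   } | |||
|   // Pick one of the supported batch sizes (the first profile, purely for illustration). | |||
|   uint64_t batch = (batch_info.empty() || batch_info[0].empty()) | |||
|                        ? 1 | |||
|                        : static_cast<uint64_t>(batch_info[0][0]); | |||
|   return executor.SetDynamicBatchSize(model_id, dynamic_addr, len, batch); | |||
| } | |||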
| ge::Status GetModelDescInfoForZeroCopy(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
| std::vector<ge::TensorDesc> &output_desc); | |||
| @@ -147,7 +198,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| /// | |||
| ge::Status GetMemAndWeightSize(const void *model_data, size_t model_size, size_t &mem_size, size_t &weight_size); | |||
| static ge::Status LoadSingleOp(const std::string &model_name, const ge::ModelData &model_data, void *stream, | |||
| static ge::Status LoadSingleOp(const std::string &modelName, const ge::ModelData &modelData, void *stream, | |||
| SingleOp **single_op); | |||
| static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | |||
| @@ -156,8 +207,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| static ge::Status ReleaseSingleOpResource(void *stream); | |||
| private: | |||
| static bool is_init_; | |||
| std::vector<std::shared_ptr<ModelListenerAdapter>> listener_adapters_; | |||
| static bool isInit_; | |||
| }; | |||
| ge::Status ModelInfoParser(const ge::ModelData &model, ge::ModelInfo &model_info); | |||
| @@ -21,7 +21,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ge/ge_ir_build.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/graph.h" | |||
| @@ -45,6 +45,8 @@ class GeGenerator { | |||
| Status GenerateOfflineModel(const Graph &graph, const std::string &file_name_prefix, | |||
| const std::vector<GeTensor> &inputs = std::vector<GeTensor>()); | |||
| Status GenerateOnlineModel(const Graph &graph, const vector<GeTensor> &inputs, ge::ModelBufferData &model); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief: Build single OP in Model. | |||
| @@ -58,6 +60,8 @@ class GeGenerator { | |||
| const std::vector<GeTensor> &outputs, const std::string &model_file_name); | |||
| private: | |||
| Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | |||
| ge::ModelBufferData &model, bool is_offline = true); | |||
| class Impl; | |||
| std::shared_ptr<Impl> impl_; | |||
| @@ -24,7 +24,6 @@ extern "C" { | |||
| #endif | |||
| typedef uint32_t Status_t; | |||
| using Status_t = uint32_t; | |||
| typedef void *OpAttr_t; | |||
| typedef void *OpTensor_t; | |||
| @@ -23,7 +23,7 @@ | |||
| #include "graph/node.h" | |||
| namespace ge { | |||
| const int64_t kMemAlignSize = 512; | |||
| const int64_t MEM_ALIGN_SIZE = 512; | |||
| class MemoryAssigner { | |||
| public: | |||
| explicit MemoryAssigner(ge::ComputeGraphPtr compute_graph) : compute_graph_(std::move(compute_graph)) {} | |||
| @@ -39,4 +39,4 @@ class MemoryAssigner { | |||
| ge::ComputeGraphPtr compute_graph_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||
| #endif // INC_FRAMEWORK_MEMORY_MEMORY_ASSIGNER_H_ | |||
| @@ -31,7 +31,6 @@ | |||
| using domi::DOMI_TENSOR_ND; | |||
| using domi::DOMI_TENSOR_RESERVED; | |||
| using domi::domiTensorFormat_t; | |||
| using domi::FMK_TYPE_RESERVED; | |||
| using domi::FrameworkType; | |||
| using std::map; | |||
| using std::string; | |||
| @@ -44,10 +43,10 @@ namespace ge { | |||
| * @brief run model | |||
| */ | |||
| enum RunMode { | |||
| kGeOmModel = 0, // generate offline model file | |||
| kModelToJson = 1, // convert to JSON file | |||
| kOnlyPreCheck = 3, // only for pre-check | |||
| kPbtxtToJson = 5 // pbtxt to json | |||
| GEN_OM_MODEL = 0, // generate offline model file | |||
| MODEL_TO_JSON = 1, // convert to JSON file | |||
| ONLY_PRE_CHECK = 3, // only for pre-check | |||
| PBTXT_TO_JSON = 5 // pbtxt to json | |||
| }; | |||
| /// | |||
| @@ -56,10 +55,10 @@ enum RunMode { | |||
| /// | |||
| enum HighPrecisionMode { | |||
| // the FP16 high-precision function is disabled in common mode | |||
| kHighPrecisonDefault = 0, | |||
| HIGH_PRECISION_DEFAULT = 0, | |||
| // high-precision mode, in which FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) is enable | |||
| kHighPrecisionFP16 = 1 | |||
| // high-precision mode, enabling FP16 high-precision mode (Convolution/FullConnect/AvgPooling are involved) | |||
| HIGH_PRECISION_FP16 = 1 | |||
| }; | |||
| /// | |||
| @@ -99,21 +98,23 @@ struct OmgContext { | |||
| // preferential format used by the entire network | |||
| domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | |||
| domi::FrameworkType type = domi::FMK_TYPE_RESERVED; | |||
| RunMode run_mode = kOnlyPreCheck; | |||
| RunMode run_mode = ONLY_PRE_CHECK; | |||
| bool train_flag = false; | |||
| // whether to use FP16 high precision | |||
| int32_t fp16_high_precision = kHighPrecisonDefault; | |||
| int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT; | |||
| std::string output_type; | |||
| // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | |||
| // network require special processing based on the specific network. | |||
| // e.g:faster-rcnn, the FirstStageProcessor module is determined as the Faster-R-CNN network based on the scope | |||
| // fusion. Then, the conv+reshape operators in the FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The | |||
| // convolution kernel rearrangement reshape operator needs to be deleted for the convolution kernel. | |||
| // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||
| // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||
| // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||
| // operator needs to be deleted for the convolution kernel. | |||
| std::string net_name; | |||
| // whether to enable dynamic batch | |||
| bool enable_l2dynamic = false; | |||
| // Whether to use dynamic batch size or dynamic image size | |||
| bool is_dynamic_input = false; | |||
| std::string dynamic_batch_size; | |||
| std::string dynamic_image_size; | |||
| }; | |||
| } // namespace ge | |||
| @@ -32,15 +32,7 @@ class PlatformVersionManager { | |||
| PlatformVersionManager() = delete; | |||
| ~PlatformVersionManager() = delete; | |||
| static Status GetPlatformVersion(std::string &ver) { | |||
| #if defined PLATFORM_PHOENIX | |||
| ver = "3.51.z"; | |||
| #elif defined PLATFORM_ORLANDO | |||
| ver = "3.31.z"; | |||
| #elif defined PLATFORM_MINI | |||
| ver = "1.11.z"; | |||
| #elif defined PLATFORM_CLOUD | |||
| ver = "1.61.z"; | |||
| #endif | |||
| std::vector<std::string> version_splits = StringUtils::Split(ver, '.'); | |||
| GE_IF_BOOL_EXEC(version_splits.size() < 3, GELOGW("Read platform version error!"); return FAILED;); | |||
| @@ -20,13 +20,17 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/range_vistor.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| enum AnchorStatus { ANCHOR_SUSPEND = 0, ANCHOR_CONST = 1, ANCHOR_DATA = 2, ANCHOR_RESERVED = 3 }; | |||
| enum AnchorStatus { | |||
| ANCHOR_SUSPEND = 0,  // data is null | |||
| ANCHOR_CONST = 1, | |||
| ANCHOR_DATA = 2, // Effective | |||
| ANCHOR_RESERVED = 3 | |||
| }; | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -81,17 +85,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||
| virtual ~Anchor() = default; | |||
| protected: | |||
| // Whether the two anchors are equal | |||
| // Whether the two anchor is equal | |||
| virtual bool Equal(AnchorPtr anchor) const = 0; | |||
| virtual bool IsTypeOf(TYPE type) const; | |||
| public: | |||
| // Get all peer anchors connected to current anchor | |||
| Vistor<AnchorPtr> GetPeerAnchors() const; | |||
| // Get the first peer anchor | |||
| // Get peer anchor size | |||
| size_t GetPeerAnchorsSize() const; | |||
| // Get first peer anchor | |||
| AnchorPtr GetFirstPeerAnchor() const; | |||
| // Get the node which is the owner of the anchor | |||
| // Get the anchor belong to which node | |||
| NodePtr GetOwnerNode() const; | |||
| // Remove all links with the anchor | |||
| @@ -100,22 +106,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Anchor : public std::enable | |||
| // Remove link with the given anchor | |||
| graphStatus Unlink(const AnchorPtr &peer); | |||
| // Replace the peeranchor with the new peeranchor | |||
| // Replace peer with new peers | |||
| graphStatus ReplacePeer(const AnchorPtr &oldPeer, const AnchorPtr &firstPeer, const AnchorPtr &secondPeer); | |||
| // Judge if the anchor is linked with the given anchor | |||
| bool IsLinkedWith(const AnchorPtr &peer); | |||
| // Get the anchor index of the node | |||
| // Get anchor index of the node | |||
| int GetIdx() const; | |||
| // Set the anchor index of the node | |||
| // set anchor index of the node | |||
| void SetIdx(int index); | |||
| protected: | |||
| // All peer anchors connected to current anchor | |||
| vector<std::weak_ptr<Anchor>> peer_anchors_; | |||
| // The owner nodes of the anchor | |||
| // The owner node of anchor | |||
| std::weak_ptr<Node> owner_node_; | |||
| // The index of current anchor | |||
| int idx_; | |||
| @@ -167,7 +173,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchor : public DataA | |||
| virtual ~InDataAnchor() = default; | |||
| // Get source out data anchor | |||
| // Get source out data anchor | |||
| OutDataAnchorPtr GetPeerOutAnchor() const; | |||
| // Build connection from OutDataAnchor to InDataAnchor | |||
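| // Usage sketch (illustration only): walking a node's input edges back to their producers through | |||
| // the anchor accessors above. GetAllInDataAnchors and GetName are assumed from ge::Node elsewhere | |||
| // in this codebase, and std::printf stands in for the project's logging macros. | |||
| #include <cstdio> | |||
| void PrintProducers(const ge::NodePtr &node) { | |||
|   for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
|     if (in_anchor == nullptr) continue; | |||
|     auto peer_out = in_anchor->GetPeerOutAnchor();  // source OutDataAnchor, may be null | |||
|     if (peer_out == nullptr) continue;              // unconnected input | |||
|     auto producer = peer_out->GetOwnerNode();       // the node that owns the peer anchor | |||
|     if (producer != nullptr) { | |||
|       std::printf("input %d is fed by node %s\n", in_anchor->GetIdx(), producer->GetName().c_str()); | |||
|     } | |||
|   } | |||
| } | |||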
| @@ -19,10 +19,10 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/ge_attr_value.h" | |||
| namespace ge { | |||
| class GeAttrValue; | |||
| class _GeSerializable { | |||
| public: | |||
| @@ -107,7 +107,6 @@ class _GeSerializable { | |||
| static graphStatus LoadItem(GeAttrValue::NamedAttrs &namedAttrs __attribute__((__unused__))) { return GRAPH_SUCCESS; } | |||
| }; | |||
| #define _GE_FI(a) #a, a | |||
| #define _GE_MAP_FIELDS1(a1) _GE_FI(a1) | |||
| #define _GE_MAP_FIELDS2(a1, a2) _GE_FI(a1), _GE_FI(a2) | |||
| @@ -130,23 +129,23 @@ class _GeSerializable { | |||
| #define _GE_MAP_FIELDS11(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11) \ | |||
| _GE_FI(a1) \ | |||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
| _GE_FI(a11) | |||
| _GE_FI(a11) | |||
| #define _GE_MAP_FIELDS12(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12) \ | |||
| _GE_FI(a1) \ | |||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
| _GE_FI(a11), _GE_FI(a12) | |||
| _GE_FI(a11), _GE_FI(a12) | |||
| #define _GE_MAP_FIELDS13(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13) \ | |||
| _GE_FI(a1) \ | |||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13) | |||
| #define _GE_MAP_FIELDS14(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14) \ | |||
| _GE_FI(a1) \ | |||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14) | |||
| #define _GE_MAP_FIELDS15(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15) \ | |||
| _GE_FI(a1) \ | |||
| , _GE_FI(a2), _GE_FI(a3), _GE_FI(a4), _GE_FI(a5), _GE_FI(a6), _GE_FI(a7), _GE_FI(a8), _GE_FI(a9), _GE_FI(a10), \ | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||
| _GE_FI(a11), _GE_FI(a12), _GE_FI(a13), _GE_FI(a14), _GE_FI(a15) | |||
| #define _GE_PRIVATE_ARGS_GLUE(x, y) x y | |||
| @@ -17,12 +17,11 @@ | |||
| #ifndef INC_GRAPH_BUFFER_H_ | |||
| #define INC_GRAPH_BUFFER_H_ | |||
| #include <graph/types.h> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| #ifdef HOST_VISIBILITY | |||
| @@ -72,7 +71,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||
| GeIrProtoHelper<proto::AttrDef> data_; | |||
| std::string *buffer_ = nullptr; | |||
| // Create buffer from protobuf obj | |||
| // Create from protobuf obj | |||
| Buffer(const ProtoMsgOwner &protoOnwer, proto::AttrDef *buffer); | |||
| Buffer(const ProtoMsgOwner &protoOnwer, std::string *buffer); | |||
| @@ -17,7 +17,6 @@ | |||
| #ifndef INC_GRAPH_COMPUTE_GRAPH_H_ | |||
| #define INC_GRAPH_COMPUTE_GRAPH_H_ | |||
| #include <deque> | |||
| #include <map> | |||
| #include <memory> | |||
| #include <string> | |||
| @@ -63,7 +62,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| using Vistor = RangeVistor<T, std::shared_ptr<ConstComputeGraph>>; | |||
| explicit ComputeGraph(const std::string &name); | |||
| virtual ~ComputeGraph(); | |||
| ~ComputeGraph() override; | |||
| std::string GetName() const; | |||
| void SetName(const std::string &name); | |||
| @@ -81,7 +80,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| Vistor<NodePtr> GetOutputNodes() const; | |||
| NodePtr FindNode(const std::string &name) const; | |||
| // Add node | |||
| // AddNode with NodePtr | |||
| NodePtr AddNode(NodePtr node); | |||
| NodePtr AddNode(OpDescPtr op); | |||
| NodePtr AddNodeFront(NodePtr node); | |||
| @@ -94,9 +93,40 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| graphStatus RemoveOutputNode(const NodePtr &node); | |||
| graphStatus RemoveConstInput(const NodePtr &node); | |||
| /// Add a subgraph to this graph. The subgraph must have a parent graph and a parent node, | |||
| /// which means the member functions `SetParentGraph` and `SetParentNode` of the subgraph | |||
| /// must be called before adding it to the root graph, and subgraph->GetParentNode()->GetOwnerGraph() | |||
| /// must equal subgraph->GetOwnerGraph(). | |||
| /// The subgraphs can only be added to a *root graph*. A root graph is a graph without any parent graph. | |||
| /// The subgraph's name SHOULD (not MUST) be the same as the parameter `name`. | |||
| graphStatus AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph); | |||
| graphStatus AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||
| void RemoveSubgraph(const std::string &name); | |||
| void RemoveSubgraph(const std::shared_ptr<ComputeGraph> &subgraph); | |||
| std::shared_ptr<ComputeGraph> GetSubgraph(const std::string &name) const; | |||
| std::vector<std::shared_ptr<ComputeGraph>> GetAllSubgraphs() const; | |||
| // obsolete | |||
| std::shared_ptr<ComputeGraph> AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph); | |||
| // obsolete | |||
| graphStatus RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph); | |||
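| // Usage sketch (illustration only) of the call order the AddSubgraph contract above implies. | |||
| // Construction of the graphs and of the parent node is elided; the function name and its | |||
| // parameters are made up for this sketch. | |||
| #include <memory> | |||
| ge::graphStatus AttachSubgraph(const std::shared_ptr<ge::ComputeGraph> &root_graph, | |||
|                                const std::shared_ptr<ge::ComputeGraph> &parent_graph, | |||
|                                const ge::NodePtr &parent_node, | |||
|                                const std::shared_ptr<ge::ComputeGraph> &subgraph) { | |||
|   // Parent links must be set before the subgraph is handed to the root graph; | |||
|   // parent_node is expected to live inside parent_graph. | |||
|   subgraph->SetParentGraph(parent_graph); | |||
|   subgraph->SetParentNode(parent_node); | |||
|   // The name passed here SHOULD match subgraph->GetName(). | |||
|   return root_graph->AddSubgraph(subgraph->GetName(), subgraph); | |||
| } | |||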
| /// | |||
| /// @brief Update input-mapping | |||
| /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||
| /// | |||
| /// @brief Update output-mapping | |||
| /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||
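| // Usage sketch (illustration only): the map passed to UpdateInputMapping is keyed by the current | |||
| // input index and valued by the new input index, per the doc comments above. The indices and the | |||
| // helper name are made up. | |||
| #include <cstdint> | |||
| #include <map> | |||
| #include <memory> | |||
| ge::graphStatus RemapInputs(const std::shared_ptr<ge::ComputeGraph> &graph) { | |||
|   std::map<uint32_t, uint32_t> input_mapping = { | |||
|       {0, 2},  // current graph input 0 becomes input 2 in the new graph | |||
|       {1, 0},  // current graph input 1 becomes input 0 in the new graph | |||
|   }; | |||
|   return graph->UpdateInputMapping(input_mapping); | |||
| } | |||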
| graphStatus TopologicalSorting(); | |||
| bool IsValid() const; | |||
| void Dump() const; | |||
| @@ -127,6 +157,11 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| } | |||
| } | |||
| shared_ptr<ComputeGraph> GetParentGraph(); | |||
| void SetParentGraph(const shared_ptr<ComputeGraph> &parent); | |||
| shared_ptr<Node> GetParentNode(); | |||
| void SetParentNode(const shared_ptr<Node> &parent); | |||
| const std::map<std::string, std::vector<int32_t>> &GetGraphOutNodes() const { return out_nodes_map_; } | |||
| void SetOrigGraph(ComputeGraphPtr orig_graph) { origGraph_ = orig_graph; } | |||
| @@ -138,8 +173,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| uint32_t GetInputSize() const { return input_size_; } | |||
| /// | |||
| /// Set iteration needed. | |||
| /// If set is true, it means this graph need run iteration some | |||
| /// Set is need train iteration. | |||
| /// If set true, it means this graph need to be run iteration some | |||
| /// times(according variant "npu_runconfig/iterations_per_loop"). | |||
| /// @param need_iteration is need iteration | |||
| /// | |||
| @@ -150,7 +185,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| const std::string GetOutput(); | |||
| /// | |||
| /// Get need_iteration. | |||
| /// Get is need train iteration. | |||
| /// @return is need iteration | |||
| /// | |||
| bool GetNeedIteration() const { return need_iteration_; } | |||
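| // Usage note (illustration only): toggling the train-iteration flag documented above. | |||
| // SetNeedIteration is assumed to be the setter this doc comment belongs to; its declaration | |||
| // falls outside this hunk. | |||
| #include <memory> | |||
| void EnableTrainIteration(const std::shared_ptr<ge::ComputeGraph> &graph) { | |||
|   graph->SetNeedIteration(true);                     // iterate per npu_runconfig/iterations_per_loop | |||
|   const bool need_iter = graph->GetNeedIteration();  // reads back true | |||
|   (void)need_iter; | |||
| } | |||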
| @@ -201,6 +236,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| std::deque<NodePtr> &stack); | |||
| graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | |||
| std::map<string, NodePtr> &breadth_node_map); | |||
| graphStatus TopologicalSortingSubgraph(); | |||
| graphStatus SortNodes(std::vector<NodePtr> &stack, std::map<NodePtr, uint32_t> &mapInEdgeNum); | |||
| size_t GetInEdgeSize(const NodePtr &node); | |||
| size_t GetOutEdgeSize(const NodePtr &node); | |||
| @@ -210,31 +246,38 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | |||
| const std::vector<NodePtr> &l_node_ptr_vector) const; | |||
| ProtoAttrMapHelper attrs_; | |||
| friend class ModelSerializeImp; | |||
| friend class GraphDebugImp; | |||
| friend class OnnxUtils; | |||
| std::string name_; | |||
| uint32_t graph_id_ = 0; | |||
| ProtoAttrMapHelper attrs_; | |||
| std::vector<NodePtr> nodes_; | |||
| std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||
| std::vector<NodePtr> target_nodes_info_; | |||
| std::vector<NodePtr> input_nodes_; | |||
| std::vector<std::string> inputs_order_; | |||
| uint32_t input_size_ = 1; | |||
| std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||
| uint32_t output_size_ = 1; | |||
| std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||
| std::vector<std::shared_ptr<ComputeGraph>> sub_graph_; | |||
| std::string name_; | |||
| std::map<std::string, std::shared_ptr<ComputeGraph>> names_to_subgraph_; | |||
| std::weak_ptr<ComputeGraph> parent_graph_; | |||
| std::weak_ptr<Node> parent_node_; | |||
| // The members that follow should not be in the ComputeGraph class | |||
| bool is_valid_flag_; | |||
| bool is_summary_graph_ = false; | |||
| // Indicates whether it is need iteration | |||
| bool need_iteration_ = false; | |||
| std::map<std::vector<std::string>, std::vector<std::string>> params_share_map_; | |||
| std::map<std::string, std::vector<int32_t>> out_nodes_map_; | |||
| // TaskIdx -> op_name Map | |||
| std::map<uint32_t, std::string> op_name_map_; | |||
| std::vector<std::string> inputs_order_; | |||
| uint32_t output_size_ = 1; | |||
| uint32_t input_size_ = 1; | |||
| std::map<OperatorImplPtr, NodePtr> all_nodes_infos_; | |||
| std::vector<std::pair<NodePtr, int32_t>> output_nodes_info_; | |||
| std::vector<NodePtr> target_nodes_info_; | |||
| uint64_t session_id_ = 0; | |||
| uint32_t graph_id_ = 0; | |||
| ge::Format data_format_ = ge::FORMAT_ND; | |||
| }; | |||
| } // namespace ge | |||
| @@ -18,7 +18,6 @@ | |||
| #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| #include <string> | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| @@ -59,6 +58,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BIAS_TERM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HAS_BIAS_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PAD; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PADS; | |||
| @@ -75,8 +76,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_CEIL_MODE; | |||
| // GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string | |||
| // ATTR_NAME_WEIGHTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STRIDE_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RELUMODE; | |||
| @@ -124,6 +124,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NAN_OPT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_FORMAT; | |||
| @@ -141,10 +148,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_PERMUTE_DELETED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IGNORE_PRED_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DIM_ALIGN; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_START; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WEIGHTS_DATA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_BATCH_NUM; | |||
| @@ -166,15 +178,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_LOC_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_CONF_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_PRIORBOX_CONCAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEED_INFER; | |||
| // _Arg | |||
| @@ -263,7 +275,29 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNOR | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_BIAS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_DATA_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION; | |||
| // Huberloss | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HUBER_LOSS_ATTR_DELTA; | |||
| // SSDRealDivTileMul | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA; | |||
| // SSDSumMulRealDivMean | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM; | |||
| /// ConcatFive2Four | |||
| /// ConcatFour2Five | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_CLASS_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TRANS_FOR_LOSS_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOX_TYPE_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_HIGH; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_FEATURE_MAP_WIDTH; | |||
| // Scale | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_ATTR_BIAS; | |||
| @@ -300,7 +334,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string CONST_ATTR_NAME_OUTPUT_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
| // Roipooling | |||
| @@ -313,6 +346,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIPOOLI | |||
| // DetectionOutput | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NUM_CLASSES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_OCR_NUM_CLASSES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_NMS_THRESHOLD; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_TOP_K; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DETECTIONOUTPUT_ATTR_CONFIDENCE_THRESHOLD; | |||
| @@ -371,6 +405,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SOFTMAX_ | |||
| // Permute | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_ORDER; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_ATTR_PERM; | |||
| // SSD Normalize | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL; | |||
| @@ -411,9 +446,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POWER_ATTR_NAME_SHIFT; | |||
| // Log | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SCALE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_SHIFT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_ATTR_NAME_BASE; | |||
| // Pack | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PACK_ATTR_NAME_NUM; | |||
| // Dynamic stitch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | |||
| // Unpack | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UNPACK_ATTR_NAME_NUM; | |||
| // Gathernd | |||
| @@ -422,8 +463,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERND | |||
| // Argmax | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_TOPK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_OUTMAX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_AXISTYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ARGMAX_ATTR_NAME_KEEPDIMS; | |||
| // Upsample | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_H; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE_W; | |||
| // Relu | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEGATIVE_SLOPE; | |||
| @@ -439,6 +488,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPLIT_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_MAGIC; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_BLOCKDIM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_METADATA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TVM_ATTR_NAME_WORKSPACE_TYPE; | |||
| // Squeeze | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SQUEEZE_ATTR_AXIS; | |||
| @@ -461,6 +511,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_SAMPLING_RATIO; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_H; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_POOLED_W; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ROIALIGN_ATTR_NAME_TF; | |||
| // Generate_rpn_proposal | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK; | |||
| @@ -493,6 +544,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REORG_AT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_DEAD_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MERGE_PRENODE_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string TO_BE_OUTPUT; | |||
| static const std::string NOT_NET_OUTPUT = "not_net_output"; | |||
| // ENTER | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ENTER_ATTR_FRAME_NAME; | |||
| @@ -518,6 +570,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_B | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_ALPHA; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESIZE_BILINEAR_ATTR_BETA; | |||
| // RetinaNet | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_FILTER_BACKGROUND_TRUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RETINANET_ANCHOR_FUSION; | |||
| // MatMul | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_X; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MATMUL_TRANSPOSE_W; | |||
| @@ -566,10 +621,30 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GRU_CELL | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_HT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_XT_HT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RNN_BATCH_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_CELL_CLIP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_PROJ_CLIP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_ACTIVATE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MAP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_OUT_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_STATE_OUT_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_TIME_MAJOR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSTM_IS_INPUT_PRE_PROCESS; | |||
| // Upsample | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string UPSAMPLE_ATTR_NAME_SCALE; | |||
| // PadV2 | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PADS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_T; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_PAD_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PADV2_ATTR_NAME_CONST_VALUE; | |||
| // MirrorPad | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PADS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE; | |||
| // Filler | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FILLER_VALUE; | |||
| @@ -590,36 +665,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PAD_LEFT | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_ALGO_ATTR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SCALE_TYPE_ATTR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_CONST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GROUP; | |||
| @@ -634,6 +679,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_EVENT_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_LABEL_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_MEMORY_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | |||
| @@ -642,12 +689,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_WEIGHT_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
| // Public attribute | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IMPLY_TYPE; | |||
| @@ -685,11 +726,178 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REFERENCE; | |||
| // Used for operators that do not generate task | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOTASK; | |||
| // Used for operators that output reuse input | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_REUSE_INPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ATOMIC_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_VAR_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_VAR_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_INDEX_OP_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEQUANTIZE_OFFSET_PAD_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_SCALE_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_DATA_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_WEIGHT_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REQUANTIZE_OFFSET_PAD_OFFSET; | |||
| // L2_normalize | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_AXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string L2_NORMALIZE_ATTR_EPS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_WINDOW; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_CEIL_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_DATA_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_NAN_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string POOL_PARAMA_ATTR_PAD_MOD; | |||
| // HCOM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_ROOT_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_RANK_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_REDUCTION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_GROUP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SR_TAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SRC_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DEST_RANK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HCOM_ATTR_DATA_TYPE; | |||
| // Log time stamp | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_LOGID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LOG_TIME_STAMP_NOTIFY; | |||
| // SpaceToDepth/DepthToSpace | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BLOCK_SIZE; | |||
| // SparseSoftmaxCrossEntropyWithLogits | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SPARSE_SOFT_MAX_ATTR_TLABLES; | |||
| // MaxPoolGradWithArgmax | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string MAX_POOL_GRAD_OUTPUT_SHAPE; | |||
| // AvgPoolGrad | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string AVG_POOL_GRAD_OUTPUT_SHAPE; | |||
| // Variable | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_FRACTALZ_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_4D_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_5D_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DATA_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHAPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string HALF_VAR_NAME_END; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_CONTAINER; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SHARED_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_DTYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_ADDR_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IN_INDEX_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_OUT_INDEX_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_SAVE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_RESTORE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_VAR_IS_BROADCAST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | |||
| // Assign | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ASSIGN_VALIDATE_SHAPE; | |||
| // ShapeN | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_N; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_IN_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SHAPEN_ATTR_OUT_TYPE; | |||
| // Space2batch batch2space | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_BLOCK; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string BATCH_SPACE_ATTR_PADDING; | |||
| // Depth_to_space space_to_depth | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE; | |||
| // FakeQuantWithMinMaxVars | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MAX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string FakeQuantWithMinMaxVars_ATTR_MIN; | |||
| // Mobilenet_ssd_conv_fusion | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_BOXES_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_SCORES_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM; | |||
| // Lsh project | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string LSH_PROJ_TYPE; | |||
| // Control flow | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ITERATORS_PER_LOOP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG; | |||
| // GatherV2 attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TAXIS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TINDICES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string GATHERV2_ATTR_NAME_TPARAMS; | |||
| // Reshape attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_INPUT_DESC; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC; | |||
| // Axis attr def | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AXIS_ORG_OP; | |||
| // The node linked with SparseSoftmaxCrossEntropyWithLogits | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LINK_WITH_SPARE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NET_OUTPUT_DATATYPE; | |||
| // For constant folding | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_NEED_CONSTANT_FOLDING; | |||
| // Used to mark the active label list for finding the stream of the activated node | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_LABEL_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE; | |||
| // Multi batch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | |||
| @@ -697,7 +905,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| // Control flow | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TRUE_BRANCH_STREAM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | |||
| @@ -709,6 +916,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEXT_ITERATION; | |||
| // Function Op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_NODE_INDEX; | |||
| // Used to mark that the active node is for a loop, type: bool | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_LOOP_ACTIVE; | |||
| @@ -742,6 +952,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INS | |||
| // For inserted op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | |||
| // For compress weight | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMPRESS_WEIGHT; | |||
| // For data dump | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP; | |||
| @@ -752,6 +965,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | |||
| // Used for L1 fusion and other fusions in the future | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
| // used for label switch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | |||
| // Variable | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | |||
| @@ -20,10 +20,8 @@ | |||
| #include <atomic> | |||
| #include <memory> | |||
| #include <vector> | |||
| #include "graph/attr_value_serializable.h" | |||
| #include "graph/buffer.h" | |||
| namespace ge { | |||
| #define DEF_TYPE_DEC(type, name) \ | |||
| inline void set_##name(const type &value) { name = value; } \ | |||
| @@ -49,10 +47,9 @@ namespace ge { | |||
| inline void add_##name(type value) { name.push_back(value); } \ | |||
| inline std::vector<type> *mutable_##name() { return &name; } | |||
| #define DEF_TYPE_BYTES_DEC(name) \ | |||
| inline void clear_##name() { name.ClearBuffer(); } \ | |||
| inline void set_##name(const void *value, size_t size) { \ | |||
| name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||
| #define DEF_TYPE_BYTES_DEC(name) \ | |||
| inline void clear_##name() { name.ClearBuffer(); } \ | |||
| inline void set_##name(const void *value, size_t size) { name = Buffer::CopyFrom((const uint8_t *)(value), size); } \ | |||
| inline Buffer *mutable_##name() { return &name; } | |||
| struct CompressInfo { | |||
| @@ -23,7 +23,6 @@ | |||
| #include <unordered_set> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "graph/detail/any_map.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| @@ -96,7 +95,7 @@ class GeIrProtoHelper { | |||
| } | |||
| } | |||
| // protoMsg_ is part of protoOwner_ and they have the same runtime | |||
| // protoMsg_ is part of protoOwner_, they have the same runtime | |||
| ProtoMsgOwner protoOwner_ = nullptr; | |||
| ProtoType *protoMsg_ = nullptr; | |||
| friend class GeIrProtoHelper<typename std::conditional< | |||
| @@ -21,9 +21,7 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/anchor.h" | |||
| #include "graph/model.h" | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/graph.h" | |||
| @@ -48,15 +46,15 @@ struct NodeNameNodeReq { | |||
| class ModelSerializeImp { | |||
| public: | |||
| bool SerializeModel(const Model &model, proto::ModelDef *modeProto); | |||
| bool SerializeModel(const Model &model, proto::ModelDef *modeProto, bool is_dump = false); | |||
| bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto); | |||
| bool SerializeGraph(const ConstComputeGraphPtr &graph, proto::GraphDef *graphProto, bool is_dump = false); | |||
| bool SerializeEdge(const NodePtr &node, proto::OpDef *opDefProto); | |||
| bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto); | |||
| bool SerializeOpDesc(const ConstOpDescPtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||
| bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto); | |||
| bool SerializeNode(const NodePtr &node, proto::OpDef *opDefProto, bool is_dump = false); | |||
| bool SerializeTensor(const ConstGeTensorPtr &tensor, proto::TensorDef *tensorProto); | |||
| @@ -23,7 +23,6 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "graph/buffer.h" | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/ge_error_codes.h" | |||
| @@ -139,15 +138,14 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| template <typename vector_type> | |||
| // To cols | |||
| using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, | |||
| int>::type; | |||
| using enable_if_vector_type_valid_t = typename std::enable_if<IsAttrTypeEnable<vector_type>::LIST_VALUE, int>::type; | |||
| template <typename one_type> | |||
| using enable_if_one_type_valid_t = typename std::enable_if<IsAttrTypeEnable<one_type>::VALUE, int>::type; | |||
| template <typename val_type> | |||
| using enable_if_type_valid_t = | |||
| typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||
| typename std::enable_if<IsAttrTypeEnable<val_type>::VALUE || IsAttrTypeEnable<val_type>::LIST_VALUE, int>::type; | |||
| template <typename seriliable_type> | |||
| using enable_if_seriliable_type_valid_t = typename seriliable_type::__ge_serializable; | |||
| @@ -18,7 +18,6 @@ | |||
| #define INC_GRAPH_GE_CONTEXT_H_ | |||
| #include <string> | |||
| #include "graph/ge_error_codes.h" | |||
| namespace ge { | |||
| @@ -42,4 +41,4 @@ class GEContext { | |||
| GEContext &GetContext(); | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_GE_CONTEXT_H_ | |||
| #endif // INC_GRAPH_GE_CONTEXT_H_ | |||
| @@ -20,7 +20,6 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/ge_error_codes.h" | |||
| using std::map; | |||
| @@ -42,5 +41,4 @@ class GEThreadLocalContext { | |||
| GEThreadLocalContext &GetThreadLocalContext(); | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_GE_LOCAL_CONTEXT_H_ | |||
| @@ -21,12 +21,10 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/buffer.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
| public: | |||
| @@ -43,6 +41,18 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
| int64_t GetShapeSize() const; | |||
| std::string ToString() const; | |||
| /// | |||
| /// @brief Check is unknown shape | |||
| /// @return bool | |||
| /// | |||
| bool IsUnknownShape() const; | |||
| /// | |||
| /// @brief Check is a scalar | |||
| /// @return bool | |||
| /// | |||
| bool IsScalar() const; | |||
| GeShape(const GeShape &other); | |||
| GeShape(GeShape &&other); | |||
| GeShape &operator=(const GeShape &other); | |||
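The two predicates added above are straightforward to exercise. The sketch below assumes `graph/ge_tensor.h` is the declaring header and that GE's usual convention of a -1 dimension marking an unknown dim applies; neither is stated in this hunk.

```cpp
#include "graph/ge_tensor.h"  // assumed declaring header for GeShape

// Minimal sketch: branch on the new shape predicates.
void DescribeShape(const ge::GeShape &shape) {
  if (shape.IsUnknownShape()) {
    // e.g. ge::GeShape({-1, 224, 224, 3}); assumes -1 marks an unknown dim
  } else if (shape.IsScalar()) {
    // e.g. ge::GeShape({}); rank-0 shape
  }
}
```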
| @@ -51,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
| private: | |||
| GeIrProtoHelper<proto::ShapeDef> shape_def_; | |||
| friend class GeTensorDesc; | |||
| // Create geshape from proto obj | |||
| // Create from proto obj | |||
| GeShape(const ProtoMsgOwner &protoOnwer, proto::ShapeDef *protoMsg); | |||
| void RefTo(const GeShape &shape) { shape_def_ = shape.shape_def_; } | |||
| @@ -112,7 +122,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
| void Init(); | |||
| // Create getensordesc from proto obj | |||
| // Create from proto obj | |||
| GeTensorDesc(const ProtoMsgOwner &protoOnwer, proto::TensorDescriptor *protoMsg); | |||
| friend class GeTensor; | |||
| friend class GeAttrValueImp; | |||
| @@ -159,10 +169,10 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||
| friend class GeAttrValueImp; | |||
| friend class ModelSerializeImp; | |||
| friend class OnnxUtils; | |||
| // Create getensor from proto obj | |||
| // Create from proto obj | |||
| GeTensor(const ProtoMsgOwner &protoOnwer, proto::TensorDef *protoMsg); | |||
| GeIrProtoHelper<proto::TensorDef> tensor_def_; | |||
| // Reference from tensorDef_, cab not use it directly | |||
| // Reference from tensorDef_, do not direct use | |||
| mutable GeTensorDesc __desc_; | |||
| GeTensorDesc &DescReference() const; | |||
| }; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/graph.h" | |||
| @@ -62,7 +61,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Model : public AttrHolder { | |||
| using AttrHolder::HasAttr; | |||
| using AttrHolder::SetAttr; | |||
| graphStatus Save(Buffer &buffer) const; | |||
| graphStatus Save(Buffer &buffer, bool is_dump = false) const; | |||
| graphStatus SaveToFile(const string &file_name) const; | |||
| // Model will be rewritten | |||
| @@ -19,7 +19,6 @@ | |||
| #include <map> | |||
| #include <string> | |||
| #include "graph/buffer.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/model.h" | |||
| @@ -27,7 +26,7 @@ | |||
| namespace ge { | |||
| class ModelSerialize { | |||
| public: | |||
| Buffer SerializeModel(const Model &model); | |||
| Buffer SerializeModel(const Model &model, bool is_dump = false); | |||
| Model UnserializeModel(const uint8_t *data, size_t len); | |||
| Model UnserializeModel(ge::proto::ModelDef &model_def); | |||
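A minimal sketch of the widened SerializeModel signature, assuming `graph/model_serialize.h` is the declaring header; `is_dump = true` presumably selects the dump-oriented serialization path, while the default keeps the previous behaviour.

```cpp
#include "graph/model_serialize.h"  // assumed declaring header

// Sketch: serialize a model through the dump-oriented path.
ge::Buffer SerializeForDump(const ge::Model &model) {
  ge::ModelSerialize serializer;
  return serializer.SerializeModel(model, /*is_dump=*/true);
}
```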
| @@ -113,25 +113,25 @@ class Node : public std::enable_shared_from_this<Node> { | |||
| bool IsAllInNodesSeen(std::unordered_set<Node *> &nodes_seen) const; | |||
| // All inData nodes | |||
| // All in Data nodes | |||
| Vistor<NodePtr> GetInDataNodes() const; | |||
| // All inControl nodes | |||
| // All in Control nodes | |||
| Vistor<NodePtr> GetInControlNodes() const; | |||
| // GetInAllNodes = InDataNodes + InControlNodes | |||
| Vistor<NodePtr> GetInAllNodes() const; | |||
| // All outData nodes | |||
| // All out Data nodes | |||
| Vistor<NodePtr> GetOutDataNodes() const; | |||
| uint32_t GetOutDataNodesSize() const; | |||
| // All outControl nodes | |||
| // All out Control nodes | |||
| Vistor<NodePtr> GetOutControlNodes() const; | |||
| // GetOutAllNodes = OutDataNodes + OutControlNodes | |||
| Vistor<NodePtr> GetOutAllNodes() const; | |||
| // Get all indata nodes and its outanchor | |||
| // Get all in data nodes and its out-anchor | |||
| Vistor<std::pair<NodePtr, OutDataAnchorPtr>> GetInDataNodesAndAnchors() const; | |||
| // Get all outdata nodes and its inanchor | |||
| // Get all out data nodes and its in-anchor | |||
| Vistor<std::pair<NodePtr, InDataAnchorPtr>> GetOutDataNodesAndAnchors() const; | |||
| graphStatus InferShapeAndType() const; | |||
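For the renamed accessors above, a short traversal sketch, assuming `graph/node.h` is the declaring header and that the Vistor return type is range-iterable as it is elsewhere in the codebase.

```cpp
#include "graph/node.h"  // assumed declaring header

// Sketch: visit each in-data producer of a node together with its out-anchor.
void VisitInDataNeighbours(const ge::NodePtr &node) {
  for (const auto &node_and_anchor : node->GetInDataNodesAndAnchors()) {
    const ge::NodePtr &producer = node_and_anchor.first;              // upstream node
    const ge::OutDataAnchorPtr &out_anchor = node_and_anchor.second;  // its out-anchor
    (void)producer;
    (void)out_anchor;
  }
}
```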
| @@ -176,7 +176,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||
| void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | |||
| NodePtr GetOrigNode(void) { return orig_node_; } | |||
| NodePtr GetOrigNode() { return orig_node_; } | |||
| private: | |||
| bool NodeMembersAreEqual(const Node &r_node) const; | |||
| @@ -23,7 +23,6 @@ | |||
| #include <string> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "detail/attributes_holder.h" | |||
| #include "graph/range_vistor.h" | |||
| @@ -108,6 +107,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| size_t GetInputsSize() const; | |||
| size_t GetAllInputsSize() const; | |||
| graphStatus AddOutputDesc(const GeTensorDesc &output_desc); | |||
| graphStatus AddOutputDesc(const string &name, const GeTensorDesc &output_desc); | |||
| @@ -122,6 +123,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | |||
| uint32_t GetAllOutputsDescSize() const; | |||
| Vistor<GeTensorDesc> GetAllOutputsDesc() const; | |||
| Vistor<GeTensorDescPtr> GetAllOutputsDescPtr() const; | |||
| @@ -132,6 +135,10 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| ConstGeTensorDescPtr GetInputDescPtr(uint32_t index) const; | |||
| ConstGeTensorDescPtr GetInputDescPtrDfault(uint32_t index) const; | |||
| ConstGeTensorDescPtr GetInputDescPtr(const string &name) const; | |||
| graphStatus AddDynamicInputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
| graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
| @@ -140,7 +147,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| bool IsOptionalInput(uint32_t index) const; | |||
| std::map<string, uint32_t> GetAllInputName(); | |||
| std::map<string, uint32_t> GetAllInputName() const; | |||
| void SetAllInputName(const std::map<string, uint32_t> &input_name_idx); | |||
| std::vector<string> GetAllOptionalInputName() const; | |||
| std::map<string, uint32_t> GetAllOutputName(); | |||
| @@ -225,6 +236,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| std::string GetOpEngineName() const; | |||
| graphStatus AddSubgraphName(const std::string &name); | |||
| const std::map<std::string, uint32_t> &GetSubgraphNameIndexes() const; | |||
| std::string GetSubgraphInstanceName(uint32_t index) const; | |||
| const std::vector<std::string> &GetSubgraphInstanceNames() const; | |||
| void AddSubgraphInstanceName(std::string name); | |||
| void RemoveSubgraphInstanceName(const std::string &name); | |||
| protected: | |||
| ProtoAttrMapHelper MutableAttrMap() override; | |||
| ConstProtoAttrMapHelper GetAttrMap() const override; | |||
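The new subgraph bookkeeping on OpDesc can be exercised roughly as below; the slot and instance names are illustrative only, and `GRAPH_SUCCESS` is assumed to come from graph/ge_error_codes.h.

```cpp
#include "graph/op_desc.h"  // assumed declaring header

// Sketch: register a subgraph slot on an op and attach one instance name to it.
void AttachSubgraphNames(const ge::OpDescPtr &op_desc) {
  if (op_desc->AddSubgraphName("body") != ge::GRAPH_SUCCESS) {  // "body": illustrative slot name
    return;  // slot already registered or invalid
  }
  op_desc->AddSubgraphInstanceName("body_instance_0");  // illustrative instance name
}
```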
| @@ -236,9 +255,9 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| bool OpDescGenTensorDescsAreEqual(const OpDesc &r_op_desc) const; | |||
| GeIrProtoHelper<ge::proto::OpDef> op_def_; | |||
| std::vector<std::string> subgraph_instance_names_; | |||
| std::map<std::string, uint32_t> subgraph_names_to_index_; | |||
| vector<GeTensorDescPtr> inputs_desc_{}; | |||
| map<string, uint32_t> input_name_idx_{}; | |||
| std::unordered_set<string> optional_input_names_{}; | |||
| vector<GeTensorDescPtr> outputs_desc_{}; | |||
| map<string, uint32_t> output_name_idx_{}; | |||
| std::function<graphStatus(Operator &)> infer_func_ = nullptr; | |||
| @@ -21,7 +21,6 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/operator_factory.h" | |||
| namespace ge { | |||
| @@ -47,7 +46,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OperatorFactoryImpl { | |||
| static graphStatus RegisterVerifyFunc(const std::string &operator_type, VerifyFunc const verify_func); | |||
| private: | |||
| static shared_ptr<std::map<string, OpCreator>> operator_creators_; | |||
| static shared_ptr<std::map<string, InferShapeFunc>> operator_infershape_funcs_; | |||
| static shared_ptr<std::map<string, InferFormatFunc>> operator_inferformat_funcs_; | |||
| @@ -18,8 +18,8 @@ | |||
| #define INC_GRAPH_SHAPE_REFINER_H_ | |||
| #include <string> | |||
| #include "external/graph/inference_context.h" | |||
| #include "external/graph/ge_error_codes.h" | |||
| #include "graph/node.h" | |||
| @@ -27,8 +27,10 @@ namespace ge { | |||
| // ShapeRefiner performs shape inference for compute graphs | |||
| class ShapeRefiner { | |||
| public: | |||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph); | |||
| static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||
| static graphStatus InferShapeAndType(const NodePtr &node); | |||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||
| private: | |||
| static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | |||
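A sketch of how the reordered overloads might be called; the exact semantics of `before_subgraph` are not documented in this hunk, so the flag value here is only illustrative.

```cpp
#include "graph/shape_refiner.h"  // path inferred from the include guard above

// Sketch: run the pre-subgraph inference pass, then the default overload.
ge::graphStatus RefineNode(const ge::NodePtr &node) {
  ge::graphStatus ret = ge::ShapeRefiner::InferShapeAndType(node, /*before_subgraph=*/true);
  if (ret != ge::GRAPH_SUCCESS) {
    return ret;
  }
  return ge::ShapeRefiner::InferShapeAndType(node);
}
```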
| @@ -14,8 +14,8 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #define INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #ifndef INC_GRAPH_USR_TYPES_H_ | |||
| #define INC_GRAPH_USR_TYPES_H_ | |||
| #include <atomic> | |||
| #include <memory> | |||
| @@ -130,4 +130,4 @@ struct UsrQuantizeFactorParams { | |||
| #undef USR_TYPE_BYTES_DEC | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GRAPH_USR_TYPES_H_ | |||
| #endif // INC_GRAPH_USR_TYPES_H_ | |||
| @@ -99,8 +99,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
| static bool SetZeroCopyBytes(AttrHolderAdapter &&obj, const string &name, Buffer &&buffer); | |||
| static bool GetZeroCopyBytes(ConstAttrHolderAdapter &&obj, const string &name, Buffer &buffer); | |||
| // Value will be moved | |||
| static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, | |||
| vector<Buffer> &listBuffer); | |||
| static bool SetZeroCopyListBytes(AttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||
| static bool GetZeroCopyListBytes(ConstAttrHolderAdapter &&obj, const string &name, vector<Buffer> &listBuffer); | |||
| static bool SetListListInt(AttrHolderAdapter &&obj, const string &name, const vector<vector<int64_t>> &value); | |||
| @@ -116,6 +115,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrUtils { | |||
| static OpDescPtr CopyOpDesc(const ConstOpDescPtr &orgOpDesc); | |||
| static std::string GetAllAttrsStr(ConstAttrHolderAdapter &&obj); | |||
| class AttrHolderAdapter { | |||
| public: | |||
| AttrHolderAdapter(AttrHolder *obj) : obj_(obj) {} | |||
| @@ -137,6 +137,18 @@ class GraphUtils { | |||
| static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | |||
| const std::vector<OpDescPtr> &vec_op_desc); | |||
| /// | |||
| /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst | |||
| /// @param [in] src | |||
| /// @param [in] dsts | |||
| /// @param [in] insert_node | |||
| /// @param [in] input_index | |||
| /// @param [in] output_index | |||
| /// @return graphStatus | |||
| /// | |||
| static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
| static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | |||
| static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node); | |||
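The doc comment above fixes the splice semantics of InsertNodeBefore; a call would look roughly like this, with the anchors and node assumed to come from the surrounding pass.

```cpp
#include <vector>
#include "graph/utils/graph_utils.h"  // assumed declaring header

// Sketch: splice insert_node between src (out-anchor) and every dst in-anchor,
// i.e. src -> insert_node:0 and insert_node:0 -> each dst.
ge::graphStatus SpliceBefore(const ge::OutDataAnchorPtr &src,
                             const std::vector<ge::InDataAnchorPtr> &dsts,
                             const ge::NodePtr &insert_node) {
  return ge::GraphUtils::InsertNodeBefore(src, dsts, insert_node,
                                          /*input_index=*/0, /*output_index=*/0);
}
```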
| @@ -145,16 +157,12 @@ class GraphUtils { | |||
| static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node); | |||
| static bool CheckIsTrainGraph(const ge::ComputeGraphPtr &compute_graph); | |||
| static bool MatchDumpStr(const std::string &suffix); | |||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | |||
| static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | |||
| static bool CheckGlobalStepNode(const ge::NodePtr &node); | |||
| static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | |||
| static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | |||
| @@ -252,6 +260,315 @@ class GraphUtils { | |||
| /// @return success: GRAPH_SUCESS | |||
| /// | |||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
| }; | |||
| class ComputeGraphBuilder { | |||
| public: | |||
| ComputeGraphBuilder() : owner_graph_(nullptr) {} | |||
| ComputeGraphBuilder(const ComputeGraphBuilder &) = delete; | |||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete; | |||
| ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete; | |||
| ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete; | |||
| ~ComputeGraphBuilder() = default; | |||
| /// | |||
| /// @brief Add node to graph | |||
| /// @param [in] op_desc | |||
| /// @return ComputeGraphBuilder | |||
| /// | |||
| virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc); | |||
| /// | |||
| /// @brief Add data-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] out_anchor_ind | |||
| /// @param [in] dst_name | |||
| /// @param [in] in_anchor_ind | |||
| /// @return ComputeGraphBuilder | |||
| /// | |||
| virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, | |||
| const std::string &dst_name, uint32_t in_anchor_ind); | |||
| /// | |||
| /// @brief Add ctrl-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] dst_name | |||
| /// @return ComputeGraphBuilder | |||
| /// | |||
| virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name); | |||
| /// | |||
| /// @brief Build graph | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return ComputeGraphPtr | |||
| /// | |||
| virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0; | |||
| /// | |||
| /// @brief Get node with name | |||
| /// @param [in] name | |||
| /// @return NodePtr | |||
| /// | |||
| NodePtr GetNode(const std::string &name); | |||
| protected: | |||
| /// | |||
| /// @brief Build nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildNodes(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Build data-links | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildDataLinks(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Build ctrl-links | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg); | |||
| ComputeGraphPtr owner_graph_; | |||
| // node_name -> node | |||
| std::map<std::string, NodePtr> node_names_; | |||
| std::vector<OpDescPtr> nodes_; | |||
| // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind> | |||
| std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_; | |||
| // src_node_name -> dst_node_name | |||
| std::vector<std::pair<std::string, std::string>> ctrl_links_; | |||
| }; | |||
| class CompleteGraphBuilder : public ComputeGraphBuilder { | |||
| public: | |||
| explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||
| CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | |||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | |||
| CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | |||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete; | |||
| ~CompleteGraphBuilder() = default; | |||
| /// | |||
| /// @brief Add node to graph | |||
| /// @param [in] op_desc | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||
| /// | |||
| /// @brief Add data-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] out_anchor_ind | |||
| /// @param [in] dst_name | |||
| /// @param [in] in_anchor_ind | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||
| uint32_t in_anchor_ind) override; | |||
| /// | |||
| /// @brief Add ctrl-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] dst_name | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||
| /// | |||
| /// @brief Set index_th input anchor for graph | |||
| /// @param [in] index | |||
| /// @param [in] node_names | |||
| /// @param [in] anchor_inds | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names, | |||
| const std::vector<uint32_t> &anchor_inds); | |||
| /// | |||
| /// @brief Set index_th input of graph as useless | |||
| /// @param [in] index | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &SetUselessInput(uint32_t index); | |||
| /// | |||
| /// @brief Add output anchor for graph | |||
| /// @param [in] owner_node_name | |||
| /// @param [in] anchor_ind | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | |||
| /// | |||
| /// @brief Set parent-node of graph | |||
| /// @param [in] parent_node | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node); | |||
| /// | |||
| /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node | |||
| /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping); | |||
| /// | |||
| /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind | |||
| /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node | |||
| /// @return CompleteGraphBuilder | |||
| /// | |||
| CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping); | |||
| /// | |||
| /// @brief Build graph | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return ComputeGraphPtr | |||
| /// | |||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||
| private: | |||
| /// | |||
| /// @brief Build inputs | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildInputs(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add data node | |||
| /// @param [in] index | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return NodePtr | |||
| /// | |||
| NodePtr AddDateNode(uint32_t index, graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Build outputs | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildOutputs(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add NetOutput node | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return NodePtr | |||
| /// | |||
| NodePtr AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||
| /// | |||
| /// @brief Add input/output tensor for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_desc | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus BuildInOutForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| OpDescPtr &net_output_desc); | |||
| /// | |||
| /// @brief Add edge for NetOutput node | |||
| /// @param [in] out_nodes_info | |||
| /// @param [out] net_output_node | |||
| /// @return graphStatus | |||
| /// | |||
| graphStatus AddEdgeForNetOutput(const std::vector<std::pair<NodePtr, int32_t>> &out_nodes_info, | |||
| const NodePtr &net_output_node); | |||
| std::string name_; | |||
| NodePtr parent_node_; | |||
| std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | |||
| std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | |||
| // index_of_graph_input -> in_anchor_index_of_parent_node | |||
| std::map<uint32_t, uint32_t> input_mapping_; | |||
| // index_of_graph_output -> out_anchor_index_of_parent_node | |||
| std::map<uint32_t, uint32_t> output_mapping_; | |||
| }; | |||
| class PartialGraphBuilder : public ComputeGraphBuilder { | |||
| public: | |||
| PartialGraphBuilder() = default; | |||
| PartialGraphBuilder(const PartialGraphBuilder &) = delete; | |||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete; | |||
| PartialGraphBuilder(const PartialGraphBuilder &&) = delete; | |||
| PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete; | |||
| ~PartialGraphBuilder() = default; | |||
| /// | |||
| /// @brief Add node to graph | |||
| /// @param [in] op_desc | |||
| /// @return PartialGraphBuilder | |||
| /// | |||
| PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override; | |||
| /// | |||
| /// @brief Add data-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] out_anchor_ind | |||
| /// @param [in] dst_name | |||
| /// @param [in] in_anchor_ind | |||
| /// @return PartialGraphBuilder | |||
| /// | |||
| PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name, | |||
| uint32_t in_anchor_ind) override; | |||
| /// | |||
| /// @brief Add ctrl-link among nodes in graph | |||
| /// @param [in] src_name | |||
| /// @param [in] dst_name | |||
| /// @return PartialGraphBuilder | |||
| /// | |||
| PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override; | |||
| /// | |||
| /// @brief Set owner graph | |||
| /// @param [in] graph | |||
| /// @return PartialGraphBuilder | |||
| /// | |||
| PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph); | |||
| /// | |||
| /// @brief Add exist node | |||
| /// @param [in] node | |||
| /// @return PartialGraphBuilder | |||
| /// | |||
| PartialGraphBuilder &AddExistNode(const NodePtr &node); | |||
| /// | |||
| /// @brief Build multi nodes with links | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return ComputeGraphPtr | |||
| /// | |||
| ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override; | |||
| private: | |||
| /// | |||
| /// @brief Build exist nodes | |||
| /// @param [out] error_code | |||
| /// @param [out] error_msg | |||
| /// @return void | |||
| /// | |||
| void BuildExistNodes(graphStatus &error_code, std::string &error_msg); | |||
| std::vector<NodePtr> exist_nodes_; | |||
| }; | |||
| } // namespace ge | |||
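Putting the builder API above together, a subgraph could be assembled as in the sketch below; `relu_op` is a hypothetical OpDescPtr assumed to be named "relu", and the header path is an assumption.

```cpp
#include <string>
#include "graph/utils/graph_utils.h"  // assumed declaring header for the builders

// Sketch: build a one-node subgraph whose op consumes graph input 0
// and produces graph output 0, attached under parent_node.
ge::ComputeGraphPtr BuildBody(const ge::OpDescPtr &relu_op,  // hypothetical, named "relu"
                              const ge::NodePtr &parent_node) {
  ge::graphStatus error_code = ge::GRAPH_SUCCESS;
  std::string error_msg;
  ge::CompleteGraphBuilder builder("body");
  builder.AddNode(relu_op)
      .SetInput(0, {"relu"}, {0})   // graph input 0 feeds relu:0
      .AddOutput("relu", 0)         // graph output 0 comes from relu:0
      .SetParentNode(parent_node)
      .SetInputMapping({{0, 0}})    // graph input 0 <- parent in-anchor 0
      .SetOutputMapping({{0, 0}});  // graph output 0 -> parent out-anchor 0
  return builder.Build(error_code, error_msg);
}
```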
| @@ -56,6 +56,11 @@ class NodeUtils { | |||
| static graphStatus UpdateOutputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
| static graphStatus UpdateInputShape(const Node &node, uint32_t index, const GeShape &shape); | |||
| static std::string GetNodeType(const Node &node); | |||
| static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | |||
| static graphStatus AddSubgraph(Node &node, const ComputeGraphPtr &subgraph); | |||
| private: | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | |||
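The new NodeUtils helpers above might be used as follows; the null-return behaviour for a missing subgraph is an assumption, not something this diff states.

```cpp
#include "graph/utils/node_utils.h"  // assumed declaring header

// Sketch: fetch the first subgraph attached to a node, if any.
ge::ComputeGraphPtr FirstSubgraph(const ge::Node &node) {
  return ge::NodeUtils::GetSubgraph(node, 0);  // assumed to return nullptr when none is attached
}
```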
| @@ -20,7 +20,6 @@ | |||
| #include <memory> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/def_types.h" | |||
| #include "graph/node.h" | |||
| #include "graph/op_desc.h" | |||
| @@ -29,7 +28,6 @@ | |||
| namespace ge { | |||
| class OpDesc; | |||
| using OpDescPtr = std::shared_ptr<OpDesc>; | |||
| class OpDescUtils { | |||
| @@ -39,55 +37,108 @@ class OpDescUtils { | |||
| OpDescUtils() = default; | |||
| ~OpDescUtils() = default; | |||
| static bool HasQuantizeFactorParams(const OpDescPtr &op_desc); | |||
| static bool HasQuantizeFactorParams(const OpDesc &op_desc); | |||
| static graphStatus GetQuantizeFactorParams(const OpDescPtr &op_desc, QuantizeFactorParams &quant); | |||
| static graphStatus GetQuantizeFactorParams(const OpDesc &op_desc, QuantizeFactorParams &quant); | |||
| static graphStatus SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant); | |||
| static graphStatus SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant); | |||
| static vector<ge::NodePtr> GetConstInputNode(const ge::Node &node); | |||
| static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr> &input_nodes); | |||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node &node); | |||
| static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr &node); | |||
| static vector<GeTensorPtr> MutableWeights(const ge::Node &node); | |||
| static bool HasQuantizeFactorParams(const OpDescPtr& op_desc); | |||
| static bool HasQuantizeFactorParams(const OpDesc& op_desc); | |||
| static graphStatus GetQuantizeFactorParams(const OpDescPtr& op_desc, QuantizeFactorParams& quant); | |||
| static graphStatus GetQuantizeFactorParams(const OpDesc& op_desc, QuantizeFactorParams& quant); | |||
| static graphStatus SetQuantizeFactorParams(const OpDescPtr& op_desc, const QuantizeFactorParams& quant); | |||
| static graphStatus SetQuantizeFactorParams(OpDesc& op_desc, const QuantizeFactorParams& quant); | |||
| static vector<ge::NodePtr> GetConstInputNode(const ge::Node& node); | |||
| static vector<ConstGeTensorPtr> GetInputData(const vector<ge::NodePtr>& input_nodes); | |||
| static vector<ConstGeTensorPtr> GetWeights(const ge::Node& node); | |||
| static vector<ConstGeTensorPtr> GetWeights(const ge::ConstNodePtr& node); | |||
| static vector<GeTensorPtr> MutableWeights(const ge::Node& node); | |||
| static vector<GeTensorPtr> MutableWeights(const ge::NodePtr node); | |||
| static graphStatus SetWeights(ge::Node &node, const vector<ge::GeTensorPtr> &weights); | |||
| static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr> &weights); | |||
| static graphStatus SetWeights(ge::Node& node, const vector<ge::GeTensorPtr>& weights); | |||
| static graphStatus SetWeights(ge::NodePtr node, const vector<ge::GeTensorPtr>& weights); | |||
| static graphStatus ClearWeights(ge::NodePtr node); | |||
| static bool ClearInputDesc(ge::OpDescPtr op_desc, uint32_t index); | |||
| static bool ClearInputDesc(const ge::NodePtr &node); | |||
| static bool ClearOutputDesc(const ge::OpDescPtr &op_desc, uint32_t index); | |||
| static bool ClearOutputDesc(const ge::NodePtr &node); | |||
| static vector<ge::NodePtr> GetConstInputs(const ge::Node &node); | |||
| static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr &node); | |||
| static size_t GetNonConstInputsSize(const ge::Node &node); | |||
| static bool ClearInputDesc(const ge::NodePtr& node); | |||
| static bool ClearOutputDesc(const ge::OpDescPtr& op_desc, uint32_t index); | |||
| static bool ClearOutputDesc(const ge::NodePtr& node); | |||
| static vector<ge::NodePtr> GetConstInputs(const ge::Node& node); | |||
| static vector<ge::NodePtr> GetConstInputs(const ge::ConstNodePtr& node); | |||
| static size_t GetNonConstInputsSize(const ge::Node& node); | |||
| static size_t GetNonConstInputsSize(ge::ConstNodePtr node); | |||
| // Index: Indicate the index of all non const inputs | |||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node &node, size_t index_non_const = 0); | |||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr &node, size_t index_non_const = 0); | |||
| static bool GetNonConstInputIndex(const ge::Node &node, size_t index_non_const, size_t &index); | |||
| static bool GetNonConstInputIndex(const ge::ConstNodePtr &node, size_t index_non_const, size_t &index); | |||
| // Index: Indicate the index of all inputs | |||
| static bool IsNonConstInput(const ge::Node &node, size_t index = 0); | |||
| static bool IsNonConstInput(const ge::ConstNodePtr &node, size_t index = 0); | |||
| static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr &node); | |||
| static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr &tensor_ptr); | |||
| // Index: Indicates the index of all non const inputs | |||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::Node& node, size_t index_non_const = 0); | |||
| static GeTensorDesc GetNonConstInputTensorDesc(const ge::ConstNodePtr& node, size_t index_non_const = 0); | |||
| static bool GetNonConstInputIndex(const ge::Node& node, size_t index_non_const, size_t& index); | |||
| static bool GetNonConstInputIndex(const ge::ConstNodePtr& node, size_t index_non_const, size_t& index); | |||
| // Index: Indicates the index of all inputs | |||
| static bool IsNonConstInput(const ge::Node& node, size_t index = 0); | |||
| static bool IsNonConstInput(const ge::ConstNodePtr& node, size_t index = 0); | |||
| static vector<ge::GeTensorDesc> GetNonConstTensorDesc(const ge::ConstNodePtr& node); | |||
| static graphStatus AddConstOpToAnchor(InDataAnchorPtr in_anchor, const GeTensorPtr& tensor_ptr); | |||
| static Operator CreateOperatorFromOpDesc(OpDescPtr op_desc); | |||
| static Operator CreateOperatorFromNode(ge::ConstNodePtr node_ptr); | |||
| static OpDescPtr GetOpDescFromOperator(const Operator &oprt); | |||
| static OpDescPtr GetOpDescFromOperator(const Operator& oprt); | |||
| static OpDescPtr CreateConstOp(const GeTensorPtr &tensor_ptr); | |||
| static OpDescPtr CreateConstOp(const GeTensorPtr& tensor_ptr); | |||
| private: | |||
| static GeTensorPtr MutableWeights(ge::OpDesc &op_desc); | |||
| static GeTensorPtr MutableWeights(ge::OpDesc& op_desc); | |||
| static GeTensorPtr MutableWeights(ge::OpDescPtr op_desc); | |||
| static graphStatus SetWeights(ge::OpDesc &op_desc, const GeTensorPtr weight); | |||
| static graphStatus SetWeights(ge::OpDesc& op_desc, const GeTensorPtr weight); | |||
| static graphStatus SetWeights(ge::OpDescPtr op_desc, const GeTensorPtr weight); | |||
| }; | |||
| class OpDescBuilder { | |||
| public: | |||
| OpDescBuilder(std::string name, std::string type) : name_(std::move(name)), type_(std::move(type)) {} | |||
| OpDescBuilder(const OpDescBuilder&) = delete; | |||
| OpDescBuilder& operator=(const OpDescBuilder&) = delete; | |||
| OpDescBuilder(const OpDescBuilder&&) = delete; | |||
| OpDescBuilder& operator=(const OpDescBuilder&&) = delete; | |||
| ~OpDescBuilder() = default; | |||
| /// | |||
| /// @brief Add input | |||
| /// @param [in] name | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddInput(const std::string& name); | |||
| /// | |||
| /// @brief Add dynamic input | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddDynamicInput(const std::string& name, uint32_t num); | |||
| /// | |||
| /// @brief Add output | |||
| /// @param [in] name | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddOutput(const std::string& name); | |||
| /// | |||
| /// @brief Add dynamic output | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @return OpDescBuilder | |||
| /// | |||
| OpDescBuilder& AddDynamicOutput(const std::string& name, uint32_t num); | |||
| /// | |||
| /// @brief Build op_desc | |||
| /// @return OpDescPtr | |||
| /// | |||
| OpDescPtr Build(); | |||
| private: | |||
| std::string name_; | |||
| std::string type_; | |||
| std::vector<std::string> inputs_; | |||
| std::vector<std::string> outputs_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_UTILS_OP_DESC_UTILS_H_ | |||
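A usage sketch for the new fluent OpDescBuilder; the op name and type strings are illustrative only.

```cpp
#include "graph/utils/op_desc_utils.h"  // path inferred from the include guard above

// Sketch: assemble an OpDesc through the fluent builder.
ge::OpDescPtr BuildMatMulDesc() {
  return ge::OpDescBuilder("matmul_0", "MatMul")  // name/type: illustrative only
      .AddInput("x1")
      .AddInput("x2")
      .AddOutput("y")
      .Build();
}
```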
| @@ -18,15 +18,14 @@ | |||
| #define INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||
| #include <vector> | |||
| #include "graph/def_types.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/ge_tensor.h" | |||
| namespace ge { | |||
| class TensorUtils { | |||
| public: | |||
| static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, uint32_t &size); | |||
| static void SetSize(GeTensorDesc &tensorDesc, uint32_t size); | |||
| static ge::graphStatus GetSize(const GeTensorDesc &tensorDesc, int64_t &size); | |||
| static void SetSize(GeTensorDesc &tensorDesc, int64_t size); | |||
| static uint32_t GetWeightSize(const ConstGeTensorPtr &tensorPtr); | |||
| static uint32_t GetWeightSize(const GeTensor &tensor); | |||
| static uint32_t GetWeightSize(const GeTensorDesc &tensorDesc); | |||
| @@ -62,16 +61,16 @@ class TensorUtils { | |||
| static void SetRC(GeTensorDesc &tensorDesc, uint32_t rc); | |||
| /// | |||
| /// calculate mem size of the tensor. | |||
| /// calculate tensor mem size. | |||
| /// @param shape tensor shape | |||
| /// @param format tensor format | |||
| /// @param data_type tensor data type | |||
| /// @param mem_size -1 means unknown shape,others means mem size | |||
| /// @return GRAPH_SUCCESS:success, others:failed | |||
| /// @param mem_size -1 means unknown shape,other means mem size | |||
| /// @return GRAPH_SUCCESS:success, other:failed | |||
| /// | |||
| static ge::graphStatus CalcTensorMemSize(const GeShape &shape, Format format, DataType data_type, int64_t &mem_size); | |||
| static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||
| static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp); | |||
| static ge::graphStatus GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||
| static ge::graphStatus GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp); | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_UTILS_TENSOR_UTILS_H_ | |||
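With the size interfaces widened to int64_t, a caller sketch; the -1 "unknown shape" convention is taken from the CalcTensorMemSize comment in this hunk.

```cpp
#include "graph/utils/tensor_utils.h"  // path inferred from the include guard above

// Sketch: compute a tensor's memory size; mem_size == -1 signals an unknown shape.
bool TryCalcMemSize(const ge::GeShape &shape, ge::Format format, ge::DataType data_type,
                    int64_t &mem_size) {
  return ge::TensorUtils::CalcTensorMemSize(shape, format, data_type, mem_size) ==
         ge::GRAPH_SUCCESS;
}
```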
| @@ -58,6 +58,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||
| include_directories(${GE_SOURCE_DIR}/inc/graph) | |||
| include_directories(${GE_SOURCE_DIR}/inc/common) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/ops) | |||
| include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | |||
| include_directories(${CMAKE_BINARY_DIR}) | |||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| @@ -26,6 +26,8 @@ Anchor::Anchor(const NodePtr &owner_node, int idx) : owner_node_(owner_node), id | |||
| bool Anchor::IsTypeOf(TYPE type) const { return strcmp(Anchor::TypeOf<Anchor>(), type) == 0; } | |||
| size_t Anchor::GetPeerAnchorsSize() const { return peer_anchors_.size(); } | |||
| Anchor::Vistor<AnchorPtr> Anchor::GetPeerAnchors() const { | |||
| vector<AnchorPtr> ret; | |||
| for (const auto &anchor : peer_anchors_) { | |||
| @@ -32,8 +32,7 @@ Buffer::Buffer(const Buffer &other) { | |||
| buffer_ = other.buffer_; | |||
| } | |||
| // default | |||
| Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { | |||
| Buffer::Buffer(std::size_t buffer_size, std::uint8_t default_val) : Buffer() { // default | |||
| auto proto_msg = data_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| try { | |||
| @@ -15,9 +15,7 @@ | |||
| */ | |||
| #include "graph/compute_graph.h" | |||
| #include <deque> | |||
| #include "./format_refiner.h" | |||
| #include "./ge_context.h" | |||
| #include "debug/ge_attr_define.h" | |||
| @@ -41,7 +39,7 @@ const size_t OUTPUT_PARAM_SIZE = 2; | |||
| } // namespace | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | |||
| : nodes_(), input_nodes_(), sub_graph_(), name_(name), is_valid_flag_(false), need_iteration_(false) { | |||
| : name_(name), nodes_(), input_nodes_(), sub_graph_(), is_valid_flag_(false), need_iteration_(false) { | |||
| attrs_.InitDefault(); | |||
| } | |||
| ComputeGraph::~ComputeGraph() {} | |||
| @@ -154,7 +152,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::VectorInputNod | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ComputeGraph::GraphMembersAreEqual( | |||
| const ComputeGraph &r_graph) const { | |||
| return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.sub_graph_.size()") && | |||
| return (IsEqual(this->sub_graph_.size(), r_graph.sub_graph_.size(), "graph.subgraphs_.size()") && | |||
| IsEqual(this->nodes_.size(), r_graph.nodes_.size(), "graph.nodes_.size()") && | |||
| VectorInputNodePtrIsEqual(this->input_nodes_, r_graph.input_nodes_) && | |||
| IsEqual(this->name_, r_graph.name_, "graph.name_") && | |||
| @@ -398,6 +396,165 @@ graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &su | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<ComputeGraph> &subgraph) { | |||
| if (subgraph == nullptr) { | |||
| GE_LOGE("Try to add a null subgraph, name %s", name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto parent_graph = subgraph->GetParentGraph(); | |||
| if (parent_graph == nullptr) { | |||
| GE_LOGE("Try to add subgraph without parent graph, name %s", name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto parent_node = subgraph->GetParentNode(); | |||
| if (parent_node == nullptr) { | |||
| GE_LOGE("Try to add a subgraph without parent node, name %s", name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| if (parent_node->GetOwnerComputeGraph() != parent_graph) { | |||
| GE_LOGE( | |||
| "Try to add a subgraph which parent node's parent graph is not equal to " | |||
| "the subgraph's parent graph, subgraph name %s, parent node name %s", | |||
| subgraph->GetName().c_str(), parent_graph->GetName().c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| if (!this->parent_graph_.expired()) { | |||
| GE_LOGE("The subgraphs can only be added to the root graph"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| if (name != subgraph->GetName()) { | |||
| GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||
| } | |||
| sub_graph_.push_back(subgraph); | |||
| names_to_subgraph_[name] = subgraph; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::AddSubgraph(const std::shared_ptr<ComputeGraph> &subgraph) { | |||
| if (subgraph == nullptr) { | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| return AddSubgraph(subgraph->GetName(), subgraph); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph(const std::string &name) { | |||
| auto iter = names_to_subgraph_.find(name); | |||
| if (iter == names_to_subgraph_.end()) { | |||
| return; | |||
| } | |||
| for (auto vec_iter = sub_graph_.begin(); vec_iter != sub_graph_.end(); ++vec_iter) { | |||
| if (*vec_iter == iter->second) { | |||
| sub_graph_.erase(vec_iter); | |||
| break; | |||
| } | |||
| } | |||
| names_to_subgraph_.erase(iter); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::RemoveSubgraph( | |||
| const std::shared_ptr<ComputeGraph> &subgraph) { | |||
| if (subgraph != nullptr) { | |||
| RemoveSubgraph(subgraph->GetName()); | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::shared_ptr<ComputeGraph> ComputeGraph::GetSubgraph( | |||
| const std::string &name) const { | |||
| auto iter = names_to_subgraph_.find(name); | |||
| return iter == names_to_subgraph_.end() ? nullptr : iter->second; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::vector<std::shared_ptr<ComputeGraph>> | |||
| ComputeGraph::GetAllSubgraphs() const { | |||
| return sub_graph_; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<ComputeGraph> ComputeGraph::GetParentGraph() { | |||
| return parent_graph_.lock(); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentGraph( | |||
| const shared_ptr<ComputeGraph> &parent) { | |||
| parent_graph_ = parent; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY shared_ptr<Node> ComputeGraph::GetParentNode() { | |||
| return parent_node_.lock(); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode(const shared_ptr<Node> &parent) { | |||
| parent_node_ = parent; | |||
| } | |||
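Taken together, the new accessors above define a registration contract: a subgraph must carry its parent graph and parent node, that node must belong to the parent graph, and only the root graph may own subgraphs. A hedged usage sketch (the graph and node variables are illustrative and assumed to exist):

```cpp
// root: the root ComputeGraph; if_node: a node in root that owns the branch;
// branch: a subgraph built elsewhere. All three are assumptions for illustration.
ge::graphStatus RegisterBranchGraph(const ge::ComputeGraphPtr &root,
                                    const ge::NodePtr &if_node,
                                    const ge::ComputeGraphPtr &branch) {
  branch->SetParentGraph(if_node->GetOwnerComputeGraph());  // must match the node's owner graph
  branch->SetParentNode(if_node);
  // AddSubgraph re-checks this contract and warns when the passed name differs
  // from branch->GetName(); GetSubgraph/RemoveSubgraph use the same name as key.
  return root->AddSubgraph(branch->GetName(), branch);
}
```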
| /// | |||
| /// @brief Update input-mapping | |||
| /// @param [in] input_mapping : index_of_cur_graph_node_input -> index_of_new_graph_node_input | |||
| /// @return graphStatus | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | |||
| for (auto &input : input_nodes_) { | |||
| uint32_t cur_index = 0; | |||
| if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
| continue; | |||
| } | |||
| auto iter = input_mapping.find(cur_index); | |||
| if (iter == input_mapping.end()) { | |||
| continue; | |||
| } | |||
| if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
| GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Update output-mapping | |||
| /// @param [in] output_mapping : index_of_cur_graph_node_output -> index_of_new_graph_node_output | |||
| /// @return graphStatus | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| ComputeGraph::UpdateOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping) { | |||
| NodePtr net_output = FindNode(kNodeNameNetOutput); | |||
| if (net_output == nullptr) { | |||
| GE_LOGE("UpdateOutputMapping failed: node %s not exist in graph.", kNodeNameNetOutput); | |||
| return GRAPH_FAILED; | |||
| } | |||
| OpDescPtr op_desc = net_output->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| GE_LOGE("UpdateOutputMapping failed: op_desc is NULL."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| size_t num = op_desc->GetInputsSize(); | |||
| for (size_t i = 0; i < num; i++) { | |||
| GeTensorDesc tensor = op_desc->GetInputDesc(i); | |||
| uint32_t cur_index = 0; | |||
| if (!ge::AttrUtils::GetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||
| continue; | |||
| } | |||
| auto iter = output_mapping.find(cur_index); | |||
| if (iter == output_mapping.end()) { | |||
| continue; | |||
| } | |||
| if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||
| GE_LOGE("UpdateOutputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (op_desc->UpdateInputDesc(i, tensor) != GRAPH_SUCCESS) { | |||
| GE_LOGE("UpdateOutputMapping failed: update %u input_tensor failed.", i); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
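A short sketch of driving both mapping updates after graph inputs and outputs are renumbered; the concrete index pairs are illustrative only:

```cpp
#include <map>

ge::graphStatus RemapParentIndices(const ge::ComputeGraphPtr &graph) {
  // old parent-node index -> new parent-node index
  std::map<uint32_t, uint32_t> input_mapping{{0, 1}, {1, 0}};
  std::map<uint32_t, uint32_t> output_mapping{{0, 0}};
  if (graph->UpdateInputMapping(input_mapping) != ge::GRAPH_SUCCESS) {
    return ge::GRAPH_FAILED;
  }
  // Rewrites ATTR_NAME_PARENT_NODE_INDEX on the NetOutput node's input tensors.
  return graph->UpdateOutputMapping(output_mapping);
}
```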
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertEventNodes() { | |||
| std::vector<NodePtr> node_vec = nodes_; | |||
| for (const auto &node : GetAllNodes()) { | |||
| @@ -551,6 +708,23 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | |||
| auto ret = TopologicalSortingSubgraph(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Sub graph partition Failed"); | |||
| return ret; | |||
| } | |||
| // topologically sort each subgraph | |||
| for (const auto &sub_graph : GetAllSubgraphs()) { | |||
| ret = sub_graph->TopologicalSortingSubgraph(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Sub graph topological sort Failed"); | |||
| return ret; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
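With the loop above, a single call on the root graph now also sorts every registered subgraph, so callers no longer need to sort branch graphs one by one; a minimal sketch:

```cpp
ge::graphStatus SortWholeModel(const ge::ComputeGraphPtr &root) {
  // Sorts the root graph first, then every entry returned by GetAllSubgraphs().
  return root->TopologicalSorting();
}
```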
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSortingSubgraph() { | |||
| std::vector<NodePtr> node_vec; | |||
| std::map<NodePtr, uint32_t> map_in_edge_num; | |||
| bool use_BFS = false; | |||
| @@ -598,6 +772,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||
| node->GetOpDesc()->SetId(i); // [node->GetOpDesc(): should not be null] | |||
| nodes_.push_back(node); | |||
| } | |||
| is_valid_flag_ = true; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -614,7 +789,7 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
| verify_isolated = true; | |||
| } | |||
| } | |||
| for (const auto &node : GetAllNodes()) { | |||
| for (const auto &node : GetDirectNode()) { | |||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||
| map_in_edge_num[node] = static_cast<uint32_t>(GetInEdgeSize(node)); | |||
| if (map_in_edge_num[node] == 0) { | |||
| @@ -640,16 +815,16 @@ graphStatus ComputeGraph::SortNodes(std::vector<NodePtr> &stack, std::map<NodePt | |||
| /// 2. Compare the two indices; if they do not match, swap the positions of the two inputs | |||
| /// *: Reminder: the stack is in reverse order | |||
| for (size_t i = 0; i < stack.size(); ++i) { | |||
| // [stack: should not be null] | |||
| // If not found in 'inputs_order_', skip it | |||
| auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||
| GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||
| auto inx_i = it_i - inputs_order_.begin(); | |||
| for (size_t j = i + 1; j < stack.size(); ++j) { | |||
| // If not found in 'inputs_order_', skip it | |||
| auto it_i = std::find(inputs_order_.begin(), inputs_order_.end(), stack[i]->GetName()); | |||
| GE_IF_BOOL_EXEC(it_i == inputs_order_.end(), continue); | |||
| auto it_j = std::find(inputs_order_.begin(), inputs_order_.end(), stack[j]->GetName()); | |||
| GE_IF_BOOL_EXEC(it_j == inputs_order_.end(), continue); | |||
| // Compare index, swap them if it should be | |||
| auto inx_i = it_i - inputs_order_.begin(); | |||
| auto inx_j = it_j - inputs_order_.begin(); | |||
| GE_IF_BOOL_EXEC(inx_i < inx_j, std::swap(stack[i], stack[j])); | |||
| } | |||
| @@ -663,7 +838,7 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
| return in_edge_size; | |||
| } | |||
| for (const auto &anchor : node->GetAllInDataAnchors()) { | |||
| in_edge_size = in_edge_size + anchor->GetPeerAnchors().size(); | |||
| in_edge_size = in_edge_size + anchor->GetPeerAnchorsSize(); | |||
| // Break flow control data loop. | |||
| OutDataAnchorPtr out_anchor = anchor->GetPeerOutAnchor(); | |||
| if ((out_anchor != nullptr) && (out_anchor->GetOwnerNode() != nullptr)) { | |||
| @@ -680,10 +855,11 @@ size_t ComputeGraph::GetInEdgeSize(const NodePtr &node) { | |||
| } | |||
| } | |||
| if (node->GetInControlAnchor() != nullptr) { | |||
| in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchors().size(); | |||
| in_edge_size = in_edge_size + node->GetInControlAnchor()->GetPeerAnchorsSize(); | |||
| } | |||
| return in_edge_size; | |||
| } | |||
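The switch to GetPeerAnchorsSize() above avoids materialising the peer-anchor vector just to take its size (the method reads peer_anchors_.size() directly, per the Anchor hunk earlier in this diff). A stripped-down counting sketch without the flow-control special casing:

```cpp
size_t CountInEdges(const ge::NodePtr &node) {
  size_t n = 0;
  for (const auto &anchor : node->GetAllInDataAnchors()) {
    n += anchor->GetPeerAnchorsSize();  // no temporary vector is built
  }
  if (node->GetInControlAnchor() != nullptr) {
    n += node->GetInControlAnchor()->GetPeerAnchorsSize();
  }
  return n;
}
```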
| size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
| size_t out_edge_size = 0; | |||
| if (node == nullptr) { | |||
| @@ -699,7 +875,7 @@ size_t ComputeGraph::GetOutEdgeSize(const NodePtr &node) { | |||
| } | |||
| } | |||
| if (node->GetOutControlAnchor() != nullptr) { | |||
| if (out_edge_size > (UINT32_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||
| if (out_edge_size > (UINT64_MAX - node->GetOutControlAnchor()->GetPeerAnchors().size())) { | |||
| return 0; | |||
| } | |||
| out_edge_size = out_edge_size + node->GetOutControlAnchor()->GetPeerAnchors().size(); | |||
| @@ -724,17 +900,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
| } | |||
| } | |||
| GE_IF_BOOL_EXEC(node->GetOutControlAnchor() == nullptr, GELOGE(GRAPH_FAILED, "Out control anchor is null"); | |||
| return ); | |||
| for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) { | |||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
| } | |||
| for (const auto &peer_in_anchor : node->GetOutControlAnchor()->GetPeerInDataAnchors()) { | |||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
| auto out_control_anchor = node->GetOutControlAnchor(); | |||
| if (out_control_anchor != nullptr) { | |||
| for (const auto &peer_in_anchor : out_control_anchor->GetPeerInControlAnchors()) { | |||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
| } | |||
| for (const auto &peer_in_anchor : out_control_anchor->GetPeerInDataAnchors()) { | |||
| GE_IF_BOOL_EXEC(peer_in_anchor != nullptr && peer_in_anchor->GetOwnerNode() != nullptr, | |||
| GELOGI("node name = %s, out control node name = %s.", node->GetName().c_str(), | |||
| peer_in_anchor->GetOwnerNode()->GetName().c_str())); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| @@ -18,21 +18,9 @@ | |||
| #define COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||
| #include "graph/ge_error_codes.h" | |||
| #include "toolchain/slog.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #define GE_MOD_ID GE | |||
| #ifdef _MSC_VER | |||
| #define FUNC_NAME __FUNCTION__ | |||
| #else | |||
| #define FUNC_NAME __PRETTY_FUNCTION__ | |||
| #endif | |||
| #define D_GE_LOGE(fmt, ...) \ | |||
| dlog_error(static_cast<int>(GE_MOD_ID), "%s:" fmt, __FUNCTION__, ##__VA_ARGS__) | |||
| #define GE_LOGE(...) D_GE_LOGE(__VA_ARGS__) | |||
| #define GE_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
| #define GE_LOGI_IF(condition, ...) \ | |||
| if ((condition)) { \ | |||
| @@ -44,15 +32,15 @@ | |||
| GELOGW(__VA_ARGS__); \ | |||
| } | |||
| #define GE_LOGE_IF(condition, ...) \ | |||
| if ((condition)) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| #define GE_LOGE_IF(condition, ...) \ | |||
| if ((condition)) { \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| } | |||
| #define GE_CHK_STATUS_RET_NOLOG(expr) \ | |||
| do { \ | |||
| const ge::graphStatus _status = (expr); \ | |||
| if (_status != ge::GRAPH_SUCCESS) { \ | |||
| if (ge::SUCCESS != _status) { \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -61,7 +49,7 @@ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -85,7 +73,7 @@ | |||
| do { \ | |||
| const ge::graphStatus _status = (expr); \ | |||
| if (_status) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0) | |||
| @@ -95,7 +83,7 @@ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (b) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } | |||
| @@ -119,63 +107,41 @@ | |||
| } while (0) | |||
| // If expr is not true, the log is printed and a custom statement is executed | |||
| #define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } | |||
| // If expr is not true, the log is printed and a custom statement is executed | |||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| #define GE_CHK_BOOL_EXEC(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } | |||
| // If expr is not true, the log is printed and a custom statement is executed | |||
| #define GE_CHK_BOOL_EXEC_DEBUG(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGD(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| } | |||
| // If expr is not GRAPH_SUCCESS, print the log and return the same value | |||
| #define GE_CHK_STATUS_RET(expr, ...) \ | |||
| do { \ | |||
| const ge::graphStatus _status = (expr); \ | |||
| if (_status != ge::GRAPH_SUCCESS) { \ | |||
| GELOGE(ge::GRAPH_FAILED, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| #define GE_CHK_STATUS_RET(expr, ...) \ | |||
| do { \ | |||
| const ge::graphStatus _status = (expr); \ | |||
| if (ge::SUCCESS != _status) { \ | |||
| GELOGE(ge::FAILED, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0) | |||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
| try { \ | |||
| exec_expr0; \ | |||
| } catch (...) { \ | |||
| GELOGE(ge::GRAPH_FAILED, "Make shared failed"); \ | |||
| exec_expr1; \ | |||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
| try { \ | |||
| exec_expr0; \ | |||
| } catch (...) { \ | |||
| GELOGE(ge::FAILED, "Make shared failed"); \ | |||
| exec_expr1; \ | |||
| } | |||
| /// CCE related macro definition | |||
| /// If expr is not CC_STATUS_GRAPH_SUCCESS, print the log and return | |||
| #define GE_CHK_CCE_RET(expr) \ | |||
| do { \ | |||
| ccgraphStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_GRAPH_SUCCESS) { \ | |||
| GELOGE(ge::GRAPH_FAILED, "Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| return ge::GRAPH_FAILED; \ | |||
| } \ | |||
| } while (0) | |||
| #endif // COMMON_GRAPH_DEBUG_GE_LOG_H_ | |||
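A usage sketch of the check macros this header keeps after the cleanup; DoWork() is a placeholder returning ge::graphStatus, not a real GE function:

```cpp
#include <memory>
#include "graph/debug/ge_log.h"

ge::graphStatus DoWork();  // illustrative declaration only

ge::graphStatus Pipeline() {
  std::shared_ptr<int> buf;
  GE_MAKE_SHARED(buf = std::make_shared<int>(0), return ge::GRAPH_FAILED);  // logs and bails if allocation throws
  GE_LOGE_IF(buf == nullptr, "buffer is null");                             // log-only check, execution continues
  GE_CHK_STATUS_RET(DoWork(), "DoWork failed");                             // propagates a non-SUCCESS status
  return ge::GRAPH_SUCCESS;
}
```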
| @@ -25,7 +25,6 @@ | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/debug/ge_log.h" | |||
| #include "graph/ge_error_codes.h" | |||
| @@ -15,12 +15,10 @@ | |||
| */ | |||
| #include "graph/debug/graph_debug.h" | |||
| #include <algorithm> | |||
| #include <unordered_set> | |||
| #include <vector> | |||
| #include "debug/ge_util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #define TAB " " | |||
| @@ -16,13 +16,11 @@ | |||
| #ifndef COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||
| #define COMMON_GRAPH_DEBUG_GRAPH_DEBUG_H_ | |||
| #include <cstdint> | |||
| #include <fstream> | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include <string> | |||
| #include "external/graph/graph.h" | |||
| #include "./ge_error_codes.h" | |||
| #include "graph/compute_graph.h" | |||
| @@ -15,9 +15,7 @@ | |||
| */ | |||
| #include "detail/attributes_holder.h" | |||
| #include <map> | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -14,14 +14,12 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph/format_refiner.h" | |||
| #include "format_refiner.h" | |||
| #include <deque> | |||
| #include <iostream> | |||
| #include <set> | |||
| #include <unordered_map> | |||
| #include <unordered_set> | |||
| #include "./compute_graph.h" | |||
| #include "./ge_error_codes.h" | |||
| #include "./graph/ge_tensor.h" | |||
| @@ -57,6 +55,7 @@ graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||
| std::vector<ge::NodePtr> &data_nodes, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
| @@ -82,10 +81,10 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| // consider special node save process | |||
| // get all input desc format | |||
| bool node_is_all_nd = false; | |||
| for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetInputsSize()); i++) { | |||
| auto input_desc = op_desc->GetInputDesc(i); | |||
| auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||
| for (uint32_t i = 0; i < input_size; i++) { | |||
| // Operator pre-set format but not origin format | |||
| auto input_format = input_desc.GetFormat(); | |||
| auto input_format = op_desc->MutableInputDesc(i)->GetFormat(); | |||
| // Pre-save data node and default infer fail | |||
| if (node_ptr->GetType() == DATA) { | |||
| data_nodes.push_back(node_ptr); | |||
| @@ -95,9 +94,9 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| } | |||
| } | |||
| // Get all output desc format | |||
| for (uint32_t i = 0; i < static_cast<uint32_t>(op_desc->GetOutputsSize()); i++) { | |||
| GeTensorDesc output_desc = op_desc->GetOutputDesc(i); | |||
| auto output_format = output_desc.GetFormat(); | |||
| auto output_size = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||
| for (uint32_t i = 0; i < output_size; i++) { | |||
| auto output_format = op_desc->MutableOutputDesc(i)->GetFormat(); | |||
| if (output_format != FORMAT_ND && output_format != FORMAT_RESERVED) { | |||
| node_is_all_nd = true; | |||
| } | |||
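The loops above now read formats through MutableInputDesc/MutableOutputDesc, which return pointers into the OpDesc instead of copying a GeTensorDesc per iteration, and the size query is hoisted out of the loop. A hedged sketch of the same pattern with an explicit null guard (the helper name is illustrative; includes as in format_refiner.cc):

```cpp
bool HasNonNdInputFormat(const ge::OpDescPtr &op_desc) {
  auto input_size = static_cast<uint32_t>(op_desc->GetInputsSize());
  for (uint32_t i = 0; i < input_size; ++i) {
    auto desc = op_desc->MutableInputDesc(i);
    if (desc == nullptr) {  // guard the pointer before dereferencing
      continue;
    }
    auto fmt = desc->GetFormat();
    if (fmt != ge::FORMAT_ND && fmt != ge::FORMAT_RESERVED) {
      return true;
    }
  }
  return false;
}
```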
| @@ -145,7 +144,8 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
| GELOGD("Node is [%s] [B]", (node->GetName()).c_str()); | |||
| auto in_data_anchor_idx = in_anchor->GetIdx(); | |||
| auto to_be_set_format = (node->GetOpDesc()->GetInputDesc(in_data_anchor_idx)).GetOriginFormat(); | |||
| auto to_be_set_format = | |||
| node->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx))->GetOriginFormat(); | |||
| if (to_be_set_format == FORMAT_ND) { | |||
| GELOGD("Node [%s] [B], format is ND", (node->GetName()).c_str()); | |||
| continue; | |||
| @@ -162,7 +162,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
| } | |||
| // Check whether the format has already been set | |||
| int idx = peer_out_data_anchor->GetIdx(); | |||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(idx); | |||
| auto ge_tensor_desc = peer_out_data_node->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(idx)); | |||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
| if (dim_num == 0) { | |||
| @@ -182,7 +182,7 @@ graphStatus FormatRefiner::BackInferProcess(std::deque<ge::NodePtr> &nodes, ge:: | |||
| ge_tensor_desc.SetOriginFormat(to_be_set_format); | |||
| ge_tensor_desc.SetFormat(to_be_set_format); | |||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(idx, ge_tensor_desc); | |||
| (void)peer_out_data_node->GetOpDesc()->UpdateOutputDesc(static_cast<uint32_t>(idx), ge_tensor_desc); | |||
| // Call operator infer format api (forward) to get out format | |||
| GELOGD("call infer format func[Back]!Node is [%s] ", (peer_out_data_node->GetName()).c_str()); | |||
| @@ -205,7 +205,8 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
| GELOGD("Node is [%s] [F]", (node->GetName()).c_str()); | |||
| GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue); | |||
| auto out_data_anchor_idx = out_data_anchor->GetIdx(); | |||
| auto to_be_set_format = (node->GetOpDesc()->GetOutputDesc(out_data_anchor_idx)).GetOriginFormat(); | |||
| auto to_be_set_format = | |||
| node->GetOpDesc()->MutableOutputDesc(static_cast<uint32_t>(out_data_anchor_idx))->GetOriginFormat(); | |||
| if (to_be_set_format == FORMAT_ND) { | |||
| GELOGD("Node [%s] format is ND.[F]", (node->GetName()).c_str()); | |||
| continue; | |||
| @@ -222,7 +223,7 @@ graphStatus FormatRefiner::ForwardInferProcess(std::deque<ge::NodePtr> &nodes, g | |||
| } | |||
| // Check whether the format has already been set | |||
| int idx = peer_in_data_anchor->GetIdx(); | |||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(idx); | |||
| auto ge_tensor_desc = peer_in_data_node->GetOpDesc()->GetInputDesc(static_cast<uint32_t>(idx)); | |||
| if (ge_tensor_desc.GetOriginFormat() == FORMAT_ND) { | |||
| auto dim_num = ge_tensor_desc.GetShape().GetDimNum(); | |||
| if (dim_num == 0) { | |||
| @@ -285,9 +286,9 @@ void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = | |||
| graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
| bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | |||
| bool need_process = ((!is_first_infer) && (is_internal_format == false) && (data_format != FORMAT_ND)); | |||
| bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||
| if (!need_process) { | |||
| GELOGI("no necessary to do DataNodeFormatProcess.IsFirstInfer: %d, data_format:%s", is_first_infer, | |||
| GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||
| TypeUtils::FormatToSerialString(data_format).c_str()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -378,9 +379,9 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| /// Notice: ignore 5D formats | |||
| auto data_format = graph->GetDataFormat(); | |||
| status = DataNodeFormatProcess(data_nodes, data_format, node_status); | |||
| // Set infer flag to false | |||
| SetInferOrigineFormatFlag(false); | |||
| return status; | |||
| } | |||
| } // namespace ge | |||
| @@ -42,6 +42,8 @@ const std::string ATTR_NAME_BIAS = "bias"; | |||
| const std::string ATTR_NAME_BIAS_TERM = "bias_term"; | |||
| const std::string ATTR_NAME_HAS_BIAS_VALUE = "has_bias_value"; | |||
| const std::string ATTR_NAME_PAD = "pad"; | |||
| const std::string ATTR_NAME_PADS = "pad"; | |||
| @@ -83,6 +85,7 @@ const std::string ATTR_NAME_LRN_BETA = "lrn_beta"; | |||
| const std::string ATTR_NAME_AXIS = "axis"; | |||
| const std::string ATTR_NAME_BROADCAST = "broadcast"; | |||
| const std::string ATTR_NAME_OUTPUT = "output"; | |||
| const std::string ATTR_NAME_OUTPUT_NUM = "output_num"; | |||
| const std::string ATTR_NAME_TIDX = "t_idx"; | |||
| @@ -103,6 +106,13 @@ const std::string ATTR_NAME_TSHAPE = "Tshape"; | |||
| const std::string ATTR_NAME_NAN_OPT = "nan_opt"; | |||
| const std::string ATTR_NAME_AIPP = "aipp"; | |||
| const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | |||
| const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | |||
| const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
| const std::string ATTR_NAME_INPUT_FORMAT = "input_format"; | |||
| const std::string ATTR_NAME_OUTPUT_FORMAT = "output_format"; | |||
| @@ -111,6 +121,7 @@ const std::string ATTR_NAME_FRAMEWORK_NODE_DEF = "node_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_OP_DEF = "op_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_FWK_TYPE = "framework_type"; | |||
| const std::string ATTR_NAME_FRAMEWORK_FUNC_DEF = "func_def"; | |||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
| const std::string ATTR_NAME_INPUT_TENSOR_DESC = "input_tensor_desc"; | |||
| const std::string ATTR_NAME_OUTPUT_TENSOR_DESC = "output_tensor_desc"; | |||
| @@ -122,15 +133,11 @@ const std::string ATTR_NAME_WEIGHTS = "value"; | |||
| const std::string ATTR_NAME_WEIGHTS_DATA = "weights_data"; | |||
| const std::string ATTR_NAME_BROACAST_REAL_DIM_CNT = "broacast_real_dim_cnt"; | |||
| const std::string ATTR_NAME_DIM_ALIGN = "dim_align"; | |||
| const std::string ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE = "original_type"; | |||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
| const std::string ATTR_MODEL_BATCH_NUM = "batch_num"; | |||
| const std::string ATTR_NAME_STREAM_LABEL = "_stream_label"; | |||
| const std::string ATTR_NAME_STREAM_CYCLE_EVENT_FLAG = "need_stream_cycle_event"; | |||
| const std::string ATTR_NAME_RTSWITCH_RECV_EVENT_ID = "rtswitch_event_id"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_START = "automic_add_addr_start"; | |||
| const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
| // To be deleted | |||
| const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | |||
| @@ -144,15 +151,13 @@ const std::string SSD_MBOX_OCR_FUSION = "permute_flatten_ocr_fusion"; | |||
| const std::string SSD_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
| const std::string SSD_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
| // Refinedet | |||
| const std::string REFINEDET_MBOX_LOC_FUSION = "permute_flatten_fusion"; | |||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string REFINEDET_MBOX_CONF_FUSION = "permute_flatten_reshape_flatten_fusion"; | |||
| const std::string REFINEDET_MBOX_FUSION_BOX_TYPE_NUM = "ssd_mbox_fusion_box_type_num"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| const std::string REFINEDET_RESHAPE_SLICE_CONCAT_FUSION = "reshape_slice_concat_fusion"; | |||
| const std::string SSD_PRIORBOX_CONCAT = "ssd_mbox_conf_priorbox_concat_flag"; | |||
| // _Arg | |||
| const std::string ATTR_NAME_INDEX = "index"; | |||
| @@ -242,6 +247,30 @@ const std::string BATCHNORM_ATTR_ESTIMATED_MEAN = "estimated_mean"; | |||
| const std::string BATCHNORM_ATTR_ESTIMATED_VARIANCE = "estimated_variance"; | |||
| const std::string BATCHNORM_ATTR_SCALE = "scale"; | |||
| const std::string BATCHNORM_ATTR_BIAS = "bias"; | |||
| const std::string BATCHNORM_ATTR_DATA_FORMAT = "data_format"; | |||
| const std::string BATCHNORM_ATTR_IS_TRAINING = "is_training"; | |||
| const std::string BATCHNORM_ATTR_IS_TRAINING_FUSION = "is_training_fusion"; | |||
| // huberloss | |||
| const std::string HUBER_LOSS_ATTR_DELTA = "delta"; | |||
| // SSDRealDivTileMul | |||
| const std::string SSD_REAL_DIV_TILE_MUL_ATTR_TILE_PARA = "tilepara"; | |||
| // SSDSumMulRealDivMean | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_REDUCTION_INDICES = "reduction_indices"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_AXIS = "axis"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_MEAN_PARA = "mean_para"; | |||
| const std::string SSD_SUM_MUL_REALDIV_MEAN_ATTR_HAS_SUM = "has_sum"; | |||
| // ConcatFive2Four | |||
| // ConcatFour2Five | |||
| const std::string SSD_BOX_TYPE_NUM = "box_type_num"; | |||
| const std::string SSD_CLASS_NUM = "class_num"; | |||
| const std::string TRANS_FOR_LOSS_MODE = "trans_for_loss_mode"; | |||
| const std::string SSD_FEATURE_MAP_SIZE = "feature_map_size"; | |||
| const std::string SSD_FEATURE_MAP_HIGH = "feature_map_high"; | |||
| const std::string SSD_FEATURE_MAP_WIDTH = "feature_map_width"; | |||
| // Scale | |||
| const std::string SCALE_ATTR_SCALE = "scale"; | |||
| @@ -346,6 +375,7 @@ const std::string SOFTMAX_ATTR_AXIS = "axis"; | |||
| // Permute | |||
| const std::string PERMUTE_ATTR_ORDER = "order"; | |||
| const std::string PERMUTE_ATTR_PERM = "perm"; | |||
| // SSD Normalize | |||
| const std::string SSDNORMALIZE_ATTR_ACCROSS_SPATIAL = "across_spatial"; | |||
| @@ -373,6 +403,10 @@ const std::string SSD_PRIOR_BOX_ATTR_ASPECT_RATIO_NUM = "aspect_ratio_num"; | |||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| const std::string SSD_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| // RefinedetDetectionOutput | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE_NUM = "variance_num"; | |||
| const std::string REFINEDET_PRIOR_BOX_ATTR_VARIANCE = "variance"; | |||
| // PRelu | |||
| const std::string PRELU_ATTR_CHANNEL_SHARED = "channel_shared"; | |||
| @@ -386,11 +420,16 @@ const std::string POWER_ATTR_NAME_POWER = "power"; | |||
| const std::string POWER_ATTR_NAME_SCALE = "scale"; | |||
| const std::string POWER_ATTR_NAME_SHIFT = "shift"; | |||
| // log | |||
| const std::string LOG_ATTR_NAME_SCALE = "scale"; | |||
| const std::string LOG_ATTR_NAME_SHIFT = "shift"; | |||
| const std::string LOG_ATTR_NAME_BASE = "base"; | |||
| // Pack | |||
| const std::string PACK_ATTR_NAME_NUM = "N"; | |||
| // Unpack | |||
| const std::string UNPACK_ATTR_NAME_NUM = "num"; | |||
| const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
| // Gathernd | |||
| const std::string GATHERND_ATTR_NAME_TINDICES = "Tindices"; | |||
| const std::string GATHERND_ATTR_NAME_TPARAMS = "Tparams"; | |||
| @@ -400,6 +439,13 @@ const std::string ARGMAX_ATTR_NAME_TOPK = "topk"; | |||
| const std::string ARGMAX_ATTR_NAME_REDUCESIZE = "reduce_size"; | |||
| const std::string ARGMAX_ATTR_NAME_REDUCESTRIDE = "reduce_stride"; | |||
| const std::string ARGMAX_ATTR_NAME_OUTMAX = "outmaxval"; | |||
| const std::string ARGMAX_ATTR_NAME_AXIS = "axis"; | |||
| const std::string ARGMAX_ATTR_NAME_AXISTYPE = "axis_type"; | |||
| const std::string ARGMAX_ATTR_NAME_KEEPDIMS = "keep_dims"; | |||
| // upsample | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_H = "scale_h"; | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE_W = "scale_w"; | |||
| // Relu | |||
| const std::string ATTR_NAME_NEGATIVE_SLOPE = "negative_slope"; | |||
| @@ -416,6 +462,7 @@ const std::string SPLIT_ATTR_NAME_NUM_SPLIT = "num_split"; | |||
| const std::string TVM_ATTR_NAME_MAGIC = "tvm_magic"; | |||
| const std::string TVM_ATTR_NAME_BLOCKDIM = "tvm_blockdim"; | |||
| const std::string TVM_ATTR_NAME_METADATA = "tvm_metadata"; | |||
| const std::string TVM_ATTR_NAME_WORKSPACE_TYPE = "tvm_workspace_type"; | |||
| // Squeeze | |||
| const std::string SQUEEZE_ATTR_AXIS = "axis"; | |||
| @@ -438,6 +485,7 @@ const std::string ROIALIGN_ATTR_SPATIAL_SCALE = "spatial_scale"; | |||
| const std::string ROIALIGN_ATTR_SAMPLING_RATIO = "sampling_ratio"; | |||
| const std::string ROIALIGN_ATTR_NAME_POOLED_H = "pooled_h"; | |||
| const std::string ROIALIGN_ATTR_NAME_POOLED_W = "pooled_w"; | |||
| const std::string ROIALIGN_ATTR_NAME_TF = "roialign_tf"; | |||
| // Generate_rpn_proposal | |||
| const std::string GENERATE_RPN_PROPOSAL_ATTR_PRE_NMS_TOPK = "pre_nms_topk"; | |||
| @@ -536,19 +584,42 @@ const std::string CONV_GRAD_FILTER_OUTPUT_SHAPE = "conv_grad_filter_output_shape | |||
| const std::string CONV_GRAD_INPUT_OUTPUT_SHAPE = "conv_grad_input_output_shape"; | |||
| // Rnn | |||
| const std::string RNN_MODE_ = "rnn_"; | |||
| const std::string CNN_RNN = "cnn_rnn"; | |||
| const std::string RNN_TENSORFLOW = "rnn_tensorflow"; | |||
| const std::string RNN_MODE_STATIC = "rnn_static"; | |||
| const std::string MUTI_RNN = "multi_rnn"; | |||
| const std::string CNN_RNN = "cnn_rnn"; | |||
| const std::string RNN_MODE_ = "rnn_"; | |||
| const std::string CELL_MODE = "mode"; | |||
| const std::string LSTM_CELL = "lstm_cell"; | |||
| const std::string GRU_CELL = "gru_cell"; | |||
| const std::string RNN_HT = "ht"; | |||
| const std::string RNN_XT_HT = "xt_ht"; | |||
| const std::string RNN_BATCH_SIZE = "batch_size"; | |||
| const std::string LSTM_CELL_CLIP = "lstm_cell_clip"; | |||
| const std::string LSTM_PROJ_CLIP = "lstm_proj_clip"; | |||
| const std::string LSTM_ACTIVATE = "lstm_activate"; | |||
| const std::string LSTM_OUT_MAP = "lstm_out_map"; | |||
| const std::string LSTM_OUT_MODE = "lstm_out_mode"; | |||
| const std::string LSTM_STATE_OUT_MODE = "lstm_state_out_mode"; | |||
| const std::string LSTM_TIME_MAJOR = "lstm_time_major"; | |||
| const std::string LSTM_IS_INPUT_PRE_PROCESS = "lstm_is_input_pre_process"; | |||
| // Upsample | |||
| const std::string UPSAMPLE_ATTR_NAME_SCALE = "scale"; | |||
| // PadV2 | |||
| const std::string PADV2_ATTR_NAME_MODE = "mode"; | |||
| const std::string PADV2_ATTR_NAME_PADS = "paddings"; | |||
| const std::string PADV2_ATTR_NAME_T = "T"; | |||
| const std::string PADV2_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
| const std::string PADV2_ATTR_NAME_CONST_VALUE = "const_value"; | |||
| // MirrorPad | |||
| const std::string MIRRORPAD_ATTR_NAME_MODE = "mode"; | |||
| const std::string MIRRORPAD_ATTR_NAME_PADS = "paddings"; | |||
| const std::string MIRRORPAD_ATTR_NAME_PAD_FORMAT = "pad_format"; | |||
| const std::string MIRRORPAD_ATTR_NAME_CONST_VALUE = "const_value"; | |||
| // Filler | |||
| const std::string FILLER_TYPE = "filler_type"; | |||
| const std::string FILLER_VALUE = "filler_value"; | |||
| @@ -559,9 +630,6 @@ const std::string SHUFFLE_CHANNEL_GROUP = "group"; | |||
| // TopKV2 | |||
| const std::string TOPKV2_ATTR_K = "k"; | |||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
| // Calibration | |||
| const std::string STRIDE_H_INDEX = "STRIDE_H_INDEX"; | |||
| const std::string STRIDE_W_INDEX = "STRIDE_W_INDEX"; | |||
| @@ -616,6 +684,8 @@ const std::string ATTR_MODEL_STREAM_NUM = "stream_num"; | |||
| const std::string ATTR_MODEL_EVENT_NUM = "event_num"; | |||
| const std::string ATTR_MODEL_LABEL_NUM = "label_num"; | |||
| const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||
| const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | |||
| @@ -630,6 +700,8 @@ const std::string ATTR_MODEL_VAR_SIZE = "variable_size"; | |||
| const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | |||
| const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||
| // Public attribute | |||
| const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | |||
| @@ -661,17 +733,145 @@ const std::string TARGET_TYPE_TINY = "TINY"; | |||
| const std::string TARGET_TYPE_LITE = "LITE"; | |||
| // l2_normalize | |||
| const std::string L2_NORMALIZE_ATTR_AXIS = "axis"; | |||
| const std::string L2_NORMALIZE_ATTR_EPS = "eps"; | |||
| const std::string POOL_PARAMA_ATTR_WINDOW = "window"; | |||
| const std::string POOL_PARAMA_ATTR_CEIL_MODE = "ceil_mode"; | |||
| const std::string POOL_PARAMA_ATTR_DATA_MODE = "data_mode"; | |||
| const std::string POOL_PARAMA_ATTR_GLOBAL_POOLING = "global_pooling"; | |||
| const std::string POOL_PARAMA_ATTR_NAN_OP = "nan_opt"; | |||
| const std::string POOL_PARAMA_ATTR_PAD_MOD = "pad_mode"; | |||
| // HCOM | |||
| const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
| const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
| const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
| const std::string HCOM_ATTR_GROUP = "group"; | |||
| const std::string HCOM_ATTR_SR_TAG = "sr_tag"; | |||
| const std::string HCOM_ATTR_SRC_RANK = "src_rank"; | |||
| const std::string HCOM_ATTR_DEST_RANK = "dest_rank"; | |||
| const std::string HCOM_ATTR_FUSION = "fusion"; | |||
| const std::string HCOM_ATTR_SHAPE = "shape"; | |||
| const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
| // SpaceToDepth/DepthToSpace | |||
| const std::string ATTR_NAME_BLOCK_SIZE = "block_size"; | |||
| // SparseSoftmaxCrossEntropyWithLogits | |||
| const std::string SPARSE_SOFT_MAX_ATTR_TLABLES = "Tlabels"; | |||
| // MaxPoolGradWithArgmax | |||
| const std::string MAX_POOL_GRAD_OUTPUT_SHAPE = "max_pool_grad_output_shape"; | |||
| // AvgPoolGrad | |||
| const std::string AVG_POOL_GRAD_OUTPUT_SHAPE = "avg_pool_grad_output_shape"; | |||
| // Pad | |||
| const std::string ATTR_PAD_FORMAT = "attr_pad_format"; | |||
| // Variable | |||
| const std::string VAR_ATTR_FORMAT = "_var_format"; | |||
| const std::string VAR_ATTR_NAME = "var_name"; | |||
| const std::string VAR_ATTR_FRACTALZ_FORMAT = "FZ"; | |||
| const std::string VAR_ATTR_4D_FORMAT = "4D"; | |||
| const std::string VAR_ATTR_5D_FORMAT = "5D"; | |||
| const std::string VAR_ATTR_DATA_TYPE = "data_format"; | |||
| const std::string VAR_ATTR_VAR_IN_NAME = "var_in_name"; | |||
| const std::string VAR_ATTR_VAR_IN_INDEX = "var_in_index"; | |||
| const std::string VAR_ATTR_VAR_OUT_INDEX = "var_out_index"; | |||
| const std::string VAR_ATTR_SHAPE = "shape"; | |||
| const std::string HALF_VAR_NAME_END = "_fp16"; | |||
| const std::string VAR_ATTR_INITED = "var_is_inited"; | |||
| const std::string VAR_ATTR_CONTAINER = "container"; | |||
| const std::string VAR_ATTR_SHARED_NAME = "shared_name"; | |||
| const std::string VAR_ATTR_DTYPE = "dtype"; | |||
| const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
| const std::string VAR_ATTR_VAR_IS_SAVE = "_var_is_save"; | |||
| const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
| const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
| const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
| const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
| // Assign | |||
| const std::string ASSIGN_VALIDATE_SHAPE = "validate_shape"; | |||
| // space2batch batch2space | |||
| const std::string BATCH_SPACE_ATTR_BLOCK = "block"; | |||
| const std::string BATCH_SPACE_ATTR_PADDING = "padding"; | |||
| // depth_to_space space_to_depth | |||
| const std::string DEPTH_SPACE_ATTR_BLOCK_SIZE = "block_size"; | |||
| // FakeQuantWithMinMaxVars | |||
| const std::string FakeQuantWithMinMaxVars_ATTR_MAX = "max"; | |||
| const std::string FakeQuantWithMinMaxVars_ATTR_MIN = "min"; | |||
| // mobilenet_ssd_conv_fusion | |||
| const std::string SSD_BOXPREDICTOR_BOXES_FUSION = "ssd_boxpredictor_boxes_fusion"; | |||
| const std::string SSD_BOXPREDICTOR_SCORES_FUSION = "ssd_boxpredictor_scores_fusion"; | |||
| const std::string SSD_BOXPREDICTOR_FUSION_BOX_TYPE_NUM = "ssd_boxpredictor_fusion_box_type_num"; | |||
| // lsh project | |||
| const std::string LSH_PROJ_TYPE = "lsh_project_type"; | |||
| // log time stamp | |||
| const std::string LOG_TIME_STAMP_LOGID = "logid"; | |||
| const std::string LOG_TIME_STAMP_NOTIFY = "notify"; | |||
| // ShapeN | |||
| const std::string SHAPEN_ATTR_N = "N"; | |||
| const std::string SHAPEN_ATTR_IN_TYPE = "in_type"; | |||
| const std::string SHAPEN_ATTR_OUT_TYPE = "dtype"; | |||
| // GatherV2 attr def | |||
| const std::string GATHERV2_ATTR_NAME_TAXIS = "Taxis"; | |||
| const std::string GATHERV2_ATTR_NAME_TINDICES = "Tindices"; | |||
| const std::string GATHERV2_ATTR_NAME_TPARAMS = "Tparams"; | |||
| // Reshape attr def | |||
| const std::string RESHAPE_ATTR_NAME_INPUT_DESC = "input_desc_reshape"; | |||
| const std::string RESHAPE_ATTR_NAME_OUTPUT_DESC = "output_desc_reshape"; | |||
| // axis attr def | |||
| const std::string ATTR_NAME_AXIS_ORG_OP = "axis_org_op"; | |||
| const std::string ATTR_NAME_LINK_WITH_SPARE = "link_with_sparse"; | |||
| const std::string ATTR_NAME_NET_OUTPUT_FORMAT = "net_output_format"; | |||
| const std::string ATTR_NAME_NET_OUTPUT_DATATYPE = "net_output_datatype"; | |||
| // For constant folding | |||
| const std::string ATTR_NO_NEED_CONSTANT_FOLDING = "no_need_constant_folding"; | |||
| const std::string ATTR_NAME_CONTINUOUS_INPUT = "continuous_input"; | |||
| const std::string ATTR_NAME_CONTINUOUS_OUTPUT = "continuous_output"; | |||
| const std::string ATTR_NAME_REFERENCE = "reference"; | |||
| const std::string ATTR_NAME_NOTASK = "_no_task"; | |||
| const std::string ATTR_NAME_OUTPUT_REUSE_INPUT = "_output_reuse_input"; | |||
| const std::string ATTR_NAME_REUSE_INPUT_ON_DIM_INDEX = "_reuse_input_on_dim_index"; | |||
| const std::string ATTR_NAME_NOPADDING_CONTINUOUS_INPUT = "_no_padding_continuous_input"; | |||
| const std::string ATTR_NAME_NOPADDING_CONTINUOUS_OUTPUT = "_no_padding_continuous_output"; | |||
| const std::string ATTR_NAME_ATOMIC_INDEX = "atomic_index"; | |||
| // Used to mark the active label list stream of the activated node | |||
| const std::string ATTR_NAME_ACTIVE_LABEL_LIST = "_active_label_list"; | |||
| // Used for l2cache, true: the memory of all inputs is used for the last time. | |||
| const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_lifecycle"; | |||
| // Multi batch | |||
| const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | |||
| const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | |||
| @@ -682,6 +882,8 @@ const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||
| const std::string ATTR_NAME_TRUE_BRANCH_STREAM = "true_branch_stream"; | |||
| const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | |||
| const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | |||
| const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||
| const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||
| const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | |||
| const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | |||
| @@ -691,6 +893,9 @@ const std::string ATTR_NAME_CYCLIC_DEPENDENCE_FLAG = "_cyclic_dependence_flag"; | |||
| const std::string ATTR_NAME_NEXT_ITERATION = "_next_iteration_node"; | |||
| // Function Op | |||
| const std::string ATTR_NAME_PARENT_NODE_INDEX = "_parent_node_index"; | |||
| // Used to mark that the active node is for a loop, type: bool | |||
| const std::string ATTR_NAME_IS_LOOP_ACTIVE = "is_loop_active"; | |||
| @@ -702,6 +907,20 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||
| const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
| // l1 fusion and other fusions in the future | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
| const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | |||
| const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | |||
| const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||
| const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||
| const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | |||
| const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | |||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | |||
| const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
| const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
| const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
| // Atomic addr clean attrs | |||
| const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
| const std::string ATOMIC_ATTR_OUTPUT_INDEX = "atomic_output_index"; | |||
| @@ -722,6 +941,9 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||
| // For inserted op | |||
| const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | |||
| // For compress weight | |||
| const std::string ATTR_NAME_COMPRESS_WEIGHT = "_is_compress_weight"; | |||
| // For data dump | |||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES = "_datadump_original_op_names"; | |||
| const std::string ATTR_NAME_DATA_DUMP_IS_MULTIOP = "_datadump_is_multiop"; | |||
| @@ -732,24 +954,17 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX = "_datadump_origin_ou | |||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||
| const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | |||
| // Variable | |||
| const std::string REF_VAR_SRC_VAR_NAME = "ref_var_src_var_name"; | |||
| const std::string VAR_ATTR_SRC_VAR_NAME = "_src_var_name"; | |||
| const std::string REF_VAR_PRE_PEER_OUT_INDEX = "ref_var_pre_peer_out_index"; | |||
| const std::string VAR_ATTR_VAR_IS_BROADCAST = "_var_is_broadcast"; | |||
| const std::string VAR_ATTR_VAR_IS_RESTORE = "_var_is_restore"; | |||
| // HCOM | |||
| const std::string HCOM_ATTR_ROOT_RANK = "root_rank"; | |||
| const std::string HCOM_ATTR_RANK_SIZE = "rank_size"; | |||
| const std::string HCOM_ATTR_SHAPE = "shape"; | |||
| const std::string HCOM_ATTR_DATA_TYPE = "dtype"; | |||
| // functional ops attr | |||
| const std::string ATTR_NAME_TCOND = "Tcond"; | |||
| const std::string ATTR_NAME_TIN = "Tin"; | |||
| const std::string ATTR_NAME_TOUT = "Tout"; | |||
| const std::string ATTR_NAME_THEN_BRANCH = "then_branch"; | |||
| const std::string ATTR_NAME_ELSE_BRANCH = "else_branch"; | |||
| const std::string HCOM_ATTR_REDUCE_TYPE = "reduction"; | |||
| // used for label switch | |||
| const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | |||
| const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | |||
| const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | |||
| const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | |||
| // Dynamic stitch | |||
| const std::string DYNAMIC_STITCH_ATTR_NAME_NUM = "DynamicStitchN_"; | |||
| } // namespace ge | |||
| @@ -22,7 +22,7 @@ | |||
| #include "graph/model_serialize.h" | |||
| #include "proto/ge_ir.pb.h" | |||
| #include "detail/model_serialize_imp.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "debug/ge_attr_define.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -53,7 +53,7 @@ string GeAttrValue::NamedAttrs::GetName() const { | |||
| GeAttrValue GeAttrValue::NamedAttrs::GetItem(const string &key) const { | |||
| GeAttrValue value; | |||
| (void)GetAttr(key, value); | |||
| GetAttr(key, value); | |||
| return value; | |||
| } | |||
| @@ -1081,6 +1081,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||
| if (!GetListInt(std::move(obj), name, int64_list)) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < int64_list.size(); ++i) { | |||
| if (int64_list[i] > INT32_MAX) { | |||
| GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to int32_t", i, int64_list[i]); | |||
| @@ -1098,6 +1099,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListInt(ConstA | |||
| if (!GetListInt(std::move(obj), name, int64_list)) { | |||
| return false; | |||
| } | |||
| for (size_t i = 0; i < int64_list.size(); ++i) { | |||
| if (int64_list[i] > UINT32_MAX) { | |||
| GELOGE(GRAPH_FAILED, "index %zu %ld int64_t value cannot cast to uint32_t", i, int64_list[i]); | |||
| @@ -1215,6 +1217,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||
| GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | |||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
| if (op_desc->HasAttr("_input_name_idx_key")) { | |||
| if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); | |||
| } | |||
| } | |||
| if (op_desc->HasAttr("_input_name_idx_value")) { | |||
| if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); | |||
| } | |||
| } | |||
| if (op_desc->HasAttr("_opt_input")) { | |||
| if (op_desc->DelAttr("_opt_input") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); | |||
| } | |||
| } | |||
| return op_desc; | |||
| } | |||
| @@ -1237,11 +1256,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
| op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); | |||
| op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), | |||
| org_op_desc->optional_input_names_.end()); | |||
| op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||
| op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||
| op_desc->infer_func_ = org_op_desc->infer_func_; | |||
| @@ -1250,4 +1264,25 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
| return op_desc; | |||
| } | |||
| std::string AttrUtils::GetAllAttrsStr(AttrUtils::ConstAttrHolderAdapter &&obj) { | |||
| auto holder = obj.get(); | |||
| if (holder == nullptr) { | |||
| return ""; | |||
| } | |||
| auto attrs_map = holder->GetAttrMap(); | |||
| if (attrs_map.GetProtoMsg() == nullptr) { | |||
| return ""; | |||
| } | |||
| std::map<std::string, std::string> ordered_attrs; | |||
| for (auto &attr : *(attrs_map.GetProtoMsg())) { | |||
| ordered_attrs[attr.first] = attr.second.SerializeAsString(); | |||
| } | |||
| std::stringstream ss; | |||
| for (auto &attr : ordered_attrs) { | |||
| ss << attr.first << ":" << attr.second << ";"; | |||
| } | |||
| return ss.str(); | |||
| } | |||
| } // namespace ge | |||
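A usage sketch of the new attribute-dump helper, e.g. as a stable de-duplication key; passing an OpDescPtr through the adapter is assumed to work the same way as for the other AttrUtils helpers:

```cpp
std::string NodeAttrSignature(const ge::OpDescPtr &op_desc) {
  // Attributes are emitted in sorted map order as "name:serialized_value;...",
  // so two OpDescs with identical attrs yield identical strings.
  return ge::AttrUtils::GetAllAttrsStr(op_desc);
}
```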
| @@ -163,6 +163,34 @@ int64_t GeShape::GetShapeSize() const { | |||
| return res; | |||
| } | |||
| /// | |||
| /// @brief Check whether the shape is unknown (any dim is negative) | |||
| /// @return bool | |||
| /// | |||
| bool GeShape::IsUnknownShape() const { | |||
| auto proto_msg = shape_def_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| for (auto i : proto_msg->dim()) { | |||
| if (i < 0) { | |||
| return true; | |||
| } | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| /// | |||
| /// @brief Check whether the shape is a scalar (rank 0) | |||
| /// @return bool | |||
| /// | |||
| bool GeShape::IsScalar() const { | |||
| auto proto_msg = shape_def_.GetProtoMsg(); | |||
| if (proto_msg != nullptr) { | |||
| return proto_msg->dim().empty(); | |||
| } | |||
| return false; | |||
| } | |||
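A short sketch of the two shape queries added above; the helper name is illustrative:

```cpp
bool NeedsDynamicSizing(const ge::GeShape &shape) {
  if (shape.IsScalar()) {
    return false;                 // rank 0: the size is just the element size
  }
  return shape.IsUnknownShape();  // any negative dim marks the shape as unknown
}
```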
| const string TENSOR_UTILS_SIZE = "size"; | |||
| const string TENSOR_UTILS_WEIGHT_SIZE = "weight_size"; | |||
| const string TENSOR_UTILS_REUSE_INPUT = "reuse_input"; | |||
| @@ -639,14 +667,14 @@ GeTensor &GeTensor::operator=(const GeTensor &other) { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::GetSize(const GeTensorDesc &tensor_desc, | |||
| uint32_t &size) { | |||
| int64_t &size) { | |||
| auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | |||
| GE_CHECK_NOTNULL(tensor_descriptor_msg); | |||
| size = static_cast<uint32_t>(tensor_descriptor_msg->size()); | |||
| size = static_cast<int64_t>(tensor_descriptor_msg->size()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, uint32_t size) { | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void TensorUtils::SetSize(GeTensorDesc &tensor_desc, int64_t size) { | |||
| auto tensor_descriptor_msg = tensor_desc.tensor_descriptor_.GetProtoMsg(); | |||
| if (tensor_descriptor_msg != nullptr) { | |||
| tensor_descriptor_msg->set_size(size); | |||
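The uint32_t to int64_t widening of the tensor size field matters once a single tensor exceeds 4 GiB; a self-contained back-of-the-envelope check (the shape is made up):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  const int64_t elems = 32LL * 3 * 224 * 224 * 300;                   // ~1.45e9 elements
  const int64_t bytes = elems * static_cast<int64_t>(sizeof(float));  // ~5.78e9 bytes
  // Prints true: this byte count no longer fits in uint32_t, hence the int64_t size field.
  std::cout << std::boolalpha << (bytes > INT64_C(0xFFFFFFFF)) << std::endl;
  return 0;
}
```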
| @@ -49,6 +49,7 @@ void Model::Init() { | |||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | |||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | |||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | |||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | |||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0); | |||
| (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI); | |||
| version_ = 0; | |||
| @@ -77,9 +78,9 @@ void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; } | |||
| Graph Model::GetGraph() const { return graph_; } | |||
| graphStatus Model::Save(Buffer &buffer) const { | |||
| graphStatus Model::Save(Buffer &buffer, bool is_dump) const { | |||
| ModelSerialize serialize; | |||
| buffer = serialize.SerializeModel(*this); | |||
| buffer = serialize.SerializeModel(*this, is_dump); | |||
| return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED; | |||
| } | |||
| @@ -113,7 +114,7 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||
| } | |||
| int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS); | |||
| if (fd < 0) { | |||
| GELOGE(GRAPH_FAILED, "open file failed, file path [%s] ", real_path); | |||
| GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno)); | |||
| return GRAPH_FAILED; | |||
| } | |||
| bool ret = ge_proto.SerializeToFileDescriptor(fd); | |||
| @@ -129,6 +130,10 @@ graphStatus Model::SaveToFile(const string &file_name) const { | |||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!ret) { | |||
| GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -152,7 +157,7 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||
| } | |||
| int fd = open(real_path, O_RDONLY); | |||
| if (fd < 0) { | |||
| GELOGE(GRAPH_FAILED, "open file failed"); | |||
| GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno)); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -170,6 +175,10 @@ graphStatus Model::LoadFromFile(const string &file_name) { | |||
| GELOGE(GRAPH_FAILED, "close file descriptor fail."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!ret) { | |||
| GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| return Load(model_def); | |||
| } | |||
| @@ -15,10 +15,8 @@ | |||
| */ | |||
| #include "graph/model_serialize.h" | |||
| #include <google/protobuf/text_format.h> | |||
| #include <iostream> | |||
| #include "debug/ge_attr_define.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -26,6 +24,7 @@ | |||
| #include "graph/detail/model_serialize_imp.h" | |||
| #include "proto/ge_ir.pb.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "debug/ge_op_types.h" | |||
| using std::string; | |||
| @@ -84,20 +83,29 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||
| return true; | |||
| } | |||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||
| if (op_desc == nullptr || op_def_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input Para Invalid"); | |||
| return false; | |||
| } | |||
| if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||
| *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||
| // Delete unnecessary attr | |||
| if (is_dump) { | |||
| auto attr = op_def_proto->mutable_attr(); | |||
| attr->erase(ATTR_NAME_FRAMEWORK_NODE_DEF); | |||
| attr->erase(ATTR_NAME_FRAMEWORK_OP_DEF); | |||
| attr->erase(ATTR_NAME_FRAMEWORK_FUNC_DEF); | |||
| GE_IF_BOOL_EXEC((op_def_proto->type() == CONSTANT || op_def_proto->type() == CONSTANTOP), | |||
| attr->erase(ATTR_NAME_WEIGHTS)); | |||
| } | |||
| op_def_proto->clear_input_desc(); | |||
| op_def_proto->clear_output_desc(); | |||
| // Input descs | |||
| if (op_desc->GetInputsSize() > 0) { | |||
| auto size = static_cast<uint32_t>(op_desc->GetInputsSize()); | |||
| if (op_desc->GetAllInputsSize() > 0) { | |||
| auto size = static_cast<uint32_t>(op_desc->GetAllInputsSize()); | |||
| for (uint32_t i = 0; i < size; i++) { | |||
| auto tensor_desc = op_desc->GetInputDescPtr(i); | |||
| auto tensor_desc = op_desc->GetInputDescPtrDfault(i); | |||
| if (tensor_desc != nullptr && tensor_desc->tensor_descriptor_.GetProtoMsg() != nullptr) { | |||
| *op_def_proto->add_input_desc() = *(tensor_desc->tensor_descriptor_.GetProtoMsg()); | |||
| } | |||
| @@ -117,12 +125,12 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
| return true; | |||
| } | |||
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto) { | |||
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||
| if (node == nullptr || op_def_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input Para Node Invalid"); | |||
| return false; | |||
| } | |||
| if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto)) { | |||
| if (!SerializeOpDesc(node->GetOpDesc(), op_def_proto, is_dump)) { | |||
| GELOGE(GRAPH_FAILED, "Serialize OpDesc failed"); | |||
| return false; | |||
| } | |||
| @@ -134,7 +142,8 @@ bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_ | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::SerializeGraph(const ConstComputeGraphPtr &graph, | |||
| proto::GraphDef *graph_proto) { | |||
| proto::GraphDef *graph_proto, | |||
| bool is_dump) { | |||
| if (graph == nullptr || graph_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input para Invalid"); | |||
| return false; | |||
| @@ -156,7 +165,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||
| *graph_proto->mutable_attr() = *graph->attrs_.GetProtoMsg(); | |||
| } | |||
| for (const auto &node : graph->GetDirectNode()) { | |||
| if (!SerializeNode(node, graph_proto->add_op())) { | |||
| if (!SerializeNode(node, graph_proto->add_op(), is_dump)) { | |||
| if (node->GetOpDesc() != nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Serialize Node %s failed", node->GetName().c_str()); | |||
| } | |||
| @@ -166,7 +175,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||
| return true; | |||
| } | |||
| bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto) { | |||
| bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *model_proto, bool is_dump) { | |||
| if (model_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "model_proto para Invalid"); | |||
| return false; | |||
| @@ -183,7 +192,7 @@ bool ModelSerializeImp::SerializeModel(const Model &model, proto::ModelDef *mode | |||
| GELOGE(GRAPH_FAILED, "GetComputeGraph return nullptr"); | |||
| return false; | |||
| } | |||
| if (!SerializeGraph(compute_graph, model_proto->add_graph())) { | |||
| if (!SerializeGraph(compute_graph, model_proto->add_graph(), is_dump)) { | |||
| GELOGE(GRAPH_FAILED, "SerializeGraph fail"); | |||
| return false; | |||
| } | |||
| @@ -390,10 +399,10 @@ bool ReadProtoFromBinaryFile(const uint8_t *data, size_t len, google::protobuf:: | |||
| return true; | |||
| } | |||
| Buffer ModelSerialize::SerializeModel(const Model &model) { | |||
| Buffer ModelSerialize::SerializeModel(const Model &model, bool is_dump) { | |||
| proto::ModelDef model_def; | |||
| ModelSerializeImp imp; | |||
| if (!imp.SerializeModel(model, &model_def)) { | |||
| if (!imp.SerializeModel(model, &model_def, is_dump)) { | |||
| return Buffer(); | |||
| } | |||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||
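A hedged sketch of how a caller might use the new dump flag end to end; the helper name is ours, and it assumes the Model/Buffer declarations from graph/model.h as used elsewhere in this repo:

```cpp
// is_dump=true makes SerializeOpDesc drop the framework node/op/func defs and
// constant weights, so the buffer is suited to graph dumping rather than reload.
ge::Buffer DumpModelWithoutWeights(const ge::Model &model) {
  ge::Buffer buffer;
  (void)model.Save(buffer, /*is_dump=*/true);  // GRAPH_SUCCESS when buffer is non-empty
  return buffer;
}
```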
| @@ -401,7 +401,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||
| vec.push_back(in_anchor); | |||
| } | |||
| } | |||
| // Push back in_control_anchor_ | |||
| // Push back in_control_anchor_ | |||
| if ((in_control_anchor_->GetPeerOutControlAnchors().size() > 0) || | |||
| (in_control_anchor_->GetPeerOutDataAnchors().size() > 0)) { | |||
| auto in_anchor = Anchor::DynamicAnchorCast<Anchor>(in_control_anchor_); | |||
| @@ -512,7 +512,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||
| auto peer_out_anchors = in_control_anchor_->GetPeerOutDataAnchors(); | |||
| for (const auto &out_anchor : peer_out_anchors) { | |||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, " in_control_anchor_ peer out data anchors is nullptr"); | |||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "in_control_anchor_ peer out data anchors is nullptr"); | |||
| auto node = out_anchor->GetOwnerNode(); | |||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||
| vec.push_back(node); | |||
| @@ -521,7 +521,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<NodePtr> Node::GetIn | |||
| auto peer_out_control_anchors = in_control_anchor_->GetPeerOutControlAnchors(); | |||
| for (const auto &out_control_anchor : peer_out_control_anchors) { | |||
| GE_CHK_BOOL_EXEC(out_control_anchor != nullptr, continue, | |||
| " in_control_anchor_ peer out control anchors is nullptr"); | |||
| "in_control_anchor_ peer out control anchors is nullptr"); | |||
| auto node = out_control_anchor->GetOwnerNode(); | |||
| GE_CHK_BOOL_EXEC(node != nullptr, continue, "GetOwnerNode is nullptr"); | |||
| vec.push_back(node); | |||
| @@ -785,6 +785,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus Node::UpdateOpDesc(co | |||
| GE_CHK_BOOL_EXEC(op_->GetInputsSize() == op_desc->GetInputsSize(), return GRAPH_PARAM_INVALID, | |||
| "Inputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetInputsSize(), | |||
| op_desc->GetInputsSize()); | |||
| GE_CHK_BOOL_EXEC(op_->GetOutputsSize() == op_desc->GetOutputsSize(), return GRAPH_PARAM_INVALID, | |||
| "Outputs count expected to be same, orginial OpDesc %zu, Param OpDesc %zu", op_->GetOutputsSize(), | |||
| op_desc->GetOutputsSize()); | |||
| @@ -61,6 +61,12 @@ const std::string ATTR_NAME_WORKSPACE_BYTES = "workspace_bytes"; | |||
| const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | |||
| const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||
| const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||
| const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | |||
| op_def_.InitDefault(); | |||
| if (op_def_.GetProtoMsg() != nullptr) { | |||
| @@ -202,7 +208,8 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d | |||
| } | |||
| graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||
| GELOGI("input %s is exist, update it", name.c_str()); | |||
| graphStatus ret = UpdateInputDesc(name, input_desc); | |||
| return ret; | |||
| @@ -214,15 +221,17 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||
| return GRAPH_FAILED; | |||
| } | |||
| inputs_desc_.push_back(in_desc); | |||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||
| (void)input_name_idx.insert(make_pair(name, index)); | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } | |||
| graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| string input_name = name + std::to_string(i); | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
| "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
| std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
| @@ -234,12 +243,13 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||
| (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | |||
| // Update index in input_name_idx | |||
| for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||
| for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
| it->second += 1; | |||
| } | |||
| (void)input_name_idx_.insert(make_pair(input_name, 0)); | |||
| (void)input_name_idx.insert(make_pair(input_name, 0)); | |||
| } | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -270,10 +280,19 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int | |||
| graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
| if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | |||
| (void)optional_input_names_.insert(name); | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| optional_input_names.push_back(name); | |||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::vector<string> OpDesc::GetAllOptionalInputName() const { | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| return optional_input_names; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
| GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | |||
| @@ -288,11 +307,12 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | |||
| return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && | |||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
| IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && | |||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
| return ( | |||
| IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && | |||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
| IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && | |||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | |||
| @@ -366,8 +386,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD | |||
| } | |||
| graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | |||
| auto it = input_name_idx_.find(name); | |||
| if (it == input_name_idx_.end()) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it == input_name_idx.end()) { | |||
| GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -387,8 +408,9 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & | |||
| } | |||
| bool OpDesc::InputIsSet(const string &name) const { | |||
| auto it = input_name_idx_.find(name); | |||
| if (it != input_name_idx_.end()) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it != input_name_idx.end()) { | |||
| GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | |||
| auto tensor_desc = inputs_desc_[it->second]; | |||
| GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | |||
| @@ -406,18 +428,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc | |||
| } | |||
| GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | |||
| auto it = input_name_idx_.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | |||
| return *(inputs_desc_[it->second].get()); | |||
| } | |||
| GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| vector<string> names; | |||
| if (input_name_idx_.empty()) { | |||
| if (input_name_idx.empty()) { | |||
| return OpDesc::Vistor<string>(shared_from_this(), names); | |||
| } | |||
| for (std::pair<string, uint32_t> input : input_name_idx_) { | |||
| for (std::pair<string, uint32_t> input : input_name_idx) { | |||
| names.push_back(input.first); | |||
| } | |||
| return OpDesc::Vistor<string>(shared_from_this(), names); | |||
| @@ -483,6 +507,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetInputsSize() co | |||
| return size; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t OpDesc::GetAllInputsSize() const { return inputs_desc_.size(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddOutputDesc(const ge::GeTensorDesc &output_desc) { | |||
| int index = static_cast<int>(outputs_desc_.size()); | |||
| return AddOutputDesc("__output" + std::to_string(index), output_desc); | |||
| @@ -548,6 +574,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||
| return outputs_desc_[index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||
| return static_cast<uint32_t>(outputs_desc_.size()); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllOutputsDesc() const { | |||
| vector<GeTensorDesc> temp{}; | |||
| for (const auto &it : outputs_desc_) { | |||
| @@ -580,6 +610,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr | |||
| OpDesc::GetInputDescPtrDfault(uint32_t index) const { | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG((index) < (uint32_t)(inputs_desc_.size()), nullptr); | |||
| return inputs_desc_[(int32_t)index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>()); | |||
| return inputs_desc_[it->second]; | |||
| } | |||
| graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | |||
| if (is_push_back) { | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| @@ -603,12 +646,45 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||
| } | |||
| bool OpDesc::IsOptionalInput(const string &name) const { | |||
| return optional_input_names_.find(name) != optional_input_names_.end(); | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| for (auto &item : optional_input_names) { | |||
| if (item == name) { | |||
| return true; | |||
| } | |||
| } | |||
| return false; | |||
| } | |||
| bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | |||
| std::map<string, uint32_t> OpDesc::GetAllInputName() { return input_name_idx_; } | |||
| std::map<string, uint32_t> OpDesc::GetAllInputName() const { | |||
| std::map<string, uint32_t> input_name_idx; | |||
| std::vector<string> key; | |||
| std::vector<uint32_t> value; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
| (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
| if (key.size() != value.size()) { | |||
| GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||
| } else { | |||
| for (uint32_t i = 0; i < key.size(); ++i) { | |||
| input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||
| } | |||
| } | |||
| return input_name_idx; | |||
| } | |||
| void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) { | |||
| std::vector<string> key; | |||
| std::vector<uint32_t> value; | |||
| for (auto &item : input_name_idx) { | |||
| key.emplace_back(item.first); | |||
| value.emplace_back(item.second); | |||
| } | |||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
| (void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
| } | |||
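The Get/SetAllInputName pair above flattens the name-to-index map into two parallel list attributes so it survives serialization, then rebuilds it on read. A container-only sketch of that round trip (no AttrUtils involved; the size check mirrors the one in GetAllInputName):

```cpp
#include <cstdint>
#include <map>
#include <string>
#include <vector>

void Flatten(const std::map<std::string, uint32_t> &idx,
             std::vector<std::string> &keys, std::vector<uint32_t> &values) {
  for (const auto &item : idx) {  // std::map iterates in key order, so output is stable
    keys.push_back(item.first);
    values.push_back(item.second);
  }
}

std::map<std::string, uint32_t> Rebuild(const std::vector<std::string> &keys,
                                        const std::vector<uint32_t> &values) {
  std::map<std::string, uint32_t> idx;
  if (keys.size() != values.size()) {
    return idx;  // mismatched lists: return empty rather than guess pairings
  }
  for (size_t i = 0; i < keys.size(); ++i) {
    idx.emplace(keys[i], values[i]);
  }
  return idx;
}
```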
| std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; } | |||
| @@ -619,6 +695,7 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
| auto factory_map_size = input_name_idx.size(); | |||
| // It indicates that some inputs have no optional name. | |||
| // The factory's redundant optional names need to be deleted and then assigned | |||
| auto all_input_name_idx = GetAllInputName(); | |||
| if (input_map_size < factory_map_size) { | |||
| GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | |||
| factory_map_size); | |||
| @@ -631,22 +708,23 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
| } | |||
| if (input_name_idx.size() == input_map_size) { | |||
| GELOGI("UpdateInputName"); | |||
| input_name_idx_ = input_name_idx; | |||
| all_input_name_idx = input_name_idx; | |||
| } else { | |||
| ret = false; | |||
| GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | |||
| } | |||
| } else if (input_map_size == factory_map_size) { | |||
| input_name_idx_ = input_name_idx; | |||
| all_input_name_idx = input_name_idx; | |||
| } else { | |||
| ret = false; | |||
| GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | |||
| } | |||
| SetAllInputName(all_input_name_idx); | |||
| return ret; | |||
| } | |||
| bool OpDesc::UpdateOutputName(std::map<string, uint32_t> output_name_idx) { | |||
| size_t output_map_size = GetAllOutputsDesc().size(); | |||
| size_t output_map_size = GetAllOutputsDescSize(); | |||
| size_t factory_map_size = output_name_idx.size(); | |||
| if (output_map_size < factory_map_size) { | |||
| GELOGI("UpdateOutputName org outputname map size: %zu, factory outputname map size: %zu", output_map_size, | |||
| @@ -754,17 +832,17 @@ graphStatus OpDesc::OpVerify() { | |||
| } | |||
| graphStatus OpDesc::CommonVerify() const { | |||
| for (string iname : GetAllInputNames()) { | |||
| for (const string &iname : GetAllInputNames()) { | |||
| // Checking shape of all inputs | |||
| vector<int64_t> ishape = GetInputDesc(iname).GetShape().GetDims(); | |||
| vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||
| for (int64_t dim : ishape) { | |||
| GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||
| iname.c_str()); | |||
| } | |||
| } | |||
| // Check all attributes defined | |||
| const auto all_attributes = GetAllAttrs(); | |||
| for (const auto name : GetAllAttrNames()) { | |||
| const auto &all_attributes = GetAllAttrs(); | |||
| for (const auto &name : GetAllAttrNames()) { | |||
| GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | |||
| "operator attribute %s is empty.", name.c_str()); | |||
| } | |||
| @@ -773,19 +851,21 @@ graphStatus OpDesc::CommonVerify() const { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { | |||
| auto it = input_name_idx_.begin(); | |||
| for (; it != input_name_idx_.end(); ++it) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.begin(); | |||
| for (; it != input_name_idx.end(); ++it) { | |||
| if (it->second == index) { | |||
| break; | |||
| } | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), ""); | |||
| return it->first; | |||
| } | |||
| int OpDesc::GetInputIndexByName(const string &name) const { | |||
| auto it_find = input_name_idx_.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it_find = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); | |||
| return static_cast<int>(it_find->second); | |||
| } | |||
| @@ -1065,10 +1145,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, | |||
| const int &index) { | |||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||
| GELOGI("Restore input name index is existed. name[%s]", name.c_str()); | |||
| } | |||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||
| (void)input_name_idx.insert(make_pair(name, index)); | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -1104,4 +1186,45 @@ graphStatus OpDesc::CallInferFormatFunc(Operator &op) { | |||
| } | |||
| return (graphStatus)infer_format_func_(op); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetSubgraphInstanceName(uint32_t index) const { | |||
| if (static_cast<size_t>(index) >= subgraph_instance_names_.size()) { | |||
| return ""; | |||
| } | |||
| return subgraph_instance_names_.at(index); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::vector<std::string> &OpDesc::GetSubgraphInstanceNames() | |||
| const { | |||
| return subgraph_instance_names_; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::AddSubgraphInstanceName(std::string name) { | |||
| subgraph_instance_names_.emplace_back(std::move(name)); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::RemoveSubgraphInstanceName(const std::string &name) { | |||
| for (auto iter = subgraph_instance_names_.begin(); iter != subgraph_instance_names_.end(); ++iter) { | |||
| if (*iter == name) { | |||
| subgraph_instance_names_.erase(iter); | |||
| return; | |||
| } | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::AddSubgraphName(const std::string &name) { | |||
| auto iter = subgraph_names_to_index_.find(name); | |||
| if (iter != subgraph_names_to_index_.end()) { | |||
| GELOGW("The subgraph name %s exists, index %u", name.c_str(), iter->second); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto size = subgraph_names_to_index_.size(); | |||
| subgraph_names_to_index_[name] = size; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY const std::map<std::string, uint32_t> &OpDesc::GetSubgraphNameIndexes() | |||
| const { | |||
| return subgraph_names_to_index_; | |||
| } | |||
| } // namespace ge | |||
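A hypothetical call sequence for the new subgraph bookkeeping: AddSubgraphName assigns indices in registration order and rejects duplicates, while instance names are kept as a separate ordered list. The wrapper function and the names passed in are ours:

```cpp
void RegisterBranches(const ge::OpDescPtr &op_desc) {
  (void)op_desc->AddSubgraphName("then_branch");      // registered with index 0
  (void)op_desc->AddSubgraphName("else_branch");      // registered with index 1
  (void)op_desc->AddSubgraphName("then_branch");      // returns GRAPH_FAILED: duplicate name
  op_desc->AddSubgraphInstanceName("then_branch_0");  // instance list entry 0
  op_desc->AddSubgraphInstanceName("else_branch_0");  // instance list entry 1
}
```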
| @@ -20,8 +20,7 @@ | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_util.h" | |||
| using std::function; | |||
| using std::vector; | |||
| using namespace std; | |||
| namespace ge { | |||
| @@ -15,13 +15,12 @@ | |||
| */ | |||
| #include "external/graph/operator.h" | |||
| #include <stdint.h> | |||
| #include <algorithm> | |||
| #include <mutex> | |||
| #include <queue> | |||
| #include <set> | |||
| #include "array_ops.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -33,7 +32,6 @@ | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/node.h" | |||
| #include "graph/op_desc.h" | |||
| #include "graph/operator_factory.h" | |||
| #include "graph/usr_types.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/op_desc_utils.h" | |||
| @@ -48,10 +46,6 @@ using std::string; | |||
| using std::to_string; | |||
| using std::vector; | |||
| namespace { | |||
| const char *const kValue = "value"; | |||
| } // namespace | |||
| namespace ge { | |||
| class OpIO { | |||
| public: | |||
| @@ -148,6 +142,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
| for (int i = static_cast<int>(is_input_const.size()); i <= dst_index; ++i) { | |||
| is_input_const.push_back(false); | |||
| } | |||
| is_input_const[dst_index] = is_const; | |||
| op_desc_->SetIsInputConst(is_input_const); | |||
| @@ -179,8 +174,8 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
| GE_CHK_BOOL_EXEC(dst_index >= 0, return, "Find input index by name failed. name[%s], op name:%s", dst_name.c_str(), | |||
| op_desc_->GetName().c_str()); | |||
| auto out_op_impl = out_handler->GetOwner(); | |||
| GE_CHK_BOOL_EXEC(out_op_impl && out_op_impl->GetOpDescImpl(), return, "out_handler invalid. name[%s]", | |||
| dst_name.c_str()); | |||
| GE_CHK_BOOL_EXEC(out_op_impl != nullptr && out_op_impl->GetOpDescImpl() != nullptr, return, | |||
| "out_handler invalid. name[%s]", dst_name.c_str()); | |||
| bool is_const = false; | |||
| if (out_op_impl->GetOpDescImpl()->GetType() == CONSTANT) { | |||
| is_const = true; | |||
| @@ -193,7 +188,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
| op_desc_->SetIsInputConst(is_input_const); | |||
| OpIO in_handler(dst_name, dst_index, shared_from_this()); | |||
| GE_CHK_BOOL_EXEC(!!out_op_impl, return, "Get out_handler's impl failed."); | |||
| GE_CHK_BOOL_EXEC(out_op_impl != nullptr, return, "Get out_handler's impl failed."); | |||
| out_op_impl->UpdateLinkMapImpl(src_name, in_handler); | |||
| auto src_output_desc = out_op_impl->GetOutputDesc(src_name); | |||
| @@ -210,7 +205,7 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||
| void AddControlInputImp(const ge::Operator &src_oprt) { | |||
| if (src_oprt.operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Src operator impl is nullptr"); | |||
| GELOGE(FAILED, "Src operator impl is nullptr"); | |||
| return; | |||
| } | |||
| for (auto &input : control_input_link_) { | |||
| @@ -520,9 +515,9 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||
| if (peer_node_ptr->GetOpDesc() != nullptr) { | |||
| const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | |||
| if (op_descType == CONSTANTOP) { | |||
| return const_op.GetAttr(kValue, data); | |||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
| } else if (op_descType == CONSTANT) { | |||
| return const_op.GetAttr(kValue, data); | |||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -542,9 +537,9 @@ graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) | |||
| Operator const_op(out_handle.GetOwner()); | |||
| const auto &op_desc_impl_type = out_handle.GetOwner()->GetOpDescImpl()->GetType(); | |||
| if (op_desc_impl_type == CONSTANTOP) { | |||
| return const_op.GetAttr(kValue, data); | |||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
| } else if (op_desc_impl_type == CONSTANT) { | |||
| return const_op.GetAttr(kValue, data); | |||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
| } | |||
| } | |||
| return GRAPH_FAILED; | |||
| @@ -709,6 +704,7 @@ void Operator::InputRegister(const string &name) { | |||
| void Operator::OptionalInputRegister(const string &name) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| // [No need to verify return value] | |||
| (void)operator_impl_->GetOpDescImpl()->AddOptionalInputDesc(name, | |||
| GeTensorDesc(GeShape(), FORMAT_RESERVED, DT_UNDEFINED)); | |||
| } | |||
| @@ -716,24 +712,28 @@ void Operator::OptionalInputRegister(const string &name) { | |||
| void Operator::InferFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| // [No need to verify return value] | |||
| (void)operator_impl_->GetOpDescImpl()->AddInferFunc(func); | |||
| } | |||
| void Operator::InferFormatFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| // [No need to verify return value] | |||
| (void)operator_impl_->GetOpDescImpl()->AddInferFormatFunc(func); | |||
| } | |||
| void Operator::VerifierFuncRegister(const std::function<graphStatus(Operator &)> &func) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| // [No need to verify return value] | |||
| (void)operator_impl_->GetOpDescImpl()->AddVerifierFunc(func); | |||
| } | |||
| void Operator::OutputRegister(const string &name) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| // [No need to verify return value] | |||
| (void)operator_impl_->GetOpDescImpl()->AddOutputDesc(name, GeTensorDesc()); | |||
| } | |||
| @@ -757,7 +757,8 @@ int Operator::GetDynamicInputNum(const string &name) const { | |||
| void Operator::DynamicOutputRegister(const string &name, const unsigned int num, bool is_push_back) { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return, "GetOpDescImpl is nullptr."); | |||
| (void)AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num); | |||
| GE_CHK_BOOL_EXEC(AttrUtils::SetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return, | |||
| "Set %s int failed", name.c_str()); | |||
| (void)operator_impl_->GetOpDescImpl()->AddDynamicOutputDesc(name, num, is_push_back); | |||
| } | |||
| @@ -765,7 +766,8 @@ int Operator::GetDynamicOutputNum(const string &name) const { | |||
| GE_CHK_BOOL_EXEC(operator_impl_ != nullptr, return 0, "operator impl is nullptr."); | |||
| GE_CHK_BOOL_EXEC(operator_impl_->GetOpDescImpl() != nullptr, return 0, "GetOpDescImpl is nullptr."); | |||
| int num = 0; | |||
| (void)AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_INPUT_TD_NUM(name), num); | |||
| GE_CHK_BOOL_EXEC(AttrUtils::GetInt(operator_impl_->GetOpDescImpl(), DYNAMIC_OUTPUT_TD_NUM(name), num), return num, | |||
| "Get %s int failed", name.c_str()); | |||
| return num; | |||
| } | |||
| @@ -1141,7 +1143,9 @@ class GraphBuilderImpl { | |||
| GELOGW("Input operator should be Data, Variable operator or operator that has output but no input."); | |||
| } | |||
| } | |||
| GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | |||
| "User Input do not include operator such as \ | |||
| Data, Variable operator or operator that has output but no input."); | |||
| auto ret = WalkAllOperators(vec_inputs); | |||
| GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | |||
| @@ -1163,7 +1167,8 @@ class GraphBuilderImpl { | |||
| que.pop(); | |||
| for (const auto &op_impl : vec_tem) { | |||
| GE_CHK_BOOL_EXEC(op_impl != nullptr, return GRAPH_FAILED, "Operator Impl is null.") | |||
| GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue) | |||
| GE_CHK_BOOL_EXEC_INFO(all_nodes_info_.find(op_impl) == all_nodes_info_.end(), continue, | |||
| "This node %s has created.", op_impl->GetName().c_str()) | |||
| auto node_ptr = graph_->AddNode(op_impl->op_desc_); | |||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr, return GRAPH_FAILED, "Add node failed."); | |||
| all_nodes_info_.insert(std::make_pair(op_impl, node_ptr)); | |||
| @@ -1202,10 +1207,13 @@ class GraphBuilderImpl { | |||
| for (const auto &node_info : all_nodes_info_) { | |||
| auto src_op_impl_ptr = node_info.first; | |||
| auto src_node_ptr = node_info.second; | |||
| GE_IF_BOOL_EXEC(src_op_impl_ptr == nullptr || src_node_ptr == nullptr, continue); | |||
| auto out_links = src_op_impl_ptr->output_links_; | |||
| GE_CHK_BOOL_EXEC(src_op_impl_ptr->op_desc_ != nullptr, return GRAPH_FAILED, | |||
| "Src operator impl's op_desc is null."); | |||
| auto &op_desc = src_op_impl_ptr->op_desc_; | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, continue); | |||
| for (const auto &out : out_links) { | |||
| auto src_idx = op_desc->GetOutputIndexByName(out.first); | |||
| GE_CHK_BOOL_EXEC(src_idx >= 0, return GRAPH_FAILED, "Find output index by name failed"); | |||
| @@ -1216,7 +1224,9 @@ class GraphBuilderImpl { | |||
| for (const auto &dst_opio : out.second) { | |||
| auto dst_node_info = all_nodes_info_.find(dst_opio.GetOwner()); | |||
| GE_CHK_BOOL_EXEC(dst_node_info != all_nodes_info_.end(), return GRAPH_FAILED, "Find Dst node failed."); | |||
| GE_IF_BOOL_EXEC(dst_node_info->second == nullptr, continue); | |||
| auto dst_anchor = dst_node_info->second->GetInDataAnchor(dst_opio.GetIndex()); | |||
| GE_CHK_BOOL_EXEC(dst_anchor != nullptr, return GRAPH_FAILED, "GetInDataAnchor failed."); | |||
| @@ -1260,8 +1270,7 @@ inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | |||
| ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | |||
| auto graph_builder_impl = GraphBuilderImpl(name); | |||
| ComputeGraphPtr compute_graph = graph_builder_impl.BuildGraph(inputs); | |||
| GE_IF_BOOL_EXEC(compute_graph == nullptr, return compute_graph); | |||
| GE_CHK_BOOL_EXEC(compute_graph != nullptr, return compute_graph, "Compute graph is nullptr"); | |||
| compute_graph->SetAllNodesInfo(graph_builder_impl.GetAllNodesInfo()); | |||
| if (HasSameNameNode(compute_graph)) { | |||
| GELOGW("Compute do not allow has same name nodes."); | |||
| @@ -15,13 +15,11 @@ | |||
| */ | |||
| #include "graph/opsproto_manager.h" | |||
| #include <algorithm> | |||
| #include <cstdlib> | |||
| #include <algorithm> | |||
| #include <functional> | |||
| #include <iostream> | |||
| #include <sstream> | |||
| #include "debug/ge_util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/debug/ge_log.h" | |||
| @@ -155,7 +153,7 @@ void OpsProtoManager::LoadOpsProtoPluginSo(std::string &path) { | |||
| // Load .so file | |||
| for (auto elem : file_list) { | |||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL | RTLD_NODELETE); | |||
| void *handle = dlopen(elem.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
| if (handle == nullptr) { | |||
| GELOGW("OpsProtoManager dlopen failed, plugin name:%s. Message(%s).", elem.c_str(), dlerror()); | |||
| continue; | |||
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "./ge_context.h" | |||
| #include "./ge_global_options.h" | |||
| #include "./ge_local_context.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -87,4 +86,5 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||
| uint64_t GEContext::TraceId() { return trace_id_; } | |||
| void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | |||
| } // namespace ge | |||
| @@ -22,6 +22,7 @@ | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "external/graph/operator.h" | |||
| @@ -34,6 +35,122 @@ | |||
| #include "utils/type_utils.h" | |||
| namespace ge { | |||
| namespace { | |||
| constexpr const char *kRefIndex = "parent_node_index"; | |||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (const auto &node_sub : sub_graph->GetDirectNode()) { | |||
| if (node_sub->GetType() != DATA) { | |||
| continue; | |||
| } | |||
| int ref_i; | |||
| auto data_opdesc = node_sub->GetOpDesc(); | |||
| if (data_opdesc == nullptr) { | |||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no OpDesc", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!AttrUtils::GetInt(node_sub->GetOpDesc(), kRefIndex, ref_i)) { | |||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||
| if (input_desc == nullptr) { | |||
| GE_LOGE( | |||
| "The ref index(%d) on the data %s on the sub graph %s " | |||
| "parent node %s are incompatible, inputs num %u", | |||
| ref_i, node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str(), node->GetAllOutDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto ret = data_opdesc->UpdateInputDesc(0, *input_desc); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GE_LOGE("Failed to update input desc of data %s on the sub graph %s parent node %s", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| ret = data_opdesc->UpdateOutputDesc(0, *input_desc); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GE_LOGE("Failed to update output desc of data %s on the sub graph %s parent node %s", | |||
| node_sub->GetName().c_str(), name.c_str(), node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
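For context, a sketch of the convention UpdateSubGraphDataNodes relies on: every Data node inside a subgraph carries an int attribute (kRefIndex, i.e. "parent_node_index") naming which input of the parent node it mirrors, and the code above copies that parent input desc onto the Data node. The tagging helper below is hypothetical:

```cpp
void TagSubgraphData(const ge::OpDescPtr &data_op, int64_t parent_input_index) {
  // kRefIndex ("parent_node_index") links this Data node to one parent input.
  (void)ge::AttrUtils::SetInt(data_op, "parent_node_index", parent_input_index);
}
```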
| graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| auto sub_graph = root_graph->GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GE_LOGE("Can node find the subgrpah %s for node %s", name.c_str(), node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodePtr netoutput = nullptr; | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| break; | |||
| } | |||
| } | |||
| if (netoutput == nullptr) { | |||
| GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto netoutput_opdesc = netoutput->GetOpDesc(); | |||
| if (netoutput_opdesc == nullptr) { | |||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, no OpDesc on it", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (auto &edge_anchor : netoutput->GetAllInDataAnchors()) { | |||
| auto edge_desc = netoutput_opdesc->MutableInputDesc(edge_anchor->GetIdx()); | |||
| if (edge_desc == nullptr) { | |||
| GE_LOGE("Invalid NetOutput node on sub graph %s, parent node %s, can not find input tensor %d", name.c_str(), | |||
| node->GetName().c_str(), edge_anchor->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(edge_desc, kRefIndex, ref_i)) { | |||
| // If there is no ref index on the TensorDesc, the output data will be ignored by the outer graph. | |||
| continue; | |||
| } | |||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
| if (output_desc == nullptr) { | |||
| GE_LOGE( | |||
| "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||
| "parent node %s are incompatible, outputs num %u", | |||
| ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllOutDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } // namespace | |||
| void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "node is null"); | |||
| @@ -42,7 +159,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||
| std::string str; | |||
| if (!op_desc->GetAllInputsDescPtr().empty()) { | |||
| if (op_desc->GetInputsSize() != 0) { | |||
| std::string input_desc_str = "input shape: "; | |||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||
| input_desc_str += "["; | |||
| @@ -56,7 +173,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| str += input_desc_str; | |||
| } | |||
| if (!op_desc->GetAllOutputsDescPtr().empty()) { | |||
| if (op_desc->GetAllOutputsDescSize() != 0) { | |||
| std::string output_desc_str = "output shape: "; | |||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||
| if (output_desc == nullptr) { | |||
| @@ -76,13 +193,24 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| } | |||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op) { | |||
| return InferShapeAndType(node, op, true); | |||
| } | |||
| graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator &op, bool before_subgraph) { | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||
| auto op_desc = node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | |||
| const auto &op_type = op_desc->GetType(); | |||
| graphStatus ret; | |||
| if (before_subgraph) { | |||
| ret = UpdateSubGraphDataNodes(node); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| } | |||
| // Get infer func and execute | |||
| graphStatus ret = op_desc->CallInferFunc(op); | |||
| ret = op_desc->CallInferFunc(op); | |||
| if (ret == GRAPH_PARAM_INVALID) { | |||
| // Op ir no infer func, try to get infer func from operator factory | |||
| auto node_op = ge::OperatorFactory::CreateOperator("node_op", op_desc->GetType()); | |||
| @@ -113,7 +241,14 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||
| ret = op_desc->CallInferFunc(op); | |||
| GELOGI("op CallInferFunc second. ret: %u", ret); | |||
| } | |||
| return ret; | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| if (!before_subgraph) { | |||
| return UpdateParentNodeOutTensor(node); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
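Illustrative only: the new before_subgraph flag picks which side of the subgraph boundary is synchronized around the node's own infer function. Passing true refreshes the subgraph Data nodes from the parent's inputs before inference; passing false runs inference and then refreshes the parent's outputs from the subgraph NetOutput. The wrapper name below is ours:

```cpp
ge::graphStatus RefineWithSubgraphSync(const ge::ConstNodePtr &node, ge::Operator &op,
                                       bool descending_into_subgraphs) {
  return ge::ShapeRefiner::InferShapeAndType(node, op, /*before_subgraph=*/descending_into_subgraphs);
}
```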
| InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, InferenceContextPtr> &context_map, | |||
| @@ -179,8 +314,11 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||
| namespace { | |||
| std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | |||
| return InferShapeAndType(node, true); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node, | |||
| bool before_subgraph) { | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | |||
| if (node->Verify() != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | |||
| @@ -199,7 +337,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
| op.SetInferenceContext(inference_context); | |||
| graphStatus status = InferShapeAndType(node, op); | |||
| graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||
| (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | |||
| } else { | |||
| @@ -353,6 +353,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||
| } | |||
| } | |||
| } | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | |||
| } | |||
| @@ -516,13 +517,14 @@ graphStatus Tensor::IsValid() { | |||
| GELOGW("mul overflow: %lu, %u", shape_size, type_length); | |||
| } else { | |||
| if (shape_size * type_length != data_size) { | |||
| // [Just log] Constructor | |||
| GELOGW("tensor length not equal: shape_byte_size=%lu, data_size=%zu, dt_type=%s.", shape_size * type_length, | |||
| data_size, TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -539,7 +541,7 @@ GeTensorDesc TensorAdapter::TensorDesc2GeTensorDesc(const TensorDesc &tensor_des | |||
| tensor_desc.GetDataType()); | |||
| ge_tensor_desc.SetOriginShape(GeShape(tensor_desc.GetOriginShape().GetDims())); | |||
| ge_tensor_desc.SetOriginFormat(tensor_desc.GetOriginFormat()); | |||
| auto size = static_cast<uint32_t>(tensor_desc.GetSize()); | |||
| auto size = tensor_desc.GetSize(); | |||
| TensorUtils::SetSize(ge_tensor_desc, size); | |||
| auto real_dim_cnt = static_cast<uint32_t>(tensor_desc.GetRealDimCnt()); | |||
| @@ -552,7 +554,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||
| ge_tensor_desc.GetDataType()); | |||
| tensor_desc.SetOriginShape(Shape(ge_tensor_desc.GetOriginShape().GetDims())); | |||
| tensor_desc.SetOriginFormat(ge_tensor_desc.GetOriginFormat()); | |||
| uint32_t size = 0; | |||
| int64_t size = 0; | |||
| (void)TensorUtils::GetSize(ge_tensor_desc, size); | |||
| tensor_desc.SetSize(size); | |||
| @@ -15,18 +15,21 @@ | |||
| */ | |||
| #include "graph/utils/ge_ir_utils.h" | |||
| #include <utility> | |||
| #include "framework/common/debug/ge_log.h" | |||
| namespace { | |||
| const char *const kControlAnchorIndex = ":-1"; | |||
| const char *const kNodeTypeForSubgraph = "subgraph"; | |||
| const char *const kPrefixForInputDesc = "input_desc_attr_"; | |||
| const char *const kPrefixForOutputDesc = "output_desc_attr_"; | |||
| const char *const kDumpGEGraph = "DUMP_GE_GRAPH"; | |||
| const int8_t kMaxRecursionDepth = 10; | |||
| const char *const kDumpGeGraph = std::getenv(kDumpGEGraph); | |||
| const int64_t kDumpLevel = (kDumpGeGraph != nullptr) ? std::strtol(kDumpGeGraph, nullptr, 10) : ge::OnnxUtils::NO_DUMP; | |||
| const int64_t kInputPrefixLength = 5; | |||
| const int64_t kOutputPrefixLength = 6; | |||
| using AttrDefPair = ::google::protobuf::MapPair<std::string, ge::proto::AttrDef>; | |||
| } // namespace | |||
| namespace ge { | |||
| @@ -198,7 +201,7 @@ void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_A | |||
| void OnnxUtils::AddAttrProto(onnx::NodeProto *node_proto, onnx::AttributeProto_AttributeType type, const string &name, | |||
| ::google::protobuf::RepeatedField<bool> data) { | |||
| if (node_proto == nullptr) { | |||
| GELOGE(FAILED, "Node_proto %s is nullptr.", name.c_str()); | |||
| GELOGE(FAILED, "Node proto %s is nullptr.", name.c_str()); | |||
| return; | |||
| } | |||
| if (!data.empty()) { | |||
| @@ -320,7 +323,16 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||
| auto cmps_tab_offset = tensor_descriptor->cmps_tab_offset(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | |||
| "input_desc_cmps_tab_offset:" + std::to_string(i), &cmps_tab_offset); | |||
| const auto &tensor_desc_map = tensor_descriptor->attr(); | |||
| std::string suffix = ":" + std::to_string(i); | |||
| AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForInputDesc, suffix); | |||
| } else { | |||
| GELOGW("Tensor descriptor is nullptr"); | |||
| continue; | |||
| } | |||
| } else { | |||
| GELOGW("Input desc is nullptr"); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| @@ -360,16 +372,25 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||
| auto real_dim_cnt = tensor_descriptor->real_dim_cnt(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, | |||
| "output_desc_real_dim_cnt:" + std::to_string(i), &real_dim_cnt); | |||
| const auto &tensor_desc_map = tensor_descriptor->attr(); | |||
| std::string suffix = ":" + std::to_string(i); | |||
| AddAttrProtoForAttrsFromAttrMap(tensor_desc_map, node_proto, kPrefixForOutputDesc, suffix); | |||
| } else { | |||
| GELOGW("Tensor descriptor is nullptr"); | |||
| continue; | |||
| } | |||
| } else { | |||
| GELOGW("Output desc is nullptr"); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| } | |||
| void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto) { | |||
| GE_CHK_BOOL_EXEC(op_def != nullptr, return, "Opdef is nullptr"); | |||
| const auto &op_def_attr_map = op_def->attr(); | |||
| for (const auto &item : op_def_attr_map) { | |||
| void OnnxUtils::AddAttrProtoForAttrsFromAttrMap( | |||
| const ::google::protobuf::Map<std::string, ::ge::proto::AttrDef> &attr_map, onnx::NodeProto *node_proto, | |||
| const std::string &prefix, const std::string &suffix) { | |||
| for (const auto &item : attr_map) { | |||
| auto attr_name = item.first; | |||
| auto attr_def = item.second; | |||
| auto attr_type = attr_def.value_case(); | |||
| @@ -377,36 +398,40 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||
| const auto &tensor_def = attr_def.t(); | |||
| const auto &tensor_desc = tensor_def.desc(); | |||
| auto data_type = ge::proto::DataType_Name(tensor_desc.dtype()); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_dtype:", &data_type); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_dtype" + suffix, | |||
| &data_type); | |||
| auto dims = tensor_desc.shape().dim(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name + "_desc_shape:", dims); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + "_desc_shape" + suffix, | |||
| dims); | |||
| auto layout = tensor_desc.layout(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_desc_layout:", &layout); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_desc_layout" + suffix, | |||
| &layout); | |||
| auto device_type = tensor_desc.device_type(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | |||
| attr_name + "_desc_device_type:", &device_type); | |||
| prefix + attr_name + "_desc_device_type" + suffix, &device_type); | |||
| if (kDumpLevel == DUMP_ALL) { | |||
| auto data = tensor_def.data(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name + "_data", &data); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + "_data" + suffix, | |||
| &data); | |||
| } | |||
| } | |||
| if (attr_type == ge::proto::AttrDef::kS) { | |||
| if (kDumpLevel == DUMP_ALL) { | |||
| auto str_value = attr_def.s(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, attr_name, &str_value); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, prefix + attr_name + suffix, &str_value); | |||
| } | |||
| } | |||
| if (attr_type == ge::proto::AttrDef::kI) { | |||
| auto int_value = attr_def.i(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||
| } | |||
| if (attr_type == ge::proto::AttrDef::kF) { | |||
| auto float_value = attr_def.f(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, attr_name, &float_value); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOAT, prefix + attr_name + suffix, &float_value); | |||
| } | |||
| if (attr_type == ge::proto::AttrDef::kB) { | |||
| auto int_value = static_cast<int64_t>(attr_def.b()); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, attr_name, &int_value); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INT, prefix + attr_name + suffix, &int_value); | |||
| } | |||
| if (attr_type == ge::proto::AttrDef::kList) { | |||
| const auto &list_value = attr_def.list(); | |||
| @@ -415,21 +440,21 @@ void OnnxUtils::AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, on | |||
| ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_STRING) { | |||
| if (kDumpLevel == DUMP_ALL) { | |||
| const auto &strings = list_value.s(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, attr_name, strings); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRINGS, prefix + attr_name + suffix, strings); | |||
| } | |||
| } | |||
| if (list_value_type == | |||
| ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_FLOAT) { | |||
| const auto &floats = list_value.f(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, attr_name, floats); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_FLOATS, prefix + attr_name + suffix, floats); | |||
| } | |||
| if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_INT) { | |||
| const auto &ints = list_value.i(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, ints); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, ints); | |||
| } | |||
| if (list_value_type == ge::proto::AttrDef_ListValue_ListValueType::AttrDef_ListValue_ListValueType_VT_LIST_BOOL) { | |||
| const auto &bools = list_value.b(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, attr_name, bools); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, prefix + attr_name + suffix, bools); | |||
| } | |||
| } | |||
| } | |||
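For reference, the prefix/suffix arguments threaded through AddAttrProtoForAttrsFromAttrMap compose each exported attribute name as prefix + attr name + field + suffix, where the suffix carries the tensor index (the ":" + std::to_string(i) suffix built at the call sites above). A minimal sketch of that naming scheme, assuming kPrefixForInputDesc expands to "input_desc_" (its value is not shown here):

// Sketch only: mirrors the name composition used above; the prefix value is an assumption.
#include <string>

static std::string ComposeAttrName(const std::string &prefix, const std::string &attr_name,
                                   const std::string &field, int index) {
  return prefix + attr_name + field + ":" + std::to_string(index);
}

// ComposeAttrName("input_desc_", "quant_scale", "_desc_dtype", 0)
//   -> "input_desc_quant_scale_desc_dtype:0"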
| @@ -481,8 +506,15 @@ void OnnxUtils::AddAttrProtoFromNodeMembers(const NodePtr &node, onnx::NodeProto | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "workspace_bytes", workspace_bytes); | |||
| const auto &is_input_const = op_def->is_input_const(); | |||
| AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "is_input_const", is_input_const); | |||
| AddAttrProtoForAttrsFromOpDef(op_def, node_proto); | |||
| const auto &op_def_attr_map = op_def->attr(); | |||
| AddAttrProtoForAttrsFromAttrMap(op_def_attr_map, node_proto); | |||
| } else { | |||
| GELOGE(FAILED, "Opdef is nullptr"); | |||
| return; | |||
| } | |||
| } else { | |||
| GELOGE(FAILED, "Opdesc is nullptr"); | |||
| return; | |||
| } | |||
| } | |||
| @@ -526,15 +558,13 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||
| node_proto->clear_input(); | |||
| // 1. Add inputs from in-data edges | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| if (in_data_anchor != nullptr) { | |||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||
| std::to_string(peer_out_anchor->GetIdx())); | |||
| } else { | |||
| // Add "" input | |||
| node_proto->add_input(""); | |||
| } | |||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if ((peer_out_anchor != nullptr) && (peer_out_anchor->GetOwnerNode() != nullptr)) { | |||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + ":" + | |||
| std::to_string(peer_out_anchor->GetIdx())); | |||
| } else { | |||
| // Add "" input | |||
| node_proto->add_input(""); | |||
| } | |||
| } | |||
| @@ -547,6 +577,9 @@ bool OnnxUtils::EncodeNodeLink(const NodePtr &node, onnx::NodeProto *node_proto) | |||
| node_proto->add_input(peer_out_anchor->GetOwnerNode()->GetName() + kControlAnchorIndex); | |||
| } | |||
| } | |||
| } else { | |||
| GELOGE(FAILED, "Incontrol anchor is nullptr"); | |||
| return false; | |||
| } | |||
| // 3. Add output for Netron visual support | |||
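The inputs recorded in steps 1 and 2 follow a simple textual convention: a data edge is encoded as "<peer node name>:<peer out-anchor index>", a missing peer becomes an empty string, and a control edge appends kControlAnchorIndex, whose value is not shown here. A hedged sketch of the data-edge form:

// Sketch only: illustrates the "name:index" convention used by EncodeNodeLink above.
#include <string>

static std::string DataInputName(const std::string &peer_node_name, int out_anchor_idx) {
  return peer_node_name + ":" + std::to_string(out_anchor_idx);  // e.g. "conv1:0"
}
// A control edge would instead be peer_node_name + kControlAnchorIndex.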
| @@ -584,7 +617,7 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||
| } | |||
| const auto &op_desc = node->GetOpDesc(); | |||
| if (op_desc != nullptr) { | |||
| auto size_out = op_desc->GetOutputsSize(); | |||
| uint32_t size_out = static_cast<uint32_t>(op_desc->GetOutputsSize()); | |||
| if (size_out > 0) { | |||
| for (uint32_t i = 0; i < size_out; i++) { | |||
| const ConstGeTensorDescPtr &ge_tensor = op_desc->GetOutputDescPtr(i); | |||
| @@ -598,7 +631,13 @@ void OnnxUtils::EncodeTypeProtoTensorType(const NodePtr &node, onnx::TypeProto_T | |||
| auto dim = shape->add_dim(); | |||
| dim->set_dim_value(d); | |||
| } | |||
| } else { | |||
| GELOGW("Shape is nullptr"); | |||
| continue; | |||
| } | |||
| } else { | |||
| GELOGW("Ge tensor is nullptr"); | |||
| continue; | |||
| } | |||
| } | |||
| } | |||
| @@ -666,7 +705,7 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||
| } | |||
| // For subgraphs: a subgraph is represented by a node | |||
| for (const auto &sub_compute_graph : compute_graph->sub_graph_) { | |||
| for (const auto &sub_compute_graph : compute_graph->GetAllSubgraphs()) { | |||
| if (sub_compute_graph != nullptr) { | |||
| auto node_proto = graph_proto->add_node(); | |||
| if (node_proto == nullptr) { | |||
| @@ -679,6 +718,10 @@ bool OnnxUtils::ConvertGeModelToModelProto(const ge::Model &model, onnx::ModelPr | |||
| attr->set_name("graph"); | |||
| attr->set_type(onnx::AttributeProto_AttributeType_GRAPH); | |||
| auto sub_graph_proto = attr->mutable_g(); | |||
| if (sub_graph_proto == nullptr) { | |||
| GELOGW("Sub graph proto is nullptr"); | |||
| continue; | |||
| } | |||
| if (!EncodeGraph(sub_compute_graph, sub_graph_proto)) { | |||
| GELOGW("Encode sub graph: %s fail", sub_compute_graph->GetName().c_str()); | |||
| continue; | |||
| @@ -831,56 +874,116 @@ void OnnxUtils::DecodeAttribute(const onnx::AttributeProto &attr_proto, int64_t | |||
| value = attr_proto.i(); | |||
| } | |||
| void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||
| OpDescPtr &op_desc) { | |||
| if (op_desc == nullptr || op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "op_desc or op_desc->MutableInputDesc(index) is nullptr"); | |||
| void OnnxUtils::DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_input_desc, int32_t index, | |||
| OpDescPtr &op_desc) { | |||
| if (op_desc->MutableInputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableInputDesc(static_cast<uint32_t>(index)) is nullptr", | |||
| op_desc->GetName().c_str(), attr_name_for_input_desc.c_str()); | |||
| return; | |||
| } | |||
| if (attr_name_for_input_output_desc == "input_desc_dtype") { | |||
| if (attr_name_for_input_desc == "input_desc_dtype") { | |||
| auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | |||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | |||
| } else if (attr_name_for_input_output_desc == "input_desc_shape") { | |||
| } else if (attr_name_for_input_desc == "input_desc_shape") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| GeShape ge_shape(ints); | |||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | |||
| } else if (attr_name_for_input_output_desc == "input_desc_layout") { | |||
| } else if (attr_name_for_input_desc == "input_desc_layout") { | |||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | |||
| } else if (attr_name_for_input_output_desc == "input_desc_origin_shape") { | |||
| } else if (attr_name_for_input_desc == "input_desc_origin_shape") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| GeShape ge_shape(ints); | |||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | |||
| } else if (attr_name_for_input_output_desc == "input_desc_origin_layout") { | |||
| } else if (attr_name_for_input_desc == "input_desc_origin_layout") { | |||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
| op_desc->MutableInputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | |||
| } else if (attr_name_for_input_output_desc == "output_desc_dtype") { | |||
| } else if (attr_name_for_input_desc == "input_desc_size") { | |||
| int64_t input_size = 0; | |||
| auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
| DecodeAttribute(attr_proto, input_size); | |||
| tensor_descriptor->set_size(input_size); | |||
| } else if (attr_name_for_input_desc == "input_desc_data_offset") { | |||
| auto tensor_descriptor = op_desc->MutableInputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
| int64_t offset = 0; | |||
| DecodeAttribute(attr_proto, offset); | |||
| tensor_descriptor->set_data_offset(offset); | |||
| } else { | |||
| return; | |||
| } | |||
| } | |||
| void OnnxUtils::DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_output_desc, int32_t index, | |||
| OpDescPtr &op_desc) { | |||
| if (op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "[op name %s,attr name %s]op_desc->MutableOutputDesc(static_cast<uint32_t>(index)) is nullptr", | |||
| op_desc->GetName().c_str(), attr_name_for_output_desc.c_str()); | |||
| return; | |||
| } | |||
| if (attr_name_for_output_desc == "output_desc_dtype") { | |||
| auto data_type = TypeUtils::SerialStringToDataType(attr_proto.s()); | |||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetDataType(data_type); | |||
| } else if (attr_name_for_input_output_desc == "output_desc_shape") { | |||
| } else if (attr_name_for_output_desc == "output_desc_shape") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| GeShape ge_shape(ints); | |||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetShape(ge_shape); | |||
| } else if (attr_name_for_input_output_desc == "output_desc_layout") { | |||
| } else if (attr_name_for_output_desc == "output_desc_layout") { | |||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetFormat(data_format); | |||
| } else if (attr_name_for_input_output_desc == "output_desc_origin_shape") { | |||
| } else if (attr_name_for_output_desc == "output_desc_origin_shape") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| GeShape ge_shape(ints); | |||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginShape(ge_shape); | |||
| } else if (attr_name_for_input_output_desc == "output_desc_origin_layout") { | |||
| } else if (attr_name_for_output_desc == "output_desc_origin_layout") { | |||
| auto data_format = TypeUtils::SerialStringToFormat(attr_proto.s()); | |||
| op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->SetOriginFormat(data_format); | |||
| } else if (attr_name_for_output_desc == "output_desc_size") { | |||
| int64_t output_size = 0; | |||
| auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
| DecodeAttribute(attr_proto, output_size); | |||
| tensor_descriptor->set_size(output_size); | |||
| } else if (attr_name_for_output_desc == "output_desc_data_offset") { | |||
| auto tensor_descriptor = op_desc->MutableOutputDesc(static_cast<uint32_t>(index))->tensor_descriptor_.GetProtoMsg(); | |||
| int64_t offset = 0; | |||
| DecodeAttribute(attr_proto, offset); | |||
| tensor_descriptor->set_data_offset(offset); | |||
| } else { | |||
| return; | |||
| } | |||
| } | |||
| void OnnxUtils::DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||
| OpDescPtr &op_desc) { | |||
| if (op_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "op_desc is nullptr"); | |||
| return; | |||
| } | |||
| if (attr_name_for_input_output_desc.substr(0, kInputPrefixLength) == "input") { | |||
| DecodeNodeAttributeForOpInDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||
| } else if (attr_name_for_input_output_desc.substr(0, kOutputPrefixLength) == "output") { | |||
| DecodeNodeAttributeForOpOutDesc(attr_proto, attr_name_for_input_output_desc, index, op_desc); | |||
| } else { | |||
| return; | |||
| } | |||
| } | |||
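The dispatch above keys off the attribute-name prefix: names starting with "input" are routed to DecodeNodeAttributeForOpInDesc and names starting with "output" to DecodeNodeAttributeForOpOutDesc. kOutputPrefixLength = 6 is defined at the top of this file; kInputPrefixLength is assumed to be 5 (the length of "input"). A minimal sketch of the routing test:

// Sketch only: kInputPrefixLength = 5 is an assumption.
#include <string>

static bool RoutesToInputDesc(const std::string &attr_name) {
  return attr_name.compare(0, 5, "input") == 0;    // e.g. "input_desc_shape"
}
static bool RoutesToOutputDesc(const std::string &attr_name) {
  return attr_name.compare(0, 6, "output") == 0;   // e.g. "output_desc_data_offset"
}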
| void OnnxUtils::DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def) { | |||
| auto attr_map = op_def.mutable_attr(); | |||
| const auto &attr_name = attr_proto.name(); | |||
| ge::proto::AttrDef op_attr; | |||
| int64_t value = 0; | |||
| DecodeAttribute(attr_proto, value); | |||
| op_attr.set_i(value); | |||
| attr_map->insert(AttrDefPair(attr_name, op_attr)); | |||
| } | |||
| void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc) { | |||
| if (op_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "DecodeNodeAttributeForOpDesc: op_desc is nullptr"); | |||
| @@ -910,6 +1013,16 @@ void OnnxUtils::DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_pr | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| op_desc->SetDstIndex(ints); | |||
| } else if (attr_name == "fusion_scope") { | |||
| DecodeNodeAttributeForOpDef(attr_proto, *op_desc->op_def_.GetProtoMsg()); | |||
| } else if (attr_name == "input_i") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| op_desc->SetInputOffset(ints); | |||
| } else if (attr_name == "output_i") { | |||
| std::vector<std::int64_t> ints; | |||
| DecodeAttribute(attr_proto, ints); | |||
| op_desc->SetOutputOffset(ints); | |||
| } else { | |||
| return; | |||
| } | |||
| @@ -939,20 +1052,14 @@ bool OnnxUtils::DecodeNodeDesc(const onnx::NodeProto *node_proto, OpDescPtr &op_ | |||
| auto size_in = attr.i(); | |||
| for (int64_t i = 0; i < size_in; i++) { | |||
| GeTensorDesc ge_tensor_desc; | |||
| if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||
| GELOGW("Add inputdesc failed"); | |||
| continue; | |||
| } | |||
| GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add inputdesc failed."); | |||
| } | |||
| } | |||
| if (attr.name() == "output_desc_nums") { | |||
| auto size_out = attr.i(); | |||
| for (int64_t i = 0; i < size_out; i++) { | |||
| GeTensorDesc ge_tensor_desc; | |||
| if (op_desc->AddInputDesc(ge_tensor_desc) != GRAPH_SUCCESS) { | |||
| GELOGW("add inputdesc failed"); | |||
| continue; | |||
| } | |||
| GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(ge_tensor_desc) == GRAPH_SUCCESS, continue, "Add outputdesc failed."); | |||
| } | |||
| } | |||
| } | |||
| @@ -970,10 +1077,7 @@ bool OnnxUtils::DecodeGraph(int recursion_depth, const onnx::GraphProto &graph_p | |||
| } | |||
| graph = ComGraphMakeShared<ge::ComputeGraph>(graph_proto.name()); | |||
| if (graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "ComputeGraph make shared failed"); | |||
| return false; | |||
| } | |||
| GE_CHK_BOOL_EXEC(graph != nullptr, return false, "ComputeGraph make shared failed"); | |||
| /// 1. Decode all nodes first; nodes should include input | |||
| /// and output nodes as well as nodes that represent subgraphs | |||
| std::map<std::string, NodePtr> node_map; | |||
| @@ -131,6 +131,10 @@ class OnnxUtils { | |||
| static void AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const OpDescPtr &op_desc); | |||
| static void AddAttrProtoForAttrsFromAttrMap(const ::google::protobuf::Map<std::string, ge::proto::AttrDef> &attr_map, | |||
| onnx::NodeProto *node_proto, const std::string &prefix = "", | |||
| const std::string &suffix = ""); | |||
| static void AddAttrProtoForAttrsFromOpDef(const ge::proto::OpDef *op_def, onnx::NodeProto *node_proto); | |||
| static onnx::TensorProto_DataType EncodeDataType(ge::DataType data_type); | |||
| @@ -172,10 +176,20 @@ class OnnxUtils { | |||
| static void DecodeAttribute(const onnx::AttributeProto &attr_proto, std::string &value); | |||
| static void DecodeNodeAttributeForOpOutDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_output_desc, int32_t index, | |||
| OpDescPtr &op_desc); | |||
| static void DecodeNodeAttributeForOpInDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_input_desc, int32_t index, | |||
| OpDescPtr &op_desc); | |||
| static void DecodeNodeAttributeForOpInAndOutDesc(const onnx::AttributeProto &attr_proto, | |||
| const std::string &attr_name_for_input_output_desc, int32_t index, | |||
| OpDescPtr &op_desc); | |||
| static void DecodeNodeAttributeForOpDef(const onnx::AttributeProto &attr_proto, ge::proto::OpDef &op_def); | |||
| static void DecodeNodeAttributeForOpDesc(const onnx::AttributeProto &attr_proto, OpDescPtr &op_desc); | |||
| static bool DecodeNodeLinkImp(const NodeLinkInfo &item, NodePtr &node_ptr); | |||
| @@ -15,10 +15,12 @@ | |||
| */ | |||
| #include "utils/node_utils.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "debug/ge_util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/anchor.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "utils/tensor_utils.h" | |||
| #include "utils/type_utils.h" | |||
| @@ -109,6 +111,7 @@ graphStatus NodeUtils::GetDataOutAnchorAndControlInAnchor(const NodePtr &node_pt | |||
| graphStatus NodeUtils::ClearInDataAnchor(const NodePtr &node_ptr, const InDataAnchorPtr &in_data_anchor) { | |||
| GE_CHK_BOOL_EXEC(node_ptr != nullptr && in_data_anchor != nullptr, return GRAPH_FAILED, | |||
| "node or in_data_anchor is nullptr"); | |||
| bool find_flag = false; | |||
| uint32_t index = 0; | |||
| vector<InDataAnchorPtr>::iterator it = node_ptr->in_data_anchors_.end(); | |||
| @@ -358,4 +361,45 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
| input_desc->SetShape(shape); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::string NodeUtils::GetNodeType(const Node &node) { | |||
| if (node.GetType() != FRAMEWORKOP) { | |||
| return node.GetType(); | |||
| } | |||
| std::string type; | |||
| (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||
| return type; | |||
| } | |||
| ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||
| auto op_desc = node.GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| return nullptr; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
| if (root_graph == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return root_graph->GetSubgraph(op_desc->GetSubgraphInstanceName(index)); | |||
| } | |||
| graphStatus NodeUtils::AddSubgraph(Node &node, const ComputeGraphPtr &subgraph) { | |||
| if (subgraph == nullptr) { | |||
| GE_LOGE("Failed to add subgraph to node %s, null subgraph", node.GetName().c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto op_desc = node.GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||
| if (root_graph == nullptr) { | |||
| GE_LOGE("Failed to add subgraph to node %s, null root graph", node.GetName().c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| op_desc->AddSubgraphInstanceName(subgraph->GetName()); | |||
| subgraph->SetParentNode(node.shared_from_this()); | |||
| subgraph->SetParentGraph(node.GetOwnerComputeGraph()); | |||
| root_graph->AddSubgraph(subgraph); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } // namespace ge | |||
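AddSubgraph wires three relationships at once: it registers the subgraph instance name on the node's op desc, points the subgraph back at its parent node and parent graph, and adds the subgraph to the root graph so that GetSubgraph(node, index) can later resolve it. A hedged usage sketch, assuming if_node is an existing node in a graph:

// Sketch only: if_node is assumed to be a valid ge::NodePtr created elsewhere.
ge::ComputeGraphPtr branch_graph = ge::ComGraphMakeShared<ge::ComputeGraph>("then_branch");
if (ge::NodeUtils::AddSubgraph(*if_node, branch_graph) == ge::GRAPH_SUCCESS) {
  // Index 0 assumes this is the first subgraph instance registered on if_node.
  ge::ComputeGraphPtr found = ge::NodeUtils::GetSubgraph(*if_node, 0);
}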
| @@ -15,9 +15,7 @@ | |||
| */ | |||
| #include "utils/op_desc_utils.h" | |||
| #include <algorithm> | |||
| #include "debug/ge_attr_define.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -209,6 +207,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ConstGeTensorPtr> OpDescUtils::GetInputData( | |||
| const vector<ge::NodePtr> &input_nodes) { | |||
| vector<ConstGeTensorPtr> ret; | |||
| for (const auto &input_node : input_nodes) { | |||
| auto temp_weight = MutableWeights(input_node->GetOpDesc()); | |||
| if (temp_weight == nullptr) { | |||
| @@ -379,7 +378,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||
| if (NodeUtils::IsAnchorStatusSet(*node)) { | |||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
| if (ge::AnchorUtils::GetStatus(in_anchor) == ANCHOR_DATA) { | |||
| (void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
| } | |||
| } | |||
| } else { | |||
| @@ -389,7 +388,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::GeTensorDesc> OpDescUt | |||
| continue; | |||
| } | |||
| if (out_anchor->GetOwnerNode()->GetOpDesc()->GetType() != CONSTANT) { | |||
| (void)ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
| ret.push_back(node->GetOpDesc()->GetInputDesc(in_anchor->GetIdx())); | |||
| } | |||
| } | |||
| } | |||
| @@ -572,4 +571,80 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::ClearWei | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Add input | |||
| /// @param [in] name | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddInput(const std::string &name) { | |||
| inputs_.emplace_back(name); | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add dynamic input | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicInput(const std::string &name, | |||
| uint32_t num) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| inputs_.emplace_back(name + std::to_string(i)); | |||
| } | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add output | |||
| /// @param [in] name | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddOutput(const std::string &name) { | |||
| outputs_.emplace_back(name); | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Add dynamic output | |||
| /// @param [in] name | |||
| /// @param [in] num | |||
| /// @return OpDescBuilder | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescBuilder &OpDescBuilder::AddDynamicOutput(const std::string &name, | |||
| uint32_t num) { | |||
| for (uint32_t i = 0; i < num; i++) { | |||
| outputs_.emplace_back(name + std::to_string(i)); | |||
| } | |||
| return *this; | |||
| } | |||
| /// | |||
| /// @brief Build op_desc | |||
| /// @return OpDescPtr | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr OpDescBuilder::Build() { | |||
| OpDescPtr op_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name_, type_)); | |||
| if (op_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "OpDesc is nullptr"); | |||
| return nullptr; | |||
| } | |||
| for (auto &input : inputs_) { | |||
| if (op_desc->AddInputDesc(input, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Add input_desc failed."); | |||
| return nullptr; | |||
| } | |||
| } | |||
| for (auto &output : outputs_) { | |||
| if (op_desc->AddOutputDesc(output, GeTensorDesc()) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Add output_desc failed."); | |||
| return nullptr; | |||
| } | |||
| } | |||
| return op_desc; | |||
| } | |||
| } // namespace ge | |||
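A hedged usage sketch of the builder added above; an OpDescBuilder(name, type) constructor is assumed from the name_ and type_ members, since it is not shown here:

// Sketch only: the (name, type) constructor is an assumption.
ge::OpDescPtr op_desc = ge::OpDescBuilder("my_add", "Add")
                            .AddInput("x1")
                            .AddInput("x2")
                            .AddDynamicInput("bias", 2)  // expands to "bias0", "bias1"
                            .AddOutput("y")
                            .Build();                    // returns nullptr on failure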
| @@ -15,7 +15,6 @@ | |||
| */ | |||
| #include "graph/utils/tensor_utils.h" | |||
| #include <cmath> | |||
| #include "debug/ge_log.h" | |||
| @@ -276,6 +275,14 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| break; | |||
| case FORMAT_FRACTAL_NZ: | |||
| case FORMAT_FRACTAL_ZZ: | |||
| case FORMAT_NDHWC: | |||
| case FORMAT_NCDHW: | |||
| case FORMAT_DHWCN: | |||
| case FORMAT_DHWNC: | |||
| case FORMAT_FRACTAL_Z_3D: | |||
| case FORMAT_FRACTAL_Z_3D_TRANSPOSE: | |||
| case FORMAT_NDC1HWC0: | |||
| case FORMAT_FRACTAL_Z_C04: | |||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||
| break; | |||
| default: | |||
| @@ -351,21 +358,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTens | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||
| TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||
| graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp); | |||
| if (graph_status != GRAPH_SUCCESS) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| // 64-byte alignment; if size is 0, align to 32 bytes | |||
| if (size_temp > (UINT32_MAX - kNum2 * kDataMemAlignSize)) { | |||
| GELOGW("The updated mem size %u is bigger than UINT32_MAX", size_temp); | |||
| if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) { | |||
| GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp); | |||
| } else { | |||
| size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize; | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_temp) { | |||
| TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) { | |||
| GeShape output_shape = desc_temp.GetShape(); | |||
| Format format = desc_temp.GetFormat(); | |||
| DataType data_type = desc_temp.GetDataType(); | |||
| @@ -376,13 +383,13 @@ TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, uint32_t &size_ | |||
| return GRAPH_FAILED; | |||
| } | |||
| if ((output_mem_size > UINT32_MAX) || (output_mem_size < 0)) { | |||
| GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %u]", | |||
| output_mem_size, UINT32_MAX); | |||
| if (output_mem_size < 0) { | |||
| GELOGE(GRAPH_FAILED, "After calc concat tensor memory size, output_mem_size = %ld, out of data range [0, %ld]", | |||
| output_mem_size, INT64_MAX); | |||
| return GRAPH_FAILED; | |||
| } | |||
| size_temp = static_cast<uint32_t>(output_mem_size); | |||
| size_temp = output_mem_size; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } // namespace ge | |||
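The rounding in GetTensorMemorySizeInBytes pads the size by kNum2 * kDataMemAlignSize - 1 before truncating to a kDataMemAlignSize boundary, which is why a size of 0 still occupies one aligned block. A standalone sketch of the same arithmetic, assuming kDataMemAlignSize is 32 and kNum2 is 2 (their values are not shown here):

// Sketch only: the constants are assumptions; mirrors the alignment applied above.
#include <cstdint>

static int64_t AlignTensorMemSize(int64_t size) {
  const int64_t kAlign = 32;  // assumed kDataMemAlignSize
  const int64_t kTwo = 2;     // assumed kNum2
  if (size > INT64_MAX - kTwo * kAlign) {
    return size;  // mirrors the warning path above: the size is left unchanged
  }
  return ((size + kTwo * kAlign - 1) / kAlign) * kAlign;  // e.g. 0 -> 32, 100 -> 160
}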
| @@ -19,43 +19,45 @@ | |||
| namespace ge { | |||
| static const std::map<Format, std::string> kFormatToStringMap = { | |||
| {FORMAT_NCHW, "NCHW"}, | |||
| {FORMAT_NHWC, "NHWC"}, | |||
| {FORMAT_ND, "ND"}, | |||
| {FORMAT_NC1HWC0, "NC1HWC0"}, | |||
| {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||
| {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||
| {FORMAT_NHWC1C0, "NHWC1C0"}, | |||
| {FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||
| {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||
| {FORMAT_C1HWNC0, "C1HWNC0"}, | |||
| {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||
| {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||
| {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||
| {FORMAT_CHWN, "CHWN"}, | |||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||
| {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||
| {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||
| {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||
| {FORMAT_HWCN, "HWCN"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||
| {FORMAT_MD, "MD"}, | |||
| {FORMAT_NDHWC, "NDHWC"}, | |||
| {FORMAT_NCDHW, "NCDHW"}, | |||
| {FORMAT_DHWCK, "DHWCK"}, | |||
| {FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||
| {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||
| {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
| {FORMAT_CN, "CN"}, | |||
| {FORMAT_NC, "NC"}, | |||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
| {FORMAT_ALL, "ALL"}}; | |||
| {FORMAT_NCHW, "NCHW"}, | |||
| {FORMAT_NHWC, "NHWC"}, | |||
| {FORMAT_ND, "ND"}, | |||
| {FORMAT_NC1HWC0, "NC1HWC0"}, | |||
| {FORMAT_FRACTAL_Z, "FRACTAL_Z"}, | |||
| {FORMAT_NC1C0HWPAD, "NC1C0HWPAD"}, | |||
| {FORMAT_NHWC1C0, "NHWC1C0"}, | |||
| {FORMAT_FSR_NCHW, "FSR_NCHW"}, | |||
| {FORMAT_FRACTAL_DECONV, "FRACTAL_DECONV"}, | |||
| {FORMAT_C1HWNC0, "C1HWNC0"}, | |||
| {FORMAT_FRACTAL_DECONV_TRANSPOSE, "FRACTAL_DECONV_TRANSPOSE"}, | |||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS, "FRACTAL_DECONV_SP_STRIDE_TRANS"}, | |||
| {FORMAT_NC1HWC0_C04, "NC1HWC0_C04"}, | |||
| {FORMAT_FRACTAL_Z_C04, "FRACTAL_Z_C04"}, | |||
| {FORMAT_CHWN, "CHWN"}, | |||
| {FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS, "DECONV_SP_STRIDE8_TRANS"}, | |||
| {FORMAT_NC1KHKWHWC0, "NC1KHKWHWC0"}, | |||
| {FORMAT_BN_WEIGHT, "BN_WEIGHT"}, | |||
| {FORMAT_FILTER_HWCK, "FILTER_HWCK"}, | |||
| {FORMAT_HWCN, "HWCN"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_LOOKUPS, "LOOKUP_LOOKUPS"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_KEYS, "LOOKUP_KEYS"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_VALUE, "LOOKUP_VALUE"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_OUTPUT, "LOOKUP_OUTPUT"}, | |||
| {FORMAT_HASHTABLE_LOOKUP_HITS, "LOOKUP_HITS"}, | |||
| {FORMAT_MD, "MD"}, | |||
| {FORMAT_NDHWC, "NDHWC"}, | |||
| {FORMAT_NCDHW, "NCDHW"}, | |||
| {FORMAT_DHWCN, "DHWCN"}, | |||
| {FORMAT_DHWNC, "DHWNC"}, | |||
| {FORMAT_NDC1HWC0, "NDC1HWC0"}, | |||
| {FORMAT_FRACTAL_Z_3D, "FRACTAL_Z_3D"}, | |||
| {FORMAT_FRACTAL_Z_3D_TRANSPOSE, "FRACTAL_Z_3D_TRANSPOSE"}, | |||
| {FORMAT_C1HWNCoC0, "C1HWNCoC0"}, | |||
| {FORMAT_FRACTAL_NZ, "FRACTAL_NZ"}, | |||
| {FORMAT_CN, "CN"}, | |||
| {FORMAT_NC, "NC"}, | |||
| {FORMAT_RESERVED, "FORMAT_RESERVED"}, | |||
| {FORMAT_ALL, "ALL"}}; | |||
| static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
| "FRACTAL_Z", | |||
| @@ -73,137 +75,140 @@ static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | |||
| "FRACTAL_ZZ", | |||
| "FRACTAL_NZ", | |||
| "NDC1HWC0", | |||
| "FORMAT_FRACTAL_Z_3D"}; | |||
| "FORMAT_FRACTAL_Z_3D", | |||
| "FORMAT_FRACTAL_Z_3D_TRANSPOSE"}; | |||
| static const std::map<std::string, Format> kDataFormatMap = { | |||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"ND", FORMAT_ND}}; | |||
| {"NCHW", FORMAT_NCHW}, {"NHWC", FORMAT_NHWC}, {"NDHWC", FORMAT_NDHWC}, {"NCDHW", FORMAT_NCDHW}, {"ND", FORMAT_ND}}; | |||
| static const std::map<std::string, Format> kStringToFormatMap = { | |||
| {"NCHW", FORMAT_NCHW}, | |||
| {"NHWC", FORMAT_NHWC}, | |||
| {"ND", FORMAT_ND}, | |||
| {"NC1HWC0", FORMAT_NC1HWC0}, | |||
| {"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||
| {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||
| {"NHWC1C0", FORMAT_NHWC1C0}, | |||
| {"FSR_NCHW", FORMAT_FSR_NCHW}, | |||
| {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||
| {"C1HWNC0", FORMAT_C1HWNC0}, | |||
| {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||
| {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||
| {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||
| {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||
| {"CHWN", FORMAT_CHWN}, | |||
| {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||
| {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||
| {"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||
| {"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||
| {"HWCN", FORMAT_HWCN}, | |||
| {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||
| {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||
| {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||
| {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||
| {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||
| {"MD", FORMAT_MD}, | |||
| {"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||
| {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||
| {"NDHWC", FORMAT_NDHWC}, | |||
| {"NCDHW", FORMAT_NCDHW}, | |||
| {"DHWCK", FORMAT_DHWCK}, | |||
| {"NDC1HWC0", FORMAT_NDC1HWC0}, | |||
| {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||
| {"CN", FORMAT_CN}, | |||
| {"NC", FORMAT_NC}, | |||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
| {"ALL", FORMAT_ALL}}; | |||
| {"NCHW", FORMAT_NCHW}, | |||
| {"NHWC", FORMAT_NHWC}, | |||
| {"ND", FORMAT_ND}, | |||
| {"NC1HWC0", FORMAT_NC1HWC0}, | |||
| {"FRACTAL_Z", FORMAT_FRACTAL_Z}, | |||
| {"NC1C0HWPAD", FORMAT_NC1C0HWPAD}, | |||
| {"NHWC1C0", FORMAT_NHWC1C0}, | |||
| {"FSR_NCHW", FORMAT_FSR_NCHW}, | |||
| {"FRACTAL_DECONV", FORMAT_FRACTAL_DECONV}, | |||
| {"C1HWNC0", FORMAT_C1HWNC0}, | |||
| {"FRACTAL_DECONV_TRANSPOSE", FORMAT_FRACTAL_DECONV_TRANSPOSE}, | |||
| {"FRACTAL_DECONV_SP_STRIDE_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE_TRANS}, | |||
| {"NC1HWC0_C04", FORMAT_NC1HWC0_C04}, | |||
| {"FRACTAL_Z_C04", FORMAT_FRACTAL_Z_C04}, | |||
| {"CHWN", FORMAT_CHWN}, | |||
| {"DECONV_SP_STRIDE8_TRANS", FORMAT_FRACTAL_DECONV_SP_STRIDE8_TRANS}, | |||
| {"NC1KHKWHWC0", FORMAT_NC1KHKWHWC0}, | |||
| {"BN_WEIGHT", FORMAT_BN_WEIGHT}, | |||
| {"FILTER_HWCK", FORMAT_FILTER_HWCK}, | |||
| {"HWCN", FORMAT_HWCN}, | |||
| {"LOOKUP_LOOKUPS", FORMAT_HASHTABLE_LOOKUP_LOOKUPS}, | |||
| {"LOOKUP_KEYS", FORMAT_HASHTABLE_LOOKUP_KEYS}, | |||
| {"LOOKUP_VALUE", FORMAT_HASHTABLE_LOOKUP_VALUE}, | |||
| {"LOOKUP_OUTPUT", FORMAT_HASHTABLE_LOOKUP_OUTPUT}, | |||
| {"LOOKUP_HITS", FORMAT_HASHTABLE_LOOKUP_HITS}, | |||
| {"MD", FORMAT_MD}, | |||
| {"C1HWNCoC0", FORMAT_C1HWNCoC0}, | |||
| {"FRACTAL_NZ", FORMAT_FRACTAL_NZ}, | |||
| {"NDHWC", FORMAT_NDHWC}, | |||
| {"NCDHW", FORMAT_NCDHW}, | |||
| {"DHWCN", FORMAT_DHWCN}, | |||
| {"DHWNC", FORMAT_DHWNC}, | |||
| {"NDC1HWC0", FORMAT_NDC1HWC0}, | |||
| {"FRACTAL_Z_3D", FORMAT_FRACTAL_Z_3D}, | |||
| {"FRACTAL_Z_3D_TRANSPOSE", FORMAT_FRACTAL_Z_3D_TRANSPOSE}, | |||
| {"CN", FORMAT_CN}, | |||
| {"NC", FORMAT_NC}, | |||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
| {"ALL", FORMAT_ALL}}; | |||
| static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
| {DT_FLOAT, "DT_FLOAT"}, // float type | |||
| {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||
| {DT_INT8, "DT_INT8"}, // int8 type | |||
| {DT_INT16, "DT_INT16"}, // int16 type | |||
| {DT_UINT16, "DT_UINT16"}, // uint16 type | |||
| {DT_UINT8, "DT_UINT8"}, // uint8 type | |||
| {DT_INT32, "DT_INT32"}, // uint32 type | |||
| {DT_INT64, "DT_INT64"}, // int64 type | |||
| {DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||
| {DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||
| {DT_BOOL, "DT_BOOL"}, // bool type | |||
| {DT_DOUBLE, "DT_DOUBLE"}, // double type | |||
| {DT_DUAL, "DT_DUAL"}, // dual output type | |||
| {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||
| {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||
| {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||
| {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||
| {DT_QINT8, "DT_QINT8"}, // qint8 type | |||
| {DT_QINT16, "DT_QINT16"}, // qint16 type | |||
| {DT_QINT32, "DT_QINT32"}, // qint32 type | |||
| {DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||
| {DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||
| {DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||
| {DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||
| {DT_STRING, "DT_STRING"}, // string type | |||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
| {DT_FLOAT, "DT_FLOAT"}, // float type | |||
| {DT_FLOAT16, "DT_FLOAT16"}, // fp16 type | |||
| {DT_INT8, "DT_INT8"}, // int8 type | |||
| {DT_INT16, "DT_INT16"}, // int16 type | |||
| {DT_UINT16, "DT_UINT16"}, // uint16 type | |||
| {DT_UINT8, "DT_UINT8"}, // uint8 type | |||
| {DT_INT32, "DT_INT32"}, // uint32 type | |||
| {DT_INT64, "DT_INT64"}, // int64 type | |||
| {DT_UINT32, "DT_UINT32"}, // unsigned int32 | |||
| {DT_UINT64, "DT_UINT64"}, // unsigned int64 | |||
| {DT_BOOL, "DT_BOOL"}, // bool type | |||
| {DT_DOUBLE, "DT_DOUBLE"}, // double type | |||
| {DT_DUAL, "DT_DUAL"}, // dual output type | |||
| {DT_DUAL_SUB_INT8, "DT_DUAL_SUB_INT8"}, // dual output int8 type | |||
| {DT_DUAL_SUB_UINT8, "DT_DUAL_SUB_UINT8"}, // dual output uint8 type | |||
| {DT_COMPLEX64, "DT_COMPLEX64"}, // complex64 type | |||
| {DT_COMPLEX128, "DT_COMPLEX128"}, // complex128 type | |||
| {DT_QINT8, "DT_QINT8"}, // qint8 type | |||
| {DT_QINT16, "DT_QINT16"}, // qint16 type | |||
| {DT_QINT32, "DT_QINT32"}, // qint32 type | |||
| {DT_QUINT8, "DT_QUINT8"}, // quint8 type | |||
| {DT_QUINT16, "DT_QUINT16"}, // quint16 type | |||
| {DT_RESOURCE, "DT_RESOURCE"}, // resource type | |||
| {DT_STRING_REF, "DT_STRING_REF"}, // string ref type | |||
| {DT_STRING, "DT_STRING"}, // string type | |||
| }; | |||
| static const std::map<std::string, DataType> kStringTodataTypeMap = { | |||
| {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||
| {"DT_FLOAT", DT_FLOAT}, // float type | |||
| { | |||
| "DT_FLOAT16", | |||
| DT_FLOAT16, | |||
| }, // fp16 type | |||
| {"DT_INT8", DT_INT8}, // int8 type | |||
| {"DT_INT16", DT_INT16}, // int16 type | |||
| {"DT_UINT16", DT_UINT16}, // uint16 type | |||
| {"DT_UINT8", DT_UINT8}, // uint8 type | |||
| {"DT_INT32", DT_INT32}, // uint32 type | |||
| {"DT_INT64", DT_INT64}, // int64 type | |||
| {"DT_UINT32", DT_UINT32}, // unsigned int32 | |||
| {"DT_UINT64", DT_UINT64}, // unsigned int64 | |||
| {"DT_BOOL", DT_BOOL}, // bool type | |||
| {"DT_DOUBLE", DT_DOUBLE}, // double type | |||
| {"DT_DUAL", DT_DUAL}, // dual output type | |||
| {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||
| {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||
| {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||
| {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||
| {"DT_QINT8", DT_QINT8}, // qint8 type | |||
| {"DT_QINT16", DT_QINT16}, // qint16 type | |||
| {"DT_QINT32", DT_QINT32}, // qint32 type | |||
| {"DT_QUINT8", DT_QUINT8}, // quint8 type | |||
| {"DT_QUINT16", DT_QUINT16}, // quint16 type | |||
| {"DT_RESOURCE", DT_RESOURCE}, // resource type | |||
| {"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||
| {"DT_STRING", DT_STRING}, // string type | |||
| {"DT_UNDEFINED", DT_UNDEFINED}, // Used to indicate a DataType field has not been set. | |||
| {"DT_FLOAT", DT_FLOAT}, // float type | |||
| { | |||
| "DT_FLOAT16", | |||
| DT_FLOAT16, | |||
| }, // fp16 type | |||
| {"DT_INT8", DT_INT8}, // int8 type | |||
| {"DT_INT16", DT_INT16}, // int16 type | |||
| {"DT_UINT16", DT_UINT16}, // uint16 type | |||
| {"DT_UINT8", DT_UINT8}, // uint8 type | |||
| {"DT_INT32", DT_INT32}, // uint32 type | |||
| {"DT_INT64", DT_INT64}, // int64 type | |||
| {"DT_UINT32", DT_UINT32}, // unsigned int32 | |||
| {"DT_UINT64", DT_UINT64}, // unsigned int64 | |||
| {"DT_BOOL", DT_BOOL}, // bool type | |||
| {"DT_DOUBLE", DT_DOUBLE}, // double type | |||
| {"DT_DUAL", DT_DUAL}, // dual output type | |||
| {"DT_DUAL_SUB_INT8", DT_DUAL_SUB_INT8}, // dual output int8 type | |||
| {"DT_DUAL_SUB_UINT8", DT_DUAL_SUB_UINT8}, // dual output uint8 type | |||
| {"DT_COMPLEX64", DT_COMPLEX64}, // complex64 type | |||
| {"DT_COMPLEX128", DT_COMPLEX128}, // complex128 type | |||
| {"DT_QINT8", DT_QINT8}, // qint8 type | |||
| {"DT_QINT16", DT_QINT16}, // qint16 type | |||
| {"DT_QINT32", DT_QINT32}, // qint32 type | |||
| {"DT_QUINT8", DT_QUINT8}, // quint8 type | |||
| {"DT_QUINT16", DT_QUINT16}, // quint16 type | |||
| {"DT_RESOURCE", DT_RESOURCE}, // resource type | |||
| {"DT_STRING_REF", DT_STRING_REF}, // string ref type | |||
| {"DT_STRING", DT_STRING}, // string type | |||
| }; | |||
| static const std::map<ge::DataType, uint32_t> kDataTypeToLength = { | |||
| {DT_BOOL, sizeof(bool)}, | |||
| {DT_INT64, sizeof(int64_t)}, | |||
| {DT_UINT64, sizeof(int64_t)}, | |||
| {DT_FLOAT, sizeof(float)}, | |||
| {DT_INT32, sizeof(int32_t)}, | |||
| {DT_UINT32, sizeof(int32_t)}, | |||
| {DT_INT8, sizeof(char)}, | |||
| {DT_UINT8, sizeof(char)}, | |||
| {DT_INT16, sizeof(int16_t)}, | |||
| {DT_UINT16, sizeof(int16_t)}, | |||
| {DT_FLOAT16, sizeof(int16_t)}, | |||
| {DT_DOUBLE, sizeof(double)}, | |||
| {DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||
| {DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||
| {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||
| {DT_COMPLEX64, sizeof(int64_t)}, | |||
| {DT_COMPLEX128, sizeof(int64_t) * 2}, | |||
| {DT_QINT8, sizeof(int8_t)}, | |||
| {DT_QINT16, sizeof(int16_t)}, | |||
| {DT_QINT32, sizeof(int32_t)}, | |||
| {DT_QUINT8, sizeof(uint8_t)}, | |||
| {DT_QUINT16, sizeof(uint16_t)}, | |||
| {DT_STRING_REF, sizeof(uint64_t) * 2}, | |||
| {DT_STRING, sizeof(uint64_t)}, | |||
| {DT_RESOURCE, sizeof(uint64_t)}, | |||
| {DT_BOOL, sizeof(bool)}, | |||
| {DT_INT64, sizeof(int64_t)}, | |||
| {DT_UINT64, sizeof(int64_t)}, | |||
| {DT_FLOAT, sizeof(float)}, | |||
| {DT_INT32, sizeof(int32_t)}, | |||
| {DT_UINT32, sizeof(int32_t)}, | |||
| {DT_INT8, sizeof(char)}, | |||
| {DT_UINT8, sizeof(char)}, | |||
| {DT_INT16, sizeof(int16_t)}, | |||
| {DT_UINT16, sizeof(int16_t)}, | |||
| {DT_FLOAT16, sizeof(int16_t)}, | |||
| {DT_DOUBLE, sizeof(double)}, | |||
| {DT_DUAL, sizeof(float) + sizeof(int8_t)}, | |||
| {DT_DUAL_SUB_INT8, sizeof(int8_t)}, | |||
| {DT_DUAL_SUB_UINT8, sizeof(uint8_t)}, | |||
| {DT_COMPLEX64, sizeof(int64_t)}, | |||
| {DT_COMPLEX128, sizeof(int64_t) * 2}, | |||
| {DT_QINT8, sizeof(int8_t)}, | |||
| {DT_QINT16, sizeof(int16_t)}, | |||
| {DT_QINT32, sizeof(int32_t)}, | |||
| {DT_QUINT8, sizeof(uint8_t)}, | |||
| {DT_QUINT16, sizeof(uint16_t)}, | |||
| {DT_STRING_REF, sizeof(uint64_t) * 2}, | |||
| {DT_STRING, sizeof(uint64_t)}, | |||
| {DT_RESOURCE, sizeof(uint64_t)}, | |||
| }; | |||
| bool TypeUtils::IsDataTypeValid(DataType dt) { | |||
| @@ -13,7 +13,7 @@ | |||
| # limitations under the License. | |||
| # ============================================================================ | |||
| # libge.so & libge_train.so | |||
| # libge_compiler.so & libge_train.so | |||
| # will later be integrated into libgraph_runner.so, which works for both training and inference | |||
| # compiling proto files generates some warnings; use no-unused-variable to suppress them | |||
| set(CMAKE_CXX_FLAGS "-Wno-unused-variable ${CMAKE_CXX_FLAGS}") | |||
| @@ -49,7 +49,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||
| ######### libge_train.so ############# | |||
| # need to remove dependencies on pb files later | |||
| file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/format_transfers/*.cc" | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| @@ -57,20 +57,24 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/ge/plugin_manager.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| "engine_manager/dnnengine_manager.cc" | |||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||
| "generator/ge_generator.cc" | |||
| "generator/generator_api.cc" | |||
| "graph/build/graph_build.cc" | |||
| "graph/build/graph_builder.cc" | |||
| "graph/build/label_allocator.cc" | |||
| "graph/build/logical_stream_allocator.cc" | |||
| "graph/build/model_builder.cc" | |||
| "graph/build/optimize_stream_graph.cc" | |||
| "graph/build/run_context.cc" | |||
| "graph/build/stream_allocator.cc" | |||
| "graph/build/stream_graph_optimizer.cc" | |||
| "graph/build/task_generator.cc" | |||
| "graph/common/bcast.cc" | |||
| "graph/common/omg_util.cc" | |||
| "graph/common/transop_util.cc" | |||
| "graph/execute/graph_execute.cc" | |||
| "graph/label/*.cc" | |||
| "graph/load/graph_loader.cc" | |||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
| "graph/load/new_model_manager/data_dumper.cc" | |||
| "graph/load/new_model_manager/data_inputer.cc" | |||
| "graph/load/new_model_manager/davinci_model.cc" | |||
| @@ -92,10 +96,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/new_model_manager/tbe_handle_store.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/custom/custom_op.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| @@ -105,12 +111,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/hcom_util.cc" | |||
| "graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/optimize/graph_functiondef.cc" | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/graph_optimizer.cc" | |||
| "graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/engine_place.cc" | |||
| @@ -120,7 +123,9 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/atomic_addr_clean_pass.cc" | |||
| "graph/passes/base_pass.cc" | |||
| "graph/passes/cast_remove_pass.cc" | |||
| "graph/passes/cast_translate_pass.cc" | |||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||
| "graph/passes/compile_nodes_pass.cc" | |||
| "graph/passes/constant_folding_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| @@ -159,12 +164,14 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/folding_kernel/shape_kernel.cc" | |||
| "graph/passes/folding_kernel/shape_n_kernel.cc" | |||
| "graph/passes/folding_kernel/size_kernel.cc" | |||
| "graph/passes/folding_kernel/slice_d_kernel.cc" | |||
| "graph/passes/folding_kernel/slice_kernel.cc" | |||
| "graph/passes/folding_kernel/squeeze_kernel.cc" | |||
| "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||
| "graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
| "graph/passes/folding_kernel/sub_kernel.cc" | |||
| "graph/passes/folding_kernel/transdata_kernel.cc" | |||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||
| "graph/passes/folding_pass.cc" | |||
| "graph/passes/get_original_format_pass.cc" | |||
| "graph/passes/guarantee_const_pass.cc" | |||
| @@ -179,7 +186,6 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/multi_batch_pass.cc" | |||
| "graph/passes/net_output_pass.cc" | |||
| "graph/passes/next_iteration_pass.cc" | |||
| "graph/passes/no_reshape_op_remove_pass.cc" | |||
| "graph/passes/no_use_reshape_remove_pass.cc" | |||
| "graph/passes/pass_manager.cc" | |||
| "graph/passes/pass_utils.cc" | |||
| @@ -188,6 +194,7 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/prevent_gradient_pass.cc" | |||
| "graph/passes/print_op_pass.cc" | |||
| "graph/passes/prune_pass.cc" | |||
| "graph/passes/replace_with_empty_const_pass.cc" | |||
| "graph/passes/reshape_remove_pass.cc" | |||
| "graph/passes/resource_pair_add_control_pass.cc" | |||
| "graph/passes/resource_pair_remove_control_pass.cc" | |||
| @@ -206,14 +213,12 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/update_net_output_pass.cc" | |||
| "graph/passes/var_is_initialized_op_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/variable_prepare_op_pass.cc" | |||
| "graph/passes/variable_ref_delete_op_pass.cc" | |||
| "graph/preprocess/graph_preprocess.cc" | |||
| "graph/preprocess/insert_op/base_insert_op.cc" | |||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | |||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
| "graph/preprocess/multi_batch_copy_graph.cc" | |||
| @@ -223,13 +228,8 @@ file(GLOB_RECURSE TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "opskernel_manager/ops_kernel_manager.cc" | |||
| "session/inner_session.cc" | |||
| "session/session_manager.cc" | |||
| "single_op/single_op.cc" | |||
| "single_op/single_op_manager.cc" | |||
| "single_op/single_op_model.cc" | |||
| "single_op/stream_resource.cc" | |||
| "single_op/task/build_task_utils.cc" | |||
| "single_op/task/op_task.cc" | |||
| "single_op/task/tbe_task_builder.cc" | |||
| "single_op/*.cc" | |||
| "single_op/task/*.cc" | |||
| ) | |||
| @@ -261,9 +261,9 @@ target_link_libraries(ge_train | |||
| rt | |||
| dl) | |||
| ######### libge.so ############# | |||
| ######### libge_compiler.so ############# | |||
| # need to remove dependencies on pb files later | |||
| file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/format_transfers/*.cc" | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| @@ -271,20 +271,24 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/ge/plugin_manager.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| "engine_manager/dnnengine_manager.cc" | |||
| "ge_local_engine/engine/host_cpu_engine.cc" | |||
| "generator/ge_generator.cc" | |||
| "generator/generator_api.cc" | |||
| "graph/build/graph_build.cc" | |||
| "graph/build/graph_builder.cc" | |||
| "graph/build/label_allocator.cc" | |||
| "graph/build/logical_stream_allocator.cc" | |||
| "graph/build/model_builder.cc" | |||
| "graph/build/optimize_stream_graph.cc" | |||
| "graph/build/run_context.cc" | |||
| "graph/build/stream_allocator.cc" | |||
| "graph/build/stream_graph_optimizer.cc" | |||
| "graph/build/task_generator.cc" | |||
| "graph/common/bcast.cc" | |||
| "graph/common/omg_util.cc" | |||
| "graph/common/transop_util.cc" | |||
| "graph/execute/graph_execute.cc" | |||
| "graph/label/*.cc" | |||
| "graph/load/graph_loader.cc" | |||
| "graph/load/new_model_manager/cpu_queue_schedule.cc" | |||
| "graph/load/new_model_manager/data_dumper.cc" | |||
| "graph/load/new_model_manager/data_inputer.cc" | |||
| "graph/load/new_model_manager/davinci_model.cc" | |||
| @@ -305,10 +309,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_switch_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/stream_switchn_task_info.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/new_model_manager/tbe_handle_store.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/custom/custom_op.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| @@ -317,13 +323,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/node_searcher/need_rebuild_node_searcher.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/optimize/graph_functiondef.cc" | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/graph_optimizer.cc" | |||
| "graph/optimize/optimizer/allreduce_fusion_inference_pass.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/engine_place.cc" | |||
| "graph/partition/graph_partition.cc" | |||
| @@ -332,7 +334,9 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/assert_pass.cc" | |||
| "graph/passes/atomic_addr_clean_pass.cc" | |||
| "graph/passes/base_pass.cc" | |||
| "graph/passes/cast_remove_pass.cc" | |||
| "graph/passes/cast_translate_pass.cc" | |||
| "graph/passes/common_subexpression_elimination_pass.cc" | |||
| "graph/passes/compile_nodes_pass.cc" | |||
| "graph/passes/constant_folding_pass.cc" | |||
| "graph/passes/constant_fuse_same_pass.cc" | |||
| @@ -371,12 +375,14 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/folding_kernel/shape_kernel.cc" | |||
| "graph/passes/folding_kernel/shape_n_kernel.cc" | |||
| "graph/passes/folding_kernel/size_kernel.cc" | |||
| "graph/passes/folding_kernel/slice_d_kernel.cc" | |||
| "graph/passes/folding_kernel/slice_kernel.cc" | |||
| "graph/passes/folding_kernel/squeeze_kernel.cc" | |||
| "graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||
| "graph/passes/folding_kernel/strided_slice_kernel.cc" | |||
| "graph/passes/folding_kernel/sub_kernel.cc" | |||
| "graph/passes/folding_kernel/transdata_kernel.cc" | |||
| "graph/passes/folding_kernel/unpack_kernel.cc" | |||
| "graph/passes/folding_pass.cc" | |||
| "graph/passes/get_original_format_pass.cc" | |||
| "graph/passes/guarantee_const_pass.cc" | |||
| @@ -391,7 +397,6 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/multi_batch_pass.cc" | |||
| "graph/passes/net_output_pass.cc" | |||
| "graph/passes/next_iteration_pass.cc" | |||
| "graph/passes/no_reshape_op_remove_pass.cc" | |||
| "graph/passes/no_use_reshape_remove_pass.cc" | |||
| "graph/passes/pass_manager.cc" | |||
| "graph/passes/pass_utils.cc" | |||
| @@ -400,6 +405,7 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/prevent_gradient_pass.cc" | |||
| "graph/passes/print_op_pass.cc" | |||
| "graph/passes/prune_pass.cc" | |||
| "graph/passes/replace_with_empty_const_pass.cc" | |||
| "graph/passes/reshape_remove_pass.cc" | |||
| "graph/passes/resource_pair_add_control_pass.cc" | |||
| "graph/passes/resource_pair_remove_control_pass.cc" | |||
| @@ -418,14 +424,12 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/passes/transpose_transdata_pass.cc" | |||
| "graph/passes/unused_const_pass.cc" | |||
| "graph/passes/unused_op_remove_pass.cc" | |||
| "graph/passes/update_net_output_pass.cc" | |||
| "graph/passes/var_is_initialized_op_pass.cc" | |||
| "graph/passes/variable_format_pass.cc" | |||
| "graph/passes/variable_op_pass.cc" | |||
| "graph/passes/variable_prepare_op_pass.cc" | |||
| "graph/passes/variable_ref_delete_op_pass.cc" | |||
| "graph/preprocess/graph_preprocess.cc" | |||
| "graph/preprocess/insert_op/base_insert_op.cc" | |||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | |||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | |||
| "graph/preprocess/multi_batch_copy_graph.cc" | |||
| @@ -442,16 +446,19 @@ file(GLOB_RECURSE INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "single_op/task/build_task_utils.cc" | |||
| "single_op/task/op_task.cc" | |||
| "single_op/task/tbe_task_builder.cc" | |||
| ########################################## | |||
| # "ir_build/ge_ir_build.cc" | |||
| # "offline/atc_ir_common.cc" | |||
| ) | |||
| add_library(ge SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge PRIVATE | |||
| add_library(ge_compiler SHARED ${INFER_SRC_LIST} ${PROTO_SRCS} ${PROTO_HEADER_HDRS}) | |||
| target_compile_definitions(ge_compiler PRIVATE | |||
| PROTOBUF_INLINE_NOT_IN_HEADERS=0 | |||
| DAVINCI_SUPPORT_PROFILING | |||
| REUSE_MEMORY=1 | |||
| FMK_HOST_INFER | |||
| PLATFORM_CLOUD) | |||
| target_link_libraries(ge | |||
| target_link_libraries(ge_compiler | |||
| graph | |||
| ge_common | |||
| "-Wl,--whole-archive" | |||
| @@ -80,7 +80,7 @@ target_compile_definitions(ge_client_train PRIVATE | |||
| PLATFORM_CLOUD) | |||
| target_link_libraries(ge_client | |||
| graph | |||
| ge | |||
| ge_compiler | |||
| ge_common | |||
| ${PROTOBUF_LIBRARY} | |||
| ${register} | |||
| @@ -61,14 +61,14 @@ Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||
| const int kDecimal = 10; | |||
| auto dump_op_env = std::getenv("DUMP_OP"); | |||
| int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | |||
| auto disable_reuse_memory_iter = options.find("ge.exec.disableReuseMemory"); | |||
| if (disable_reuse_memory_iter != options.end()) { | |||
| if (disable_reuse_memory_iter->second == "0") { | |||
| auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||
| if (disableReuseMemoryIter != options.end()) { | |||
| if (disableReuseMemoryIter->second == "0") { | |||
| GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | |||
| if (dump_op_flag) { | |||
| GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | |||
| } | |||
| } else if (disable_reuse_memory_iter->second == "1") { | |||
| } else if (disableReuseMemoryIter->second == "1") { | |||
| GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | |||
| } else { | |||
| GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is invalid"); | |||
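A minimal standalone sketch of the flag-combination logic this hunk renames: read the `DUMP_OP` environment variable with `strtol` and cross-check it against the `ge.exec.disableReuseMemory` option. `CheckReuseMemoryOption` and the plain `fprintf` diagnostics are hypothetical stand-ins for illustration, not the GE logging macros.

```cpp
#include <cstdlib>
#include <cstdio>
#include <map>
#include <string>

// Hypothetical helper mirroring the shape of CheckDumpAndReuseMemory above:
// "0" means memory reuse stays on, "1" means it is disabled, anything else is rejected.
bool CheckReuseMemoryOption(const std::map<std::string, std::string> &options) {
  const int kDecimal = 10;
  const char *dump_op_env = std::getenv("DUMP_OP");
  long dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0;

  auto it = options.find("ge.exec.disableReuseMemory");
  if (it == options.end()) {
    return true;  // option absent: nothing to validate
  }
  if (it->second == "0") {
    if (dump_op_flag != 0) {
      std::fprintf(stderr, "warning: dumping op data while memory reuse is on may produce incorrect dumps\n");
    }
    return true;
  }
  if (it->second == "1") {
    return true;  // memory reuse explicitly disabled
  }
  std::fprintf(stderr, "error: ge.exec.disableReuseMemory must be \"0\" or \"1\"\n");
  return false;
}
```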
| @@ -128,22 +128,29 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||
| std::map<string, string> option_tmp; | |||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||
| GE_TIMESTAMP_START(GEInitialize); | |||
| bool is_proto_init = manager->Initialize(option_tmp); | |||
| GE_TIMESTAMP_END(GEInitialize, "GEInitialize::ManagerInitialize"); | |||
| if (!is_proto_init) { | |||
| GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, ops proto path is invalid."); | |||
| return FAILED; | |||
| } | |||
| // check options is valid | |||
| GE_TIMESTAMP_START(CheckOptionsValid); | |||
| if (CheckOptionsValid(options) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | |||
| GE_TIMESTAMP_START(InitPreparation); | |||
| SaveDdkVersion(options); | |||
| GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||
| // call Initialize | |||
| GELOGT(TRACE_RUNNING, "Initializing environment"); | |||
| GE_TIMESTAMP_START(GELibInitialize); | |||
| Status ret = ge::GELib::Initialize(options); | |||
| GE_TIMESTAMP_END(GELibInitialize, "GEInitialize::GELibInitialize"); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(GE_CLI_INIT_FAILED, "geInitialize failed, error code = %u", ret); | |||
| return FAILED; | |||
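This hunk wraps each initialization phase in `GE_TIMESTAMP_START`/`GE_TIMESTAMP_END` pairs. The real macros are defined elsewhere in GE and are not shown in the diff; the sketch below only illustrates the start/end timing pattern they follow, using hypothetical `SKETCH_TIMESTAMP_*` names.

```cpp
#include <chrono>
#include <cstdio>

// Assumption: the actual GE_TIMESTAMP_* macros differ in detail; this is only
// a sketch of the "record start, measure and log at end" pattern per phase.
#define SKETCH_TIMESTAMP_START(tag) \
  auto tag##_start = std::chrono::steady_clock::now()

#define SKETCH_TIMESTAMP_END(tag, desc)                                                            \
  do {                                                                                             \
    auto tag##_end = std::chrono::steady_clock::now();                                             \
    auto tag##_us =                                                                                 \
        std::chrono::duration_cast<std::chrono::microseconds>(tag##_end - tag##_start).count();    \
    std::printf("[%s] cost %lld us\n", desc, static_cast<long long>(tag##_us));                    \
  } while (0)

int main() {
  SKETCH_TIMESTAMP_START(CheckOptionsValid);
  // ... phase body, e.g. option validation ...
  SKETCH_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid");
  return 0;
}
```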
| @@ -170,17 +177,20 @@ Status GEFinalize() { | |||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
| // call Finalize | |||
| Status ret = SUCCESS; | |||
| Status middle_ret = SUCCESS; | |||
| GELOGT(TRACE_RUNNING, "Finalizing environment"); | |||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GEFinalize Failed: GE not initialized"); | |||
| return GE_CLI_GE_NOT_INITIALIZED; | |||
| } | |||
| Status ret = instance_ptr->Finalize(); | |||
| GELOGI("GEFinalize finalize gelib ret=%u", ret); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "GEFinalize Failed"); | |||
| return FAILED; | |||
| std::shared_ptr<GELib> instancePtr = ge::GELib::GetInstance(); | |||
| if (instancePtr == nullptr || !instancePtr->InitFlag()) { | |||
| GELOGW("GEFinalize Failed: GE not initialized."); | |||
| ret = GE_CLI_GE_NOT_INITIALIZED; | |||
| } | |||
| if (ret != GE_CLI_GE_NOT_INITIALIZED) { | |||
| middle_ret = instancePtr->Finalize(); | |||
| GELOGI("GEFinalize finalize gelib ret=%u", middle_ret); | |||
| if (middle_ret != SUCCESS) { | |||
| ret = middle_ret; | |||
| } | |||
| } | |||
| if (kGeInitialized && ret == SUCCESS) { | |||
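The refactored GEFinalize no longer returns at the first failure; it records the failing status and keeps going so the remaining teardown (and the `kGeInitialized` reset) still runs. The sketch below shows that error-aggregation pattern under assumed names; `FinalizeLib` and `ReleaseSharedResources` are illustrative stand-ins, not GE APIs.

```cpp
#include <cstdint>
#include <cstdio>

using Status = uint32_t;
constexpr Status SUCCESS = 0;

// Hypothetical teardown steps standing in for GELib::Finalize() and the rest
// of GEFinalize's cleanup; names and bodies are illustrative only.
static Status FinalizeLib() { return SUCCESS; }
static Status ReleaseSharedResources() { return SUCCESS; }

// Run every teardown step, remember a failing status instead of returning
// early, and report the combined result at the end.
Status FinalizeAll() {
  Status ret = SUCCESS;

  Status step_ret = FinalizeLib();
  if (step_ret != SUCCESS) {
    ret = step_ret;  // record the failure but keep finalizing
  }

  step_ret = ReleaseSharedResources();
  if (step_ret != SUCCESS) {
    ret = step_ret;
  }

  return ret;  // SUCCESS only if every step succeeded
}

int main() {
  std::printf("FinalizeAll -> %u\n", FinalizeAll());
  return 0;
}
```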
| @@ -379,8 +389,6 @@ Status Session::RunGraph(uint32_t graph_id, const std::vector<Tensor> &inputs, s | |||
| } | |||
| Status Session::RegisterCallBackFunc(const std::string &key, const pCallBackFunc &callback) { | |||
| GELOGW( | |||
| "The callback function will not be checked. Please ensure that the implementation of the function is trusted."); | |||
| return ge::GELib::GetInstance()->SessionManagerObj().RegisterCallBackFunc(sessionId_, key, callback); | |||
| } | |||