| @@ -174,9 +174,11 @@ echo "---------------- GraphEngine output generated ----------------" | |||
| # generate output package in tar form, including ut/st libraries/executables | |||
| cd ${BASEPATH} | |||
| mkdir -p output/plugin/nnengine/ge_config/ | |||
| mkdir -p output/plugin/opskernel/ | |||
| find output/ -name graphengine_lib.tar -exec rm {} \; | |||
| cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/ | |||
| find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \; | |||
| find output/ -maxdepth 1 -name libge_local_engine.so -exec mv -f {} output/plugin/opskernel/ \; | |||
| tar -cf graphengine_lib.tar output/* | |||
| mv -f graphengine_lib.tar output | |||
| echo "---------------- GraphEngine package archive generated ----------------" | |||
| @@ -52,5 +52,23 @@ struct GETaskInfo { | |||
| std::vector<GETaskKernelHcclInfo> kernelHcclInfo; | |||
| }; | |||
| struct HcomOpertion { | |||
| std::string hcclType; | |||
| void *inputPtr; | |||
| void *outputPtr; | |||
| uint64_t count; | |||
| int32_t dataType; | |||
| int32_t opType; | |||
| int32_t root; | |||
| }; | |||
| struct HcomRemoteAccessAddrInfo { | |||
| uint32_t remotetRankID; | |||
| uint64_t remoteAddr; // host embedding table address | |||
| uint64_t localAddr; // device HBM address | |||
| uint64_t length; // memory Length in Bytes | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ | |||
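The two structs added here are plain aggregates, so a caller fills them field by field before handing them to the HCCL ops-kernel layer. A minimal sketch, assuming hypothetical buffers, rank ids and type codes (the numeric dataType/opType encodings and the consuming API are not part of this header):

#include <cstdint>
#include <string>
#include "common/opskernel/ge_task_info.h"  // assumed include path per the header guard above

void BuildHcomDescriptors(void *in, void *out, uint64_t host_table_addr, uint64_t hbm_addr) {
  ge::HcomOpertion op;                 // spelling follows the header
  op.hcclType = "HcomAllReduce";       // hypothetical collective type string
  op.inputPtr = in;
  op.outputPtr = out;
  op.count = 1024;                     // element count, example value
  op.dataType = 0;                     // numeric codes are defined elsewhere; 0 is a placeholder
  op.opType = 0;
  op.root = 0;

  ge::HcomRemoteAccessAddrInfo addr;
  addr.remotetRankID = 1;              // field name as declared above
  addr.remoteAddr = host_table_addr;   // host embedding table address
  addr.localAddr = hbm_addr;           // device HBM address
  addr.length = 4096;                  // bytes to transfer, example value
  // Handing op/addr to the HCCL executor is outside the scope of this header.
}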
| @@ -43,10 +43,10 @@ class OpsKernelInfoStore { | |||
| virtual ~OpsKernelInfoStore() {} | |||
| // initialize opsKernelInfoStore | |||
| virtual Status Initialize(const map<string, string> &options) = 0; | |||
| virtual Status Initialize(const map<string, string> &options) = 0; /*lint -e148*/ | |||
| // close opsKernelInfoStore | |||
| virtual Status Finalize() = 0; | |||
| virtual Status Finalize() = 0; /*lint -e148*/ | |||
| virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | |||
| @@ -66,10 +66,11 @@ class OpsKernelInfoStore { | |||
| virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | |||
| // memory allocation requirement | |||
| virtual Status CalcOpRunningParam(Node &node) = 0; | |||
| virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/ | |||
| // generate task for op. | |||
| virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | |||
| virtual Status GenerateTask(const Node &node, RunContext &context, | |||
| std::vector<domi::TaskDef> &tasks) = 0; /*lint -e148*/ | |||
| // only call fe engine interface to compile single op | |||
| virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | |||
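For context, a skeleton subclass that overrides the pure-virtual methods visible in this excerpt (the /*lint -e148*/ comments are PC-lint suppressions and do not change the signatures). This is only a sketch: the full header declares further pure virtuals not shown here, and a real store would compute memory requirements and emit task definitions.

#include <map>
#include <string>
#include <vector>
#include "common/opskernel/ops_kernel_info_store.h"  // assumed include path

class DummyOpsKernelInfoStore : public ge::OpsKernelInfoStore {
 public:
  ge::Status Initialize(const std::map<std::string, std::string> &options) override { return ge::SUCCESS; }
  ge::Status Finalize() override { return ge::SUCCESS; }
  ge::Status CalcOpRunningParam(ge::Node &node) override {
    // A real implementation would set output/workspace sizes on the node's OpDesc here.
    return ge::SUCCESS;
  }
  ge::Status GenerateTask(const ge::Node &node, ge::RunContext &context,
                          std::vector<domi::TaskDef> &tasks) override {
    // A real implementation would append TaskDef entries describing the kernels to launch.
    return ge::SUCCESS;
  }
};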
| @@ -26,6 +26,7 @@ | |||
| using std::string; | |||
| namespace ge { | |||
| /*lint -e148*/ | |||
| struct RunContext { | |||
| rtModel_t model; | |||
| rtStream_t stream; | |||
| @@ -40,6 +41,8 @@ struct RunContext { | |||
| std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | |||
| }; | |||
| /*lint +e148*/ | |||
| struct Task { | |||
| uint32_t id; | |||
| uint16_t type; | |||
| @@ -48,7 +51,8 @@ struct Task { | |||
| }; | |||
| struct OpInfo { | |||
| string engine;  // which engine | |||
| string engine;        // which engine | |||
| /*lint -e148*/ | |||
| string opKernelLib; // which opsKernelStore | |||
| int computeCost; // compute cost | |||
| bool flagPartial;  // whether support depends on the shape | |||
| @@ -27,6 +27,7 @@ | |||
| using std::map; | |||
| using std::string; | |||
| /*lint -e148*/ | |||
| namespace ge { | |||
| class GraphOptimizer { | |||
| public: | |||
| @@ -60,4 +61,5 @@ class GraphOptimizer { | |||
| virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } | |||
| }; | |||
| } // namespace ge | |||
| /*lint +e148*/ | |||
| #endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ | |||
| @@ -28,6 +28,7 @@ struct CompressConfig { | |||
| size_t channel; // channels of L2 or DDR. For load balance | |||
| size_t fractalSize; // size of compressing block | |||
| bool isTight;  // whether to pack the compressed data tightly | |||
| size_t init_offset; | |||
| }; | |||
| CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef COMPRESS_WEIGHT_H | |||
| #define COMPRESS_WEIGHT_H | |||
| #include "compress.h" | |||
| const int SHAPE_SIZE_WEIGHT = 4; | |||
| struct CompressOpConfig { | |||
| int64_t wShape[SHAPE_SIZE_WEIGHT]; | |||
| size_t compressTilingK; | |||
| size_t compressTilingN; | |||
| struct CompressConfig compressConfig; | |||
| }; | |||
| extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, | |||
| CompressOpConfig *const param); | |||
| #endif // COMPRESS_WEIGHT_H | |||
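A hedged usage sketch for this new header: the weight shape, tiling factors and buffer sizes are illustrative only, and CmpStatus/CompressConfig come from the included compress.h.

#include <cstddef>
#include <vector>
#include "compress_weight.h"

// Minimal sketch: compress a hypothetical 4-D Conv2D weight blob.
CmpStatus CompressExample(const char *weights, size_t weight_bytes) {
  CompressOpConfig cfg{};
  cfg.wShape[0] = 64;  cfg.wShape[1] = 64;   // example output/input channels
  cfg.wShape[2] = 3;   cfg.wShape[3] = 3;    // example kernel height/width
  cfg.compressTilingK = 16;                  // illustrative tiling factors
  cfg.compressTilingN = 16;
  cfg.compressConfig.isTight = true;         // pack compressed blocks tightly

  std::vector<char> zip(weight_bytes);       // compressed output buffer, sized conservatively
  std::vector<char> info(1024);              // index/info buffer, example size
  return CompressWeightsConv2D(weights, zip.data(), info.data(), &cfg);
}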
| @@ -31,27 +31,37 @@ class ErrorManager { | |||
| /// | |||
| /// @brief init | |||
| /// @param [in] path current so path | |||
| /// @param [in] path: current so path | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int Init(std::string path); | |||
| /// | |||
| /// @brief Report error message | |||
| /// @param [in] errCode error code | |||
| /// @param [in] mapArgs parameter map | |||
| /// @param [in] error_code: error code | |||
| /// @param [in] args_map: parameter map | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int ReportErrMessage(std::string error_code, const std::map<std::string, std::string> &args_map); | |||
| /// | |||
| /// @brief output error message | |||
| /// @param [in] handle print handle | |||
| /// @param [in] handle: print handle | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int OutputErrMessage(int handle); | |||
| /// | |||
| /// @brief output message | |||
| /// @param [in] handle: print handle | |||
| /// @return int 0(success) -1(fail) | |||
| /// | |||
| int OutputMessage(int handle); | |||
| /// | |||
| /// @brief Report error message | |||
| /// @param [in] vector parameter key, vector parameter value | |||
| /// @param [in] key: vector parameter key | |||
| /// @param [in] value: vector parameter value | |||
| /// | |||
| void ATCReportErrMessage(std::string error_code, const std::vector<std::string> &key = {}, | |||
| const std::vector<std::string> &value = {}); | |||
| @@ -60,7 +70,7 @@ class ErrorManager { | |||
| struct ErrorInfo { | |||
| std::string error_id; | |||
| std::string error_message; | |||
| std::vector<std::string> arglist; | |||
| std::vector<std::string> arg_list; | |||
| }; | |||
| ErrorManager() {} | |||
| @@ -77,7 +87,8 @@ class ErrorManager { | |||
| bool is_init_ = false; | |||
| std::map<std::string, ErrorInfo> error_map_; | |||
| std::vector<std::string> error_message_evc_; | |||
| std::vector<std::string> error_messages_; | |||
| std::vector<std::string> warning_messages_; | |||
| }; | |||
| #endif // ERROR_MANAGER_H_ | |||
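A short usage sketch of the interface documented above. How an ErrorManager instance is obtained is not shown in this excerpt, so the mgr reference below is an assumption, and the error code, path and argument names are placeholders.

#include <map>
#include <string>
#include <vector>
#include "error_manager.h"  // assumed include path

void ReportExample(ErrorManager &mgr) {
  if (mgr.Init("/usr/local/Ascend/fwkacllib/lib64") != 0) {  // hypothetical .so path
    return;
  }
  std::map<std::string, std::string> args = {{"parameter", "input_shape"}, {"value", "-1"}};
  (void)mgr.ReportErrMessage("E10001", args);  // "E10001" is a placeholder error code
  (void)mgr.OutputErrMessage(2);               // handle value is an example

  // ATC-style reporting takes parallel key/value vectors.
  mgr.ATCReportErrMessage("E10001", {"parameter"}, {"input_shape"});
}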
| @@ -27,7 +27,6 @@ using std::string; | |||
| using std::vector; | |||
| namespace fe { | |||
| class PlatformInfoManager { | |||
| public: | |||
| PlatformInfoManager(const PlatformInfoManager &) = delete; | |||
| @@ -39,6 +38,8 @@ class PlatformInfoManager { | |||
| uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||
| uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||
| void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); | |||
| private: | |||
| @@ -81,6 +82,8 @@ class PlatformInfoManager { | |||
| void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||
| void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platformInfoTemp); | |||
| void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap, | |||
| PlatformInfo &platformInfoTemp); | |||
| @@ -94,6 +97,5 @@ class PlatformInfoManager { | |||
| map<string, PlatformInfo> platformInfoMap_; | |||
| OptionalInfo optiCompilationInfo_; | |||
| }; | |||
| } // namespace fe | |||
| #endif | |||
| @@ -73,6 +73,8 @@ typedef struct tagAiCoreSpec { | |||
| typedef struct tagAiCoreMemoryRates { | |||
| double ddrRate; | |||
| double ddrReadRate; | |||
| double ddrWriteRate; | |||
| double l2Rate; | |||
| double l2ReadRate; | |||
| double l2WriteRate; | |||
| @@ -86,6 +88,7 @@ typedef struct tagAiCoreMemoryRates { | |||
| } AiCoreMemoryRates; | |||
| typedef struct tagVectorCoreSpec { | |||
| double vecFreq; | |||
| uint64_t vecCalcSize; | |||
| uint64_t smaskBuffer; | |||
| uint64_t ubSize; | |||
| @@ -94,10 +97,15 @@ typedef struct tagVectorCoreSpec { | |||
| uint64_t ubbankNum; | |||
| uint64_t ubburstInOneBlock; | |||
| uint64_t ubbankGroupNum; | |||
| uint64_t vectorRegSize; | |||
| uint64_t predicateRegSize; | |||
| uint64_t addressRegSize; | |||
| } VectorCoreSpec; | |||
| typedef struct tagVectorCoreMemoryRates { | |||
| double ddrRate; | |||
| double ddrReadRate; | |||
| double ddrWriteRate; | |||
| double l2Rate; | |||
| double l2ReadRate; | |||
| double l2WriteRate; | |||
| @@ -105,6 +113,11 @@ typedef struct tagVectorCoreMemoryRates { | |||
| double ubToDdrRate; | |||
| } VectorCoreMemoryRates; | |||
| typedef struct tagCPUCache { | |||
| uint32_t AICPUSyncBySW; | |||
| uint32_t TSCPUSyncBySW; | |||
| } CPUCache; | |||
| typedef struct tagPlatformInfo { | |||
| StrInfo strInfo; | |||
| SoCInfo socInfo; | |||
| @@ -113,6 +126,7 @@ typedef struct tagPlatformInfo { | |||
| map<string, vector<string>> aiCoreIntrinsicDtypeMap; | |||
| VectorCoreSpec vectorCoreSpec; | |||
| VectorCoreMemoryRates vectorCoreMemoryRates; | |||
| CPUCache cpucache; | |||
| map<string, vector<string>> vectorCoreIntrinsicDtypeMap; | |||
| } PlatformInfo; | |||
| @@ -70,7 +70,7 @@ using Status = uint32_t; | |||
| // General error code | |||
| GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); | |||
| GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); | |||
| GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/ | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ | |||
| @@ -44,8 +44,11 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
| const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; | |||
| const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; | |||
| const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; | |||
| const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||
| const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||
| const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; | |||
| // profiling flag | |||
| const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | |||
| const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||
| @@ -170,6 +173,9 @@ const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | |||
| // configure whether to use dynamic image size | |||
| const char *const kDynamicImageSize = "ge.dynamicImageSize"; | |||
| // Configure whether to use dynamic dims | |||
| const char *const kDynamicDims = "ge.dynamicDims"; | |||
| // Configure the auto tune mode; this option only takes effect while AUTO_TUNE_FLAG is Y. | |||
| // Example: GA|RL; multiple modes can be configured, separated by | | |||
| const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; | |||
| @@ -219,6 +225,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | |||
| // Configure input fp16 nodes | |||
| const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||
| // Configure the op debug level; its value should be 0 (default), 1, or 2. | |||
| // 0: disable debug; 1: enable TBE compiler debug; 2: enable ccec compiler debug | |||
| const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||
| // Graph run mode | |||
| enum GraphRunMode { PREDICTION = 0, TRAIN }; | |||
| @@ -261,6 +271,7 @@ static const char *const INPUT_SHAPE = "input_shape"; | |||
| static const char *const OP_NAME_MAP = "op_name_map"; | |||
| static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; | |||
| static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; | |||
| static const char *const DYNAMIC_DIMS = kDynamicDims; | |||
| static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); | |||
| static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); | |||
| static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; | |||
| @@ -283,10 +294,11 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c | |||
| // for interface: aclgrphBuildModel | |||
| const std::set<std::string> ir_builder_suppported_options = { | |||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, | |||
| DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
| AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, INPUT_FP16_NODES, | |||
| LOG_LEVEL}; | |||
| INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, | |||
| DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS, | |||
| INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||
| AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, | |||
| INPUT_FP16_NODES, LOG_LEVEL}; | |||
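With DYNAMIC_DIMS now accepted by aclgrphBuildModel, a caller can pass a dynamic-dims profile alongside the existing options. A minimal sketch of such an options map; the keys are the string values declared in this header, while the shape/profile syntax shown is illustrative only:

#include <map>
#include <string>

std::map<std::string, std::string> MakeBuildOptions() {
  std::map<std::string, std::string> options;
  options["input_shape"] = "data:1,3,-1,-1";       // INPUT_SHAPE: -1 marks dims resolved at runtime (hypothetical spec)
  options["ge.dynamicDims"] = "224,224;448,448";   // kDynamicDims / DYNAMIC_DIMS: candidate profiles (hypothetical syntax)
  return options;
}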
| // for interface: aclgrphBuildInitialize | |||
| const std::set<std::string> global_options = {CORE_TYPE, | |||
| SOC_VERSION, | |||
| @@ -34,6 +34,7 @@ using std::vector; | |||
| namespace ge { | |||
| class AttrValueImpl; | |||
| /*lint -e148*/ | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | |||
| public: | |||
| using INT = int64_t; | |||
| @@ -69,5 +70,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | |||
| VALUE_SET_GET_DEC(AttrValue::FLOAT) | |||
| #undef VALUE_SET_GET_DEC | |||
| }; | |||
| /*lint +e148*/ | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | |||
| @@ -61,6 +61,7 @@ using std::function; | |||
| using std::shared_ptr; | |||
| using std::string; | |||
| /*lint -e148*/ | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| public: | |||
| friend class OperatorImpl; | |||
| @@ -88,7 +89,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| explicit Operator(const string &type); | |||
| Operator(const string &name, const string &type); | |||
| Operator(const string &name, const string &type); // lint !e148 | |||
| virtual ~Operator() = default; | |||
| @@ -101,7 +102,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| // Only has one output index = 0 | |||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt); | |||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); | |||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 | |||
| Operator &AddControlInput(const Operator &src_oprt); | |||
| @@ -123,22 +124,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| TensorDesc GetOutputDesc(uint32_t index) const; | |||
| graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); | |||
| graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 | |||
| TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; | |||
| graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); | |||
| graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||
| TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; | |||
| graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); | |||
| graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||
| graphStatus InferShapeAndType(); | |||
| graphStatus InferShapeAndType(); // lint !e148 | |||
| void SetInferenceContext(const InferenceContextPtr &inference_context); | |||
| InferenceContextPtr GetInferenceContext() const; | |||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); | |||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 | |||
| size_t GetInputsSize() const; | |||
| @@ -251,19 +252,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| void RequiredAttrRegister(const string &name); | |||
| graphStatus VerifyAll(); | |||
| graphStatus VerifyAll(); // lint !e148 | |||
| // Only has one output index = 0 | |||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); | |||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | |||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, | |||
| const string &name); // lint !e148 | |||
| void SubgraphRegister(const string &ir_name, bool dynamic); | |||
| void SubgraphCountRegister(const string &ir_name, uint32_t count); | |||
| void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); | |||
| private: | |||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | |||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 | |||
| OutHandler GetOutput(const string &name) const; | |||
| @@ -273,6 +275,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||
| graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; | |||
| }; | |||
| /*lint +e148*/ | |||
| } // namespace ge | |||
| #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ | |||
| @@ -343,6 +343,7 @@ class OpReg { | |||
| auto x_type = op.GetInputDesc(in_name).GetDataType(); \ | |||
| TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ | |||
| op_output_desc.SetShape(ge::Shape(x_shape)); \ | |||
| op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ | |||
| op_output_desc.SetDataType(x_type); \ | |||
| return op.UpdateOutputDesc(out_name, op_output_desc); \ | |||
| } | |||
| @@ -126,5 +126,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { | |||
| friend class TensorAdapter; | |||
| }; | |||
| } // namespace ge | |||
| /*lint +e148*/ | |||
| #endif // INC_EXTERNAL_GRAPH_TENSOR_H_ | |||
| @@ -145,7 +145,8 @@ enum Format { | |||
| FORMAT_FRACTAL_ZN_LSTM, | |||
| FORMAT_FRACTAL_Z_G, | |||
| FORMAT_RESERVED, | |||
| FORMAT_ALL | |||
| FORMAT_ALL, | |||
| FORMAT_NULL | |||
| }; | |||
| // for unknown shape op type | |||
| @@ -40,6 +40,7 @@ using std::to_string; | |||
| using std::unique_ptr; | |||
| using std::vector; | |||
| /*lint -e148*/ | |||
| namespace ge { | |||
| class Operator; | |||
| class TensorDesc; | |||
| @@ -98,6 +99,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||
| OpRegistrationData &InputReorderVector(const vector<int> &input_order); | |||
| domi::ImplyType GetImplyType() const; | |||
| std::string GetOmOptype() const; | |||
| std::set<std::string> GetOriginOpTypeSet() const; | |||
| @@ -130,4 +133,5 @@ namespace ge { | |||
| using OpRegistrationData = domi::OpRegistrationData; | |||
| using OpReceiver = domi::OpReceiver; | |||
| } // namespace ge | |||
| /*lint +e148*/ | |||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | |||
| @@ -51,30 +51,6 @@ inline pid_t GetTid() { | |||
| return tid; | |||
| } | |||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
| #define GE_TIMESTAMP_END(stage, stage_name) \ | |||
| do { \ | |||
| uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||
| (endUsec_##stage - startUsec_##stage)); \ | |||
| } while (0); | |||
| #define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
| uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| uint64_t call_num_of##stage = 0; \ | |||
| uint64_t time_of##stage = 0 | |||
| #define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||
| #define GE_TIMESTAMP_ADD(stage) \ | |||
| time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||
| call_num_of##stage++ | |||
| #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | |||
| call_num_of##stage) | |||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
| dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ | |||
| ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | |||
| @@ -19,15 +19,12 @@ | |||
| #include <string> | |||
| #include "cce/cce_def.hpp" | |||
| #include "runtime/rt.h" | |||
| #include "common/string_util.h" | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "ge/ge_api_error_codes.h" | |||
| using cce::CC_STATUS_SUCCESS; | |||
| using cce::ccStatus_t; | |||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||
| #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
| #else | |||
| @@ -102,17 +99,13 @@ using cce::ccStatus_t; | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| std::string msg; \ | |||
| (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| DOMI_LOGE("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGE(_status, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| @@ -132,7 +125,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \ | |||
| @@ -142,7 +135,7 @@ using cce::ccStatus_t; | |||
| GELOGW(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| @@ -151,7 +144,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ | |||
| @@ -161,7 +154,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is true, print logs and execute custom statements | |||
| #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ | |||
| @@ -171,7 +164,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is true, print the Information log and execute a custom statement | |||
| #define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| @@ -180,7 +173,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not SUCCESS, print the log and execute the expression + return | |||
| #define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ | |||
| @@ -191,7 +184,7 @@ using cce::ccStatus_t; | |||
| exec_expr; \ | |||
| return; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not SUCCESS, print the log and execute the expression + return _status | |||
| #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ | |||
| @@ -202,7 +195,7 @@ using cce::ccStatus_t; | |||
| exec_expr; \ | |||
| return _status; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ | |||
| @@ -211,7 +204,7 @@ using cce::ccStatus_t; | |||
| if (!b) { \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // -----------------runtime related macro definitions------------------------------- | |||
| // If expr is not RT_ERROR_NONE, print the log | |||
| @@ -231,7 +224,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not RT_ERROR_NONE, print the log and return | |||
| #define GE_CHK_RT_RET(expr) \ | |||
| @@ -239,27 +232,17 @@ using cce::ccStatus_t; | |||
| rtError_t _rt_ret = (expr); \ | |||
| if (_rt_ret != RT_ERROR_NONE) { \ | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| return ge::RT_FAILED; \ | |||
| return RT_ERROR_TO_GE_STATUS(_rt_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // ------------------------cce related macro definitions---------------------------- | |||
| // If expr is not CC_STATUS_SUCCESS, print the log | |||
| #define GE_CHK_CCE(expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is true, execute exec_expr without printing logs | |||
| #define GE_IF_BOOL_EXEC(expr, exec_expr) \ | |||
| { \ | |||
| if (expr) { \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If make_shared is abnormal, print the log and execute the statement | |||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| /*lint -e* */ | |||
| #ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | |||
| #define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | |||
| @@ -280,8 +281,24 @@ GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_REDUCE_SCATTER_FAILED, 47, "call hccl hcom r | |||
| // Executor module error code definition | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_NOT_INIT, 1, "GE Executor is not yet initialized."); | |||
| GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 2, "GE AIPP is not exist."); | |||
| GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 3, "GE Dynamic AIPP is not support to query temporarily."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PATH_INVALID, 2, "Model file path is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_KEY_PATH_INVALID, 3, "Key file path of model is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_ID_INVALID, 4, "Model id is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_DATA_SIZE_INVALID, 5, "Data size of model is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PARTITION_NUM_INVALID, 6, "Partition number of model is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_QUEUE_ID_INVALID, 7, "Queue id of model is invalid."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, 8, "Model does not support encryption."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_READ_MODEL_FILE_FAILED, 9, "Failed to read model file."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_REPEATED, 10, "The model is loaded repeatedly."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_PARTITION_FAILED, 11, "Failed to load model partition."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, 12, "Failed to load weight partition."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_TASK_PARTITION_FAILED, 13, "Failed to load task partition."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, 14, "Failed to load kernel partition."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, 15, "Failed to allocate feature map memory."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate weight memory."); | |||
| GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); | |||
| GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); | |||
| GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); | |||
| // Generator module error code definition | |||
| GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); | |||
| @@ -289,6 +306,8 @@ GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, 2, "Graph mana | |||
| GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, 3, "Graph manager build graph failed."); | |||
| GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, 4, "Graph manager finalize failed."); | |||
| GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_SAVE_MODEL_FAILED, 5, "Graph manager save model failed."); | |||
| #define RT_ERROR_TO_GE_STATUS(RT_ERROR) static_cast<Status>(RT_ERROR) | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | |||
| @@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; | |||
| struct DataBuffer { | |||
| public: | |||
| void *data; // Data address | |||
| uint32_t length; // Data length | |||
| uint64_t length; // Data length | |||
| bool isDataSupportMemShare = false; | |||
| DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare) | |||
| DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare) | |||
| : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {} | |||
| DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {} | |||
| @@ -106,7 +106,7 @@ struct ShapeDescription { | |||
| // Definition of input and output description information | |||
| struct InputOutputDescInfo { | |||
| std::string name; | |||
| uint32_t size; | |||
| uint64_t size; | |||
| uint32_t data_type; | |||
| ShapeDescription shape_info; | |||
| }; | |||
| @@ -231,6 +231,7 @@ struct Options { | |||
| // Profiling info of task | |||
| struct TaskDescInfo { | |||
| std::string model_name; | |||
| std::string op_name; | |||
| uint32_t block_dim; | |||
| uint32_t task_id; | |||
| @@ -239,6 +240,7 @@ struct TaskDescInfo { | |||
| // Profiling info of graph | |||
| struct ComputeGraphDescInfo { | |||
| std::string model_name; | |||
| std::string op_name; | |||
| std::string op_type; | |||
| std::vector<Format> input_format; | |||
| @@ -44,8 +44,6 @@ class ModelHelper { | |||
| void SetSaveMode(bool val) { is_offline_ = val; } | |||
| bool GetSaveMode(void) const { return is_offline_; } | |||
| static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||
| static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||
| Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name); | |||
| Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name); | |||
| @@ -36,8 +36,8 @@ class StringUtils { | |||
| #endif | |||
| return s; | |||
| } | |||
| static std::string &Rtrim(std::string &s) { | |||
| // lint -esym(551,*) | |||
| static std::string &Rtrim(std::string &s) { /*lint !e618*/ | |||
| #if __cplusplus >= 201103L | |||
| (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | |||
| #else | |||
| @@ -45,7 +45,7 @@ class StringUtils { | |||
| #endif | |||
| return s; | |||
| } | |||
| // lint -esym(551,*) | |||
| /// | |||
| /// @ingroup domi_common | |||
| /// @brief delete spaces at the beginning and end of a string | |||
| @@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL; | |||
| // Supported public properties name | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | |||
| @@ -335,6 +338,8 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | |||
| REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | |||
| REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | |||
| REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") | |||
| REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); | |||
| REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast"); | |||
| // ANN dedicated operator | |||
| REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | |||
| @@ -428,6 +433,8 @@ REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); | |||
| REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); | |||
| REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); | |||
| REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | |||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | |||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | |||
| REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | |||
| REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); | |||
| @@ -554,6 +561,16 @@ enum ModelCheckType { | |||
| UNCHECK // no verification | |||
| }; | |||
| /// | |||
| /// @brief dynamic input type | |||
| /// | |||
| enum DynamicInputType { | |||
| FIXED = 0, // default mode | |||
| DYNAMIC_BATCH = 1, | |||
| DYNAMIC_IMAGE = 2, | |||
| DYNAMIC_DIMS = 3 | |||
| }; | |||
| /// | |||
| /// @brief magic number of the model file | |||
| /// | |||
| @@ -631,6 +648,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; | |||
| // convolution node type | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION; | |||
| // adds a convolutional node name for the hard AIPP | |||
| @@ -21,28 +21,31 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "common/dynamic_aipp.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/ge_types.h" | |||
| #include "common/types.h" | |||
| #include "graph/tensor.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "runtime/base.h" | |||
| #include "common/dynamic_aipp.h" | |||
| namespace ge { | |||
| class ModelListenerAdapter; | |||
| class SingleOp; | |||
| class DynamicSingleOp; | |||
| struct RunModelData { | |||
| uint32_t index; // Data index | |||
| uint32_t modelId; | |||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||
| uint32_t timestamp; // Data creation time | |||
| uint32_t timeout; // Processing timeout | |||
| uint64_t request_id = 0; // Request ID | |||
| uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 | |||
| uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 | |||
| uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 | |||
| std::vector<DataBuffer> blobs; // All input/output data buffer | |||
| uint32_t timestamp; // Data creation time | |||
| uint32_t timeout; // Processing timeout | |||
| uint64_t request_id = 0; // Request ID | |||
| uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 | |||
| uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 | |||
| uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 | |||
| std::vector<uint64_t> dynamic_dims; // Dynamic dims scene, set dynamic dims, not supported by default:empty | |||
| }; | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| @@ -87,16 +90,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| /// | |||
| ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, | |||
| uint64_t image_width); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Set dynamic dims info | |||
| /// @param [in] model_id: model id allocate from manager | |||
| /// @param [in] dynamic_input_addr: dynamic input addr created by user | |||
| /// @param [in] length: length of dynamic input addr | |||
| /// @param [in] dynamic_dim_num: number of dynamic dimension | |||
| /// @param [in] dynamic_dims: array of dynamic dimensions | |||
| /// @return execute result | |||
| /// | |||
| ge::Status SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, | |||
| const std::vector<uint64_t> &dynamic_dims); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Get current dynamic dims info by combined dims | |||
| /// @param [in] model_id: model id allocate from manager | |||
| /// @param [in] combined_dims: array of combined dimensions | |||
| /// @param [out] cur_dynamic_dims: current dynamic dims | |||
| /// @return execute result | |||
| /// | |||
| ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector<uint64_t> &combined_dims, | |||
| std::vector<uint64_t> &cur_dynamic_dims); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Get dynamic batch_info | |||
| /// @param [in] model_id | |||
| /// @param [out] batch_info | |||
| /// @param [out] dynamic_type | |||
| /// @return execute result | |||
| /// | |||
| ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||
| ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info, | |||
| int32_t &dynamic_type); | |||
| ge::Status GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info); | |||
| /// | |||
| /// @ingroup ge | |||
| /// @brief Get combined dynamic dims info | |||
| /// @param [in] model_id | |||
| /// @param [out] batch_info | |||
| /// @return execute result | |||
| /// | |||
| ge::Status GetCombinedDynamicDims(uint32_t model_id, std::vector<std::vector<int64_t>> &batch_info); | |||
| ge::Status GetCurShape(const uint32_t model_id, std::vector<int64_t> &batch_info, int32_t &dynamic_type); | |||
| /// | |||
| /// @ingroup ge | |||
| @@ -209,6 +248,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| static ge::Status ExecuteAsync(SingleOp *executor, const std::vector<DataBuffer> &inputs, | |||
| std::vector<DataBuffer> &outputs); | |||
| static ge::Status LoadDynamicSingleOp(const std::string &model_name, const ge::ModelData &modelData, void *stream, | |||
| DynamicSingleOp **single_op); | |||
| static ge::Status ExecuteAsync(DynamicSingleOp *executor, const std::vector<GeTensorDesc> &input_desc, | |||
| const std::vector<DataBuffer> &inputs, std::vector<GeTensorDesc> &output_desc, | |||
| std::vector<DataBuffer> &outputs); | |||
| static ge::Status ReleaseSingleOpResource(void *stream); | |||
| ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); | |||
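A usage sketch for the dynamic-dims additions above, using only the signatures declared in this class; the model id, input address and dim values are placeholders.

#include <vector>
#include "framework/executor/ge_executor.h"  // assumed include path

ge::Status RunWithDynamicDims(ge::GeExecutor &executor, uint32_t model_id,
                              void *dynamic_input_addr, uint64_t length) {
  // Tell the loaded model which concrete dims to use for this execution (example values).
  std::vector<uint64_t> dims = {1, 224, 224};
  ge::Status ret = executor.SetDynamicDims(model_id, dynamic_input_addr, length, dims);
  if (ret != ge::SUCCESS) {
    return ret;
  }

  // Query the dims gear the model actually selected for the combined input.
  std::vector<uint64_t> cur_dims;
  ret = executor.GetCurDynamicDims(model_id, dims, cur_dims);
  if (ret != ge::SUCCESS) {
    return ret;
  }

  // dynamic_type distinguishes batch / image / dims modes (see the DynamicInputType enum).
  std::vector<std::vector<int64_t>> batch_info;
  int32_t dynamic_type = 0;
  return executor.GetDynamicBatchInfo(model_id, batch_info, dynamic_type);
}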
| @@ -28,7 +28,7 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class RuntimeModel; | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class ModelRunner { | |||
| public: | |||
| static ModelRunner &Instance(); | |||
| @@ -36,8 +36,18 @@ class ModelRunner { | |||
| bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | |||
| bool DistributeTask(uint32_t model_id); | |||
| bool LoadModelComplete(uint32_t model_id); | |||
| const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||
| const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||
| void *GetModelHandle(uint32_t model_id) const; | |||
| bool UnloadModel(uint32_t model_id); | |||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||
| @@ -21,6 +21,7 @@ | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "cce/taskdown_api.h" | |||
| @@ -52,21 +53,27 @@ class TaskInfo { | |||
| virtual ~TaskInfo() {} | |||
| uint32_t stream_id() const { return stream_id_; } | |||
| TaskInfoType type() const { return type_; } | |||
| std::string op_name() const { return op_name_; } | |||
| bool dump_flag() const { return dump_flag_; } | |||
| protected: | |||
| TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} | |||
| TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) | |||
| : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} | |||
| private: | |||
| std::string op_name_; | |||
| uint32_t stream_id_; | |||
| TaskInfoType type_; | |||
| bool dump_flag_; | |||
| }; | |||
| class CceTaskInfo : public TaskInfo { | |||
| public: | |||
| CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | |||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | |||
| const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||
| : TaskInfo(stream_id, TaskInfoType::CCE), | |||
| CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, | |||
| uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size, | |||
| const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table, | |||
| const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), | |||
| ctx_(ctx), | |||
| stub_func_(stub_func), | |||
| block_dim_(block_dim), | |||
| @@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo { | |||
| class TbeTaskInfo : public TaskInfo { | |||
| public: | |||
| TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args, | |||
| uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size, | |||
| const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs) | |||
| : TaskInfo(stream_id, TaskInfoType::TBE), | |||
| TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, | |||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, | |||
| uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), | |||
| stub_func_(stub_func), | |||
| block_dim_(block_dim), | |||
| args_(args), | |||
| @@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo { | |||
| class AicpuTaskInfo : public TaskInfo { | |||
| public: | |||
| AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | |||
| const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | |||
| : TaskInfo(stream_id, TaskInfoType::AICPU), | |||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||
| const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
| so_name_(so_name), | |||
| kernel_name_(kernel_name), | |||
| node_def_(node_def), | |||
| @@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo { | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
| class LabelTaskInfo : public TaskInfo { | |||
| class LabelSetTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| protected: | |||
| LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||
| : TaskInfo(stream_id, type), label_id_(label_id) {} | |||
| virtual ~LabelTaskInfo() override {} | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSetTaskInfo : public LabelTaskInfo { | |||
| class LabelGotoTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSwitchTaskInfo : public LabelTaskInfo { | |||
| class LabelSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||
| LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, | |||
| const std::vector<uint32_t> &label_list, void *cond) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), | |||
| label_size_(label_size), | |||
| label_list_(label_list), | |||
| cond_(cond) {} | |||
| ~LabelSwitchTaskInfo() override {} | |||
| }; | |||
| uint32_t label_size() { return label_size_; }; | |||
| const std::vector<uint32_t> &label_list() { return label_list_; }; | |||
| void *cond() { return cond_; }; | |||
| class LabelGotoTaskInfo : public LabelTaskInfo { | |||
| public: | |||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| private: | |||
| uint32_t label_size_; | |||
| std::vector<uint32_t> label_list_; | |||
| void *cond_; | |||
| }; | |||
| class EventTaskInfo : public TaskInfo { | |||
| @@ -215,8 +231,8 @@ class EventTaskInfo : public TaskInfo { | |||
| uint32_t event_id() const { return event_id_; } | |||
| protected: | |||
| EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||
| : TaskInfo(stream_id, type), event_id_(event_id) {} | |||
| EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||
| : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} | |||
| virtual ~EventTaskInfo() override {} | |||
| uint32_t event_id_; | |||
| @@ -224,39 +240,41 @@ class EventTaskInfo : public TaskInfo { | |||
| class EventRecordTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||
| EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||
| ~EventRecordTaskInfo() override {} | |||
| }; | |||
| class EventWaitTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||
| EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||
| ~EventWaitTaskInfo() override {} | |||
| }; | |||
| class FusionStartTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} | |||
| explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} | |||
| ~FusionStartTaskInfo() override {} | |||
| }; | |||
| class FusionEndTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} | |||
| explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} | |||
| ~FusionEndTaskInfo() override {} | |||
| }; | |||
| class HcclTaskInfo : public TaskInfo { | |||
| public: | |||
| HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, | |||
| void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, | |||
| void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||
| int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | |||
| std::function<bool(void *)> hcom_unbind_model, | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | |||
| : TaskInfo(stream_id, TaskInfoType::HCCL), | |||
| int64_t op_type, int64_t data_type, const std::string &group, | |||
| std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model, | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||
| hccl_type_(hccl_type), | |||
| input_data_addr_(input_data_addr), | |||
| output_data_addr_(output_data_addr), | |||
| @@ -269,6 +287,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| root_id_(root_id), | |||
| op_type_(op_type), | |||
| data_type_(data_type), | |||
| group_(group), | |||
| hcom_bind_model_(hcom_bind_model), | |||
| hcom_unbind_model_(hcom_unbind_model), | |||
| hcom_distribute_task_(hcom_distribute_task) {} | |||
| @@ -286,6 +305,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| int64_t root_id() const { return root_id_; } | |||
| int64_t op_type() const { return op_type_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| const std::string &group() const { return group_; } | |||
| std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | |||
| std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | |||
| @@ -305,6 +325,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| int64_t root_id_; | |||
| int64_t op_type_; | |||
| int64_t data_type_; | |||
| std::string group_; | |||
| std::function<bool(void *, void *)> hcom_bind_model_; | |||
| std::function<bool(void *)> hcom_unbind_model_; | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_; | |||
| @@ -312,8 +333,11 @@ class HcclTaskInfo : public TaskInfo { | |||
| class ProfilerTraceTaskInfo : public TaskInfo { | |||
| public: | |||
| ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||
| : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} | |||
| ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), | |||
| log_id_(log_id), | |||
| notify_(notify), | |||
| flat_(flat) {} | |||
| ~ProfilerTraceTaskInfo() override {} | |||
| uint64_t log_id() const { return log_id_; } | |||
| @@ -328,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo { | |||
| class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| public: | |||
| MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | |||
| : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), | |||
| MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||
| uint64_t count, uint32_t kind, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||
| dst_(dst), | |||
| dst_max_(dst_max), | |||
| src_(src), | |||
| @@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| class StreamSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, | |||
| int64_t data_type) | |||
| : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), | |||
| StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, | |||
| void *value_addr, int64_t cond, int64_t data_type) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), | |||
| true_stream_id_(true_stream_id), | |||
| input_addr_(input_addr), | |||
| value_addr_(value_addr), | |||
| @@ -379,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||
| class StreamActiveTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) | |||
| : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} | |||
| StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} | |||
| ~StreamActiveTaskInfo() override {} | |||
| uint32_t active_stream_id() const { return active_stream_id_; } | |||
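All TaskInfo subclasses now take the op name as their first constructor argument (plus a dump flag where relevant), and the label tasks no longer share a LabelTaskInfo base. A construction sketch with placeholder names and ids; the namespace is an assumption based on the ge::model_runner usage shown for ModelRunner.

#include <string>
#include <vector>
#include "framework/ge_runtime/task_info.h"  // assumed include path

void BuildSampleTasks(void *cond_addr) {
  using namespace ge::model_runner;  // assumed namespace, matching the ModelRunner excerpt

  // Label set/goto carry the op name, stream id and label id directly.
  LabelSetTaskInfo set_task("label_set_0", /*stream_id=*/0, /*label_id=*/1);
  LabelGotoTaskInfo goto_task("label_goto_0", /*stream_id=*/0, /*label_id=*/1);

  // LabelSwitchTaskInfo selects among a label list based on a condition buffer.
  std::vector<uint32_t> labels = {1, 2, 3};
  LabelSwitchTaskInfo switch_task("label_switch_0", /*stream_id=*/0,
                                  static_cast<uint32_t>(labels.size()), labels, cond_addr);

  // Other tasks follow the same pattern: op name first, then the original arguments.
  StreamActiveTaskInfo active_task("stream_active_0", /*stream_id=*/0, /*active_stream_id=*/2);
}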
| @@ -27,6 +27,7 @@ | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/op_desc.h" | |||
| #include "graph/detail/attributes_holder.h" | |||
| namespace ge { | |||
| class GeGenerator { | |||
| @@ -0,0 +1,56 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_FRAMEWORK_MEMORY_MEMORY_API_H_ | |||
| #define INC_FRAMEWORK_MEMORY_MEMORY_API_H_ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "ge/ge_api_error_codes.h" | |||
| #include "runtime/mem.h" | |||
| namespace ge { | |||
| enum MemStorageType { | |||
| HBM = 0, | |||
| RDMA_HBM, | |||
| }; | |||
| struct HostVarInfo { | |||
| uint64_t base_addr; | |||
| uint64_t var_size; | |||
| }; | |||
| /// | |||
| /// \param size [in] rdma pool memory size to be allocated. | |||
| /// \param mem_type [in] memory type for rdma pool. | |||
| /// \return Status result of function | |||
| Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM); | |||
| /// | |||
| /// \param var_info [in] host variable addr infos. | |||
| /// \param mem_type [in] memory type for rdma pool. | |||
| /// \return Status result of function | |||
| Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t mem_type = RT_MEMORY_HBM); | |||
| /// | |||
| /// \param var_name [in] var_name name of host variable. | |||
| /// \param base_addr [out] base_addr base addr of host variable. | |||

| /// \param var_size [out] var_size memory_size of host variable. | |||
| /// \return Status result of function | |||
| Status GetVarBaseAddrAndSize(const std::string &var_name, uint64_t &base_addr, uint64_t &var_size); | |||
| } // namespace ge | |||
| #endif // INC_FRAMEWORK_MEMORY_MEMORY_API_H_ | |||
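A hedged usage sketch of the new RDMA memory API declared above; the include path is inferred from the header location, and the variable name "embedding_table" and the pool size are illustrative placeholders:

```cpp
#include <vector>
#include "framework/memory/memory_api.h"

ge::Status SetUpRdmaMemory() {
  // Reserve a 1 GiB RDMA pool in HBM (the default memory type).
  ge::Status ret = ge::InitRdmaPool(1UL << 30);
  if (ret != ge::SUCCESS) {
    return ret;
  }

  // Look up the host variable's base address and size, then register the
  // range so the device can access it remotely.
  ge::HostVarInfo info{};
  ret = ge::GetVarBaseAddrAndSize("embedding_table", info.base_addr, info.var_size);
  if (ret != ge::SUCCESS) {
    return ret;
  }
  return ge::RdmaRemoteRegister(std::vector<ge::HostVarInfo>{info});
}
```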
| @@ -96,17 +96,12 @@ Status CheckCustomAiCpuOpLib(); | |||
| Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | |||
| Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | |||
| Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| void UpdateOmgCtxWithParserCtx(); | |||
| void UpdateParserCtxWithOmgCtx(); | |||
| } // namespace ge | |||
| namespace domi { | |||
| @@ -120,6 +120,7 @@ struct OmgContext { | |||
| bool is_dynamic_input = false; | |||
| std::string dynamic_batch_size; | |||
| std::string dynamic_image_size; | |||
| std::string dynamic_dims; | |||
| }; | |||
| } // namespace ge | |||
| @@ -57,11 +57,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||
| // For compatibility | |||
| inline const std::uint8_t *data() const { return GetData(); } | |||
| inline std::uint8_t *data() { return GetData(); } | |||
| inline std::uint8_t *data() { return GetData(); } // lint !e659 | |||
| inline std::size_t size() const { return GetSize(); } | |||
| inline void clear() { return ClearBuffer(); } | |||
| uint8_t operator[](size_t index) const { | |||
| if (buffer_ != nullptr && index < buffer_->size()) { | |||
| uint8_t operator[](size_t index) const { // lint !e1022 !e1042 | |||
| if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 | |||
| return (uint8_t)(*buffer_)[index]; | |||
| } | |||
| return 0xff; | |||
| @@ -74,6 +74,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| size_t GetAllNodesSize() const; | |||
| Vistor<NodePtr> GetAllNodes() const; | |||
| // is_unknown_shape: false, same as GetAllNodes func | |||
| // is_unknown_shape: true, same as GetDirectNode func | |||
| Vistor<NodePtr> GetNodes(bool is_unknown_shape) const; | |||
| size_t GetDirectNodesSize() const; | |||
| Vistor<NodePtr> GetDirectNode() const; | |||
| Vistor<NodePtr> GetInputNodes() const; | |||
| @@ -81,14 +84,18 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| NodePtr FindNode(const std::string &name) const; | |||
| NodePtr FindFirstNodeMatchType(const std::string &name) const; | |||
| /*lint -e504*/ | |||
| // AddNode with NodePtr | |||
| NodePtr AddNode(NodePtr node); | |||
| NodePtr AddNode(OpDescPtr op); | |||
| NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. | |||
| NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize | |||
| NodePtr AddNodeFront(NodePtr node); | |||
| NodePtr AddNodeFront(const OpDescPtr &op); | |||
| NodePtr AddInputNode(NodePtr node); | |||
| NodePtr AddOutputNode(NodePtr node); | |||
| // insert node with specific pre_node | |||
| NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); | |||
| NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); | |||
| graphStatus RemoveNode(const NodePtr &node); | |||
| graphStatus RemoveInputNode(const NodePtr &node); | |||
| @@ -133,6 +140,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| bool IsValid() const; | |||
| void Dump() const; | |||
| void Swap(ComputeGraph &graph); | |||
| graphStatus IsolateNode(const NodePtr &node); | |||
| graphStatus Verify(); | |||
| graphStatus InferShape(); | |||
| @@ -141,6 +150,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| graphStatus InsertEventNodes(); | |||
| bool operator==(const ComputeGraph &r_compute_graph) const; | |||
| /*lint +e504*/ | |||
| const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const { | |||
| return params_share_map_; | |||
| } | |||
| @@ -174,6 +184,10 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| void SetInputSize(uint32_t size) { input_size_ = size; } | |||
| uint32_t GetInputSize() const { return input_size_; } | |||
| // false: known shape  true: unknown shape | |||
| bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } | |||
| void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } | |||
| /// | |||
| /// Set is need train iteration. | |||
| /// If set true, it means this graph need to be run iteration some | |||
| @@ -249,6 +263,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| bool VectorInputNodePtrIsEqual(const std::vector<NodePtr> &r_node_ptr_vector, | |||
| const std::vector<NodePtr> &l_node_ptr_vector) const; | |||
| void SetNodesOwner(); | |||
| friend class ModelSerializeImp; | |||
| friend class GraphDebugImp; | |||
| friend class OnnxUtils; | |||
| @@ -282,7 +298,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| std::map<uint32_t, std::string> op_name_map_; | |||
| uint64_t session_id_ = 0; | |||
| ge::Format data_format_ = ge::FORMAT_ND; | |||
| // unknown graph indicator, default is false, meaning known shape | |||
| bool is_unknown_shape_graph_ = false; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_COMPUTE_GRAPH_H_ | |||
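For context, a minimal sketch (not part of the change itself) showing how the new ComputeGraph helpers — the unknown-shape flag, GetNodes(bool), and AddNodeAfter — might be combined; `op_desc` and `pre_node` are assumed to be valid:

```cpp
void UseNewGraphHelpers(const ge::ComputeGraphPtr &graph, ge::OpDescPtr op_desc,
                        const ge::NodePtr &pre_node) {
  // Mark the graph as unknown-shape; GetNodes(true) then behaves like
  // GetDirectNode(), while GetNodes(false) behaves like GetAllNodes().
  graph->SetGraphUnknownFlag(true);
  for (const auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) {
    (void)node;  // only direct nodes are visited for an unknown-shape graph
  }

  // Insert a node right after pre_node in the direct node list; a null
  // return means op_desc was null or pre_node was not found in the graph.
  ge::NodePtr inserted = graph->AddNodeAfter(op_desc, pre_node);
  (void)inserted;
}
```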
| @@ -14,6 +14,7 @@ | |||
| * limitations under the License. | |||
| */ | |||
| /*lint -e618*/ | |||
| #ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| @@ -185,6 +186,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; | |||
| // to be deleted | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; | |||
| @@ -778,6 +782,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| @@ -930,12 +938,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; | |||
| // Control flow | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; | |||
| @@ -979,6 +989,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE | |||
| // For multi-batch | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; | |||
| // For inserted op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; | |||
| @@ -996,7 +1007,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | |||
| // used for l1 fusion and other fusion in future | |||
| // used for lX fusion | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | |||
| @@ -1010,9 +1021,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; | |||
| // for unregistered op | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; | |||
| // op overflow dump | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; | |||
| // functional ops attr | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; | |||
| @@ -1058,6 +1081,31 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR | |||
| // for gradient group | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; | |||
| // dynamic shape attrs | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; | |||
| // atc user def dtype&format | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; | |||
| // for fusion op plugin | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; | |||
| // graph partition for aicpu | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; | |||
| // input and output memory type | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; | |||
| // input_output_offset | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| /*lint +e618*/ | |||
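The attribute names above are plain string keys read and written through AttrUtils. A hedged example of tagging an op with one of the new attributes; the OpDescPtr overloads of SetBool/GetBool are assumed to mirror the graph-level overloads used later in this change:

```cpp
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"

void MarkNoTaskNeeded(const ge::OpDescPtr &op_desc) {
  // Tag the op so that no task generation or data dump is performed for it.
  (void)ge::AttrUtils::SetBool(op_desc, ge::ATTR_NO_TASK_AND_DUMP_NEEDED, true);

  bool no_task = false;
  if (ge::AttrUtils::GetBool(op_desc, ge::ATTR_NO_TASK_AND_DUMP_NEEDED, no_task) && no_task) {
    // downstream passes can skip this op
  }
}
```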
| @@ -38,7 +38,7 @@ class TypeID { | |||
| bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } | |||
| private: | |||
| explicit TypeID(string type) : type_(std::move(type)) {} | |||
| explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 | |||
| string type_; | |||
| }; | |||
| @@ -53,6 +53,8 @@ class AnyMap { | |||
| bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } | |||
| void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } | |||
| private: | |||
| class Placeholder { | |||
| public: | |||
| @@ -50,7 +50,7 @@ class OpDef; | |||
| class GraphDef; | |||
| } // namespace proto | |||
| using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; | |||
| using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 | |||
| using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; | |||
| template <class ProtoType> | |||
| @@ -95,6 +95,14 @@ class GeIrProtoHelper { | |||
| } | |||
| } | |||
| void Swap(GeIrProtoHelper<ProtoType> &other) { | |||
| protoOwner_.swap(other.protoOwner_); | |||
| ProtoType *temp = protoMsg_; | |||
| protoMsg_ = other.protoMsg_; | |||
| other.protoMsg_ = temp; | |||
| } | |||
| // protoMsg_ is part of protoOwner_; they share the same lifetime | |||
| ProtoMsgOwner protoOwner_ = nullptr; | |||
| ProtoType *protoMsg_ = nullptr; | |||
| @@ -120,6 +128,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||
| void CopyAttrsFrom(const AttrHolder &holder); | |||
| void Swap(AttrHolder &holder) { | |||
| requiredAttrs_.swap(holder.requiredAttrs_); | |||
| extAttrs_.Swap(holder.extAttrs_); | |||
| } | |||
| template <class T> | |||
| bool SetExtAttr(const string &name, const T &value) { | |||
| return extAttrs_.Set(name, value); | |||
| @@ -134,7 +147,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||
| protected: | |||
| graphStatus AddRequiredAttr(const std::string &name); | |||
| const std::unordered_set<string> GetAllAttrNames() const; | |||
| const std::map<string, GeAttrValue> GetAllAttrs() const; | |||
| const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073 | |||
| virtual ProtoAttrMapHelper MutableAttrMap() = 0; | |||
| virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; | |||
| @@ -149,5 +162,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||
| AnyMap extAttrs_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ | |||
| @@ -67,6 +67,9 @@ class ModelSerializeImp { | |||
| bool HandleNodeNameRef(); | |||
| bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); | |||
| void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out, | |||
| std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, std::vector<string> &opt); | |||
| void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); | |||
| bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); | |||
| @@ -310,7 +310,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| VALUE_SET_GET_DEC(GeAttrValue::GRAPH) | |||
| VALUE_SET_GET_DEC(BYTES) | |||
| VALUE_SET_GET_DEC(NamedAttrs) | |||
| VALUE_SET_GET_DEC(ge::DataType) | |||
| VALUE_SET_GET_DEC(ge::DataType) // lint !e665 | |||
| VALUE_SET_GET_DEC(vector<GeAttrValue::STR>) | |||
| VALUE_SET_GET_DEC(vector<GeAttrValue::INT>) | |||
| VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>) | |||
| @@ -320,8 +320,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||
| VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>) | |||
| VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>) | |||
| VALUE_SET_GET_DEC(vector<NamedAttrs>) | |||
| VALUE_SET_GET_DEC(vector<vector<int64_t>>) | |||
| VALUE_SET_GET_DEC(vector<ge::DataType>) | |||
| VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665 | |||
| VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665 | |||
| #undef VALUE_SET_GET_DEC | |||
| GeIrProtoHelper<proto::AttrDef> value_; | |||
| @@ -28,6 +28,7 @@ class GEContext { | |||
| uint32_t DeviceId(); | |||
| uint64_t TraceId(); | |||
| void Init(); | |||
| void SetSessionId(uint64_t session_id); | |||
| void SetCtxDeviceId(uint32_t device_id); | |||
| private: | |||
| @@ -25,6 +25,7 @@ | |||
| #include "graph/buffer.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
| public: | |||
| @@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
| DataType GetDataType() const; | |||
| void SetDataType(DataType dt); | |||
| void SetOriginDataType(DataType originDataType); | |||
| DataType GetOriginDataType() const; | |||
| void SetOriginDataType(DataType originDataType); | |||
| std::vector<uint32_t> GetRefPortIndex() const; | |||
| void SetRefPortByIndex(const std::vector<uint32_t> &index); | |||
| GeTensorDesc Clone() const; | |||
| GeTensorDesc &operator=(const GeTensorDesc &desc); | |||
| @@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||
| GeTensorDesc &DescReference() const; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_GE_TENSOR_H_ | |||
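A small sketch of the new GeTensorDesc accessors introduced above (origin data type and ref-port indices); the index values are illustrative only:

```cpp
void TagRefPorts(ge::GeTensorDesc &desc) {
  // Keep the origin data type in sync with the current data type.
  desc.SetOriginDataType(desc.GetDataType());
  ge::DataType origin_dtype = desc.GetOriginDataType();
  (void)origin_dtype;

  // Record which ports reference (share memory with) this tensor.
  desc.SetRefPortByIndex({0U, 2U});
  std::vector<uint32_t> ref_ports = desc.GetRefPortIndex();
  (void)ref_ports;
}
```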
| @@ -49,5 +49,4 @@ class ModelSerialize { | |||
| friend class GraphDebugImp; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_MODEL_SERIALIZE_H_ | |||
| @@ -190,7 +190,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||
| vector<OutDataAnchorPtr> out_data_anchors_; | |||
| InControlAnchorPtr in_control_anchor_; | |||
| OutControlAnchorPtr out_control_anchor_; | |||
| map<string, GeAttrValue> attrs_; | |||
| map<string, GeAttrValue> attrs_; // lint !e1073 | |||
| bool has_init_{false}; | |||
| bool anchor_status_updated_{false}; | |||
| std::vector<uint32_t> send_event_id_list_; | |||
| @@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeTensorDescPtr MutableInputDesc(uint32_t index) const; | |||
| GeTensorDescPtr MutableInputDesc(const string &name) const; | |||
| Vistor<GeTensorDesc> GetAllInputsDesc() const; | |||
| Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const; | |||
| @@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | |||
| GeTensorDescPtr MutableOutputDesc(const string &name) const; | |||
| uint32_t GetAllOutputsDescSize() const; | |||
| Vistor<GeTensorDesc> GetAllOutputsDesc() const; | |||
| @@ -149,16 +153,15 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); | |||
| void RemoveInputDesc(uint32_t index); | |||
| void RemoveOutputDesc(uint32_t index); | |||
| bool IsOptionalInput(const string &name) const; | |||
| bool IsOptionalInput(uint32_t index) const; | |||
| std::map<string, uint32_t> GetAllInputName() const; | |||
| void SetAllInputName(const std::map<string, uint32_t> &input_name_idx); | |||
| std::vector<string> GetAllOptionalInputName() const; | |||
| std::map<string, uint32_t> GetAllOutputName(); | |||
| bool UpdateInputName(std::map<string, uint32_t> inputNameIdx); | |||
| @@ -296,6 +299,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | |||
| vector<GeTensorDescPtr> inputs_desc_{}; | |||
| map<string, uint32_t> input_name_idx_{}; | |||
| std::unordered_set<string> optional_input_names_{}; | |||
| vector<GeTensorDescPtr> outputs_desc_{}; | |||
| map<string, uint32_t> output_name_idx_{}; | |||
| std::function<graphStatus(Operator &)> infer_func_ = nullptr; | |||
| @@ -31,6 +31,7 @@ class ShapeRefiner { | |||
| static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); | |||
| static graphStatus InferShapeAndType(const NodePtr &node); | |||
| static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); | |||
| static void ClearContextMap(); | |||
| private: | |||
| static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); | |||
| @@ -23,6 +23,8 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include <list> | |||
| #include <unordered_map> | |||
| #include "graph/anchor.h" | |||
| #include "graph/node.h" | |||
| #include "graph/compute_graph.h" | |||
| @@ -130,7 +132,7 @@ struct NodeIndexIO { | |||
| IOType io_type_ = kOut; | |||
| std::string value_; | |||
| std::string ToString() const { return value_; } | |||
| const std::string &ToString() const { return value_; } | |||
| }; | |||
| class GraphUtils { | |||
| @@ -188,8 +190,8 @@ class GraphUtils { | |||
| /// @param [in] output_index | |||
| /// @return graphStatus | |||
| /// | |||
| static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
| static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
| static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | |||
| @@ -303,8 +305,33 @@ class GraphUtils { | |||
| /// | |||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
| /// | |||
| /// Copy all in-data edges from `src_node` to `dst_node` | |||
| /// @param src_node | |||
| /// @param dst_node | |||
| /// @return | |||
| /// | |||
| static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); | |||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
| /// | |||
| /// Make a copy of ComputeGraph. | |||
| /// @param graph: original graph. | |||
| /// @param prefix: node name prefix of new graph. | |||
| /// @return ComputeGraphPtr | |||
| /// | |||
| static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, | |||
| std::vector<NodePtr> &input_nodes, std::vector<NodePtr> &output_nodes); | |||
| /// | |||
| /// Copy tensor attribute to new node. | |||
| /// @param [in] dst_desc: cloned node. | |||
| /// @param [in] src_node: original node. | |||
| /// @return success: GRAPH_SUCCESS | |||
| /// | |||
| static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); | |||
| static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||
| /// | |||
| @@ -392,6 +419,16 @@ class GraphUtils { | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol); | |||
| /// | |||
| /// Relink all edges for cloned ComputeGraph. | |||
| /// @param [in] node: original node. | |||
| /// @param [in] prefix: node name prefix of new node. | |||
| /// @param [in] all_nodes: all nodes in new graph. | |||
| /// @return success: GRAPH_SUCCESS | |||
| /// | |||
| static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, | |||
| const std::unordered_map<string, NodePtr> &all_nodes); | |||
| /// | |||
| /// Union ref-mapping | |||
| /// @param [in] exist_node_info1 | |||
| @@ -728,5 +765,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder { | |||
| std::vector<NodePtr> exist_nodes_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ | |||
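A hedged sketch of the new clone utilities declared above; the prefix is illustrative and error handling is reduced to a null check. RelinkGraphEdges and CopyTensorAttrs appear to be helpers used during cloning, so a caller would typically only need the call below:

```cpp
void CloneWholeGraph(const ge::ComputeGraphPtr &graph) {
  std::vector<ge::NodePtr> input_nodes;
  std::vector<ge::NodePtr> output_nodes;

  // Copy the graph, prefixing every cloned node name with "copy_";
  // input_nodes/output_nodes receive cloned boundary nodes (assumption).
  ge::ComputeGraphPtr cloned =
      ge::GraphUtils::CloneGraph(graph, "copy_", input_nodes, output_nodes);
  if (cloned == nullptr) {
    return;  // clone failed
  }

  // CopyInDataEdges can separately duplicate the in-data edges of a single
  // node onto another node, e.g. when only part of a graph is rebuilt.
}
```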
| @@ -63,6 +63,9 @@ class NodeUtils { | |||
| static void UnlinkAll(const Node &node); | |||
| static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); | |||
| static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t index); | |||
| static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t index); | |||
| static bool IsInNodesEmpty(const Node &node); | |||
| static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); | |||
| static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); | |||
| @@ -99,6 +102,13 @@ class NodeUtils { | |||
| /// | |||
| static NodePtr GetParentInput(const NodePtr &node); | |||
| /// | |||
| /// @brief Check whether the node is a varying input of a While node | |||
| /// @param [in] node: Data node for subgraph | |||
| /// @return bool | |||
| /// | |||
| static bool IsWhileVaryingInput(const ge::NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| @@ -114,6 +124,24 @@ class NodeUtils { | |||
| /// | |||
| static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph input data node by index. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index); | |||
| /// | |||
| /// @brief Get subgraph output nodes. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| static vector<NodePtr> GetSubgraphOutputNodes(const Node &node); | |||
| static NodePtr GetInDataNodeByIndex(const Node &node, int index); | |||
| static vector<NodePtr> GetOutDataNodesByIndex(const Node &node, int index); | |||
| private: | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | |||
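A short, hedged sketch of the added NodeUtils helpers; `call_node` is assumed to be a node that owns subgraphs (a While/Case-style node):

```cpp
void InspectSubgraphIo(const ge::NodePtr &call_node) {
  // Subgraph Data nodes whose parent-input index is 0.
  std::vector<ge::NodePtr> data_nodes =
      ge::NodeUtils::GetSubgraphDataNodesByIndex(*call_node, 0);
  (void)data_nodes;

  // Output nodes of the node's subgraphs.
  std::vector<ge::NodePtr> output_nodes =
      ge::NodeUtils::GetSubgraphOutputNodes(*call_node);
  (void)output_nodes;

  // Peer producer of the node's first data input, if connected.
  ge::NodePtr producer = ge::NodeUtils::GetInDataNodeByIndex(*call_node, 0);
  (void)producer;
}
```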
| @@ -20,6 +20,7 @@ | |||
| #include <memory> | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/tensor.h" | |||
| namespace ge { | |||
| using GeTensorPtr = std::shared_ptr<GeTensor>; | |||
| using ConstGeTensorPtr = std::shared_ptr<const GeTensor>; | |||
| @@ -21,6 +21,7 @@ | |||
| #include "graph/def_types.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/ge_tensor.h" | |||
| namespace ge { | |||
| class TensorUtils { | |||
| public: | |||
| @@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE | |||
| ${PROTOBUF_LIBRARY} | |||
| ${c_sec} | |||
| ${slog} | |||
| ${error_manager} | |||
| rt | |||
| dl) | |||
| @@ -62,18 +62,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() co | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { | |||
| size_t s = nodes_.size(); | |||
| for (const auto &sub_graph : sub_graph_) { | |||
| s += sub_graph->GetAllNodesSize(); | |||
| } | |||
| return s; | |||
| return GetAllNodes().size(); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | |||
| if (sub_graph_.empty()) { | |||
| return Vistor<NodePtr>(shared_from_this(), nodes_); | |||
| } | |||
| std::vector<std::shared_ptr<ComputeGraph>> subgraphs; | |||
| return AllGraphNodes(subgraphs); | |||
| } | |||
| @@ -106,6 +98,15 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share | |||
| return Vistor<NodePtr>(shared_from_this(), all_nodes); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes( | |||
| bool is_unknown_shape) const { | |||
| if (is_unknown_shape) { | |||
| return GetDirectNode(); | |||
| } else { | |||
| return GetAllNodes(); | |||
| } | |||
| } | |||
| size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | |||
| @@ -268,7 +269,7 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { | |||
| NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { | |||
| if (op == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| op->SetId(nodes_.size()); | |||
| @@ -278,9 +279,38 @@ NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { | |||
| return AddNodeFront(node_ptr); | |||
| } | |||
| NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); | |||
| return nullptr; | |||
| } | |||
| node->GetOpDesc()->SetId(nodes_.size()); | |||
| auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); | |||
| if (node_iter != nodes_.end()) { | |||
| nodes_.insert(node_iter + 1, node); | |||
| } else { | |||
| GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_."); | |||
| return nullptr; | |||
| } | |||
| return node; | |||
| } | |||
| NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) { | |||
| if (op == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| op->SetId(nodes_.size()); | |||
| NodePtr node_ptr = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | |||
| GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | |||
| GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr); | |||
| return AddNodeAfter(node_ptr, pre_node); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); | |||
| @@ -290,7 +320,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(Nod | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { | |||
| if (op == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| op->SetId(GetDirectNodesSize()); | |||
| @@ -302,7 +332,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD | |||
| NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||
| if (op == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| op->SetId(id); | |||
| @@ -315,7 +345,7 @@ NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||
| NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| input_nodes_.push_back(node); | |||
| @@ -327,7 +357,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||
| NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); | |||
| return nullptr; | |||
| } | |||
| @@ -363,7 +393,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { | |||
| if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { | |||
| GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, | |||
| "Remove edge from const op failed."); | |||
| if (out_anchor->GetOwnerNode()->GetOutDataNodes().size() == 0) { | |||
| if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { | |||
| GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); | |||
| auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); | |||
| if (iter != nodes_.end()) { | |||
| @@ -377,7 +407,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -406,7 +436,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN | |||
| // Used in sub_graph scenes | |||
| graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -421,7 +451,7 @@ graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { | |||
| // Used in sub_graph scenes | |||
| graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The node ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -442,7 +472,7 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { | |||
| std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeGraph> sub_graph) { | |||
| if (sub_graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); | |||
| return nullptr; | |||
| } | |||
| sub_graph_.push_back(sub_graph); | |||
| @@ -452,7 +482,7 @@ std::shared_ptr<ComputeGraph> ComputeGraph::AddSubGraph(std::shared_ptr<ComputeG | |||
| graphStatus ComputeGraph::RemoveSubGraph(const std::shared_ptr<ComputeGraph> &sub_graph) { | |||
| if (sub_graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); | |||
| GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -491,12 +521,15 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| if (!this->parent_graph_.expired()) { | |||
| GE_LOGE("The subgraphs can only be added to the root graph"); | |||
| return GRAPH_PARAM_INVALID; | |||
| GELOGW("The subgraphs should only be added to the root graph"); | |||
| } | |||
| if (name != subgraph->GetName()) { | |||
| GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||
| } | |||
| if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { | |||
| GE_LOGE("The subgraph %s existed", name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| sub_graph_.push_back(subgraph); | |||
| names_to_subgraph_[name] = subgraph; | |||
| return GRAPH_SUCCESS; | |||
| @@ -640,7 +673,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||
| GELOGW("node or OpDescPtr is nullptr."); | |||
| continue; | |||
| } | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); | |||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); | |||
| if (node->GetOpDesc()->GetType() == RECV) { | |||
| auto iter = find(node_vec.begin(), node_vec.end(), node); | |||
| if (iter == node_vec.end()) { | |||
| @@ -786,7 +819,8 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::TopologicalSorting() { | |||
| auto ret = TopologicalSortingGraph(); | |||
| if (ret != SUCCESS) { | |||
| GELOGE(ret, "Sub graph partition Failed"); | |||
| GraphUtils::DumpGEGraphToOnnx(*this, "black_box"); | |||
| GELOGE(ret, "Graph [%s] topological sort failed, saved to file black_box", name_.c_str()); | |||
| return ret; | |||
| } | |||
| @@ -1001,6 +1035,54 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Dump() const { | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::Swap(ComputeGraph &graph) { | |||
| this->AttrHolder::Swap(graph); | |||
| origGraph_.swap(graph.origGraph_); | |||
| name_.swap(graph.name_); | |||
| std::swap(graph_id_, graph.graph_id_); | |||
| attrs_.Swap(graph.attrs_); | |||
| nodes_.swap(graph.nodes_); | |||
| all_nodes_infos_.swap(graph.all_nodes_infos_); | |||
| target_nodes_info_.swap(graph.target_nodes_info_); | |||
| input_nodes_.swap(graph.input_nodes_); | |||
| inputs_order_.swap(graph.inputs_order_); | |||
| std::swap(input_size_, graph.input_size_); | |||
| out_nodes_map_.swap(graph.out_nodes_map_); | |||
| std::swap(output_size_, graph.output_size_); | |||
| output_nodes_info_.swap(graph.output_nodes_info_); | |||
| sub_graph_.swap(graph.sub_graph_); | |||
| names_to_subgraph_.swap(graph.names_to_subgraph_); | |||
| parent_graph_.swap(graph.parent_graph_); | |||
| parent_node_.swap(graph.parent_node_); | |||
| // the following members should not be in the ComputeGraph class | |||
| std::swap(is_valid_flag_, graph.is_valid_flag_); | |||
| std::swap(is_summary_graph_, graph.is_summary_graph_); | |||
| std::swap(need_iteration_, graph.need_iteration_); | |||
| params_share_map_.swap(graph.params_share_map_); | |||
| op_name_map_.swap(graph.op_name_map_); | |||
| std::swap(session_id_, graph.session_id_); | |||
| std::swap(data_format_, graph.data_format_); | |||
| std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); | |||
| // Update Node owner. | |||
| SetNodesOwner(); | |||
| graph.SetNodesOwner(); | |||
| } | |||
| void ComputeGraph::SetNodesOwner() { | |||
| for (const auto &node : nodes_) { | |||
| if (node == nullptr) { | |||
| continue; | |||
| } | |||
| node->SetOwnerComputeGraph(shared_from_this()); | |||
| } | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { | |||
| GE_CHECK_NOTNULL(node); | |||
| auto next_nodes = node->GetOutAllNodes(); | |||
| @@ -1104,9 +1186,11 @@ graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { | |||
| } | |||
| graphStatus ComputeGraph::Verify() { | |||
| bool is_unknown_graph = GetGraphUnknownFlag(); | |||
| for (const auto &node_ptr : GetAllNodes()) { | |||
| GE_CHECK_NOTNULL(node_ptr); | |||
| GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); | |||
| GE_IF_BOOL_EXEC(is_unknown_graph, continue); | |||
| GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, | |||
| "Verifying %s failed.", node_ptr->GetName().c_str()); | |||
| } | |||
| @@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||
| GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||
| GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||
| GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||
| GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||
| GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||
| GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||
| GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||
| GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||
| GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||
| GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||
| GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||
| GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||
| GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); | |||
| GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||
| GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||
| @@ -41,11 +41,9 @@ using namespace ge; | |||
| using namespace std; | |||
| namespace ge { | |||
| namespace { | |||
| static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
| static bool net_format_is_nd = true; | |||
| static Format g_user_set_format = FORMAT_ND; | |||
| static bool is_first_infer = true; | |||
| static RefRelations reflection_builder; | |||
| const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
| const string kIsGraphInferred = "_is_graph_inferred"; | |||
| RefRelations reflection_builder; | |||
| } // namespace | |||
| graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||
| @@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
| graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { | |||
| // 5 means dim num | |||
| if (node_ptr->GetType() != "BiasAdd") { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; | |||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { | |||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); | |||
| GE_CHECK_NOTNULL(in_desc); | |||
| if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
| continue; | |||
| } | |||
| auto format = in_desc->GetOriginFormat(); | |||
| auto key = TypeUtils::FormatToSerialString(format); | |||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
| in_desc->SetOriginFormat(fixed_format); | |||
| in_desc->SetFormat(fixed_format); | |||
| GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
| } | |||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { | |||
| auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); | |||
| GE_CHECK_NOTNULL(out_desc); | |||
| if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
| continue; | |||
| } | |||
| auto format = out_desc->GetOriginFormat(); | |||
| auto key = TypeUtils::FormatToSerialString(format); | |||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
| out_desc->SetOriginFormat(fixed_format); | |||
| out_desc->SetFormat(fixed_format); | |||
| GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | |||
| if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { | |||
| ConstGeTensorPtr tensor_value; | |||
| if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { | |||
| GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); | |||
| @@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| } | |||
| anchor_points.clear(); | |||
| // Get all anchor point nodes and switch nodes | |||
| for (const auto &node_ptr : graph->GetAllNodes()) { | |||
| for (auto &node_ptr : graph->GetAllNodes()) { | |||
| if (node_ptr == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -103,7 +141,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| if (op_desc == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| graphStatus status = RefreshConstantOutProcess(op_desc); | |||
| graphStatus status = RefreshConstantOutProcess(graph, op_desc); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); | |||
| return GRAPH_FAILED; | |||
| @@ -135,6 +173,16 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| if (!node_is_all_nd) { | |||
| continue; | |||
| } | |||
| // special process for BiasAdd op | |||
| // In TensorFlow, BiasAdd's format is always NHWC even when the "data_format" | |||
| // arg is set to NDHWC or NCDHW. That would break our format-infer mechanism, | |||
| // so it is handled specially here. | |||
| status = BiasAddFormatFixProcess(node_ptr); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); | |||
| anchor_points.push_back(node_ptr); | |||
| } | |||
| @@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor | |||
| } | |||
| } | |||
| void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; } | |||
| graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
| graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
| ge::Format data_format, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
| bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | |||
| bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||
| if (!need_process) { | |||
| GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||
| if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { | |||
| GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), | |||
| TypeUtils::FormatToSerialString(data_format).c_str()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| std::vector<ge::NodePtr> anchor_points; | |||
| std::vector<ge::NodePtr> data_nodes; | |||
| // global net format | |||
| net_format_is_nd = true; | |||
| g_user_set_format = FORMAT_ND; | |||
| if (graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "input graph is null"); | |||
| @@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| /// format for these data nodes. | |||
| /// Notice: ignore 5D formats | |||
| auto data_format = graph->GetDataFormat(); | |||
| status = DataNodeFormatProcess(data_nodes, data_format, node_status); | |||
| // Set infer flag to false | |||
| SetInferOrigineFormatFlag(false); | |||
| status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); | |||
| (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); | |||
| return status; | |||
| } | |||
| bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { | |||
| bool is_graph_inferred = false; | |||
| return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); | |||
| } | |||
| } // namespace ge | |||
| @@ -30,10 +30,9 @@ namespace ge { | |||
| class FormatRefiner { | |||
| public: | |||
| static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); | |||
| static void SetInferOrigineFormatFlag(bool is_first = true); | |||
| private: | |||
| static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc); | |||
| static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||
| static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||
| std::vector<ge::NodePtr> &data_nodes, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| @@ -43,8 +42,9 @@ class FormatRefiner { | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
| ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static bool IsGraphInferred(const ComputeGraphPtr &graph); | |||
| }; | |||
| } // namespace ge | |||
| #endif // COMMON_GRAPH_FORMAT_REFINER_H_ | |||
| @@ -158,6 +158,10 @@ const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; | |||
| const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; | |||
| const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; | |||
| // Identify node connecting to input and output | |||
| const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; | |||
| const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; | |||
| // To be deleted | |||
| const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; | |||
| const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; | |||
| @@ -725,6 +729,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | |||
| const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||
| const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; | |||
| const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; | |||
| // Public attribute | |||
| const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | |||
| @@ -901,6 +909,7 @@ const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_l | |||
| const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; | |||
| const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; | |||
| const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; | |||
| const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; | |||
| // Control flow | |||
| const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; | |||
| @@ -910,6 +919,7 @@ const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | |||
| const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | |||
| const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | |||
| const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; | |||
| const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; | |||
| const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | |||
| const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | |||
| @@ -934,7 +944,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||
| const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
| // l1 fusion and other fusion in future | |||
| // lx fusion | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
| const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
| @@ -948,9 +958,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||
| const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
| const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
| const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
| const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; | |||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||
| const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||
| const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; | |||
| // Op debug attrs | |||
| const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | |||
| const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; | |||
| // Atomic addr clean attrs | |||
| const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
| @@ -971,6 +989,8 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; | |||
| const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; | |||
| const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; | |||
| // For inserted op | |||
| const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; | |||
| @@ -1009,10 +1029,38 @@ const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_ | |||
| const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | |||
| const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; | |||
| // for unregistered op | |||
| const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; | |||
| const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; | |||
| // used for Horovod | |||
| const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; | |||
| const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; | |||
| // used for allreduce tailing optimization | |||
| const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; | |||
| const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; | |||
| // dynamic shape attr | |||
| const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; | |||
| const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; | |||
| // atc user def dtype&format | |||
| const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; | |||
| const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; | |||
| // for fusion op plugin | |||
| const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | |||
| // graph partition for aicpu | |||
| const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; | |||
| const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; | |||
| // input and output memory type | |||
| const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; | |||
| const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; | |||
| const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; | |||
| // input_output_offset | |||
| const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; | |||
| const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; | |||
| } // namespace ge | |||
| @@ -33,7 +33,8 @@ using std::vector; | |||
| namespace ge { | |||
| NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | |||
| NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} | |||
| NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | |||
| : named_attrs_(owner, proto_msg) {} // lint !e1744 | |||
| void NamedAttrs::SetName(const std::string &name) { | |||
| auto proto_msg = named_attrs_.GetProtoMsg(); | |||
| @@ -238,7 +239,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | |||
| @@ -252,9 +253,11 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | |||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | |||
| /*lint -e665*/ | |||
| ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | |||
| ATTR_VALUE_SET_GET_IMP(vector<DataType>) | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) | |||
| /*lint +e665*/ | |||
| ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665 | |||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 | |||
| #undef ATTR_VALUE_SET_GET_IMP | |||
| @@ -782,14 +785,14 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| if (graph_def == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | |||
| graph_def = nullptr; | |||
| return false; | |||
| return false; // lint !e665 | |||
| } else { | |||
| ModelSerializeImp imp; | |||
| imp.SetProtobufOwner(graph_def); | |||
| if (!imp.UnserializeGraph(graph, *graph_def)) { | |||
| GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | |||
| return false; | |||
| } | |||
| } // lint !e514 | |||
| value = graph; | |||
| } | |||
| return true; | |||
| @@ -809,7 +812,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| if (graph_def == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | |||
| graph_def = nullptr; | |||
| return false; | |||
| return false; // lint !e665 | |||
| } else { | |||
| ComputeGraphPtr graph = nullptr; | |||
| ModelSerializeImp imp; | |||
| @@ -817,7 +820,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||
| if (!imp.UnserializeGraph(graph, *graph_def)) { | |||
| GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | |||
| return false; | |||
| } | |||
| } // lint !e514 | |||
| value.push_back(graph); | |||
| } | |||
| } | |||
| @@ -969,7 +972,9 @@ ATTR_UTILS_SET_IMP(Tensor, GeTensor) | |||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | |||
| ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | |||
| ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | |||
| /*lint -e665*/ | |||
| ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | |||
| /*lint +e665*/ | |||
| ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>) | |||
| ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>) | |||
| @@ -984,8 +989,8 @@ ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | |||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | |||
| ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | |||
| ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | |||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | |||
| ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) | |||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665 | |||
| ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 | |||
| bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, | |||
| std::initializer_list<ConstGeTensorPtr> &&value) { | |||
| @@ -1154,7 +1159,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(Con | |||
| } | |||
| for (const auto &item : bytes_vals) { | |||
| ModelSerialize serialize; | |||
| auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); | |||
| auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 | |||
| value.push_back(op_desc); | |||
| } | |||
| return true; | |||
| @@ -1206,7 +1211,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||
| op_def = ComGraphMakeShared<proto::OpDef>(); | |||
| if (op_def == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); | |||
| return nullptr; | |||
| return nullptr; // lint !e665 | |||
| } | |||
| ModelSerializeImp imp; | |||
| (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); | |||
| @@ -1216,27 +1221,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||
| GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); | |||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
| if (op_desc->HasAttr("_input_name_idx_key")) { | |||
| if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); | |||
| } | |||
| } | |||
| if (op_desc->HasAttr("_input_name_idx_value")) { | |||
| if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); | |||
| } | |||
| // This function may be called by some passes of the fusion engine; in that case these attributes are not needed. | |||
| if (!op_desc->input_name_idx_.empty()) { | |||
| op_desc->input_name_idx_.clear(); | |||
| } | |||
| if (op_desc->HasAttr("_opt_input")) { | |||
| if (op_desc->DelAttr("_opt_input") != SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); | |||
| } | |||
| } | |||
| if (!op_desc->output_name_idx_.empty()) { | |||
| op_desc->output_name_idx_.clear(); | |||
| } | |||
| if (!op_desc->optional_input_names_.empty()) { | |||
| op_desc->optional_input_names_.clear(); | |||
| } | |||
| return op_desc; | |||
| } | |||
| @@ -1260,6 +1254,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c | |||
| op_desc->extAttrs_ = org_op_desc->extAttrs_; | |||
| op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); | |||
| op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), | |||
| org_op_desc->optional_input_names_.end()); | |||
| op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); | |||
| op_desc->infer_func_ = org_op_desc->infer_func_; | |||
| @@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | |||
| const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | |||
| const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | |||
| const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||
| const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; | |||
| GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | |||
| @@ -567,6 +568,16 @@ DataType GeTensorDesc::GetOriginDataType() const { | |||
| return TypeUtils::SerialStringToDataType(origin_data_type_str); | |||
| } | |||
| std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const { | |||
| vector<uint32_t> ref_port_index; | |||
| (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); | |||
| return ref_port_index; | |||
| } | |||
| void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) { | |||
| (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); | |||
| } | |||
| graphStatus GeTensorDesc::IsValid() const { | |||
| auto dtype = this->GetDataType(); | |||
| auto format = this->GetFormat(); | |||
| @@ -210,7 +210,7 @@ class GraphImpl { | |||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const { | |||
| auto it = op_list_.find(name); | |||
| GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str()); | |||
| GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); | |||
| op = it->second; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -94,10 +95,36 @@ LOCAL_CPPFLAGS += -fexceptions | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/atc/lib64/stub/graph.cc \ | |||
| ../../out/atc/lib64/stub/operator.cc \ | |||
| ../../out/atc/lib64/stub/tensor.cc \ | |||
| ../../out/atc/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for host | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := fwk_stub/libgraph | |||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||
| LOCAL_CPPFLAGS += -fexceptions | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/inference_context.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| @@ -122,6 +149,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -142,10 +170,39 @@ LOCAL_CFLAGS += -O2 | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/atc/lib64/stub/graph.cc \ | |||
| ../../out/atc/lib64/stub/operator.cc \ | |||
| ../../out/atc/lib64/stub/tensor.cc \ | |||
| ../../out/atc/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| ifeq ($(device_os),android) | |||
| LOCAL_LDFLAGS := -ldl | |||
| endif | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_SHARED_LIBRARY) | |||
| #compiler for device | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := fwk_stub/libgraph | |||
| LOCAL_CFLAGS += -O2 | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/inference_context.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| @@ -174,6 +231,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -199,6 +257,7 @@ LOCAL_STATIC_LIBRARIES := \ | |||
| LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -222,6 +281,7 @@ LOCAL_STATIC_LIBRARIES := \ | |||
| LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||
| } | |||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||
| if (op_desc == nullptr || op_def_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input Para Invalid"); | |||
| return false; | |||
| } | |||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); | |||
| GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); | |||
| if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||
| *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||
| // Delete unnecessary attr | |||
| @@ -130,18 +128,40 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
| op_def_proto->add_subgraph_name(name); | |||
| } | |||
| OpDescToAttrDef(op_desc, op_def_proto); | |||
| } | |||
| return true; | |||
| } | |||
| proto::AttrDef key; | |||
| proto::AttrDef value; | |||
| void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { | |||
| proto::AttrDef key_in; | |||
| proto::AttrDef value_in; | |||
| auto op_desc_attr = op_def_proto->mutable_attr(); | |||
| if (!op_desc->input_name_idx_.empty()) { | |||
| for (auto &item : op_desc->input_name_idx_) { | |||
| key_in.mutable_list()->add_s(item.first); | |||
| value_in.mutable_list()->add_i(item.second); | |||
| } | |||
| op_desc_attr->insert({"_input_name_key", key_in}); | |||
| op_desc_attr->insert({"_input_name_value", value_in}); | |||
| } | |||
| proto::AttrDef key_out; | |||
| proto::AttrDef value_out; | |||
| if (!op_desc->output_name_idx_.empty()) { | |||
| for (auto &item : op_desc->output_name_idx_) { | |||
| key.mutable_list()->add_s(item.first); | |||
| value.mutable_list()->add_i(item.second); | |||
| key_out.mutable_list()->add_s(item.first); | |||
| value_out.mutable_list()->add_i(item.second); | |||
| } | |||
| auto op_desc_attr = op_def_proto->mutable_attr(); | |||
| op_desc_attr->insert({"_output_name_key", key}); | |||
| op_desc_attr->insert({"_output_name_value", value}); | |||
| op_desc_attr->insert({"_output_name_key", key_out}); | |||
| op_desc_attr->insert({"_output_name_value", value_out}); | |||
| } | |||
| proto::AttrDef opt_input; | |||
| if (!op_desc->optional_input_names_.empty()) { | |||
| for (auto &item : op_desc->optional_input_names_) { | |||
| opt_input.mutable_list()->add_s(item); | |||
| } | |||
| op_desc_attr->insert({"_opt_input", opt_input}); | |||
| } | |||
| return true; | |||
| } | |||
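
For orientation: the hunk above moves the name→index bookkeeping out of SerializeOpDesc into OpDescToAttrDef, flattening input_name_idx_, output_name_idx_ and optional_input_names_ into parallel list attributes ("_input_name_key"/"_input_name_value", "_output_name_key"/"_output_name_value", "_opt_input") so they survive serialization; UnserializeOpDesc later rebuilds the maps from those lists. A minimal standalone sketch of that flatten/rebuild round trip, using a hypothetical FlatNameIdx struct in place of the proto::AttrDef lists (not the GE API):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Simplified stand-in for the proto list attrs used by OpDescToAttrDef:
    // the map is flattened into two parallel lists (names as strings, indices as ints).
    struct FlatNameIdx {
      std::vector<std::string> keys;
      std::vector<int64_t> values;
    };

    FlatNameIdx Flatten(const std::map<std::string, uint32_t> &name_idx) {
      FlatNameIdx flat;
      for (const auto &item : name_idx) {
        flat.keys.push_back(item.first);     // corresponds to key.mutable_list()->add_s(...)
        flat.values.push_back(item.second);  // corresponds to value.mutable_list()->add_i(...)
      }
      return flat;
    }

    std::map<std::string, uint32_t> Rebuild(const FlatNameIdx &flat) {
      std::map<std::string, uint32_t> name_idx;
      if (flat.keys.size() != flat.values.size()) {
        return name_idx;  // size mismatch: keep the map empty, as AttrDefToOpDesc only warns
      }
      for (size_t i = 0; i < flat.keys.size(); ++i) {
        name_idx.emplace(flat.keys[i], static_cast<uint32_t>(flat.values[i]));
      }
      return name_idx;
    }

    int main() {
      std::map<std::string, uint32_t> input_name_idx = {{"x", 0}, {"w", 1}, {"b", 2}};
      FlatNameIdx flat = Flatten(input_name_idx);
      std::map<std::string, uint32_t> round_trip = Rebuild(flat);
      std::cout << "round trip entries: " << round_trip.size() << "\n";  // prints 3
      return 0;
    }
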
| bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { | |||
| @@ -237,13 +257,70 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Unseriali | |||
| } | |||
| } | |||
| void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector<string> &key_in, std::vector<string> &key_out, | |||
| std::vector<uint32_t> &value_in, std::vector<uint32_t> &value_out, | |||
| std::vector<string> &opt_input) { | |||
| if (!key_in.empty()) { | |||
| if (key_in.size() != value_in.size()) { | |||
| GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), | |||
| value_in.size()); | |||
| } else { | |||
| for (uint32_t i = 0; i < key_in.size(); ++i) { | |||
| op_desc->input_name_idx_.insert(std::pair<string, uint32_t>(key_in.at(i), value_in.at(i))); | |||
| } | |||
| } | |||
| } | |||
| if (!key_out.empty()) { | |||
| if (key_out.size() != value_out.size()) { | |||
| GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), | |||
| value_out.size()); | |||
| } else { | |||
| for (uint32_t i = 0; i < key_out.size(); ++i) { | |||
| op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key_out.at(i), value_out.at(i))); | |||
| } | |||
| } | |||
| } | |||
| if (!opt_input.empty()) { | |||
| for (const auto &i : opt_input) { | |||
| op_desc->optional_input_names_.insert(i); | |||
| } | |||
| } | |||
| } | |||
| bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { | |||
| std::vector<string> key; | |||
| std::vector<uint32_t> value; | |||
| std::vector<string> opt_input; | |||
| std::vector<string> key_in; | |||
| std::vector<uint32_t> value_in; | |||
| if (op_def_proto.attr().count("_opt_input") > 0) { | |||
| auto &name_list = op_def_proto.attr().at("_opt_input").list(); | |||
| for (const auto &item_s : name_list.s()) { | |||
| opt_input.push_back(item_s); | |||
| } | |||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||
| op_desc_attr->erase("_opt_input"); | |||
| } | |||
| if (op_def_proto.attr().count("_input_name_key") > 0) { | |||
| auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); | |||
| for (const auto &item_s : output_name_key_list.s()) { | |||
| key_in.push_back(item_s); | |||
| } | |||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||
| op_desc_attr->erase("_input_name_key"); | |||
| } | |||
| if (op_def_proto.attr().count("_input_name_value") > 0) { | |||
| auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); | |||
| for (const auto &item_i : input_name_value_list.i()) { | |||
| value_in.push_back(static_cast<uint32_t>(item_i)); | |||
| } | |||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||
| op_desc_attr->erase("_input_name_value"); | |||
| } | |||
| std::vector<string> key_out; | |||
| std::vector<uint32_t> value_out; | |||
| if (op_def_proto.attr().count("_output_name_key") > 0) { | |||
| auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); | |||
| for (const auto &item_s : output_name_key_list.s()) { | |||
| key.push_back(item_s); | |||
| key_out.push_back(item_s); | |||
| } | |||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||
| op_desc_attr->erase("_output_name_key"); | |||
| @@ -251,7 +328,7 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||
| if (op_def_proto.attr().count("_output_name_value") > 0) { | |||
| auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); | |||
| for (const auto &item_i : output_name_value_list.i()) { | |||
| value.push_back(static_cast<uint32_t>(item_i)); | |||
| value_out.push_back(static_cast<uint32_t>(item_i)); | |||
| } | |||
| auto op_desc_attr = op_def_proto.mutable_attr(); | |||
| op_desc_attr->erase("_output_name_value"); | |||
| @@ -282,15 +359,8 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d | |||
| op_desc->SetSubgraphInstanceName(graph_index++, name); | |||
| } | |||
| if (key.size() != 0) { | |||
| if (key.size() != value.size()) { | |||
| GELOGE(GRAPH_FAILED, "twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||
| } else { | |||
| for (uint32_t i = 0; i < key.size(); ++i) { | |||
| op_desc->output_name_idx_.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||
| } | |||
| } | |||
| } | |||
| // insert name index by key and value | |||
| AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); | |||
| return true; | |||
| } | |||
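
The unserialize side reads each helper attribute and immediately erases it from the proto attr map, so the bookkeeping entries never show up as user-visible attributes on the rebuilt OpDesc. A small sketch of that consume-and-erase pattern over a plain std::map (hypothetical ConsumeAttr helper, not the proto API):

    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Consume a helper attribute if present: copy its payload out, then erase it so
    // the attribute map only keeps user-visible attrs afterwards.
    bool ConsumeAttr(std::map<std::string, std::vector<std::string>> &attrs,
                     const std::string &key, std::vector<std::string> &out) {
      auto it = attrs.find(key);
      if (it == attrs.end()) {
        return false;
      }
      out = it->second;
      attrs.erase(it);  // mirrors op_desc_attr->erase("_input_name_key") in the hunk above
      return true;
    }

    int main() {
      std::map<std::string, std::vector<std::string>> attrs = {
          {"_input_name_key", {"x", "w"}}, {"user_attr", {"keep_me"}}};
      std::vector<std::string> keys;
      ConsumeAttr(attrs, "_input_name_key", keys);
      std::cout << "consumed " << keys.size() << " keys, attrs left: " << attrs.size() << "\n";
      return 0;  // prints: consumed 2 keys, attrs left: 1
    }
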
| @@ -338,13 +408,13 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||
| item.dst_node_name.c_str(), item.dst_in_index); | |||
| return false; | |||
| } | |||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||
| } else { | |||
| // Control edge | |||
| auto src_anchor = src_node_it->second->GetOutControlAnchor(); | |||
| auto dst_anchor = item.dst_node->GetInControlAnchor(); | |||
| if (src_anchor != nullptr && dst_anchor != nullptr) { | |||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||
| } | |||
| } | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "utils/ge_ir_utils.h" | |||
| #include "utils/node_utils.h" | |||
| #include "utils/op_desc_utils.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons | |||
| const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
| const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
| if (peer_node == nullptr || r_peer_node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
| GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
| this->GetName().c_str(), i, j); | |||
| return false; | |||
| } | |||
| @@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { | |||
| if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||
| {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| return in_data_anchors_[idx]; | |||
| @@ -445,7 +449,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { | |||
| // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ | |||
| if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
| GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); | |||
| GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| // Return control anchor | |||
| @@ -461,8 +465,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { | |||
| // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ | |||
| if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, | |||
| { | |||
| GetName().c_str(), | |||
| std::to_string(idx), | |||
| "out_anchor", | |||
| GetType().c_str(), | |||
| }); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| // Return control anchor | |||
| @@ -477,8 +488,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { | |||
| if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||
| {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| return out_data_anchors_[idx]; | |||
| @@ -726,22 +740,27 @@ graphStatus Node::Verify() const { | |||
| const string aipp_data_type = "AippData"; | |||
| const string const_type = "Const"; | |||
| const string variable_type = "Variable"; | |||
| bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); | |||
| for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { | |||
| if (in_anchor_ptr == nullptr) { | |||
| GELOGW("in anchor ptr is null"); | |||
| continue; | |||
| if (!is_unknown_graph) { | |||
| for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { | |||
| GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); | |||
| bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || | |||
| op_->GetType() == const_type || op_->GetType() == variable_type || | |||
| op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; | |||
| if (!valid_anchor) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, | |||
| {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); | |||
| GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS( | |||
| op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || | |||
| op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || | |||
| in_anchor_ptr->GetPeerAnchors().size() > 0, | |||
| GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
| } | |||
| string frameworkop_type = "FrameworkOp"; | |||
| if (op_->GetType() != frameworkop_type) { | |||
| bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; | |||
| if (need_update_name) { | |||
| auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); | |||
| if (node_op.IsEmpty()) { | |||
| GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); | |||
| @@ -761,7 +780,7 @@ graphStatus Node::Verify() const { | |||
| } | |||
| node_op.BreakConnect(); | |||
| } | |||
| GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); | |||
| if (op_->CommonVerify() == GRAPH_SUCCESS) { | |||
| Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); | |||
| auto verify_func = op_->GetVerifyFunc(); | |||
| @@ -19,6 +19,7 @@ | |||
| #include "debug/ge_util.h" | |||
| #include "external/graph/operator.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/operator_factory_impl.h" | |||
| @@ -32,6 +33,7 @@ using std::shared_ptr; | |||
| using std::string; | |||
| using std::vector; | |||
| /*lint -save -e521 -e681 -e732 -e737*/ | |||
| namespace ge { | |||
| const std::string ATTR_NAME_ID = "id"; | |||
| @@ -63,12 +65,6 @@ const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; | |||
| const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; | |||
| const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; | |||
| const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; | |||
| const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { | |||
| op_def_.InitDefault(); | |||
| if (op_def_.GetProtoMsg() != nullptr) { | |||
| @@ -210,8 +206,7 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d | |||
| } | |||
| graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
| GELOGI("input %s is exist, update it", name.c_str()); | |||
| graphStatus ret = UpdateInputDesc(name, input_desc); | |||
| return ret; | |||
| @@ -223,17 +218,15 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||
| return GRAPH_FAILED; | |||
| } | |||
| inputs_desc_.push_back(in_desc); | |||
| (void)input_name_idx.insert(make_pair(name, index)); | |||
| SetAllInputName(input_name_idx); | |||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| } | |||
| graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| string input_name = name + std::to_string(i); | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||
| "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
| std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
| @@ -250,24 +243,22 @@ graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int nu | |||
| (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); | |||
| // Update index in input_name_idx | |||
| for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
| for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||
| if (it->second >= (index + i)) { | |||
| it->second += 1; | |||
| } | |||
| } | |||
| (void)input_name_idx.insert(make_pair(input_name, i + index)); | |||
| (void)input_name_idx_.insert(make_pair(input_name, i + index)); | |||
| } | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
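
AddInputDescMiddle has to keep input_name_idx_ consistent when a descriptor is spliced into the middle of inputs_desc_: every recorded index at or beyond the insertion point is bumped by one before the new name is registered. A standalone single-insert sketch of that bookkeeping, with plain strings standing in for GeTensorDesc (hypothetical InsertInputAt helper, not the GE API):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>
    #include <vector>

    // Standalone sketch of the index bookkeeping done by AddInputDescMiddle. When a new
    // input is inserted at position `index`, every recorded index >= `index` shifts up by one.
    void InsertInputAt(std::vector<std::string> &inputs_desc,
                       std::map<std::string, uint32_t> &input_name_idx,
                       const std::string &name, size_t index) {
      if (index > inputs_desc.size()) {
        return;  // out-of-range insert is rejected in the real code as well
      }
      inputs_desc.insert(inputs_desc.begin() + index, name + "_desc");
      for (auto &item : input_name_idx) {
        if (item.second >= index) {
          item.second += 1;  // mirrors the `it->second += 1` loop in the hunk above
        }
      }
      input_name_idx.emplace(name, static_cast<uint32_t>(index));
    }

    int main() {
      std::vector<std::string> inputs_desc = {"x_desc", "y_desc"};
      std::map<std::string, uint32_t> input_name_idx = {{"x", 0}, {"y", 1}};
      InsertInputAt(inputs_desc, input_name_idx, "bias", 1);
      for (const auto &item : input_name_idx) {
        std::cout << item.first << " -> " << item.second << "\n";  // bias -> 1, x -> 0, y -> 2
      }
      return 0;
    }
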
| graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| for (unsigned int i = 0; i < num; i++) { | |||
| string input_name = name + std::to_string(i); | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, | |||
| GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, | |||
| "Add input tensor_desc is existed. name[%s]", input_name.c_str()); | |||
| std::shared_ptr<GeTensorDesc> in_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||
| @@ -278,13 +269,12 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n | |||
| (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); | |||
| // Update index in input_name_idx | |||
| for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { | |||
| for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { | |||
| it->second += 1; | |||
| } | |||
| (void)input_name_idx.insert(make_pair(input_name, 0)); | |||
| (void)input_name_idx_.insert(make_pair(input_name, 0)); | |||
| } | |||
| SetAllInputName(input_name_idx); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -315,19 +305,10 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int | |||
| graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { | |||
| if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| optional_input_names.push_back(name); | |||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| (void)optional_input_names_.insert(name); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::vector<string> OpDesc::GetAllOptionalInputName() const { | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| return optional_input_names; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
| GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); | |||
| @@ -342,12 +323,11 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { | |||
| return ( | |||
| IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && | |||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
| IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && | |||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
| return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && | |||
| IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && | |||
| IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && | |||
| IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && | |||
| IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { | |||
| @@ -421,9 +401,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD | |||
| } | |||
| graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it == input_name_idx.end()) { | |||
| auto it = input_name_idx_.find(name); | |||
| if (it == input_name_idx_.end()) { | |||
| GELOGW("Cann't find the input desc. name[%s]", name.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -443,9 +422,8 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & | |||
| } | |||
| bool OpDesc::InputIsSet(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it != input_name_idx.end()) { | |||
| auto it = input_name_idx_.find(name); | |||
| if (it != input_name_idx_.end()) { | |||
| GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); | |||
| auto tensor_desc = inputs_desc_[it->second]; | |||
| GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); | |||
| @@ -463,20 +441,40 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc | |||
| } | |||
| GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); | |||
| auto it = input_name_idx_.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); | |||
| return *(inputs_desc_[it->second].get()); | |||
| } | |||
| GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
| GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
| if (inputs_desc_[index] == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { | |||
| GELOGW("input desc is invalid"); | |||
| return nullptr; | |||
| } | |||
| return inputs_desc_[index]; | |||
| } | |||
| GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it == input_name_idx.end()) { | |||
| GELOGW("Failed to get [%s] input desc", name.c_str()); | |||
| return nullptr; | |||
| } | |||
| return MutableInputDesc(it->second); | |||
| } | |||
| GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
| vector<string> names; | |||
| if (input_name_idx.empty()) { | |||
| if (input_name_idx_.empty()) { | |||
| return OpDesc::Vistor<string>(shared_from_this(), names); | |||
| } | |||
| for (std::pair<string, uint32_t> input : input_name_idx) { | |||
| for (std::pair<string, uint32_t> input : input_name_idx_) { | |||
| names.push_back(input.first); | |||
| } | |||
| return OpDesc::Vistor<string>(shared_from_this(), names); | |||
| @@ -496,15 +494,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
| GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
| if (inputs_desc_[index] == nullptr) { | |||
| return nullptr; | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); | |||
| return inputs_desc_[index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const { | |||
| vector<GeTensorDesc> temp{}; | |||
| for (const auto &it : inputs_desc_) { | |||
| @@ -609,6 +598,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||
| return outputs_desc_[index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { | |||
| auto it = output_name_idx_.find(name); | |||
| if (it == output_name_idx_.end()) { | |||
| GELOGW("Failed to get [%s] output desc", name.c_str()); | |||
| return nullptr; | |||
| } | |||
| return MutableOutputDesc(it->second); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||
| return static_cast<uint32_t>(outputs_desc_.size()); | |||
| } | |||
| @@ -652,9 +650,8 @@ OpDesc::GetInputDescPtrDfault(uint32_t index) const { | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr<const GeTensorDesc>()); | |||
| auto it = input_name_idx_.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr<const GeTensorDesc>()); | |||
| return inputs_desc_[it->second]; | |||
| } | |||
| @@ -687,47 +684,26 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| bool OpDesc::IsOptionalInput(const string &name) const { | |||
| vector<string> optional_input_names; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); | |||
| for (auto &item : optional_input_names) { | |||
| if (item == name) { | |||
| return true; | |||
| } | |||
| void OpDesc::RemoveInputDesc(uint32_t index) { | |||
| while (inputs_desc_.size() > index) { | |||
| inputs_desc_.pop_back(); | |||
| } | |||
| return false; | |||
| } | |||
| bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | |||
| std::map<string, uint32_t> OpDesc::GetAllInputName() const { | |||
| std::map<string, uint32_t> input_name_idx; | |||
| std::vector<string> key; | |||
| std::vector<uint32_t> value; | |||
| (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
| (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
| if (key.size() != value.size()) { | |||
| GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); | |||
| } else { | |||
| for (uint32_t i = 0; i < key.size(); ++i) { | |||
| input_name_idx.insert(std::pair<string, uint32_t>(key.at(i), value.at(i))); | |||
| } | |||
| void OpDesc::RemoveOutputDesc(uint32_t index) { | |||
| while (outputs_desc_.size() > index) { | |||
| outputs_desc_.pop_back(); | |||
| } | |||
| return input_name_idx; | |||
| } | |||
| void OpDesc::SetAllInputName(const std::map<string, uint32_t> &input_name_idx) { | |||
| std::vector<string> key; | |||
| std::vector<uint32_t> value; | |||
| for (auto &item : input_name_idx) { | |||
| key.emplace_back(item.first); | |||
| value.emplace_back(item.second); | |||
| } | |||
| (void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); | |||
| (void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); | |||
| bool OpDesc::IsOptionalInput(const string &name) const { | |||
| return optional_input_names_.find(name) != optional_input_names_.end(); | |||
| } | |||
| bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } | |||
| std::map<string, uint32_t> OpDesc::GetAllInputName() const { return input_name_idx_; } | |||
| std::map<string, uint32_t> OpDesc::GetAllOutputName() { return output_name_idx_; } | |||
| bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
| @@ -737,7 +713,6 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
| auto factory_map_size = input_name_idx.size(); | |||
| // It indicates that some inputs have no optionalname. | |||
| // The redundant optionalname of factory needs to be deleted and then assigned | |||
| auto all_input_name_idx = GetAllInputName(); | |||
| if (input_map_size < factory_map_size) { | |||
| GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, | |||
| factory_map_size); | |||
| @@ -750,18 +725,17 @@ bool OpDesc::UpdateInputName(std::map<string, uint32_t> input_name_idx) { | |||
| } | |||
| if (input_name_idx.size() == input_map_size) { | |||
| GELOGI("UpdateInputName"); | |||
| all_input_name_idx = input_name_idx; | |||
| input_name_idx_ = input_name_idx; | |||
| } else { | |||
| ret = false; | |||
| GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); | |||
| } | |||
| } else if (input_map_size == factory_map_size) { | |||
| all_input_name_idx = input_name_idx; | |||
| input_name_idx_ = input_name_idx; | |||
| } else { | |||
| ret = false; | |||
| GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); | |||
| } | |||
| SetAllInputName(all_input_name_idx); | |||
| return ret; | |||
| } | |||
| @@ -882,36 +856,41 @@ graphStatus OpDesc::CommonVerify() const { | |||
| // Checking shape of all inputs | |||
| vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||
| for (int64_t dim : ishape) { | |||
| GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||
| iname.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19014", {"opname", "value", "reason"}, | |||
| {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); | |||
| return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), | |||
| iname.c_str()); | |||
| } | |||
| } | |||
| // Check all attributes defined | |||
| const auto &all_attributes = GetAllAttrs(); | |||
| for (const auto &name : GetAllAttrNames()) { | |||
| GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | |||
| "operator attribute %s is empty.", name.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| all_attributes.find(name) == all_attributes.end(), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, | |||
| {GetName(), "attribute " + name, "is empty"}); | |||
| return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.begin(); | |||
| for (; it != input_name_idx.end(); ++it) { | |||
| auto it = input_name_idx_.begin(); | |||
| for (; it != input_name_idx_.end(); ++it) { | |||
| if (it->second == index) { | |||
| break; | |||
| } | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), ""); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); | |||
| return it->first; | |||
| } | |||
| int OpDesc::GetInputIndexByName(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it_find = input_name_idx.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); | |||
| auto it_find = input_name_idx_.find(name); | |||
| GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); | |||
| return static_cast<int>(it_find->second); | |||
| } | |||
| @@ -1204,12 +1183,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<bool> OpDesc::GetIsInputCo | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, | |||
| const int &index) { | |||
| auto input_name_idx = GetAllInputName(); | |||
| if (input_name_idx.find(name) != input_name_idx.end()) { | |||
| if (input_name_idx_.find(name) != input_name_idx_.end()) { | |||
| GELOGI("Restore input name index is existed. name[%s]", name.c_str()); | |||
| } | |||
| (void)input_name_idx.insert(make_pair(name, index)); | |||
| SetAllInputName(input_name_idx); | |||
| (void)input_name_idx_.insert(make_pair(name, index)); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -21,7 +21,7 @@ | |||
| #include <mutex> | |||
| #include <queue> | |||
| #include <set> | |||
| #include "array_ops.h" | |||
| #include "./array_ops.h" | |||
| #include "debug/ge_log.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "debug/ge_util.h" | |||
| @@ -36,6 +36,8 @@ | |||
| #include "graph/op_desc.h" | |||
| #include "graph/runtime_inference_context.h" | |||
| #include "graph/usr_types.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/op_desc_utils.h" | |||
| #include "utils/tensor_adapter.h" | |||
| @@ -54,11 +56,13 @@ using std::string; | |||
| using std::to_string; | |||
| using std::vector; | |||
| /*lint -save -e529 -e728*/ | |||
| /*lint -e446 -e732*/ | |||
| /*lint -e665*/ | |||
| namespace ge { | |||
| class OpIO { | |||
| public: | |||
| explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) | |||
| : name_(name), index_(index), owner_(owner) {} | |||
| OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} | |||
| ~OpIO() = default; | |||
| @@ -546,56 +550,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { | |||
| } | |||
| graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); | |||
| if (node_ptr) { | |||
| GE_CHECK_NOTNULL(operator_impl_); | |||
| auto node_ptr = operator_impl_->GetNode(); | |||
| if (node_ptr != nullptr) { | |||
| // For inner compute graph | |||
| auto op_desc = node_ptr->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "op_desc is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto index = op_desc->GetInputIndexByName(dst_name); | |||
| auto in_data_anchor = node_ptr->GetInDataAnchor(index); | |||
| if (in_data_anchor == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(in_data_anchor); | |||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (out_data_anchor == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode(); | |||
| if (peer_node_ptr == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ge::OperatorImplPtr operator_impl_ptr = nullptr; | |||
| operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr); | |||
| if (operator_impl_ptr == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| Operator const_op(std::move(operator_impl_ptr)); | |||
| if (peer_node_ptr->GetOpDesc() != nullptr) { | |||
| const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | |||
| if (op_descType == CONSTANTOP) { | |||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
| } else if (op_descType == CONSTANT) { | |||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
| GE_CHECK_NOTNULL(out_data_anchor); | |||
| auto peer_node = out_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(peer_node); | |||
| auto peer_op_desc = peer_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(peer_op_desc); | |||
| auto peer_op_type = peer_op_desc->GetType(); | |||
| if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { | |||
| auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node); | |||
| GE_CHECK_NOTNULL(const_op_impl); | |||
| Operator const_op(std::move(const_op_impl)); | |||
| return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
| } else if (peer_op_type == DATA) { | |||
| auto parent_node = NodeUtils::GetParentInput(peer_node); | |||
| while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { | |||
| parent_node = NodeUtils::GetParentInput(parent_node); | |||
| } | |||
| if ((parent_node != nullptr) && | |||
| ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { | |||
| auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node); | |||
| GE_CHECK_NOTNULL(const_op_impl); | |||
| Operator const_op(std::move(const_op_impl)); | |||
| return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
| } | |||
| } | |||
| // Try get from runtime inference context | |||
| auto session_id = std::to_string(GetContext().SessionId()); | |||
| RuntimeInferenceContext *runtime_infer_ctx = nullptr; | |||
| if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { | |||
| GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); | |||
| auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
| auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
| if (ret == GRAPH_SUCCESS) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -604,6 +598,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||
| // For outer graph | |||
| return GetInputConstDataOut(dst_name, data); | |||
| } | |||
| auto op_name = operator_impl_->GetName(); | |||
| GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
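
The reworked GetInputConstData above adds a DATA branch: inside a subgraph the peer node may be a Data op that merely forwards a parent-graph input, so the lookup climbs through Data parents (NodeUtils::GetParentInput) until it reaches a Const/Constant, and only then reads the weight attribute; otherwise it falls back to the runtime inference context. A minimal sketch of that climbing loop with a hypothetical MiniNode type (not the GE Node API):

    #include <iostream>
    #include <memory>
    #include <string>

    // A Data node inside a subgraph may just forward its parent graph's input, so the
    // lookup keeps climbing through Data nodes until it reaches a Const/Constant or
    // runs out of parents.
    struct MiniNode {
      std::string type;
      std::shared_ptr<MiniNode> parent_input;  // stands in for NodeUtils::GetParentInput
    };

    std::shared_ptr<MiniNode> FindConstSource(std::shared_ptr<MiniNode> node) {
      while (node != nullptr && node->type == "Data") {
        node = node->parent_input;
      }
      if (node != nullptr && (node->type == "Const" || node->type == "Constant")) {
        return node;  // caller would now read the weight attribute from this node
      }
      return nullptr;  // fall back to the runtime inference context / outer graph path
    }

    int main() {
      auto weight = std::make_shared<MiniNode>(MiniNode{"Const", nullptr});
      auto outer_data = std::make_shared<MiniNode>(MiniNode{"Data", weight});
      auto inner_data = std::make_shared<MiniNode>(MiniNode{"Data", outer_data});
      auto found = FindConstSource(inner_data);
      std::cout << (found != nullptr ? found->type : "not const") << "\n";  // prints Const
      return 0;
    }
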
| graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { | |||
| @@ -914,7 +910,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } | |||
| GELOGW("set attr name %s failed.", name.c_str()); \ | |||
| } \ | |||
| return *this; \ | |||
| } | |||
| } // lint !e665 | |||
| #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ | |||
| graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ | |||
| @@ -927,7 +923,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } | |||
| return GRAPH_FAILED; \ | |||
| } \ | |||
| return GRAPH_SUCCESS; \ | |||
| } | |||
| } // lint !e665 | |||
| void Operator::BreakConnect() const { | |||
| if (operator_impl_ == nullptr) { | |||
| @@ -948,7 +944,7 @@ void Operator::BreakConnect() const { | |||
| if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ | |||
| GELOGW("reg attr name %s failed.", name.c_str()); \ | |||
| } \ | |||
| } | |||
| } // lint !e665 | |||
| OP_ATTR_SET_IMP(int64_t, Int) | |||
| OP_ATTR_SET_IMP(int32_t, Int) | |||
| @@ -969,22 +965,22 @@ OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt) | |||
| OP_ATTR_SET_IMP(float, Float) | |||
| OP_ATTR_GET_IMP(float &, Float) | |||
| OP_ATTR_SET_IMP(const vector<float> &, ListFloat) | |||
| OP_ATTR_GET_IMP(vector<float> &, ListFloat) | |||
| OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665 | |||
| OP_ATTR_SET_IMP(bool, Bool) | |||
| OP_ATTR_GET_IMP(bool &, Bool) | |||
| OP_ATTR_SET_IMP(const vector<bool> &, ListBool) | |||
| OP_ATTR_GET_IMP(vector<bool> &, ListBool) | |||
| OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665 | |||
| OP_ATTR_SET_IMP(const string &, Str) | |||
| OP_ATTR_GET_IMP(string &, Str) | |||
| OP_ATTR_SET_IMP(const vector<string> &, ListStr) | |||
| OP_ATTR_GET_IMP(vector<string> &, ListStr) | |||
| OP_ATTR_GET_IMP(vector<string> &, ListStr) // lint !e665 | |||
| OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
| OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | |||
| OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
| OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||
| OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665 | |||
| OP_ATTR_REG_IMP(int64_t, Int) | |||
| OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | |||
| @@ -1547,3 +1543,5 @@ void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_node | |||
| } | |||
| } | |||
| } // namespace ge | |||
| /*lint +e446 +e732*/ | |||
| /*lint +e665*/ | |||
| @@ -31,7 +31,9 @@ OpsProtoManager *OpsProtoManager::Instance() { | |||
| } | |||
| bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | |||
| /*lint -e1561*/ | |||
| auto proto_iter = options.find("ge.opsProtoLibPath"); | |||
| /*lint +e1561*/ | |||
| if (proto_iter == options.end()) { | |||
| GELOGW("ge.opsProtoLibPath option not set, return."); | |||
| return false; | |||
| @@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||
| uint64_t GEContext::TraceId() { return trace_id_; } | |||
| void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } | |||
| void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | |||
| } // namespace ge | |||
| @@ -37,7 +37,7 @@ const string kWhile = "While"; | |||
| const string kIf = "If"; | |||
| const string kCase = "Case"; | |||
| const int kMaxElementNum = 100; | |||
| const uint16_t kMaxElementNum = 100; | |||
| std::unordered_set<string> function_op = {kWhile, kIf, kCase}; | |||
| } // namespace | |||
| @@ -170,6 +170,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||
| // data_nodes has been sorted | |||
| // for while, input num must be the same as output num | |||
| auto input_num = root_node->GetAllInDataAnchorsSize(); | |||
| NodePtr netoutput = nullptr; | |||
| size_t ref_i = 0; | |||
| while (ref_i < input_num) { | |||
| @@ -212,10 +213,44 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||
| cell_netoutput_in.in_out = NODE_IN; | |||
| cell_netoutput_in.in_out_idx = ele.second; | |||
| ref_i_all_refs.emplace_back(cell_netoutput_in); | |||
| netoutput = ele.first; | |||
| } | |||
| node_refs.emplace_back(ref_i_all_refs); | |||
| ref_i++; | |||
| } | |||
| /* There are graphs like the following, where the edges from Data0 and Data1 cross, | |||
| * meaning the netoutput's 0'th and 1'th tensors must share the same address. | |||
| * Data0 Data1 | |||
| * \/ | |||
| * /\ | |||
| * netoutput | |||
| */ | |||
| if (netoutput == nullptr) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { | |||
| auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_data_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); | |||
| if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { | |||
| GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); | |||
| continue; | |||
| } | |||
| if (peer_out_data_node->GetType() != DATA) { | |||
| continue; | |||
| } | |||
| auto in_data_anchor_idx = in_anchor->GetIdx(); | |||
| auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast<uint32_t>(in_data_anchor_idx)); | |||
| int ref_d; | |||
| int ref_n; | |||
| (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); | |||
| (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); | |||
| node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); | |||
| node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
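Editor's note: a self-contained toy sketch (not part of this patch; plain std::vector stand-ins replace the RefCell lists) of the symmetric merge the hunk above performs on node_refs for the crossed Data-to-NetOutput case:

    #include <cstdio>
    #include <vector>

    int main() {
      std::vector<int> refs_d = {0, 2};  // toy ref cells collected for Data ref_d
      std::vector<int> refs_n = {1, 3};  // toy ref cells collected for Data ref_n
      // Same two appends as the hunk: ref_d's list gains ref_n's cells, then
      // ref_n's list gains ref_d's already-extended list, so both index
      // positions end up referencing the crossed tensors.
      refs_d.insert(refs_d.end(), refs_n.begin(), refs_n.end());
      refs_n.insert(refs_n.end(), refs_d.begin(), refs_d.end());
      std::printf("refs_d has %zu cells, refs_n has %zu cells\n", refs_d.size(), refs_n.size());
      return 0;
    }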
| @@ -242,6 +277,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r | |||
| int sub_graph_idx = 0; | |||
| for (const auto &name : sub_graph_names) { | |||
| auto sub_graph = root_graph.GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); | |||
| continue; | |||
| } | |||
| for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||
| auto sub_graph_node_type = sub_graph_node->GetType(); | |||
| @@ -296,6 +335,9 @@ graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_n | |||
| data_nodes.pop_back(); | |||
| int ref_idx = 0; | |||
| (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); | |||
| if (ref_idx >= static_cast<int>(classed_data_nodes.size())) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| classed_data_nodes[ref_idx].emplace_back(data); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| @@ -317,7 +359,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||
| } | |||
| int ref_o; | |||
| if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { | |||
| if (ref_o >= kMaxElementNum) { | |||
| if (ref_o >= static_cast<int>(classed_netoutput_nodes.size())) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| classed_netoutput_nodes[ref_o].emplace_back( | |||
| @@ -349,8 +391,9 @@ graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||
| vector<NodePtr> netoutput_nodes; | |||
| // Get data and netoutput of sub_graph | |||
| GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); | |||
| vector<vector<NodePtr>> classed_data_nodes(kMaxElementNum); // according to ref_idx | |||
| vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx | |||
| size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; | |||
| vector<vector<NodePtr>> classed_data_nodes(max_elem_num); // according to ref_idx | |||
| vector<vector<std::pair<NodePtr, size_t>>> classed_netoutput_nodes(max_elem_num); // according to ref_idx | |||
| status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); | |||
| @@ -30,6 +30,7 @@ graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::lock_guard<std::mutex> lk(ctx_mu_); | |||
| auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); | |||
| if (!emplace_ret.second) { | |||
| GELOGE(GRAPH_FAILED, "Old context not destroyed"); | |||
| @@ -37,6 +37,162 @@ | |||
| namespace ge { | |||
| namespace { | |||
| const uint32_t kWhileBodySubGraphIdx = 1; | |||
| graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { | |||
| GELOGD("Enter reverse brush while body subgraph process!"); | |||
| auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); | |||
| if (sub_graph_body == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Get while body graph failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (const auto &node_sub : sub_graph_body->GetAllNodes()) { | |||
| for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { | |||
| auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); | |||
| (void)input_desc->SetUnknownDimNumShape(); | |||
| } | |||
| for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { | |||
| auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); | |||
| (void)output_desc->SetUnknownDimNumShape(); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||
| // check sub_graph shape. Get max for update. | |||
| for (size_t i = 0; i < ref_out_tensors.size(); ++i) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| int64_t max_size = 0; | |||
| size_t max_shape_index = 0; | |||
| auto &ref_out_tensor = ref_out_tensors[i].at(0); | |||
| const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||
| for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { | |||
| auto &tensor = ref_out_tensors[i].at(j); | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||
| GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", | |||
| node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| int64_t size = 1; | |||
| for (auto dim : shape.GetDims()) { | |||
| if (INT64_MAX / dim < size) { | |||
| GELOGE(PARAM_INVALID, "The shape size overflow"); | |||
| return PARAM_INVALID; | |||
| } | |||
| size *= dim; | |||
| } | |||
| if (size > max_size) { | |||
| max_size = size; | |||
| max_shape_index = j; | |||
| } | |||
| } | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||
| GELOGD("Enter update parent node shape for class branch op process"); | |||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { | |||
| return UpdataOutputForMultiBatcch(node, ref_out_tensors); | |||
| } | |||
| // check sub_graph shapes. If they differ, fall back to unknown-shape processing | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||
| for (auto &tensor : ref_out_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||
| GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||
| shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||
| break; | |||
| } | |||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||
| continue; | |||
| } | |||
| GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||
| j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||
| } | |||
| } | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
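Editor's note: a toy, stand-alone illustration (not GE code; ge::UNKNOWN_DIM is assumed to be -1 and is replaced by a local constant) of the per-dimension merge rule UpdateParentNodeForBranch applies when branch outputs share a rank but differ in some dims:

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      constexpr int64_t kUnknownDim = -1;  // stand-in for ge::UNKNOWN_DIM
      std::vector<int64_t> ref = {8, 224, 224, 3};    // first branch's output dims
      std::vector<int64_t> other = {8, 112, 112, 3};  // another branch's output dims
      // Dims that agree are kept; dims that differ are marked unknown.
      for (size_t j = 0; j < ref.size(); ++j) {
        if (ref[j] != other[j]) {
          ref[j] = kUnknownDim;
        }
      }
      for (int64_t d : ref) {
        std::printf("%lld ", static_cast<long long>(d));  // prints: 8 -1 -1 3
      }
      std::printf("\n");
      return 0;
    }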
| graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||
| GELOGD("Enter update parent node shape for class while op process"); | |||
| if (ref_data_tensors.size() != ref_out_tensors.size()) { | |||
| GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), | |||
| ref_data_tensors.size(), ref_out_tensors.size()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (size_t i = 0; i < ref_data_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].size() != 1) { | |||
| GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| bool is_need_reverse_brush = false; | |||
| // check input and output | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| auto tmp_shape = ref_out_tensor.MutableShape(); | |||
| // ref_i's data and output tensor shapes should be the same | |||
| for (auto &tensor : ref_data_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims() != tmp_shape.GetDims()) { | |||
| ref_out_tensor.SetUnknownDimNumShape(); | |||
| is_need_reverse_brush = true; | |||
| break; | |||
| } | |||
| } | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| } | |||
| // reverse refresh while body shape | |||
| if (is_need_reverse_brush) { | |||
| return ReverseBrushWhileBodySubGraph(node); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| @@ -66,11 +222,14 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), | |||
| node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { | |||
| continue; | |||
| } | |||
| auto input_desc = op_desc->MutableInputDesc(ref_i); | |||
| if (input_desc == nullptr) { | |||
| GE_LOGE( | |||
| @@ -98,6 +257,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput, | |||
| const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) { | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| } | |||
| if (sub_node->GetType() == DATA) { | |||
| if (sub_node->GetOpDesc() == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) { | |||
| GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), | |||
| ref_i, node->GetAllInDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| @@ -105,7 +295,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||
| std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| @@ -117,13 +310,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodePtr netoutput = nullptr; | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| break; | |||
| } | |||
| auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| if (netoutput == nullptr) { | |||
| GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||
| @@ -150,22 +339,23 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| continue; | |||
| } | |||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
| if (output_desc == nullptr) { | |||
| GE_LOGE( | |||
| "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||
| "parent node %s are incompatible, outputs num %u", | |||
| ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllOutDataAnchorsSize()); | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||
| ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| if (node->GetType() == WHILE) { | |||
| return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); | |||
| } | |||
| return UpdateParentNodeForBranch(node, ref_out_tensors); | |||
| } | |||
| } // namespace | |||
| void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||
| if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||
| return; | |||
| } | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "node is null"); | |||
| return; | |||
| @@ -185,6 +375,18 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; | |||
| } | |||
| str += input_desc_str; | |||
| input_desc_str = "input origin shape: "; | |||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||
| input_desc_str += "["; | |||
| for (int64_t dim : input_desc->GetOriginShape().GetDims()) { | |||
| input_desc_str += std::to_string(dim) + " "; | |||
| } | |||
| input_desc_str += "]"; | |||
| input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + | |||
| TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; | |||
| } | |||
| str += input_desc_str; | |||
| } | |||
| if (op_desc->GetAllOutputsDescSize() != 0) { | |||
| @@ -202,6 +404,21 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; | |||
| } | |||
| str += output_desc_str; | |||
| output_desc_str = "output origin shape: "; | |||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||
| if (output_desc == nullptr) { | |||
| continue; | |||
| } | |||
| output_desc_str += "["; | |||
| for (int64_t dim : output_desc->GetOriginShape().GetDims()) { | |||
| output_desc_str += std::to_string(dim) + " "; | |||
| } | |||
| output_desc_str += "]"; | |||
| output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + | |||
| TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; | |||
| } | |||
| str += output_desc_str; | |||
| } | |||
| GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str()); | |||
| } | |||
| @@ -222,7 +439,6 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & | |||
| return ret; | |||
| } | |||
| } | |||
| // Get infer func and execute | |||
| ret = op_desc->CallInferFunc(op); | |||
| if (ret == GRAPH_PARAM_INVALID) { | |||
| @@ -329,6 +545,9 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||
| namespace { | |||
| std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { | |||
| return InferShapeAndType(node, true); | |||
| } | |||
| @@ -339,19 +558,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||
| GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| PrintInOutTensorShape(node, "before_infershape"); | |||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
| auto inference_context = CreateInferenceContext(context_map, node); | |||
| if (inference_context == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "inference context is null"); | |||
| return GRAPH_FAILED; | |||
| bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| if (!is_unknown_graph) { | |||
| auto inference_context = CreateInferenceContext(context_map, node); | |||
| if (inference_context == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "inference context is null"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | |||
| op.SetInferenceContext(inference_context); | |||
| } | |||
| GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); | |||
| PrintInOutTensorShape(node, "before_infershape"); | |||
| Operator op = OpDescUtils::CreateOperatorFromNode(node); | |||
| op.SetInferenceContext(inference_context); | |||
| graphStatus status = InferShapeAndType(node, op, before_subgraph); | |||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | |||
| (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); | |||
| @@ -359,16 +579,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||
| GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto ctx_after_infer = op.GetInferenceContext(); | |||
| if (ctx_after_infer != nullptr) { | |||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { | |||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||
| (void)context_map.emplace(node, ctx_after_infer); | |||
| if (!is_unknown_graph) { | |||
| auto ctx_after_infer = op.GetInferenceContext(); | |||
| if (ctx_after_infer != nullptr) { | |||
| GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); | |||
| if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { | |||
| GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), | |||
| ctx_after_infer->GetMarks().size()); | |||
| (void)context_map.emplace(node, ctx_after_infer); | |||
| } | |||
| } | |||
| } | |||
| PrintInOutTensorShape(node, "after_infershape"); | |||
| return GRAPH_SUCCESS; | |||
| @@ -1,6 +0,0 @@ | |||
| inc_path := $(shell pwd)/inc/external/ | |||
| out_path := $(shell pwd)/out/atc/lib64/stub/ | |||
| stub_path := $(shell pwd)/common/graph/stub/ | |||
| mkdir_stub := $(shell mkdir -p $(out_path)) | |||
| graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) | |||
| @@ -1,573 +0,0 @@ | |||
| import os | |||
| import re | |||
| import sys | |||
| import logging | |||
| logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', | |||
| level=logging.INFO) | |||
| """ | |||
| this attr is used for symbol table visible | |||
| """ | |||
| GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' | |||
| """ | |||
| generate stub func body by return type | |||
| """ | |||
| RETURN_STATEMENTS = { | |||
| 'graphStatus': ' return GRAPH_SUCCESS;', | |||
| 'Status': ' return SUCCESS;', | |||
| 'Graph': ' return Graph();', | |||
| 'Graph&': ' return *this;', | |||
| 'Format': ' return Format();', | |||
| 'Format&': ' return *this;', | |||
| 'Shape': ' return Shape();', | |||
| 'Shape&': ' return *this;', | |||
| 'TensorDesc': ' return TensorDesc();', | |||
| 'TensorDesc&': ' return *this;', | |||
| 'Tensor': ' return Tensor();', | |||
| 'Tensor&': ' return *this;', | |||
| 'Operator': ' return Operator();', | |||
| 'Operator&': ' return *this;', | |||
| 'Ptr': ' return nullptr;', | |||
| 'std::string': ' return "";', | |||
| 'std::string&': ' return "";', | |||
| 'string': ' return "";', | |||
| 'int': ' return 0;', | |||
| 'DataType': ' return DT_FLOAT;', | |||
| 'InferenceContextPtr': ' return nullptr;', | |||
| 'SubgraphBuilder': ' return nullptr;', | |||
| 'OperatorImplPtr': ' return nullptr;', | |||
| 'OutHandler': ' return nullptr;', | |||
| 'std::vector<std::string>': ' return {};', | |||
| 'std::vector<int64_t>': ' return {};', | |||
| 'std::map': ' return {};', | |||
| 'uint32_t': ' return 0;', | |||
| 'int64_t': ' return 0;', | |||
| 'uint64_t': ' return 0;', | |||
| 'size_t': ' return 0;', | |||
| 'float': ' return 0.0f;', | |||
| 'bool': ' return false;', | |||
| } | |||
| """ | |||
| max code len per line in hua_wei software programming specifications | |||
| """ | |||
| max_code_len_per_line = 100 | |||
| """ | |||
| white_list_for_debug, include_dir_key_words is to | |||
| determines which header files to generate cc files from | |||
| when DEBUG on | |||
| """ | |||
| white_list_for_debug = ["operator.h", "tensor.h", | |||
| "graph.h", "operator_factory.h", | |||
| "ge_ir_build.h"] | |||
| include_dir_key_words = ["ge", "graph"] | |||
| DEBUG = True | |||
| def need_generate_func(func_line): | |||
| """ | |||
| :param func_line: | |||
| :return: | |||
| """ | |||
| if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ | |||
| or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): | |||
| return False | |||
| return True | |||
| def file_endswith_white_list_suffix(file): | |||
| """ | |||
| :param file: | |||
| :return: | |||
| """ | |||
| if DEBUG: | |||
| for suffix in white_list_for_debug: | |||
| if file.endswith(suffix): | |||
| return True | |||
| return False | |||
| else: | |||
| return True | |||
| """ | |||
| belows are patterns used for analyse .h file | |||
| """ | |||
| # pattern function | |||
| pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after | |||
| ([a-zA-Z~_] # void int likely | |||
| .* | |||
| [)] #we find ) | |||
| (?!.*{) # we do not want the case int abc() const { return 1;} | |||
| .*) | |||
| (;.*) #we want to find ; and after for we will replace these later | |||
| \n$ | |||
| """, re.VERBOSE | re.MULTILINE | re.DOTALL) | |||
| # pattern comment | |||
| pattern_comment = re.compile(r'^\s*//') | |||
| pattern_comment_2_start = re.compile(r'^\s*/[*]') | |||
| pattern_comment_2_end = re.compile(r'[*]/\s*$') | |||
| # pattern define | |||
| pattern_define = re.compile(r'^\s*#define') | |||
| pattern_define_return = re.compile(r'\\\s*$') | |||
| # blank line | |||
| pattern_blank_line = re.compile(r'^\s*$') | |||
| # virtual,explicit,friend,static | |||
| pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') | |||
| # lead space | |||
| pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') | |||
| # functions will have patterns such as func ( or func( | |||
| # but operator is an exception; the class name is preceded by an operator, and the above mode does not exist | |||
| # format like :"operator = ()" | |||
| pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') | |||
| # template | |||
| pattern_template = re.compile(r'^\s*template') | |||
| pattern_template_end = re.compile(r'>\s*$') | |||
| # namespace | |||
| pattern_namespace = re.compile(r'namespace.*{') | |||
| # class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with | |||
| pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR) | |||
| # {} | |||
| pattern_start = re.compile('{') | |||
| pattern_end = re.compile('}') | |||
| line_index = 0 | |||
| class H2CC(object): | |||
| def __init__(self, input_file, output_file, shared_includes_content): | |||
| """ | |||
| :param input_file: | |||
| :param output_file: | |||
| :param shared_includes_content: | |||
| """ | |||
| self.input_file = input_file | |||
| self.output_file = output_file | |||
| self.shared_includes_content = shared_includes_content | |||
| self.line_index = 0 | |||
| self.input_fd = open(self.input_file, 'r') | |||
| self.input_content = self.input_fd.readlines() | |||
| self.output_fd = open(self.output_file, 'w') | |||
| # The state may be normal_now(in the middle of {}),class_now,namespace_now | |||
| self.stack = [] | |||
| self.stack_class = [] | |||
| self.stack_template = [] | |||
| # record funcs generated by h2cc func | |||
| self.func_list_exist = [] | |||
| def __del__(self): | |||
| self.input_fd.close() | |||
| self.output_fd.close() | |||
| del self.stack | |||
| del self.stack_class | |||
| del self.stack_template | |||
| del self.func_list_exist | |||
| def just_skip(self): | |||
| # skip blank line or comment | |||
| if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | |||
| self.input_content[self.line_index]): # /n or comment using // | |||
| self.line_index += 1 | |||
| if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /* | |||
| while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */ | |||
| self.line_index += 1 | |||
| self.line_index += 1 | |||
| # skip define | |||
| if pattern_define.search(self.input_content[self.line_index]): | |||
| while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search( | |||
| self.input_content[self.line_index]): | |||
| self.line_index += 1 | |||
| self.line_index += 1 | |||
| def write_inc_content(self): | |||
| for shared_include_content in self.shared_includes_content: | |||
| self.output_fd.write(shared_include_content) | |||
| def h2cc(self): | |||
| """ | |||
| :return: | |||
| """ | |||
| logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file) | |||
| global pattern_comment | |||
| global pattern_comment_2_start | |||
| global pattern_comment_2_end | |||
| global pattern_blank_line | |||
| global pattern_func | |||
| global pattern_keyword | |||
| global pattern_leading_space | |||
| global pattern_func_name | |||
| global pattern_template | |||
| global pattern_template_end | |||
| global pattern_namespace | |||
| global pattern_class | |||
| global pattern_start | |||
| global pattern_end | |||
| global line_index | |||
| # write inc content | |||
| self.write_inc_content() | |||
| # core processing cycle, process the input .h file by line | |||
| while self.line_index < len(self.input_content): | |||
| # handle comment and blank line | |||
| self.just_skip() | |||
| # match namespace | |||
| self.handle_namespace() | |||
| # match template | |||
| template_string = self.handle_template() | |||
| # match class | |||
| line = self.input_content[self.line_index] | |||
| match_class = pattern_class.search(line) | |||
| match_start = pattern_start.search(line) | |||
| handle_class_result = self.handle_class(template_string, line, match_start, match_class) | |||
| if handle_class_result == "continue": | |||
| continue | |||
| # match "}" | |||
| handle_stack_result = self.handle_stack(match_start) | |||
| if handle_stack_result == "continue": | |||
| continue | |||
| # handle func | |||
| handle_func1_result, line, start_i = self.handle_func1(line) | |||
| if handle_func1_result == "continue": | |||
| continue | |||
| # here means func is found | |||
| # delete key word | |||
| line = pattern_keyword.sub('', line) | |||
| logging.info("line[%s]", line) | |||
| # Class member function | |||
| # if friend we will not add class name | |||
| friend_match = re.search('friend ', line) | |||
| if len(self.stack_class) > 0 and not friend_match: | |||
| line, func_name = self.handle_class_member_func(line, template_string) | |||
| # Normal functions | |||
| else: | |||
| line, func_name = self.handle_normal_func(line, template_string) | |||
| need_generate = need_generate_func(line) | |||
| # func body | |||
| line += self.implement_function(line) | |||
| # comment | |||
| line = self.gen_comment(start_i) + line | |||
| # write to out file | |||
| self.write_func_content(line, func_name, need_generate) | |||
| # next loop | |||
| self.line_index += 1 | |||
| logging.info('Added %s functions', len(self.func_list_exist)) | |||
| logging.info('Successfully converted,please see ' + self.output_file) | |||
| def handle_func1(self, line): | |||
| """ | |||
| :param line: | |||
| :return: | |||
| """ | |||
| find1 = re.search('[(]', line) | |||
| if not find1: | |||
| self.line_index += 1 | |||
| return "continue", line, None | |||
| find2 = re.search('[)]', line) | |||
| start_i = self.line_index | |||
| space_match = pattern_leading_space.search(line) | |||
| # deal with | |||
| # int abc(int a, | |||
| # int b) | |||
| if find1 and (not find2): | |||
| self.line_index += 1 | |||
| line2 = self.input_content[self.line_index] | |||
| if space_match: | |||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||
| line += line2 | |||
| while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): | |||
| self.line_index += 1 | |||
| line2 = self.input_content[self.line_index] | |||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||
| line += line2 | |||
| match_start = pattern_start.search(self.input_content[self.line_index]) | |||
| match_end = pattern_end.search(self.input_content[self.line_index]) | |||
| if match_start: # like ) { or ) {} int the last line | |||
| if not match_end: | |||
| self.stack.append('normal_now') | |||
| ii = start_i | |||
| while ii <= self.line_index: | |||
| ii += 1 | |||
| self.line_index += 1 | |||
| return "continue", line, start_i | |||
| logging.info("line[%s]", line) | |||
| # ' int abc();'->'int abc()' | |||
| (line, match) = pattern_func.subn(r'\2\n', line) | |||
| logging.info("line[%s]", line) | |||
| # deal with case: | |||
| # 'int \n abc(int a, int b)' | |||
| if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): | |||
| line = self.input_content[start_i - 1] + line | |||
| line = line.lstrip() | |||
| if not match: | |||
| self.line_index += 1 | |||
| return "continue", line, start_i | |||
| return "pass", line, start_i | |||
| def handle_stack(self, match_start): | |||
| """ | |||
| :param match_start: | |||
| :return: | |||
| """ | |||
| line = self.input_content[self.line_index] | |||
| match_end = pattern_end.search(line) | |||
| if match_start: | |||
| self.stack.append('normal_now') | |||
| if match_end: | |||
| top_status = self.stack.pop() | |||
| if top_status == 'namespace_now': | |||
| self.output_fd.write(line + '\n') | |||
| elif top_status == 'class_now': | |||
| self.stack_class.pop() | |||
| self.stack_template.pop() | |||
| if match_start or match_end: | |||
| self.line_index += 1 | |||
| return "continue" | |||
| if len(self.stack) > 0 and self.stack[-1] == 'normal_now': | |||
| self.line_index += 1 | |||
| return "continue" | |||
| return "pass" | |||
| def handle_class(self, template_string, line, match_start, match_class): | |||
| """ | |||
| :param template_string: | |||
| :param line: | |||
| :param match_start: | |||
| :param match_class: | |||
| :return: | |||
| """ | |||
| if match_class: # we face a class | |||
| self.stack_template.append(template_string) | |||
| self.stack.append('class_now') | |||
| class_name = match_class.group(3) | |||
| # class template specializations: class A<u,Node<u> > | |||
| if '<' in class_name: | |||
| k = line.index('<') | |||
| fit = 1 | |||
| for ii in range(k + 1, len(line)): | |||
| if line[ii] == '<': | |||
| fit += 1 | |||
| if line[ii] == '>': | |||
| fit -= 1 | |||
| if fit == 0: | |||
| break | |||
| class_name += line[k + 1:ii + 1] | |||
| logging.info('class_name[%s]', class_name) | |||
| self.stack_class.append(class_name) | |||
| while not match_start: | |||
| self.line_index += 1 | |||
| line = self.input_content[self.line_index] | |||
| match_start = pattern_start.search(line) | |||
| self.line_index += 1 | |||
| return "continue" | |||
| return "pass" | |||
| def handle_template(self): | |||
| line = self.input_content[self.line_index] | |||
| match_template = pattern_template.search(line) | |||
| template_string = '' | |||
| if match_template: | |||
| match_template_end = pattern_template_end.search(line) | |||
| template_string = line | |||
| while not match_template_end: | |||
| self.line_index += 1 | |||
| line = self.input_content[self.line_index] | |||
| template_string += line | |||
| match_template_end = pattern_template_end.search(line) | |||
| self.line_index += 1 | |||
| return template_string | |||
| def handle_namespace(self): | |||
| line = self.input_content[self.line_index] | |||
| match_namespace = pattern_namespace.search(line) | |||
| if match_namespace: # we face namespace | |||
| self.output_fd.write(line + '\n') | |||
| self.stack.append('namespace_now') | |||
| self.line_index += 1 | |||
| def handle_normal_func(self, line, template_string): | |||
| template_line = '' | |||
| self.stack_template.append(template_string) | |||
| if self.stack_template[-1] != '': | |||
| template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) | |||
| # change '< class T = a, class U = A(3)>' to '<class T, class U>' | |||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||
| line = re.sub(r'\s*=.*,', ',', line) | |||
| line = re.sub(r'\s*=.*\)', ')', line) | |||
| line = template_line + line | |||
| self.stack_template.pop() | |||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||
| logging.info("line[%s]", line) | |||
| logging.info("func_name[%s]", func_name) | |||
| return line, func_name | |||
| def handle_class_member_func(self, line, template_string): | |||
| template_line = '' | |||
| x = '' | |||
| if template_string != '': | |||
| template_string = re.sub(r'\s*template', 'template', template_string) | |||
| template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) | |||
| template_string = re.sub(r'\s*=.*,', ',', template_string) | |||
| template_string = re.sub(r'\s*=.*', '', template_string) | |||
| if self.stack_template[-1] != '': | |||
| if not (re.search(r'<\s*>', stack_template[-1])): | |||
| template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) | |||
| if not (re.search(r'<.*>', self.stack_class[-1])): | |||
| # for x we get like template<class T, typename U> -> <T,U> | |||
| x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U> | |||
| x = re.sub(r'\n', '', x) | |||
| x = re.sub(r'\s*=.*,', ',', x) | |||
| x = re.sub(r'\s*=.*\>', '>', x) | |||
| x = x.rstrip() # remove \n | |||
| x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '', | |||
| x) # remove class,typename -> <T, U> | |||
| x = re.sub(r'<\s+', '<', x) | |||
| x = re.sub(r'\s+>', '>', x) | |||
| x = re.sub(r'\s+,', ',', x) | |||
| x = re.sub(r',\s+', ', ', x) | |||
| line = re.sub(r'\s*=\s+0', '', line) | |||
| line = re.sub(r'\s*=\s+.*,', ',', line) | |||
| line = re.sub(r'\s*=\s+.*\)', ')', line) | |||
| logging.info("x[%s]\nline[%s]", x, line) | |||
| # if the function is long, void ABC::foo() | |||
| # breaks into two lines void ABC::\n foo() | |||
| temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) | |||
| if len(temp_line) > max_code_len_per_line: | |||
| line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) | |||
| else: | |||
| line = temp_line | |||
| logging.info("line[%s]", line) | |||
| # add template as the above if there is one | |||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||
| line = template_line + template_string + line | |||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||
| logging.info("line[%s]", line) | |||
| logging.info("func_name[%s]", func_name) | |||
| return line, func_name | |||
| def write_func_content(self, content, func_name, need_generate): | |||
| if not (func_name in self.func_list_exist) and need_generate: | |||
| self.output_fd.write(content) | |||
| self.func_list_exist.append(func_name) | |||
| logging.info('add func:[%s]', func_name) | |||
| def gen_comment(self, start_i): | |||
| comment_line = '' | |||
| # Function comments are on top of function declarations, copy them over | |||
| k = start_i - 1 # one line before this func start | |||
| if pattern_template.search(self.input_content[k]): | |||
| k -= 1 | |||
| if pattern_comment_2_end.search(self.input_content[k]): | |||
| comment_line = self.input_content[k].lstrip() | |||
| while not pattern_comment_2_start.search(self.input_content[k]): | |||
| k -= 1 | |||
| comment_line = self.input_content[k].lstrip() + comment_line | |||
| else: | |||
| for j in range(k, 0, -1): | |||
| c_line = self.input_content[j] | |||
| if pattern_comment.search(c_line): | |||
| c_line = re.sub(r'\s*//', '//', c_line) | |||
| comment_line = c_line + comment_line | |||
| else: | |||
| break | |||
| return comment_line | |||
| @staticmethod | |||
| def implement_function(func): | |||
| function_def = '' | |||
| function_def += '{\n' | |||
| all_items = func.split() | |||
| start = 0 | |||
| return_type = all_items[start] | |||
| if return_type == "const": | |||
| start += 1 | |||
| return_type = all_items[start] | |||
| if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||
| return_type = "std::map" | |||
| if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||
| return_type = "Ptr" | |||
| if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||
| return_type += "&" | |||
| if RETURN_STATEMENTS.__contains__(return_type): | |||
| function_def += RETURN_STATEMENTS[return_type] | |||
| else: | |||
| logging.warning("Unhandled return type[%s]", return_type) | |||
| function_def += '\n' | |||
| function_def += '}\n' | |||
| function_def += '\n' | |||
| return function_def | |||
| def collect_header_files(path): | |||
| """ | |||
| :param path: | |||
| :return: | |||
| """ | |||
| header_files = [] | |||
| shared_includes_content = [] | |||
| for root, dirs, files in os.walk(path): | |||
| files.sort() | |||
| for file in files: | |||
| if file.find("git") >= 0: | |||
| continue | |||
| if not file.endswith('.h'): | |||
| continue | |||
| file_path = os.path.join(root, file) | |||
| file_path = file_path.replace('\\', '/') | |||
| header_files.append(file_path) | |||
| include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) | |||
| shared_includes_content.append(include_str) | |||
| return header_files, shared_includes_content | |||
| def generate_stub_file(inc_dir, out_cc_dir): | |||
| """ | |||
| :param inc_dir: | |||
| :param out_cc_dir: | |||
| :return: | |||
| """ | |||
| target_header_files, shared_includes_content = collect_header_files(inc_dir) | |||
| for header_file in target_header_files: | |||
| if not file_endswith_white_list_suffix(header_file): | |||
| continue | |||
| cc_file = re.sub('.h*$', '.cc', header_file) | |||
| h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) | |||
| h_2_cc.h2cc() | |||
| def gen_code(inc_dir, out_cc_dir): | |||
| """ | |||
| :param inc_dir: | |||
| :param out_cc_dir: | |||
| :return: | |||
| """ | |||
| if not inc_dir.endswith('/'): | |||
| inc_dir += '/' | |||
| if not out_cc_dir.endswith('/'): | |||
| out_cc_dir += '/' | |||
| for include_dir_key_word in include_dir_key_words: | |||
| generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) | |||
| if __name__ == '__main__': | |||
| inc_dir = sys.argv[1] | |||
| out_cc_dir = sys.argv[2] | |||
| gen_code(inc_dir, out_cc_dir) | |||
| @@ -178,16 +178,18 @@ int64_t Shape::GetShapeSize() const { | |||
| return 0; | |||
| } | |||
| TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); } | |||
| TensorDesc::TensorDesc() { | |||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||
| } | |||
| TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { | |||
| impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); | |||
| impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665 | |||
| SetRealDimCnt(shape.GetDimNum()); | |||
| } | |||
| TensorDesc::TensorDesc(const TensorDesc &desc) { | |||
| // Copy | |||
| impl = ComGraphMakeShared<TensorDescImpl>(); | |||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||
| if (desc.impl != nullptr && impl != nullptr) { | |||
| *impl = *desc.impl; | |||
| } | |||
| @@ -358,7 +360,9 @@ void TensorDesc::SetName(const std::string &name) { | |||
| Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); } | |||
| Tensor::Tensor(const TensorDesc &tensor_desc) { impl = ComGraphMakeShared<TensorImpl>(tensor_desc); } | |||
| Tensor::Tensor(const TensorDesc &tensor_desc) { | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665 | |||
| } | |||
| Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) { | |||
| uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | |||
| @@ -380,7 +384,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) | |||
| } | |||
| } | |||
| } | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665 | |||
| } | |||
| Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { | |||
| @@ -402,7 +406,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||
| } | |||
| } | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | |||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665 | |||
| } | |||
| Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | |||
| @@ -425,7 +429,7 @@ Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | |||
| } | |||
| } | |||
| } | |||
| impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); | |||
| impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665 | |||
| } | |||
| TensorDesc Tensor::GetTensorDesc() const { | |||
| @@ -639,7 +643,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||
| GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { | |||
| GeTensorPtr ge_tensor; | |||
| if (tensor.impl != nullptr) { | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665 | |||
| } | |||
| return ge_tensor; | |||
| } | |||
| @@ -655,7 +659,7 @@ Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { | |||
| ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | |||
| GeTensorPtr ge_tensor; | |||
| if (tensor.impl != nullptr) { | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||
| } | |||
| return ge_tensor; | |||
| } | |||
| @@ -663,7 +667,7 @@ ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | |||
| GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { | |||
| GeTensorPtr ge_tensor; | |||
| if (tensor.impl != nullptr) { | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); | |||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||
| } | |||
| return ge_tensor; | |||
| } | |||
| @@ -38,6 +38,7 @@ | |||
| #include "utils/ge_ir_utils.h" | |||
| #include "utils/node_utils.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| @@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra | |||
| /// @return graphStatus | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
| GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
| GE_CHECK_NOTNULL(src); | |||
| GE_CHECK_NOTNULL(insert_node); | |||
| @@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||
| static int max_dumpfile_num = 0; | |||
| if (max_dumpfile_num == 0) { | |||
| string opt = "0"; | |||
| (void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
| } | |||
| if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { | |||
| @@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||
| if (maxDumpFileSize == 0) { | |||
| string opt = "0"; | |||
| // Can not check return value | |||
| (void)GetContext().GetOption("ge.maxDumpFileSize", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); | |||
| maxDumpFileSize = atol(opt.c_str()); | |||
| } | |||
| if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { | |||
| @@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||
| static int max_dumpfile_num = 0; | |||
| if (max_dumpfile_num == 0) { | |||
| string opt = "0"; | |||
| (void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
| } | |||
| if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { | |||
| @@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In | |||
| InNodesToOut GetFullConnectIONodes(const NodePtr &node) { | |||
| InNodesToOut in_nodes_to_out; | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "Node is nullptr"); | |||
| return in_nodes_to_out; | |||
| } | |||
| auto in_nodes_list = node->GetInNodes(); | |||
| @@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Copy all in-data edges from `src_node` to `dst_node`. | |||
| /// @param src_node | |||
| /// @param dst_node | |||
| /// @return | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, | |||
| NodePtr &dst_node) { | |||
| if ((src_node == nullptr) || (dst_node == nullptr)) { | |||
| GELOGE(GRAPH_FAILED, "Parameter is nullptr"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto src_data_in_nodes = src_node->GetInDataNodes(); | |||
| if (src_data_in_nodes.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { | |||
| auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); | |||
| auto ret = | |||
| GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", | |||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), | |||
| src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
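Editor's note: a hedged usage sketch for the new CopyInDataEdges helper (not part of this patch; src_node and dst_node are hypothetical, already-constructed NodePtr values):

    if (ge::GraphUtils::CopyInDataEdges(src_node, dst_node) != ge::GRAPH_SUCCESS) {
      // null inputs yield GRAPH_PARAM_INVALID; a failed AddEdge propagates its status
    }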
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, | |||
| const NodePtr &node) { | |||
| if (graph->AddInputNode(node) == nullptr) { | |||
| @@ -1328,6 +1359,153 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR | |||
| return result; | |||
| } | |||
| /// | |||
| /// Make a copy of ComputeGraph. | |||
| /// @param graph: original graph. | |||
| /// @param prefix: node name prefix of new graph. | |||
| /// @return ComputeGraphPtr | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr | |||
| GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector<NodePtr> &input_nodes, | |||
| std::vector<NodePtr> &output_nodes) { | |||
| GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null"); | |||
| ComputeGraphPtr new_graph = ComGraphMakeShared<ComputeGraph>(graph->GetName()); | |||
| GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed"); | |||
| std::unordered_map<std::string, NodePtr> all_new_nodes; | |||
| for (const auto &n : graph->GetDirectNode()) { | |||
| OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc()); | |||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed"); | |||
| if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { | |||
| return nullptr; | |||
| } | |||
| op_desc->SetName(prefix + n->GetName()); | |||
| NodePtr node = new_graph->AddNode(op_desc); | |||
| GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | |||
| all_new_nodes[node->GetName()] = node; | |||
| if (node->GetType() == DATA) { | |||
| input_nodes.emplace_back(node); | |||
| } else if (node->GetType() == NETOUTPUT) { | |||
| output_nodes.emplace_back(node); | |||
| } | |||
| } | |||
| for (const auto &n : graph->GetDirectNode()) { | |||
| if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) { | |||
| return nullptr; | |||
| } | |||
| } | |||
| return new_graph; | |||
| } | |||
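Editor's note: a hedged usage sketch for CloneGraph (not part of this patch; original_graph is a hypothetical, fully built ComputeGraphPtr):

    std::vector<ge::NodePtr> cloned_inputs;
    std::vector<ge::NodePtr> cloned_outputs;
    ge::ComputeGraphPtr copy =
        ge::GraphUtils::CloneGraph(original_graph, "copy_", cloned_inputs, cloned_outputs);
    if (copy == nullptr) {
      // any failure (null input, op desc copy, node add, or edge relink) returns nullptr
    }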
| /// | |||
| /// Copy tensor attribute to new node. | |||
| /// @param [in] dst_desc: op desc of the cloned node. | |||
| /// @param [in] src_node: original node. | |||
| /// @return success: GRAPH_SUCCESS | |||
| /// | |||
| graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { | |||
| if (dst_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input param dst node not valid"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (src_node == nullptr || src_node->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input param src node not valid"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| const auto &src_desc = src_node->GetOpDesc(); | |||
| dst_desc->CopyAttrsFrom(*src_desc); | |||
| for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { | |||
| auto input_desc = dst_desc->MutableInputDesc(i); | |||
| if (input_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Param dst node not valid"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); | |||
| } | |||
| for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) { | |||
| auto output_desc = dst_desc->MutableOutputDesc(i); | |||
| if (output_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Param dst node not valid"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Relink all edges for cloned ComputeGraph. | |||
| /// @param [in] node: original node. | |||
| /// @param [in] prefix: node name prefix of new node. | |||
| /// @param [in] all_nodes: all nodes in new graph. | |||
| /// @return success: GRAPH_SUCCESS | |||
| /// | |||
| graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, | |||
| const std::unordered_map<string, NodePtr> &all_nodes) { | |||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input node not valid"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto it = all_nodes.find(prefix + node->GetName()); | |||
| if (it == all_nodes.end()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| const auto &new_node = it->second; | |||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | |||
| GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); | |||
| const auto &out_anchor = in_anchor->GetPeerOutAnchor(); | |||
| if (out_anchor == nullptr) { | |||
| GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); | |||
| continue; | |||
| } | |||
| GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | |||
| it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); | |||
| if (it == all_nodes.end()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| const auto &new_out_node = it->second; | |||
| auto rslt = | |||
| GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx())); | |||
| GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", | |||
| new_out_node->GetName().c_str(), new_node->GetName().c_str()); | |||
| } | |||
| if (node->GetInControlAnchor() != nullptr) { | |||
| for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { | |||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); | |||
| GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | |||
| it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); | |||
| if (it == all_nodes.end()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| const auto &new_out_node = it->second; | |||
| auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); | |||
| GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", | |||
| new_out_node->GetName().c_str(), new_node->GetName().c_str()); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Get reference-mapping of all data_anchors in graph | |||
| /// @param [in] graph | |||
| @@ -1339,7 +1517,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| for (auto &node : graph->GetAllNodes()) { | |||
| for (const auto &node : graph->GetAllNodes()) { | |||
| // in_data_anchor | |||
| if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||
| @@ -1396,16 +1574,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||
| return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| std::string type = node->GetType(); | |||
| const std::string &type = node->GetType(); | |||
| if ((type == MERGE) || (type == STREAMMERGE)) { | |||
| return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| const std::string &symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors[symbol] = {cur_node_info}; | |||
| anchor_to_symbol[symbol] = symbol; | |||
| @@ -1432,7 +1610,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); | |||
| if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||
| continue; | |||
| @@ -1446,7 +1624,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
| return GRAPH_FAILED; | |||
| } | |||
| } else { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| const std::string &symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info})); | |||
| anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||
| @@ -1506,7 +1684,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| GE_CHECK_NOTNULL(node); | |||
| std::vector<NodeIndexIO> exist_node_infos; | |||
| std::vector<NodeIndexIO> cur_node_infos; | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string next_name; | |||
| @@ -1529,10 +1707,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| size_t anchor_nums = 0; | |||
| NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| for (const auto &temp_node_info : exist_node_infos) { | |||
| auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||
| if (iter1 != anchor_to_symbol.end()) { | |||
| std::string temp_symbol = iter1->second; | |||
| const std::string &temp_symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(temp_symbol); | |||
| if (iter2 != symbol_to_anchors.end()) { | |||
| if (iter2->second.size() > anchor_nums) { | |||
| @@ -1544,7 +1722,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| } | |||
| std::string symbol; | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| for (const auto &temp_node_info : exist_node_infos) { | |||
| if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
| GRAPH_SUCCESS) || | |||
| symbol.empty()) { | |||
| @@ -1556,7 +1734,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| auto iter = symbol_to_anchors.find(symbol); | |||
| if (iter != symbol_to_anchors.end()) { | |||
| for (auto &temp_node_info : cur_node_infos) { | |||
| for (const auto &temp_node_info : cur_node_infos) { | |||
| GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||
| iter->second.emplace_back(temp_node_info); | |||
| anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||
| @@ -1584,7 +1762,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| @@ -1627,8 +1805,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
| graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||
| std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
| std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
| const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
| const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
| if (symbol1 == symbol2) { | |||
| symbol = symbol1; | |||
| GELOGI("no need to union."); | |||
| @@ -1684,7 +1862,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string symbol = iter1->second; | |||
| const std::string &symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(symbol); | |||
| if (iter2 == symbol_to_anchors.end()) { | |||
| GE_LOGE("symbol %s not found.", symbol.c_str()); | |||
| @@ -1712,7 +1890,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
| // pass-through op | |||
| NodePtr node = out_data_anchor->GetOwnerNode(); | |||
| std::string type = node->GetType(); | |||
| const std::string &type = node->GetType(); | |||
| const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||
| if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||
| reuse_in_index = output_index; | |||
| @@ -1755,7 +1933,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
| uint32_t reuse_input_index = 0; | |||
| if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||
| reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||
| GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
| GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
| reuse_in_index); | |||
| return true; | |||
| } | |||
| @@ -2297,7 +2475,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & | |||
| return; | |||
| } | |||
| std::string name = node->GetName() + "_RetVal"; | |||
| std::string name = node->GetName() + "_RetVal_" + std::to_string(index); | |||
| OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||
| if (ret_val_desc == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| @@ -295,16 +295,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| if (op_desc == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| if (is_unknown_graph) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { | |||
| GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); | |||
| ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
| output_tensor.SetOriginShape(output_tensor.GetShape()); | |||
| output_tensor.SetOriginDataType(output_tensor.GetDataType()); | |||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||
| node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), | |||
| TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); | |||
| (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); | |||
| node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||
| for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
| if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||
| @@ -316,17 +321,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| continue; | |||
| } | |||
| GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||
| output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||
| peer_input_desc->SetShape(output_tensor.GetShape()); | |||
| peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||
| peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||
| peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), | |||
| output_tensor->GetDataType(), output_tensor->GetOriginDataType()); | |||
| peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); | |||
| peer_input_desc->SetShape(output_tensor->GetShape()); | |||
| peer_input_desc->SetDataType(output_tensor->GetDataType()); | |||
| peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| (void)output_tensor.GetShapeRange(shape_range); | |||
| (void)output_tensor->GetShapeRange(shape_range); | |||
| peer_input_desc->SetShapeRange(shape_range); | |||
| ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||
| static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
| static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
| GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||
| peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||
| @@ -334,6 +339,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, | |||
| uint32_t index) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); | |||
| OpDescPtr op_desc = node->op_; | |||
| for (size_t i = op_desc->GetInputsSize(); i < index; ++i) { | |||
| if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Add input desc failed"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto anchor = ComGraphMakeShared<InDataAnchor>(node, i); | |||
| if (anchor == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| node->in_data_anchors_.push_back(anchor); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, | |||
| uint32_t index) { | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| OpDescPtr op_desc = node->op_; | |||
| op_desc->RemoveInputDesc(index); | |||
| while (node->in_data_anchors_.size() > index) { | |||
| node->in_data_anchors_.pop_back(); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
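| // Hedged usage sketch (illustrative assumption, not from the original change): grow a | |||
| // node to four input anchors, then shrink it back so only the first two remain. | |||
| static graphStatus ResizeInputAnchorsExample(const NodePtr &node) { | |||
| if (NodeUtils::AppendInputAnchor(node, 4) != GRAPH_SUCCESS) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| return NodeUtils::RemoveInputAnchor(node, 2); | |||
| } | |||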
| bool NodeUtils::IsInNodesEmpty(const Node &node) { | |||
| for (const auto &in_anchor : node.in_data_anchors_) { | |||
| if (in_anchor != nullptr) { | |||
| @@ -401,10 +450,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
| graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||
| auto desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL(desc); | |||
| // check self | |||
| is_unknow = OpShapeIsUnknown(desc); | |||
| if (is_unknow) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| is_unknow = OpShapeIsUnknown(desc); | |||
| return GRAPH_SUCCESS; | |||
| } else { | |||
| auto owner_graph = node.GetOwnerComputeGraph(); | |||
| @@ -440,6 +492,7 @@ std::string NodeUtils::GetNodeType(const Node &node) { | |||
| (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); | |||
| return type; | |||
| } | |||
| ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | |||
| auto op_desc = node.GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| @@ -492,6 +545,14 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||
| return false; | |||
| } | |||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||
| bool is_unknown_shape = false; | |||
| (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); | |||
| if (is_unknown_shape) return false; | |||
| } | |||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && | |||
| kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && | |||
| kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { | |||
| return false; | |||
| } | |||
| @@ -513,7 +574,16 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||
| if (parent_op_desc == nullptr) { | |||
| return false; | |||
| } | |||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||
| bool is_unknown_shape = false; | |||
| (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); | |||
| if (is_unknown_shape) return false; | |||
| } | |||
| if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && | |||
| kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && | |||
| kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { | |||
| return false; | |||
| } | |||
| @@ -555,6 +625,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||
| return peer_out_anchor->GetOwnerNode(); | |||
| } | |||
| /// | |||
| /// @brief Check whether the node is a varying input of a While node | |||
| /// @param [in] node: Data node for subgraph | |||
| /// @return bool | |||
| /// | |||
| bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| if (node->GetType() != DATA) { | |||
| return false; // not input_node for subgraph | |||
| } | |||
| const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||
| if (parent_node == nullptr) { | |||
| return false; // root graph | |||
| } | |||
| if (kWhileOpTypes.count(parent_node->GetType()) == 0) { | |||
| return false; // not input_node for while subgraph | |||
| } | |||
| uint32_t index_i = 0; | |||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { | |||
| GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); | |||
| return false; | |||
| } | |||
| bool varying_flag = true; | |||
| for (const auto &item : node->GetOutDataNodesAndAnchors()) { | |||
| if (item.first->GetType() != NETOUTPUT) { | |||
| continue; | |||
| } | |||
| OpDescPtr op_desc = item.first->GetOpDesc(); | |||
| uint32_t index_o = 0; | |||
| if ((op_desc == nullptr) || | |||
| !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { | |||
| continue; // input for while-cond subgraph | |||
| } | |||
| if (index_i != index_o) { | |||
| continue; // varying input for while-body subgraph | |||
| } | |||
| varying_flag = false; | |||
| break; | |||
| } | |||
| return varying_flag; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| @@ -637,4 +754,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input Data nodes mapped to the given parent index. | |||
| /// @param [in] node | |||
| /// @return vector of Data nodes | |||
| /// | |||
| vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { | |||
| vector<NodePtr> in_data_node_vec; | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); | |||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (subgraph_names.empty()) { | |||
| GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
| return in_data_node_vec; | |||
| } | |||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||
| for (const std::string &instance_name : subgraph_names) { | |||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
| int parent_index = -1; | |||
| if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { | |||
| (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); | |||
| if (parent_index == index) { | |||
| in_data_node_vec.emplace_back(node_in_subgraph); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return in_data_node_vec; | |||
| } | |||
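| // Hedged usage sketch (assumption for illustration): gather the Data nodes mapped to | |||
| // parent input 0 of a subgraph-holding node such as an If/Case node. | |||
| static void CollectSubgraphDataNodesExample(const Node &partitioned_node) { | |||
| std::vector<NodePtr> data_nodes = NodeUtils::GetSubgraphDataNodesByIndex(partitioned_node, 0); | |||
| GELOGD("Found %zu subgraph Data nodes for parent index 0.", data_nodes.size()); | |||
| } | |||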
| /// | |||
| /// @brief Get subgraph output nodes (NetOutput nodes of all subgraphs). | |||
| /// @param [in] node | |||
| /// @return vector of output nodes | |||
| /// | |||
| vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) { | |||
| vector<NodePtr> out_data_node_vec; | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); | |||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (subgraph_names.empty()) { | |||
| GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
| return out_data_node_vec; | |||
| } | |||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||
| for (const std::string &instance_name : subgraph_names) { | |||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
| if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { | |||
| out_data_node_vec.emplace_back(node_in_subgraph); | |||
| } | |||
| } | |||
| } | |||
| return out_data_node_vec; | |||
| } | |||
| NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { | |||
| if (node.GetInDataAnchor(index) == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| } | |||
| vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { | |||
| vector<NodePtr> out_data_nodes; | |||
| auto out_data_anchor = node.GetOutDataAnchor(index); | |||
| if (out_data_anchor == nullptr) { | |||
| return out_data_nodes; | |||
| } | |||
| for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| if (peer_in_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| if (peer_in_anchor->GetOwnerNode() == nullptr) { | |||
| continue; | |||
| } | |||
| out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); | |||
| } | |||
| return out_data_nodes; | |||
| } | |||
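| // Hedged usage sketch (illustrative only): walk from a node's first input producer to | |||
| // the consumers of its first output using the two index-based helpers above. | |||
| static void TraverseNeighboursExample(const Node &node) { | |||
| NodePtr producer = NodeUtils::GetInDataNodeByIndex(node, 0); | |||
| if (producer != nullptr) { | |||
| GELOGD("Input 0 of %s comes from %s.", node.GetName().c_str(), producer->GetName().c_str()); | |||
| } | |||
| for (const auto &consumer : NodeUtils::GetOutDataNodesByIndex(node, 0)) { | |||
| GELOGD("Output 0 of %s feeds %s.", node.GetName().c_str(), consumer->GetName().c_str()); | |||
| } | |||
| } | |||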
| } // namespace ge | |||
| @@ -28,6 +28,7 @@ | |||
| using std::vector; | |||
| /*lint -e512 -e737 -e752*/ | |||
| namespace ge { | |||
| const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | |||
| static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | |||
| @@ -132,11 +133,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | |||
| GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | |||
| return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||
| return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||
| } | |||
| graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | |||
| return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||
| return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||
| } | |||
| GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | |||
| @@ -197,24 +198,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
| continue; | |||
| } | |||
| auto in_node = out_anchor->GetOwnerNode(); | |||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
| ret.push_back(in_node); | |||
| } else if (in_node->GetType() == DATA) { | |||
| const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); | |||
| GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); | |||
| const NodePtr &parent_node = graph->GetParentNode(); | |||
| if (parent_node == nullptr) { | |||
| continue; // Root graph. | |||
| } | |||
| if (kWhileOpTypes.count(parent_node->GetType()) > 0) { | |||
| continue; // Subgraph of While cond or body. | |||
| while (true) { | |||
| if (in_node == nullptr) { | |||
| break; | |||
| } | |||
| NodePtr input_node = NodeUtils::GetParentInput(in_node); | |||
| if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { | |||
| ret.push_back(input_node); | |||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
| ret.push_back(in_node); | |||
| break; | |||
| } else if (in_node->GetType() == DATA) { | |||
| if (NodeUtils::IsWhileVaryingInput(in_node)) { | |||
| break; | |||
| } | |||
| in_node = NodeUtils::GetParentInput(in_node); | |||
| } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { | |||
| bool is_constant = false; | |||
| (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); | |||
| if (!is_constant) { | |||
| break; | |||
| } | |||
| // An Enter node has exactly one input | |||
| if (in_node->GetInDataNodes().size() != 1) { | |||
| GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), | |||
| in_node->GetInDataNodes().size()); | |||
| break; | |||
| } | |||
| in_node = in_node->GetInDataNodes().at(0); | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| @@ -245,7 +255,7 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { | |||
| continue; | |||
| } | |||
| } | |||
| return input_num; | |||
| return input_num; // lint !e712 | |||
| } else { | |||
| GE_IF_BOOL_EXEC( | |||
| node.GetInDataNodes().size() < GetConstInputs(node).size(), | |||
| @@ -350,7 +360,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { | |||
| bool ret = false; | |||
| if (index < node.GetAllInDataAnchors().size()) { | |||
| if (NodeUtils::IsAnchorStatusSet(node)) { | |||
| ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); | |||
| ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712 | |||
| } else { | |||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | |||
| if (anchor->GetIdx() != static_cast<int>(index)) { | |||
| @@ -435,10 +445,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) { | |||
| vector<GeTensorPtr> ret; | |||
| GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); | |||
| // Placeholder operator: try to get the weight from the parent node | |||
| // when the parent node is a Const operator | |||
| if (node.GetType() == PLACEHOLDER) { | |||
| std::string parent_op; | |||
| (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); | |||
| // This check is necessary because subgraph optimization runs in multiple threads, | |||
| // so the parent node of the PLD operator must be of a stable type such as Const | |||
| if (parent_op == CONSTANT || parent_op == CONSTANTOP) { | |||
| NodePtr parent_node = nullptr; | |||
| parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); | |||
| if (parent_node != nullptr) { | |||
| op_desc = parent_node->GetOpDesc(); | |||
| GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| // Const operator, take the weight directly | |||
| if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { | |||
| auto weight = MutableWeights(node.GetOpDesc()); | |||
| if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { | |||
| auto weight = MutableWeights(op_desc); | |||
| if (weight == nullptr) { | |||
| GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); | |||
| return ret; | |||
| @@ -733,3 +760,4 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgr | |||
| return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | |||
| } | |||
| } // namespace ge | |||
| /*lint +e512 +e737 +e752*/ | |||
| @@ -19,6 +19,7 @@ | |||
| #include "debug/ge_log.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/types.h" | |||
| #include "graph/utils/type_utils.h" | |||
| @@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_ | |||
| element_cnt = 1; | |||
| for (int64_t dim : dims) { | |||
| if (CheckMultiplyOverflowInt64(element_cnt, dim)) { | |||
| GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19013", {"function", "var1", "var2"}, | |||
| {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); | |||
| GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); | |||
| return GRAPH_FAILED; | |||
| } | |||
| element_cnt *= dim; | |||
| @@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| case FORMAT_FRACTAL_Z: | |||
| graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); | |||
| break; | |||
| case FORMAT_NC1HWC0_C04: | |||
| case FORMAT_FRACTAL_NZ: | |||
| case FORMAT_FRACTAL_ZZ: | |||
| case FORMAT_NDHWC: | |||
| @@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| case FORMAT_NDC1HWC0: | |||
| case FORMAT_FRACTAL_Z_C04: | |||
| case FORMAT_FRACTAL_ZN_LSTM: | |||
| case FORMAT_NC1HWC0_C04: | |||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||
| break; | |||
| default: | |||
| @@ -147,7 +147,8 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||
| {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||
| {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, | |||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
| {"ALL", FORMAT_ALL}}; | |||
| {"ALL", FORMAT_ALL}, | |||
| {"NULL", FORMAT_NULL}}; | |||
| static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
| @@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/op_tiling_manager.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| @@ -94,14 +95,25 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| "graph/manager/graph_mem_allocator.cc" | |||
| "graph/manager/graph_caching_allocator.cc" | |||
| "graph/manager/graph_var_manager.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/hcom_util.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/hcom_util.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/mem_rw_conflict_optimize.cc" | |||
| "graph/optimize/optimizer/allreduce_fusion_pass.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/dynamic_shape_partition.cc" | |||
| @@ -159,8 +171,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "hybrid/node_executor/aicpu/aicpu_ext_info.cc" | |||
| "hybrid/node_executor/aicpu/aicpu_node_executor.cc" | |||
| "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" | |||
| "hybrid/node_executor/controlop/control_op_executor.cc" | |||
| "hybrid/node_executor/hccl/hccl_node_executor.cc" | |||
| "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" | |||
| "hybrid/node_executor/node_executor.cc" | |||
| "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" | |||
| "hybrid/node_executor/task_context.cc" | |||
| "init/gelib.cc" | |||
| "model/ge_model.cc" | |||
| @@ -204,6 +219,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/op_tiling_manager.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| @@ -236,13 +252,19 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/graph_caching_allocator.cc" | |||
| "graph/manager/graph_context.cc" | |||
| "graph/manager/graph_manager.cc" | |||
| "graph/manager/graph_manager_utils.cc" | |||
| "graph/manager/graph_mem_allocator.cc" | |||
| "graph/manager/trans_var_data_utils.cc" | |||
| "graph/manager/graph_var_manager.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/util/debug.cc" | |||
| "graph/manager/util/rt_context_util.cc" | |||
| "graph/manager/util/variable_accelerate_ctrl.cc" | |||
| "graph/optimize/graph_optimize.cc" | |||
| "graph/optimize/mem_rw_conflict_optimize.cc" | |||
| "graph/optimize/summary_optimize.cc" | |||
| "graph/partition/dynamic_shape_partition.cc" | |||
| "graph/partition/engine_place.cc" | |||
| @@ -28,6 +28,7 @@ | |||
| #include "graph/opsproto_manager.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "graph/manager/util/rt_context_util.h" | |||
| #include "graph/common/ge_call_wrapper.h" | |||
| #include "register/op_registry.h" | |||
| #include "common/ge/tbe_plugin_manager.h" | |||
| @@ -41,8 +42,8 @@ namespace { | |||
| const int32_t kMaxStrLen = 128; | |||
| } | |||
| static bool kGeInitialized = false; | |||
| static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | |||
| static bool g_ge_initialized = false; | |||
| static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use | |||
| namespace ge { | |||
| void GetOpsProtoPath(std::string &opsproto_path) { | |||
| @@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) { | |||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||
| } | |||
| Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||
| const int kDecimal = 10; | |||
| auto dump_op_env = std::getenv("DUMP_OP"); | |||
| int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | |||
| auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||
| if (disableReuseMemoryIter != options.end()) { | |||
| if (disableReuseMemoryIter->second == "0") { | |||
| GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | |||
| if (dump_op_flag) { | |||
| GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | |||
| } | |||
| } else if (disableReuseMemoryIter->second == "1") { | |||
| GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | |||
| } else { | |||
| GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| if (dump_op_flag) { | |||
| GELOGW("Will dump incorrect op data with default reuse memory"); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| // check job_id is valid | |||
| auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); | |||
| @@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| } | |||
| } | |||
| // Check ge.exec.disableReuseMemory and env DUMP_OP | |||
| if (CheckDumpAndReuseMemory(options) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| Status GEInitialize(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_INIT, "GEInitialize start"); | |||
| // 0.check init status | |||
| if (kGeInitialized) { | |||
| if (g_ge_initialized) { | |||
| GELOGW("GEInitialize is called more than once"); | |||
| return SUCCESS; | |||
| } | |||
| @@ -147,9 +118,9 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| } | |||
| // 7.check return status, return | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| // Initialize success, first time calling initialize | |||
| kGeInitialized = true; | |||
| g_ge_initialized = true; | |||
| } | |||
| GELOGT(TRACE_STOP, "GEInitialize finished"); | |||
| @@ -160,12 +131,12 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| Status GEFinalize() { | |||
| GELOGT(TRACE_INIT, "GEFinalize start"); | |||
| // check init status | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGW("GEFinalize is called before GEInitialize"); | |||
| return SUCCESS; | |||
| } | |||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||
| // call Finalize | |||
| Status ret = SUCCESS; | |||
| Status middle_ret; | |||
| @@ -187,10 +158,10 @@ Status GEFinalize() { | |||
| ret = middle_ret; | |||
| } | |||
| if (kGeInitialized && ret == SUCCESS) { | |||
| if (g_ge_initialized && ret == SUCCESS) { | |||
| // Unified destruct rt_context | |||
| RtContextUtil::GetInstance().DestroyrtContexts(); | |||
| kGeInitialized = false; | |||
| RtContextUtil::GetInstance().DestroyAllRtContexts(); | |||
| g_ge_initialized = false; | |||
| } | |||
| GELOGT(TRACE_STOP, "GEFinalize finished"); | |||
| @@ -202,7 +173,7 @@ Session::Session(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||
| // check init status | |||
| sessionId_ = 0; | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED); | |||
| return; | |||
| } | |||
| @@ -232,13 +203,13 @@ Session::Session(const std::map<string, string> &options) { | |||
| Session::~Session() { | |||
| GELOGT(TRACE_INIT, "Session Destructor start"); | |||
| // 0.check init status | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGW("GE is not yet initialized or is finalized."); | |||
| return; | |||
| } | |||
| Status ret = FAILED; | |||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||
| try { | |||
| uint64_t session_id = sessionId_; | |||
| // call DestroySession | |||
| @@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||
| void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
| bool enum2str) { | |||
| if (field == nullptr || reflection == nullptr) { | |||
| return; | |||
| } | |||
| switch (field->type()) { | |||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||
| const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||
| @@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||
| char str[kSignificantDigits]; | |||
| sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)); | |||
| json[field->name()] = str; | |||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { | |||
| json[field->name()] = str; | |||
| } else { | |||
| json[field->name()] = reflection->GetFloat(message, field); | |||
| } | |||
| break; | |||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||
| @@ -29,7 +29,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| enum DataTypeTransMode { | |||
| kTransferWithDatatypeFloatToFloat16, | |||
| @@ -27,7 +27,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| struct CastArgs { | |||
| const uint8_t *data; | |||
| size_t src_data_size; | |||
| @@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| dst_shape.clear(); | |||
| hw_shape.clear(); | |||
| auto w0 = GetCubeSizeByDataType(data_type); | |||
| auto h0 = GetCubeSizeByDataType(data_type); | |||
| int64_t h0 = kCubeSize; | |||
| switch (src_shape.size()) { | |||
| case 1: | |||
| dst_shape.push_back(Ceil(src_shape[0], w0)); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <securec.h> | |||
| #include <memory> | |||
| #include "common/debug/log.h" | |||
| #include "common/formats/utils/formats_definitions.h" | |||
| #include "common/formats/utils/formats_trans_utils.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| int64_t hw = h * w; | |||
| int64_t chw = c * hw; | |||
| int64_t hwc0 = hw * c0; | |||
| int64_t nchw = n * chw; | |||
| int64_t hwc0 = hw * c0; | |||
| // horizontal fractal matrix count (N) | |||
| int64_t hf_cnt = Ceil(n, static_cast<int64_t>(kNiSize)); | |||
| @@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = total_ele_cnt * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { | |||
| // vertical fractal matrix base index | |||
| @@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||
| ? dst_size - offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (need_pad_zero) { | |||
| ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||
| } else { | |||
| ret = memcpy_s(dst.get() + offset, static_cast<size_t>(protected_size), args.data + src_offset * size, | |||
| static_cast<size_t>(size)); | |||
| if (protected_size < size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size); | |||
| for (int64_t index = 0; index < size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, | |||
| @@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||
| for (int64_t hi = 0; hi < h; hi++) { | |||
| @@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| ? dst_size - dst_offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (pad_zero) { | |||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||
| static_cast<size_t>(data_size)); | |||
| } else { | |||
| if (protected_size < data_size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, data_size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||
| for (int64_t index = 0; index < data_size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| @@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||
| for (int64_t hi = 0; hi < h; hi++) { | |||
| @@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
| ? dst_size - dst_offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (pad_zero) { | |||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||
| static_cast<size_t>(data_size)); | |||
| } else { | |||
| if (protected_size < data_size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, data_size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||
| for (int64_t index = 0; index < data_size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| @@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||
| return PARAM_INVALID; | |||
| } | |||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatFromNchwToFz(args, result); | |||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatNhwcToFz(args, result); | |||
| } | |||
| if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatHwcnToFz(args, result); | |||
| } | |||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatNhwcToFz(args, result); | |||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatFromNchwToFz(args, result); | |||
| } | |||
| return UNSUPPORTED; | |||
| @@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||
| return UNSUPPORTED; | |||
| } | |||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeHwcnToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | |||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| return UNSUPPORTED; | |||
| @@ -374,6 +390,5 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z) | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -39,7 +39,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| constexpr int64_t kMaxDimsNumC = 4; | |||
| Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | |||
| @@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| return NOT_CHANGED; | |||
| } | |||
| /* prepare for padding in chw*/ | |||
| // prepare for padding in chw | |||
| int64_t tmp = h * w * c; | |||
| int64_t n_o = Ceil(n, static_cast<int64_t>(c0)); | |||
| int64_t c_o = c0; | |||
| @@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -19,7 +19,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| static const int kCubeSize = 16; | |||
| static const int kNiSize = 16; | |||
| static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; | |||
| @@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum }; | |||
| enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum }; | |||
| enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum }; | |||
| } // namespace formats | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ | |||
| @@ -21,7 +21,6 @@ | |||
| #include <sstream> | |||
| #include <string> | |||
| #include <vector> | |||
| #include "external/graph/types.h" | |||
| #include "graph/ge_tensor.h" | |||
| @@ -69,7 +68,6 @@ T Ceil(T n1, T n2) { | |||
| } | |||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_ | |||
| @@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) { | |||
| } | |||
| return len; | |||
| } | |||
| }; // namespace ge | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FP16_T_H_ | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/ge/op_tiling_manager.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include <string> | |||
| namespace { | |||
| const char *const kEnvName = "ASCEND_OPP_PATH"; | |||
| const std::string kDefaultPath = "/usr/local/Ascend/opp"; | |||
| const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so"; | |||
| const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so"; | |||
| const uint8_t kPrefixIndex = 9;  // length of the "/op_impl/" prefix, stripped to form the handle keys | |||
| } // namespace | |||
| namespace ge { | |||
| void OpTilingManager::ClearHandles() noexcept { | |||
| for (const auto &handle : handles_) { | |||
| if (dlclose(handle.second) != 0) { | |||
| GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror()); | |||
| } | |||
| } | |||
| handles_.clear(); | |||
| } | |||
| OpTilingManager::~OpTilingManager() { ClearHandles(); } | |||
| std::string OpTilingManager::GetPath() { | |||
| const char *opp_path_env = std::getenv(kEnvName); | |||
| std::string opp_path = kDefaultPath; | |||
| if (opp_path_env != nullptr) { | |||
| char resolved_path[PATH_MAX]; | |||
| if (realpath(opp_path_env, resolved_path) == NULL) { | |||
| GELOGE(PARAM_INVALID, "Failed load tiling lib as env 'ASCEND_OPP_PATH'(%s) is invalid path.", opp_path_env); | |||
| return std::string(); | |||
| } | |||
| opp_path = resolved_path; | |||
| } | |||
| return opp_path; | |||
| } | |||
| void OpTilingManager::LoadSo() { | |||
| std::string opp_path = GetPath(); | |||
| if (opp_path.empty()) { | |||
| GELOGW("Skip load tiling lib."); | |||
| return; | |||
| } | |||
| std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath; | |||
| std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath; | |||
| std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex); | |||
| std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex); | |||
| void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
| if (handle_bi == nullptr) { | |||
| GELOGW("Failed to dlopen %s!", dlerror()); | |||
| } else { | |||
| handles_[built_in_name] = handle_bi; | |||
| } | |||
| void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
| if (handle_ct == nullptr) { | |||
| GELOGW("Failed to dlopen %s!", dlerror()); | |||
| } else { | |||
| handles_[custom_name] = handle_ct; | |||
| } | |||
| } | |||
| } // namespace ge | |||
| @@ -0,0 +1,38 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
| #define GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
| #include <map> | |||
| #include <string> | |||
| namespace ge { | |||
| using SoToHandleMap = std::map<std::string, void *>; | |||
| class OpTilingManager { | |||
| public: | |||
| OpTilingManager() = default; | |||
| ~OpTilingManager(); | |||
| void LoadSo(); | |||
| private: | |||
| static std::string GetPath(); | |||
| void ClearHandles() noexcept; | |||
| SoToHandleMap handles_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
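The two new files above only locate and dlopen liboptiling.so. Below is a minimal standalone sketch (not part of the diff) of the same path assembly, assuming a Linux/glibc environment for getenv/realpath/PATH_MAX; the constants mirror those in op_tiling_manager.cc, and the printed keys match what LoadSo stores in handles_ after stripping the 9-character "/op_impl/" prefix.

#include <climits>    // PATH_MAX (Linux)
#include <cstdlib>    // std::getenv, realpath (POSIX)
#include <iostream>
#include <string>

int main() {
  // Resolve the opp directory the same way OpTilingManager::GetPath does:
  // prefer a resolvable ASCEND_OPP_PATH, otherwise fall back to the default install dir.
  const char *opp_env = std::getenv("ASCEND_OPP_PATH");
  std::string opp_path = "/usr/local/Ascend/opp";  // kDefaultPath
  if (opp_env != nullptr) {
    char resolved[PATH_MAX];
    if (realpath(opp_env, resolved) != nullptr) {
      opp_path = resolved;
    }
  }
  const std::string built_in = "/op_impl/built-in/liboptiling.so";  // kDefaultBuiltInTilingPath
  const std::string custom = "/op_impl/custom/liboptiling.so";      // kDefaultCustomTilingPath
  // Full paths handed to dlopen() and the shortened keys kept in handles_.
  std::cout << "dlopen: " << opp_path + built_in << "\n"
            << "key:    " << built_in.substr(9) << "\n"   // "built-in/liboptiling.so"
            << "dlopen: " << opp_path + custom << "\n"
            << "key:    " << custom.substr(9) << "\n";    // "custom/liboptiling.so"
  return 0;
}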
| @@ -182,7 +182,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) { | |||
| } | |||
| void TBEPluginManager::LoadCustomOpLib() { | |||
| LoadPluginSo(); | |||
| LoadPluginSo(options_); | |||
| std::vector<OpRegistrationData> registration_datas = domi::OpRegistry::Instance()->registrationDatas; | |||
| GELOGI("The size of registration_datas is: %zu", registration_datas.size()); | |||
| @@ -193,10 +193,13 @@ void TBEPluginManager::LoadCustomOpLib() { | |||
| } | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo() { | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo( | |||
| const std::map<string, string> &options) { | |||
| vector<string> file_list; | |||
| string caffe_parser_path; | |||
| std::string plugin_path; | |||
| options_ = options; | |||
| GetCustomOpPath(plugin_path); | |||
| // Whether there are files in the plugin so path | |||
| @@ -48,7 +48,7 @@ class TBEPluginManager { | |||
| static void InitPreparation(const std::map<string, string> &options); | |||
| void LoadPluginSo(); | |||
| void LoadPluginSo(const std::map<string, string> &options); | |||
| private: | |||
| TBEPluginManager() = default; | |||
| @@ -36,6 +36,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ | |||
| properties_manager.cc \ | |||
| types.cc\ | |||
| model_parser/base.cc \ | |||
| model_parser/graph_parser_util.cc \ | |||
| tbe_kernel_store.cc \ | |||
| op/attr_value_util.cc \ | |||
| op/ge_op_utils.cc \ | |||
| @@ -17,6 +17,7 @@ | |||
| #include "framework/common/helper/model_helper.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include "framework/common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -89,10 +90,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||
| } | |||
| } | |||
| auto ge_model_weight = ge_model->GetWeight(); | |||
| GELOGI("WEIGHTS_DATA size is %zu , %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | |||
| if (SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, ge_model_weight.GetData(), | |||
| ge_model_weight.GetSize()) != SUCCESS) { | |||
| GELOGW("Add weight partition failed"); // weight is not necessary | |||
| GELOGI("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); | |||
| // weight is not necessary | |||
| if (ge_model_weight.GetSize() > 0) { | |||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, | |||
| ge_model_weight.GetData(), ge_model_weight.GetSize()), | |||
| "Add weight partition failed"); | |||
| } | |||
| TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | |||
| @@ -238,44 +241,48 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(const ge::ModelData &model_data) { | |||
| if (model_data.model_data == nullptr || model_data.model_len == 0) { | |||
| GELOGE(FAILED, "Model_data is nullptr, or model_data_size is 0"); | |||
| return FAILED; | |||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); | |||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||
| } | |||
| if (is_assign_model_) { | |||
| GELOGE(FAILED, "Model helper has already loaded!"); | |||
| return FAILED; | |||
| GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!"); | |||
| return GE_EXEC_LOAD_MODEL_REPEATED; | |||
| } | |||
| if (ReleaseLocalModelData() != SUCCESS) { | |||
| GELOGE(FAILED, "ReleaseLocalModelData failed."); | |||
| return FAILED; | |||
| GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed."); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_); | |||
| if (ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_) != SUCCESS) { | |||
| GELOGE(FAILED, "Parse model content failed!"); | |||
| return FAILED; | |||
| GELOGE(status, "Parse model content failed!"); | |||
| return status; | |||
| } | |||
| file_header_ = reinterpret_cast<ModelFileHeader *>(model_data.model_data); | |||
| OmFileLoadHelper om_load_helper; | |||
| if (om_load_helper.Init(model_addr_tmp_, model_len_tmp_) != SUCCESS) { | |||
| GELOGE(FAILED, "Om_load_helper init failed"); | |||
| status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); | |||
| if (status != SUCCESS) { | |||
| GELOGE(status, "Om_load_helper init failed"); | |||
| model_addr_tmp_ = nullptr; | |||
| return FAILED; | |||
| return status; | |||
| } | |||
| auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_addr_tmp_); | |||
| if (partition_table->num == kOriginalOmPartitionNum) { | |||
| GELOGE(FAILED, "om model is error,please use executable om model"); | |||
| return FAILED; | |||
| model_addr_tmp_ = nullptr; | |||
| GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "om model is error,please use executable om model"); | |||
| return GE_EXEC_MODEL_PARTITION_NUM_INVALID; | |||
| } | |||
| // Encrypt model need to del temp model/no encrypt model don't need to del model | |||
| model_addr_tmp_ = nullptr; | |||
| if (GenerateGeModel(om_load_helper) != SUCCESS) { | |||
| GELOGE(FAILED, "GenerateGeModel failed"); | |||
| return FAILED; | |||
| status = GenerateGeModel(om_load_helper); | |||
| if (status != SUCCESS) { | |||
| GELOGE(status, "GenerateGeModel failed"); | |||
| return status; | |||
| } | |||
| is_assign_model_ = true; | |||
| @@ -287,19 +294,19 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) { | |||
| GE_CHECK_NOTNULL(model_); | |||
| Status ret = LoadModelData(om_load_helper); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| return GE_EXEC_LOAD_MODEL_PARTITION_FAILED; | |||
| } | |||
| ret = LoadWeights(om_load_helper); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED; | |||
| } | |||
| ret = LoadTask(om_load_helper); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| return GE_EXEC_LOAD_TASK_PARTITION_FAILED; | |||
| } | |||
| ret = LoadTBEKernelStore(om_load_helper); | |||
| if (ret != SUCCESS) { | |||
| return ret; | |||
| return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -390,107 +397,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo | |||
| return out_model; | |||
| } | |||
| // Transit func for model to ge_model. It will be removed when load and build support ge_model in future | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model, | |||
| GeModelPtr &ge_model) { | |||
| if (model == nullptr) { | |||
| GELOGE(FAILED, "Model is null"); | |||
| return FAILED; | |||
| } | |||
| ge_model = ge::MakeShared<ge::GeModel>(); | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| ge_model->SetGraph(model->GetGraph()); | |||
| ge_model->SetName(model->GetName()); | |||
| ge_model->SetVersion(model->GetVersion()); | |||
| ge_model->SetPlatformVersion(model->GetPlatformVersion()); | |||
| ge_model->SetAttr(model->MutableAttrMap()); | |||
| // Copy weight info | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||
| // ge::Buffer weight; | |||
| ge::Buffer weight; | |||
| (void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); | |||
| ge_model->SetWeight(weight); | |||
| // Copy task info | |||
| if (model->HasAttr(MODEL_ATTR_TASKS)) { | |||
| ge::Buffer task_buffer; | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, | |||
| "Get bytes failed."); | |||
| std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>(); | |||
| GE_CHECK_NOTNULL(task); | |||
| GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); | |||
| GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); | |||
| GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()), | |||
| return INTERNAL_ERROR, "ReadProtoFromArray failed."); | |||
| ge_model->SetModelTaskDef(task); | |||
| } | |||
| // Copy tbe kernel info | |||
| // TBEKernelStore kernel_store; | |||
| TBEKernelStore kernel_store; | |||
| if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { | |||
| for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | |||
| auto node_op_desc = n->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | |||
| TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||
| GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | |||
| kernel_store.AddTBEKernel(tbe_kernel); | |||
| GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||
| } | |||
| } | |||
| if (!kernel_store.Build()) { | |||
| GELOGE(FAILED, "TBE Kernels store build failed!"); | |||
| return FAILED; | |||
| } | |||
| ge_model->SetTBEKernelStore(kernel_store); | |||
| return SUCCESS; | |||
| } | |||
| // trasit func for ge_model to Model. will be removed when load and build support ge_model in future | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model, | |||
| ModelPtr &model) { | |||
| if (ge_model == nullptr) { | |||
| GELOGE(FAILED, "Ge_model is null"); | |||
| return FAILED; | |||
| } | |||
| model = ge::MakeShared<ge::Model>(); | |||
| GE_CHECK_NOTNULL(model); | |||
| model->SetGraph(ge_model->GetGraph()); | |||
| model->SetName(ge_model->GetName()); | |||
| model->SetVersion(ge_model->GetVersion()); | |||
| model->SetPlatformVersion(ge_model->GetPlatformVersion()); | |||
| model->SetAttr(ge_model->MutableAttrMap()); | |||
| // Copy weight info | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||
| bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight()); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy weight buffer failed!"); | |||
| return FAILED; | |||
| } | |||
| // Copy task info | |||
| std::shared_ptr<ModelTaskDef> model_task = ge_model->GetModelTaskDefPtr(); | |||
| if (model_task != nullptr) { | |||
| int size = model_task->ByteSize(); | |||
| ge::Buffer buffer(static_cast<size_t>(size)); | |||
| if (buffer.GetSize() == 0) { | |||
| GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!"); | |||
| return MEMALLOC_FAILED; | |||
| } | |||
| // no need to check value | |||
| (void)model_task->SerializePartialToArray(buffer.GetData(), size); | |||
| ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer)); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy task buffer failed!"); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ModelHelper::ReleaseLocalModelData() noexcept { | |||
| Status result = SUCCESS; | |||
| if (model_addr_tmp_ != nullptr) { | |||
| @@ -41,8 +41,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(c | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, | |||
| const uint32_t model_data_size) { | |||
| if (LoadModelPartitionTable(model_data, model_data_size) != SUCCESS) { | |||
| return FAILED; | |||
| Status status = LoadModelPartitionTable(model_data, model_data_size); | |||
| if (status != SUCCESS) { | |||
| return status; | |||
| } | |||
| is_inited_ = true; | |||
| return SUCCESS; | |||
| @@ -66,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod | |||
| } | |||
| if (!found) { | |||
| if (type != ModelPartitionType::TBE_KERNELS) { | |||
| if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA) { | |||
| GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type)); | |||
| return FAILED; | |||
| } | |||
| @@ -83,7 +84,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { | |||
| // Model length too small | |||
| if (model.model_len < (sizeof(ModelFileHeader) + sizeof(ModelPartitionTable))) { | |||
| GELOGE(PARAM_INVALID, "Invalid model. length < sizeof(ModelFileHeader) + sizeof(ModelPartitionTable)."); | |||
| GELOGE(PARAM_INVALID, | |||
| "Invalid model. length[%u] < sizeof(ModelFileHeader)[%zu] + sizeof(ModelPartitionTable)[%zu].", | |||
| model.model_len, sizeof(ModelFileHeader), sizeof(ModelPartitionTable)); | |||
| return PARAM_INVALID; | |||
| } | |||
| @@ -93,9 +96,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { | |||
| if ((model_header->length != model.model_len - sizeof(ModelFileHeader)) || | |||
| (MODEL_FILE_MAGIC_NUM != model_header->magic)) { | |||
| GELOGE(PARAM_INVALID, | |||
| "Invalid model. file_header->length(%u) + sizeof(ModelFileHeader)(%zu) != model->model_len(%u) || " | |||
| "MODEL_FILE_MAGIC_NUM != file_header->magic", | |||
| model_header->length, sizeof(ModelFileHeader), model.model_len); | |||
| "Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || " | |||
| "MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", | |||
| model_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, model_header->magic); | |||
| return PARAM_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| @@ -112,16 +115,16 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||
| // Original model partition include graph-info | |||
| if ((partition_table->num != PARTITION_SIZE) && (partition_table->num != (PARTITION_SIZE - 1)) && | |||
| (partition_table->num != 1)) { | |||
| GELOGE(PARAM_INVALID, "Invalid partition_table->num:%u", partition_table->num); | |||
| return PARAM_INVALID; | |||
| GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Invalid partition_table->num:%u", partition_table->num); | |||
| return GE_EXEC_MODEL_PARTITION_NUM_INVALID; | |||
| } | |||
| size_t mem_offset = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table); | |||
| GELOGI("ModelPartitionTable num :%u, ModelFileHeader length :%zu, ModelPartitionTable length :%zu", | |||
| partition_table->num, sizeof(ModelFileHeader), mem_offset); | |||
| if (model_data_size <= mem_offset) { | |||
| GELOGE(PARAM_INVALID, "invalid model data, partition_table->num:%u, model data size %u", partition_table->num, | |||
| model_data_size); | |||
| return PARAM_INVALID; | |||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", | |||
| partition_table->num, model_data_size); | |||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||
| } | |||
| for (uint32_t i = 0; i < partition_table->num; i++) { | |||
| ModelPartition partition; | |||
| @@ -131,9 +134,9 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||
| context_.partition_datas_.push_back(partition); | |||
| if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { | |||
| GELOGE(PARAM_INVALID, "The partition size %zu is greater than the model data size %u.", | |||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", | |||
| partition.size + mem_offset, model_data_size); | |||
| return PARAM_INVALID; | |||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||
| } | |||
| mem_offset += partition.size; | |||
| GELOGI("Partition, type:%d, size:%u", static_cast<int>(partition.type), partition.size); | |||
| @@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2); | |||
| /// @brief Calculate the minimum fp16_t of fp1 and fp2 | |||
| /// @return Returns minimum fp16_t of fp1 and fp2 | |||
| fp16_t min(fp16_t fp1, fp16_t fp2); | |||
| }; // namespace ge | |||
| } // namespace ge | |||
| #endif // GE_COMMON_MATH_FP16_MATH_H_ | |||
| @@ -27,7 +27,6 @@ | |||
| #include "mmpa/mmpa_api.h" | |||
| namespace ge { | |||
| /** | |||
| * @ingroup domi_calibration | |||
| * @brief Initializes an input array to a specified value | |||
| @@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // end namespace ge | |||
| #endif // GE_COMMON_MATH_UTIL_H_ | |||
| @@ -35,15 +35,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro | |||
| ge::ModelData &model_data) { | |||
| std::string real_path = RealPath(model_path); | |||
| if (real_path.empty()) { | |||
| GELOGE(PARAM_INVALID, "Model file path '%s' is invalid", model_path); | |||
| return PARAM_INVALID; | |||
| GELOGE(GE_EXEC_MODEL_PATH_INVALID, "Model file path '%s' is invalid", model_path); | |||
| return GE_EXEC_MODEL_PATH_INVALID; | |||
| } | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return FAILED, "File size not valid."); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return GE_EXEC_READ_MODEL_FILE_FAILED, | |||
| "File size not valid."); | |||
| std::ifstream fs(real_path.c_str(), std::ifstream::binary); | |||
| GE_CHK_BOOL_RET_STATUS(fs.is_open(), FAILED, "Open file failed! path:%s", model_path); | |||
| GE_CHK_BOOL_RET_STATUS(fs.is_open(), GE_EXEC_READ_MODEL_FILE_FAILED, "Open file failed! path:%s", model_path); | |||
| // get length of file: | |||
| (void)fs.seekg(0, std::ifstream::end); | |||
| @@ -55,7 +56,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro | |||
| char *data = new (std::nothrow) char[len]; | |||
| if (data == nullptr) { | |||
| GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%ld)", len); | |||
| GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%u)", len); | |||
| return MEMALLOC_FAILED; | |||
| } | |||
| @@ -79,31 +80,33 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseMo | |||
| GE_CHECK_NOTNULL(model.model_data); | |||
| // Model length too small | |||
| GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), PARAM_INVALID, | |||
| "Invalid model. length < sizeof(ModelFileHeader)."); | |||
| GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||
| "Invalid model. Model data size %u must be greater than or equal to %zu.", model.model_len, | |||
| sizeof(ModelFileHeader)); | |||
| // Get file header | |||
| auto file_header = reinterpret_cast<ModelFileHeader *>(model.model_data); | |||
| // Determine whether the file length and magic number match | |||
| GE_CHK_BOOL_RET_STATUS( | |||
| file_header->length == model.model_len - sizeof(ModelFileHeader) && file_header->magic == MODEL_FILE_MAGIC_NUM, | |||
| PARAM_INVALID, | |||
| "Invalid model. file_header->length + sizeof(ModelFileHeader) != model->model_len || MODEL_FILE_MAGIC_NUM != " | |||
| "file_header->magic"); | |||
| GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||
| "Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || " | |||
| "MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", | |||
| file_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, file_header->magic); | |||
| Status res = SUCCESS; | |||
| // Get data address | |||
| uint8_t *data = reinterpret_cast<uint8_t *>(model.model_data) + sizeof(ModelFileHeader); | |||
| if (file_header->is_encrypt == ModelEncryptType::UNENCRYPTED) { // Unencrypted model | |||
| GE_CHK_BOOL_RET_STATUS(model.key.empty(), PARAM_INVALID, | |||
| GE_CHK_BOOL_RET_STATUS(model.key.empty(), GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, | |||
| "Invalid param. model is unencrypted, but key is not empty."); | |||
| model_data = data; | |||
| model_len = file_header->length; | |||
| GELOGI("Model_len is %u, model_file_head_len is %zu.", model_len, sizeof(ModelFileHeader)); | |||
| } else { | |||
| GELOGE(PARAM_INVALID, "Invalid model. ModelEncryptType not supported."); | |||
| res = PARAM_INVALID; | |||
| GELOGE(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, "Invalid model. ModelEncryptType not supported."); | |||
| res = GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION; | |||
| } | |||
| return res; | |||
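For orientation, here is a minimal sketch (not part of the diff) of the header checks that OmFileLoadHelper::CheckModelValid and ModelParserBase::ParseModelContent perform above. The FileHeader struct and the magic constant are simplified stand-ins for ModelFileHeader and MODEL_FILE_MAGIC_NUM, whose real definitions are not shown here; only the length/magic/encryption checks follow the code.

#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

namespace {
constexpr uint32_t kMagic = 0x444F4D49;  // placeholder for MODEL_FILE_MAGIC_NUM

struct FileHeader {     // stand-in for ModelFileHeader
  uint32_t magic;
  uint32_t is_encrypt;  // 0 == unencrypted in this sketch
  uint32_t length;      // payload length, excluding this header
};

bool CheckModel(const uint8_t *data, size_t model_len) {
  if (model_len < sizeof(FileHeader)) return false;                   // too small to hold a header
  FileHeader header;
  std::memcpy(&header, data, sizeof(header));
  if (header.magic != kMagic) return false;                           // magic number must match
  if (header.length != model_len - sizeof(FileHeader)) return false;  // declared vs. actual payload size
  return header.is_encrypt == 0;                                      // encrypted models are rejected
}
}  // namespace

int main() {
  std::vector<uint8_t> model(sizeof(FileHeader) + 16, 0);
  FileHeader header{kMagic, 0, 16};
  std::memcpy(model.data(), &header, sizeof(header));
  std::cout << std::boolalpha << CheckModel(model.data(), model.size()) << '\n';  // prints: true
  return 0;
}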
| @@ -0,0 +1,501 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph_parser_util.h" | |||
| #include <memory> | |||
| #include "common/auth/file_saver.h" | |||
| #include "common/convert/pb2json.h" | |||
| #include "common/debug/log.h" | |||
| #include "common/debug/memory_dumper.h" | |||
| #include "common/model_parser/base.h" | |||
| #include "common/model_saver.h" | |||
| #include "common/properties_manager.h" | |||
| #include "common/string_util.h" | |||
| #include "common/types.h" | |||
| #include "common/util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "external/register/register_types.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/omg/parser/parser_inner_ctx.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/optimize/common/params.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "omg/omg_inner_types.h" | |||
| #include "omg/parser/model_parser.h" | |||
| #include "omg/parser/parser_factory.h" | |||
| #include "omg/parser/weights_parser.h" | |||
| #include "parser/common/pre_checker.h" | |||
| #include "proto/ge_ir.pb.h" | |||
| #include "register/op_registry.h" | |||
| namespace ge { | |||
| namespace { | |||
| // The function is incomplete. Currently, only l2_optimize and off_optimize are supported. | |||
| const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; | |||
| const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; | |||
| const char *const kSplitError1 = "size not equal to 2 split by \":\""; | |||
| const char *const kEmptyError = "can not be empty"; | |||
| const char *const kFloatNumError = "exist float number"; | |||
| const char *const kDigitError = "is not digit"; | |||
| const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; | |||
| const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; | |||
| const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; | |||
| vector<string> SplitInputShape(const std::string &input_shape) { | |||
| vector<string> shape_pair_vec; | |||
| size_t pos = input_shape.rfind(":"); | |||
| if (pos != std::string::npos) { | |||
| shape_pair_vec.emplace_back(input_shape.substr(0, pos)); | |||
| shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); | |||
| } | |||
| return shape_pair_vec; | |||
| } | |||
| static std::map<std::string, ge::DataType> output_type_str_to_datatype = { | |||
| {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; | |||
| static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) { | |||
| if ((s == "true") || (s == "false")) { | |||
| return true; | |||
| } else { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); | |||
| GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| bool CheckDigitStr(std::string &str) { | |||
| for (char c : str) { | |||
| if (!isdigit(c)) { | |||
| GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| return true; | |||
| } | |||
| Status StringToInt(std::string &str, int32_t &value) { | |||
| try { | |||
| if (!CheckDigitStr(str)) { | |||
| GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--output_type", str, "is not positive integer"}); | |||
| return PARAM_INVALID; | |||
| } | |||
| value = stoi(str); | |||
| } catch (std::invalid_argument &) { | |||
| GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); | |||
| return PARAM_INVALID; | |||
| } catch (std::out_of_range &) { | |||
| GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); | |||
| return PARAM_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status VerifyOutputTypeAndOutNodes(std::vector<std::string> &out_type_vec) { | |||
| std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes; | |||
| std::set<std::string> out_nodes_info; | |||
| for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { | |||
| // out_nodes set should include output_type and output_format | |||
| std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); | |||
| out_nodes_info.emplace(tmp); | |||
| } | |||
| for (uint32_t i = 0; i < out_type_vec.size(); ++i) { | |||
| if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--output_type", out_type_vec[i], kOutputTypeError}); | |||
| GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| Status ParseOutputType(const std::string &output_type, std::map<std::string, vector<uint32_t>> &out_type_index_map, | |||
| std::map<std::string, vector<ge::DataType>> &out_type_dt_map) { | |||
| if (output_type.find(':') == std::string::npos) { | |||
| GELOGI("output_type is not multiple nodes, means all out nodes"); | |||
| auto it = output_type_str_to_datatype.find(output_type); | |||
| if (it == output_type_str_to_datatype.end()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--output_type", output_type, kOutputTypeSupport}); | |||
| GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); | |||
| return domi::FAILED; | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| std::vector<std::string> out_type_vec; | |||
| vector<string> nodes_v = StringUtils::Split(output_type, ';'); | |||
| for (const string &node : nodes_v) { | |||
| vector<string> node_index_type_v = StringUtils::Split(node, ':'); | |||
| if (node_index_type_v.size() != 3) { // The size must be 3. | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--output_type", node, kOutputTypeSample}); | |||
| GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); | |||
| return domi::FAILED; | |||
| } | |||
| ge::DataType tmp_dt; | |||
| std::string node_name = StringUtils::Trim(node_index_type_v[0]); | |||
| std::string index_str = StringUtils::Trim(node_index_type_v[1]); | |||
| int32_t index; | |||
| if (StringToInt(index_str, index) != SUCCESS) { | |||
| GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); | |||
| return domi::FAILED; | |||
| } | |||
| std::string dt_value = StringUtils::Trim(node_index_type_v[2]); | |||
| auto it = output_type_str_to_datatype.find(dt_value); | |||
| if (it == output_type_str_to_datatype.end()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--output_type", dt_value, kOutputTypeSupport}); | |||
| GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); | |||
| return domi::FAILED; | |||
| } else { | |||
| tmp_dt = it->second; | |||
| } | |||
| out_type_vec.push_back(node_name + ":" + index_str); | |||
| auto it_index = out_type_index_map.find(node_name); | |||
| if (it_index == out_type_index_map.end()) { | |||
| vector<uint32_t> tmp_vec; | |||
| tmp_vec.push_back(index); | |||
| out_type_index_map.emplace(node_name, tmp_vec); | |||
| } else { | |||
| it_index->second.push_back(index); | |||
| } | |||
| auto it_dt = out_type_dt_map.find(node_name); | |||
| if (it_dt == out_type_dt_map.end()) { | |||
| vector<ge::DataType> tmp_vec; | |||
| tmp_vec.push_back(tmp_dt); | |||
| out_type_dt_map.emplace(node_name, tmp_vec); | |||
| } else { | |||
| it_dt->second.push_back(tmp_dt); | |||
| } | |||
| } | |||
| return VerifyOutputTypeAndOutNodes(out_type_vec); | |||
| } | |||
| Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { | |||
| int32_t out_size = op_desc->GetOutputsSize(); | |||
| if (index < 0 || index >= out_size) { | |||
| GELOGE(domi::FAILED, | |||
| "out_node [%s] output index:%d must be smaller " | |||
| "than node output size:%d and can not be negative!", | |||
| op_desc->GetName().c_str(), index, out_size); | |||
| std::string fail_reason = "output index:" + to_string(index) + | |||
| " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, | |||
| {"out_nodes", op_desc->GetName(), fail_reason}); | |||
| return domi::FAILED; | |||
| } | |||
| return domi::SUCCESS; | |||
| } | |||
| Status GetOutputLeaf(NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info) { | |||
| ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); | |||
| if (tmpDescPtr == nullptr) { | |||
| GELOGE(domi::FAILED, "Get outnode op desc fail."); | |||
| return domi::FAILED; | |||
| } | |||
| size_t size = tmpDescPtr->GetOutputsSize(); | |||
| if (node->GetType() != NETOUTPUT) { | |||
| for (size_t index = 0; index < size; ++index) { | |||
| output_nodes_info.push_back(std::make_pair(node, index)); | |||
| } | |||
| } else { | |||
| const auto in_anchors = node->GetAllInDataAnchors(); | |||
| for (auto in_anchor : in_anchors) { | |||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||
| if (out_anchor == nullptr) { | |||
| GELOGE(domi::FAILED, "Get leaf node op desc fail."); | |||
| return domi::FAILED; | |||
| } | |||
| auto out_node = out_anchor->GetOwnerNode(); | |||
| output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name) { | |||
| output_nodes_name.clear(); | |||
| if (domi::GetContext().out_top_names.empty()) { | |||
| // tf process, no top name. | |||
| for (const auto output_node_info : output_nodes_info) { | |||
| std::string node_name = output_node_info.first->GetName(); | |||
| int32_t index = output_node_info.second; | |||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index)); | |||
| } | |||
| return; | |||
| } | |||
| // caffe process, need add top name after node_name:index | |||
| for (size_t i = 0; i < output_nodes_info.size(); ++i) { | |||
| std::string node_name = output_nodes_info[i].first->GetName(); | |||
| int32_t index = output_nodes_info[i].second; | |||
| if (i < domi::GetContext().out_top_names.size()) { | |||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]); | |||
| } else { | |||
| GELOGW("Get top name of node [%s] fail.", node_name.c_str()); | |||
| output_nodes_name.push_back(node_name + ":" + std::to_string(index)); | |||
| } | |||
| } | |||
| } | |||
| } // namespace | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | |||
| if (is_output_fp16.empty()) { | |||
| return SUCCESS; | |||
| } | |||
| vector<domiTensorFormat_t> &output_formats = domi::GetContext().output_formats; | |||
| output_formats.clear(); | |||
| vector<string> node_format_vec = StringUtils::Split(is_output_fp16, ','); | |||
| for (auto &is_fp16 : node_format_vec) { | |||
| StringUtils::Trim(is_fp16); | |||
| if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { | |||
| GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", | |||
| is_output_fp16.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| if (is_fp16 == "false") { | |||
| output_formats.push_back(DOMI_TENSOR_ND); | |||
| } else if (is_fp16 == "true") { | |||
| output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, | |||
| const std::string &output_type, | |||
| const std::string &output) { | |||
| ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| std::vector<std::pair<std::string, int32_t>> user_out_nodes = domi::GetContext().user_out_nodes; | |||
| std::vector<domiTensorFormat_t> output_formats = domi::GetContext().output_formats; | |||
| std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_info; | |||
| std::vector<std::string> output_nodes_name; | |||
| std::map<std::string, vector<uint32_t>> out_type_index_map; | |||
| std::map<std::string, vector<ge::DataType>> out_type_dt_map; | |||
| if (!output_type.empty()) { | |||
| if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) { | |||
| GELOGE(domi::FAILED, "Parse output_type failed."); | |||
| return domi::FAILED; | |||
| } | |||
| } | |||
| // User declared outputs | |||
| for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { | |||
| ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); | |||
| if (out_node == nullptr) { | |||
| GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); | |||
| return domi::FAILED; | |||
| } | |||
| auto op_desc = out_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { | |||
| GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); | |||
| return domi::FAILED; | |||
| } | |||
| if (i < output_formats.size()) { | |||
| if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { | |||
| GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); | |||
| if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) { | |||
| GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str()); | |||
| } | |||
| } | |||
| } | |||
| auto it_index = out_type_index_map.find(user_out_nodes[i].first); | |||
| auto it_dt = out_type_dt_map.find(user_out_nodes[i].first); | |||
| if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) { | |||
| GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); | |||
| (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second); | |||
| (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second); | |||
| } | |||
| output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); | |||
| } | |||
| // default output node (leaf) | |||
| if (user_out_nodes.empty()) { | |||
| for (ge::NodePtr node : compute_graph->GetDirectNode()) { | |||
| if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) { | |||
| Status ret = GetOutputLeaf(node, output_nodes_info); | |||
| GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); | |||
| } | |||
| } | |||
| } | |||
| GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); | |||
| compute_graph->SetGraphOutNodesInfo(output_nodes_info); | |||
| domi::GetContext().net_out_nodes = output_nodes_name; | |||
| return domi::SUCCESS; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ParseInputShape( | |||
| const string &input_shape, unordered_map<string, vector<int64_t>> &shape_map, | |||
| vector<pair<string, vector<int64_t>>> &user_shape_map, bool is_dynamic_input) { | |||
| vector<string> shape_vec = StringUtils::Split(input_shape, ';'); | |||
| const int DEFAULT_SHAPE_PAIR_SIZE = 2; | |||
| for (const auto &shape : shape_vec) { | |||
| vector<string> shape_pair_vec = SplitInputShape(shape); | |||
| if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, | |||
| {shape, kSplitError1, kInputShapeSample1}); | |||
| GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", | |||
| shape.c_str(), kSplitError1, kInputShapeSample1); | |||
| return false; | |||
| } | |||
| if (shape_pair_vec[1].empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, | |||
| {shape, kEmptyError, kInputShapeSample1}); | |||
| GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", | |||
| shape.c_str(), kEmptyError, kInputShapeSample1); | |||
| return false; | |||
| } | |||
| vector<string> shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); | |||
| vector<int64_t> shape_values; | |||
| for (auto &shape_value_str : shape_value_strs) { | |||
| // stoul: The method may throw an exception: invalid_argument/out_of_range | |||
| if (std::string::npos != shape_value_str.find('.')) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, | |||
| {shape, kFloatNumError, kInputShapeSample2}); | |||
| GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", | |||
| shape.c_str(), kFloatNumError, kInputShapeSample2); | |||
| return false; | |||
| } | |||
| long left_result = 0; | |||
| try { | |||
| left_result = stol(StringUtils::Trim(shape_value_str)); | |||
| if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { | |||
| // The value may be a dynamic shape such as [-1]; strip the leading '-' and verify the rest is digits. | |||
| shape_value_str = shape_value_str.substr(1); | |||
| } | |||
| for (char c : shape_value_str) { | |||
| if (!isdigit(c)) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, | |||
| {shape, kDigitError, kInputShapeSample2}); | |||
| GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| } catch (const std::out_of_range &) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, | |||
| {"input_shape", shape_value_str}); | |||
| GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); | |||
| return false; | |||
| } catch (const std::invalid_argument &) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, | |||
| {"input_shape", shape_value_str}); | |||
| GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); | |||
| return false; | |||
| } catch (...) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, | |||
| {"input_shape", shape_value_str}); | |||
| GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); | |||
| return false; | |||
| } | |||
| int64_t result = left_result; | |||
| // - 1 is not currently supported | |||
| if (!is_dynamic_input && result <= 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); | |||
| GELOGW( | |||
| "Input parameter[--input_shape]’s shape value[%s] is invalid, " | |||
| "expect positive integer, but value is %ld.", | |||
| shape.c_str(), result); | |||
| return false; | |||
| } | |||
| shape_values.push_back(result); | |||
| } | |||
| shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); | |||
| user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); | |||
| } | |||
| return true; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputNodes(const string &out_nodes) { | |||
| try { | |||
| // parse output node | |||
| if (!out_nodes.empty()) { | |||
| domi::GetContext().out_nodes_map.clear(); | |||
| domi::GetContext().user_out_nodes.clear(); | |||
| vector<string> nodes_v = StringUtils::Split(out_nodes, ';'); | |||
| for (const string &node : nodes_v) { | |||
| vector<string> key_value_v = StringUtils::Split(node, ':'); | |||
| if (key_value_v.size() != 2) { // The size must be 2. | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E10001", {"parameter", "value", "reason"}, | |||
| {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); | |||
| GELOGE(PARAM_INVALID, | |||
| "The input format of --out_nodes is invalid, the correct format is " | |||
| "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", | |||
| node.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); | |||
| // stoi: The method may throw an exception: invalid_argument/out_of_range | |||
| if (!CheckDigitStr(key_value_v[1])) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {"--out_nodes", out_nodes, "is not positive integer"}); | |||
| GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); | |||
| return PARAM_INVALID; | |||
| } | |||
| int32_t index = stoi(StringUtils::Trim(key_value_v[1])); | |||
| if (iter != domi::GetContext().out_nodes_map.end()) { | |||
| iter->second.emplace_back(index); | |||
| } else { | |||
| std::vector<int32_t> index_v; | |||
| index_v.emplace_back(index); | |||
| domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); | |||
| } | |||
| domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); | |||
| } | |||
| } | |||
| } catch (std::invalid_argument &) { | |||
| GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes}); | |||
| return PARAM_INVALID; | |||
| } catch (std::out_of_range &) { | |||
| GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); | |||
| return PARAM_INVALID; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOpConf(const char *op_conf) { | |||
| if (op_conf != nullptr && *op_conf != '\0') { | |||
| // divided by ":" | |||
| PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER); | |||
| // Parsing the op_conf configuration item file | |||
| if (!PropertiesManager::Instance().Init(op_conf)) { | |||
| GELOGE(FAILED, "op_name_map init failed!"); | |||
| return FAILED; | |||
| } | |||
| // Return map and put it into ATC global variable | |||
| domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap(); | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
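For reference, the option strings handled by graph_parser_util.cc have these shapes: --out_nodes "node_name1:0;node_name1:1;node_name2:0", --output_type "opname:index:dtype" with dtype one of FP32/FP16/UINT8, and --input_shape "name:d1,d2,...;name:d1,d2,...". The standalone sketch below (not part of the diff) mirrors how SplitInputShape uses the last ':' to separate a tensor name from its comma-separated dims; the sample tensor names are invented for illustration.

#include <iostream>
#include <sstream>
#include <string>
#include <vector>

int main() {
  const std::string input_shape = "data:1,3,224,224;label:8";  // hypothetical --input_shape value
  std::istringstream per_input(input_shape);
  std::string entry;
  while (std::getline(per_input, entry, ';')) {
    size_t pos = entry.rfind(':');  // last ':' splits name from dims, as in SplitInputShape
    if (pos == std::string::npos) continue;
    std::string name = entry.substr(0, pos);
    std::vector<long> dims;
    std::istringstream per_dim(entry.substr(pos + 1));
    std::string dim;
    while (std::getline(per_dim, dim, ',')) dims.push_back(std::stol(dim));
    std::cout << name << " ->";
    for (long d : dims) std::cout << ' ' << d;
    std::cout << '\n';
  }
  return 0;
}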