Merge pull request !45 from yanghaoran/mastertags/v0.6.0-beta
| @@ -52,5 +52,16 @@ struct GETaskInfo { | |||
| std::vector<GETaskKernelHcclInfo> kernelHcclInfo; | |||
| }; | |||
| struct HcomOpertion { | |||
| std::string hcclType; | |||
| void *inputPtr; | |||
| void *outputPtr; | |||
| uint64_t count; | |||
| int32_t dataType; | |||
| int32_t opType; | |||
| int32_t root; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ | |||
| @@ -28,6 +28,7 @@ struct CompressConfig { | |||
| size_t channel; // channels of L2 or DDR. For load balance | |||
| size_t fractalSize; // size of compressing block | |||
| bool isTight; // whether compose compressed data tightly | |||
| size_t init_offset; | |||
| }; | |||
| CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef COMPRESS_WEIGHT_H | |||
| #define COMPRESS_WEIGHT_H | |||
| #include "compress.h" | |||
| const int SHAPE_SIZE_WEIGHT = 4; | |||
| struct CompressOpConfig { | |||
| int64_t wShape[SHAPE_SIZE_WEIGHT]; | |||
| size_t compressTilingK; | |||
| size_t compressTilingN; | |||
| struct CompressConfig compressConfig; | |||
| }; | |||
| extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, | |||
| CompressOpConfig *const param); | |||
| #endif // COMPRESS_WEIGHT_H | |||
| @@ -27,7 +27,6 @@ using std::string; | |||
| using std::vector; | |||
| namespace fe { | |||
| class PlatformInfoManager { | |||
| public: | |||
| PlatformInfoManager(const PlatformInfoManager &) = delete; | |||
| @@ -39,6 +38,8 @@ class PlatformInfoManager { | |||
| uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||
| uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||
| void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); | |||
| private: | |||
| @@ -94,6 +95,5 @@ class PlatformInfoManager { | |||
| map<string, PlatformInfo> platformInfoMap_; | |||
| OptionalInfo optiCompilationInfo_; | |||
| }; | |||
| } // namespace fe | |||
| #endif | |||
| @@ -44,8 +44,12 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | |||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | |||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | |||
| const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; | |||
| const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; | |||
| const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; | |||
| const char *const OPTION_EXEC_OP_DEBUG_LEVEL = "ge.exec.opDebugLevel"; | |||
| const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||
| const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||
| const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; | |||
| // profiling flag | |||
| const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | |||
| const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||
| @@ -219,6 +223,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | |||
| // Configure input fp16 nodes | |||
| const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||
| // Configure debug level, its value should be 0(default), 1 or 2. | |||
| // 0: close debug; 1: open TBE compiler; 2: open ccec compiler | |||
| const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; | |||
| // Graph run mode | |||
| enum GraphRunMode { PREDICTION = 0, TRAIN }; | |||
| @@ -145,7 +145,8 @@ enum Format { | |||
| FORMAT_FRACTAL_ZN_LSTM, | |||
| FORMAT_FRACTAL_Z_G, | |||
| FORMAT_RESERVED, | |||
| FORMAT_ALL | |||
| FORMAT_ALL, | |||
| FORMAT_NULL | |||
| }; | |||
| // for unknown shape op type | |||
| @@ -98,6 +98,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||
| OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); | |||
| OpRegistrationData &InputReorderVector(const vector<int> &input_order); | |||
| domi::ImplyType GetImplyType() const; | |||
| std::string GetOmOptype() const; | |||
| std::set<std::string> GetOriginOpTypeSet() const; | |||
| @@ -51,30 +51,6 @@ inline pid_t GetTid() { | |||
| return tid; | |||
| } | |||
| #define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() | |||
| #define GE_TIMESTAMP_END(stage, stage_name) \ | |||
| do { \ | |||
| uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ | |||
| (endUsec_##stage - startUsec_##stage)); \ | |||
| } while (0); | |||
| #define GE_TIMESTAMP_CALLNUM_START(stage) \ | |||
| uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ | |||
| uint64_t call_num_of##stage = 0; \ | |||
| uint64_t time_of##stage = 0 | |||
| #define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) | |||
| #define GE_TIMESTAMP_ADD(stage) \ | |||
| time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ | |||
| call_num_of##stage++ | |||
| #define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ | |||
| GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ | |||
| call_num_of##stage) | |||
| #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ | |||
| dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ | |||
| ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) | |||
| @@ -19,15 +19,12 @@ | |||
| #include <string> | |||
| #include "cce/cce_def.hpp" | |||
| #include "runtime/rt.h" | |||
| #include "common/string_util.h" | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "ge/ge_api_error_codes.h" | |||
| using cce::CC_STATUS_SUCCESS; | |||
| using cce::ccStatus_t; | |||
| #if !defined(__ANDROID__) && !defined(ANDROID) | |||
| #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) | |||
| #else | |||
| @@ -102,17 +99,13 @@ using cce::ccStatus_t; | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| std::string msg; \ | |||
| (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ | |||
| (void)msg.append( \ | |||
| ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ | |||
| DOMI_LOGE("%s", msg.c_str()); \ | |||
| return _status; \ | |||
| } \ | |||
| #define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ | |||
| do { \ | |||
| bool b = (expr); \ | |||
| if (!b) { \ | |||
| GELOGE(_status, __VA_ARGS__); \ | |||
| return _status; \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is not true, print the log and return the specified status | |||
| @@ -132,7 +125,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \ | |||
| @@ -142,7 +135,7 @@ using cce::ccStatus_t; | |||
| GELOGW(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| @@ -151,7 +144,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, print the log and execute a custom statement | |||
| #define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ | |||
| @@ -161,7 +154,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is true, print logs and execute custom statements | |||
| #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ | |||
| @@ -171,7 +164,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is true, print the Information log and execute a custom statement | |||
| #define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \ | |||
| { \ | |||
| @@ -180,7 +173,7 @@ using cce::ccStatus_t; | |||
| GELOGI(__VA_ARGS__); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not SUCCESS, print the log and execute the expression + return | |||
| #define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ | |||
| @@ -191,7 +184,7 @@ using cce::ccStatus_t; | |||
| exec_expr; \ | |||
| return; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not SUCCESS, print the log and execute the expression + return _status | |||
| #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ | |||
| @@ -202,7 +195,7 @@ using cce::ccStatus_t; | |||
| exec_expr; \ | |||
| return _status; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not true, execute a custom statement | |||
| #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ | |||
| @@ -211,7 +204,7 @@ using cce::ccStatus_t; | |||
| if (!b) { \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // -----------------runtime related macro definitions------------------------------- | |||
| // If expr is not RT_ERROR_NONE, print the log | |||
| @@ -231,7 +224,7 @@ using cce::ccStatus_t; | |||
| DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If expr is not RT_ERROR_NONE, print the log and return | |||
| #define GE_CHK_RT_RET(expr) \ | |||
| @@ -243,23 +236,13 @@ using cce::ccStatus_t; | |||
| } \ | |||
| } while (0); | |||
| // ------------------------cce related macro definitions---------------------------- | |||
| // If expr is not CC_STATUS_SUCCESS, print the log | |||
| #define GE_CHK_CCE(expr) \ | |||
| do { \ | |||
| ccStatus_t _cc_ret = (expr); \ | |||
| if (_cc_ret != CC_STATUS_SUCCESS) { \ | |||
| DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ | |||
| } \ | |||
| } while (0); | |||
| // If expr is true, execute exec_expr without printing logs | |||
| #define GE_IF_BOOL_EXEC(expr, exec_expr) \ | |||
| { \ | |||
| if (expr) { \ | |||
| exec_expr; \ | |||
| } \ | |||
| }; | |||
| } | |||
| // If make_shared is abnormal, print the log and execute the statement | |||
| #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ | |||
| @@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; | |||
| struct DataBuffer { | |||
| public: | |||
| void *data; // Data address | |||
| uint32_t length; // Data length | |||
| uint64_t length; // Data length | |||
| bool isDataSupportMemShare = false; | |||
| DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare) | |||
| DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare) | |||
| : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {} | |||
| DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {} | |||
| @@ -106,7 +106,7 @@ struct ShapeDescription { | |||
| // Definition of input and output description information | |||
| struct InputOutputDescInfo { | |||
| std::string name; | |||
| uint32_t size; | |||
| uint64_t size; | |||
| uint32_t data_type; | |||
| ShapeDescription shape_info; | |||
| }; | |||
| @@ -231,6 +231,7 @@ struct Options { | |||
| // Profiling info of task | |||
| struct TaskDescInfo { | |||
| std::string model_name; | |||
| std::string op_name; | |||
| uint32_t block_dim; | |||
| uint32_t task_id; | |||
| @@ -239,6 +240,7 @@ struct TaskDescInfo { | |||
| // Profiling info of graph | |||
| struct ComputeGraphDescInfo { | |||
| std::string model_name; | |||
| std::string op_name; | |||
| std::string op_type; | |||
| std::vector<Format> input_format; | |||
| @@ -44,8 +44,6 @@ class ModelHelper { | |||
| void SetSaveMode(bool val) { is_offline_ = val; } | |||
| bool GetSaveMode(void) const { return is_offline_; } | |||
| static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); | |||
| static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); | |||
| Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name); | |||
| Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name); | |||
| @@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL; | |||
| // Supported public properties name | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time | |||
| @@ -335,6 +338,7 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); | |||
| REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); | |||
| REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); | |||
| REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") | |||
| REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); | |||
| // ANN dedicated operator | |||
| REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); | |||
| @@ -631,6 +635,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; | |||
| // convolution node type | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION; | |||
| // adds a convolutional node name for the hard AIPP | |||
| @@ -21,12 +21,12 @@ | |||
| #include <string> | |||
| #include <vector> | |||
| #include "common/dynamic_aipp.h" | |||
| #include "common/ge_inner_error_codes.h" | |||
| #include "common/ge_types.h" | |||
| #include "common/types.h" | |||
| #include "graph/tensor.h" | |||
| #include "runtime/base.h" | |||
| #include "common/dynamic_aipp.h" | |||
| namespace ge { | |||
| class ModelListenerAdapter; | |||
| @@ -62,7 +62,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||
| // Get input and output descriptor | |||
| ge::Status GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
| std::vector<ge::TensorDesc> &output_desc); | |||
| std::vector<ge::TensorDesc> &output_desc, bool new_model_desc = false); | |||
| /// | |||
| /// @ingroup ge | |||
| @@ -28,16 +28,21 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class RuntimeModel; | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class ModelRunner { | |||
| public: | |||
| static ModelRunner &Instance(); | |||
| bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||
| std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | |||
| bool LoadModelComplete(uint32_t model_id); | |||
| const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||
| const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const; | |||
| bool UnloadModel(uint32_t model_id); | |||
| bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||
| @@ -21,6 +21,7 @@ | |||
| #include <functional> | |||
| #include <memory> | |||
| #include <string> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include "cce/taskdown_api.h" | |||
| @@ -52,21 +53,27 @@ class TaskInfo { | |||
| virtual ~TaskInfo() {} | |||
| uint32_t stream_id() const { return stream_id_; } | |||
| TaskInfoType type() const { return type_; } | |||
| std::string op_name() const { return op_name_; } | |||
| bool dump_flag() const { return dump_flag_; } | |||
| protected: | |||
| TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} | |||
| TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) | |||
| : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} | |||
| private: | |||
| std::string op_name_; | |||
| uint32_t stream_id_; | |||
| TaskInfoType type_; | |||
| bool dump_flag_; | |||
| }; | |||
| class CceTaskInfo : public TaskInfo { | |||
| public: | |||
| CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | |||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | |||
| const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||
| : TaskInfo(stream_id, TaskInfoType::CCE), | |||
| CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, | |||
| uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size, | |||
| const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table, | |||
| const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), | |||
| ctx_(ctx), | |||
| stub_func_(stub_func), | |||
| block_dim_(block_dim), | |||
| @@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo { | |||
| class TbeTaskInfo : public TaskInfo { | |||
| public: | |||
| TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args, | |||
| uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size, | |||
| const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs) | |||
| : TaskInfo(stream_id, TaskInfoType::TBE), | |||
| TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, | |||
| const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, | |||
| uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), | |||
| stub_func_(stub_func), | |||
| block_dim_(block_dim), | |||
| args_(args), | |||
| @@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo { | |||
| class AicpuTaskInfo : public TaskInfo { | |||
| public: | |||
| AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | |||
| const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | |||
| : TaskInfo(stream_id, TaskInfoType::AICPU), | |||
| AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, | |||
| const std::string &node_def, const std::vector<void *> &input_data_addrs, | |||
| const std::vector<void *> &output_data_addrs, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), | |||
| so_name_(so_name), | |||
| kernel_name_(kernel_name), | |||
| node_def_(node_def), | |||
| @@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo { | |||
| std::vector<void *> output_data_addrs_; | |||
| }; | |||
| class LabelTaskInfo : public TaskInfo { | |||
| class LabelSetTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| protected: | |||
| LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||
| : TaskInfo(stream_id, type), label_id_(label_id) {} | |||
| virtual ~LabelTaskInfo() override {} | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSetTaskInfo : public LabelTaskInfo { | |||
| class LabelGotoTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} | |||
| ~LabelSetTaskInfo() override {} | |||
| LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| uint32_t label_id() const { return label_id_; } | |||
| private: | |||
| uint32_t label_id_; | |||
| }; | |||
| class LabelSwitchTaskInfo : public LabelTaskInfo { | |||
| class LabelSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} | |||
| LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, | |||
| const std::vector<uint32_t> &label_list, void *cond) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), | |||
| label_size_(label_size), | |||
| label_list_(label_list), | |||
| cond_(cond) {} | |||
| ~LabelSwitchTaskInfo() override {} | |||
| }; | |||
| uint32_t label_size() { return label_size_; }; | |||
| const std::vector<uint32_t> &label_list() { return label_list_; }; | |||
| void *cond() { return cond_; }; | |||
| class LabelGotoTaskInfo : public LabelTaskInfo { | |||
| public: | |||
| LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) | |||
| : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} | |||
| ~LabelGotoTaskInfo() override {} | |||
| private: | |||
| uint32_t label_size_; | |||
| std::vector<uint32_t> label_list_; | |||
| void *cond_; | |||
| }; | |||
| class EventTaskInfo : public TaskInfo { | |||
| @@ -215,8 +231,8 @@ class EventTaskInfo : public TaskInfo { | |||
| uint32_t event_id() const { return event_id_; } | |||
| protected: | |||
| EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||
| : TaskInfo(stream_id, type), event_id_(event_id) {} | |||
| EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||
| : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} | |||
| virtual ~EventTaskInfo() override {} | |||
| uint32_t event_id_; | |||
| @@ -224,39 +240,41 @@ class EventTaskInfo : public TaskInfo { | |||
| class EventRecordTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||
| EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} | |||
| ~EventRecordTaskInfo() override {} | |||
| }; | |||
| class EventWaitTaskInfo : public EventTaskInfo { | |||
| public: | |||
| EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||
| EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) | |||
| : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} | |||
| ~EventWaitTaskInfo() override {} | |||
| }; | |||
| class FusionStartTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} | |||
| explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} | |||
| ~FusionStartTaskInfo() override {} | |||
| }; | |||
| class FusionEndTaskInfo : public TaskInfo { | |||
| public: | |||
| explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} | |||
| explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} | |||
| ~FusionEndTaskInfo() override {} | |||
| }; | |||
| class HcclTaskInfo : public TaskInfo { | |||
| public: | |||
| HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, | |||
| void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, | |||
| void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||
| const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||
| int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | |||
| std::function<bool(void *)> hcom_unbind_model, | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | |||
| : TaskInfo(stream_id, TaskInfoType::HCCL), | |||
| int64_t op_type, int64_t data_type, const std::string &group, | |||
| std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model, | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), | |||
| hccl_type_(hccl_type), | |||
| input_data_addr_(input_data_addr), | |||
| output_data_addr_(output_data_addr), | |||
| @@ -269,6 +287,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| root_id_(root_id), | |||
| op_type_(op_type), | |||
| data_type_(data_type), | |||
| group_(group), | |||
| hcom_bind_model_(hcom_bind_model), | |||
| hcom_unbind_model_(hcom_unbind_model), | |||
| hcom_distribute_task_(hcom_distribute_task) {} | |||
| @@ -286,6 +305,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| int64_t root_id() const { return root_id_; } | |||
| int64_t op_type() const { return op_type_; } | |||
| int64_t data_type() const { return data_type_; } | |||
| const std::string &group() const { return group_; } | |||
| std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | |||
| std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | |||
| @@ -305,6 +325,7 @@ class HcclTaskInfo : public TaskInfo { | |||
| int64_t root_id_; | |||
| int64_t op_type_; | |||
| int64_t data_type_; | |||
| std::string group_; | |||
| std::function<bool(void *, void *)> hcom_bind_model_; | |||
| std::function<bool(void *)> hcom_unbind_model_; | |||
| std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_; | |||
| @@ -312,8 +333,11 @@ class HcclTaskInfo : public TaskInfo { | |||
| class ProfilerTraceTaskInfo : public TaskInfo { | |||
| public: | |||
| ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||
| : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} | |||
| ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), | |||
| log_id_(log_id), | |||
| notify_(notify), | |||
| flat_(flat) {} | |||
| ~ProfilerTraceTaskInfo() override {} | |||
| uint64_t log_id() const { return log_id_; } | |||
| @@ -328,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo { | |||
| class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| public: | |||
| MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | |||
| : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), | |||
| MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, | |||
| uint64_t count, uint32_t kind, bool dump_flag) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), | |||
| dst_(dst), | |||
| dst_max_(dst_max), | |||
| src_(src), | |||
| @@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo { | |||
| class StreamSwitchTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, | |||
| int64_t data_type) | |||
| : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), | |||
| StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, | |||
| void *value_addr, int64_t cond, int64_t data_type) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), | |||
| true_stream_id_(true_stream_id), | |||
| input_addr_(input_addr), | |||
| value_addr_(value_addr), | |||
| @@ -379,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo { | |||
| class StreamActiveTaskInfo : public TaskInfo { | |||
| public: | |||
| StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) | |||
| : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} | |||
| StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) | |||
| : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} | |||
| ~StreamActiveTaskInfo() override {} | |||
| uint32_t active_stream_id() const { return active_stream_id_; } | |||
| @@ -27,6 +27,7 @@ | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/op_desc.h" | |||
| #include "graph/detail/attributes_holder.h" | |||
| namespace ge { | |||
| class GeGenerator { | |||
| @@ -98,13 +98,14 @@ Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | |||
| Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | |||
| Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info); | |||
| void GetOutputNodesNameAndIndex(std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||
| std::vector<std::string> &output_nodes_name); | |||
| void UpdateOmgCtxWithParserCtx(); | |||
| void UpdateParserCtxWithOmgCtx(); | |||
| } // namespace ge | |||
| namespace domi { | |||
| @@ -94,6 +94,8 @@ struct OmgContext { | |||
| std::vector<std::pair<std::string, int32_t>> user_out_nodes; | |||
| // net out nodes (where user_out_nodes or leaf nodes) | |||
| std::vector<std::string> net_out_nodes; | |||
| // net out nodes top names(only caffe has top) | |||
| std::vector<std::string> out_top_names; | |||
| // path for the aicpu custom operator so_file | |||
| std::vector<std::string> aicpu_op_run_paths; | |||
| // ddk version | |||
| @@ -74,6 +74,9 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| size_t GetAllNodesSize() const; | |||
| Vistor<NodePtr> GetAllNodes() const; | |||
| // is_unknown_shape: false, same with GetAllNodes func | |||
| // is_unknown_shape: true, same with GetDirectNodes func | |||
| Vistor<NodePtr> GetNodes(bool is_unknown_shape) const; | |||
| size_t GetDirectNodesSize() const; | |||
| Vistor<NodePtr> GetDirectNode() const; | |||
| Vistor<NodePtr> GetInputNodes() const; | |||
| @@ -174,6 +177,10 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| void SetInputSize(uint32_t size) { input_size_ = size; } | |||
| uint32_t GetInputSize() const { return input_size_; } | |||
| // false: known shape true: unknow shape | |||
| bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } | |||
| void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } | |||
| /// | |||
| /// Set is need train iteration. | |||
| /// If set true, it means this graph need to be run iteration some | |||
| @@ -282,7 +289,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||
| std::map<uint32_t, std::string> op_name_map_; | |||
| uint64_t session_id_ = 0; | |||
| ge::Format data_format_ = ge::FORMAT_ND; | |||
| // unknown graph indicator, default is false, mean known shape | |||
| bool is_unknown_shape_graph_ = false; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_COMPUTE_GRAPH_H_ | |||
| @@ -139,6 +139,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_INPUTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_AIPP_OUTPUTS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_DIMS; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; | |||
| @@ -776,6 +778,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; | |||
| @@ -994,7 +1000,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; | |||
| // used for l1 fusion and other fusion in future | |||
| // used for lX fusion | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | |||
| @@ -1008,9 +1014,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; | |||
| // op overflow dump | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; | |||
| // functional ops attr | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; | |||
| @@ -1056,6 +1070,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR | |||
| // for gradient group | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; | |||
| // dynamic shape attrs | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; | |||
| // for fusion op plugin | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | |||
| @@ -149,5 +149,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||
| AnyMap extAttrs_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ | |||
| @@ -28,6 +28,7 @@ class GEContext { | |||
| uint32_t DeviceId(); | |||
| uint64_t TraceId(); | |||
| void Init(); | |||
| void SetSessionId(uint64_t session_id); | |||
| void SetCtxDeviceId(uint32_t device_id); | |||
| private: | |||
| @@ -25,6 +25,7 @@ | |||
| #include "graph/buffer.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/types.h" | |||
| namespace ge { | |||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { | |||
| public: | |||
| @@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH | |||
| DataType GetDataType() const; | |||
| void SetDataType(DataType dt); | |||
| void SetOriginDataType(DataType originDataType); | |||
| DataType GetOriginDataType() const; | |||
| void SetOriginDataType(DataType originDataType); | |||
| std::vector<uint32_t> GetRefPortIndex() const; | |||
| void SetRefPortByIndex(const std::vector<uint32_t> &index); | |||
| GeTensorDesc Clone() const; | |||
| GeTensorDesc &operator=(const GeTensorDesc &desc); | |||
| @@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { | |||
| GeTensorDesc &DescReference() const; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_GE_TENSOR_H_ | |||
| @@ -49,5 +49,4 @@ class ModelSerialize { | |||
| friend class GraphDebugImp; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_MODEL_SERIALIZE_H_ | |||
| @@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeTensorDescPtr MutableInputDesc(uint32_t index) const; | |||
| GeTensorDescPtr MutableInputDesc(const string &name) const; | |||
| Vistor<GeTensorDesc> GetAllInputsDesc() const; | |||
| Vistor<GeTensorDescPtr> GetAllInputsDescPtr() const; | |||
| @@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||
| GeTensorDescPtr MutableOutputDesc(uint32_t index) const; | |||
| GeTensorDescPtr MutableOutputDesc(const string &name) const; | |||
| uint32_t GetAllOutputsDescSize() const; | |||
| Vistor<GeTensorDesc> GetAllOutputsDesc() const; | |||
| @@ -130,7 +130,7 @@ struct NodeIndexIO { | |||
| IOType io_type_ = kOut; | |||
| std::string value_; | |||
| std::string ToString() const { return value_; } | |||
| const std::string &ToString() const { return value_; } | |||
| }; | |||
| class GraphUtils { | |||
| @@ -188,8 +188,8 @@ class GraphUtils { | |||
| /// @param [in] output_index | |||
| /// @return graphStatus | |||
| /// | |||
| static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
| static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); | |||
| static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); | |||
| @@ -303,6 +303,14 @@ class GraphUtils { | |||
| /// | |||
| static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | |||
| /// | |||
| /// Copy all in-data edges from `src_node` to `dst_node` | |||
| /// @param src_node | |||
| /// @param dst_node | |||
| /// @return | |||
| /// | |||
| static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); | |||
| static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | |||
| static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||
| @@ -728,5 +736,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder { | |||
| std::vector<NodePtr> exist_nodes_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ | |||
| @@ -99,6 +99,13 @@ class NodeUtils { | |||
| /// | |||
| static NodePtr GetParentInput(const NodePtr &node); | |||
| /// | |||
| /// @brief Check is varying_input for while node | |||
| /// @param [in] node: Data node for subgraph | |||
| /// @return bool | |||
| /// | |||
| static bool IsWhileVaryingInput(const ge::NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| @@ -114,6 +121,24 @@ class NodeUtils { | |||
| /// | |||
| static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); | |||
| /// | |||
| /// @brief Get subgraph input data node by index. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| static vector<NodePtr> GetSubgraphDataNodesByIndex(const Node &node, int index); | |||
| /// | |||
| /// @brief Get subgraph input data node by index. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| static vector<NodePtr> GetSubgraphOutputNodes(const Node &node); | |||
| static NodePtr GetInDataNodeByIndex(const Node &node, int index); | |||
| static vector<NodePtr> GetOutDataNodesByIndex(const Node &node, int index); | |||
| private: | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | |||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | |||
| @@ -20,6 +20,7 @@ | |||
| #include <memory> | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/tensor.h" | |||
| namespace ge { | |||
| using GeTensorPtr = std::shared_ptr<GeTensor>; | |||
| using ConstGeTensorPtr = std::shared_ptr<const GeTensor>; | |||
| @@ -21,6 +21,7 @@ | |||
| #include "graph/def_types.h" | |||
| #include "graph/ge_error_codes.h" | |||
| #include "graph/ge_tensor.h" | |||
| namespace ge { | |||
| class TensorUtils { | |||
| public: | |||
| @@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE | |||
| ${PROTOBUF_LIBRARY} | |||
| ${c_sec} | |||
| ${slog} | |||
| ${error_manager} | |||
| rt | |||
| dl) | |||
| @@ -106,6 +106,15 @@ ComputeGraph::Vistor<NodePtr> ComputeGraph::AllGraphNodes(std::vector<std::share | |||
| return Vistor<NodePtr>(shared_from_this(), all_nodes); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetNodes( | |||
| bool is_unknown_shape) const { | |||
| if (is_unknown_shape) { | |||
| return GetDirectNode(); | |||
| } else { | |||
| return GetAllNodes(); | |||
| } | |||
| } | |||
| size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetDirectNode() const { | |||
| @@ -497,6 +506,10 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptr<Compute | |||
| if (name != subgraph->GetName()) { | |||
| GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); | |||
| } | |||
| if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { | |||
| GE_LOGE("The subgraph %s existed", name.c_str()); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| sub_graph_.push_back(subgraph); | |||
| names_to_subgraph_[name] = subgraph; | |||
| return GRAPH_SUCCESS; | |||
| @@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | |||
| GE_REGISTER_OPTYPE(SWITCH, "Switch"); | |||
| GE_REGISTER_OPTYPE(MERGE, "Merge"); | |||
| GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | |||
| GE_REGISTER_OPTYPE(ENTER, "Enter"); | |||
| GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); | |||
| GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||
| GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | |||
| GE_REGISTER_OPTYPE(CONSTANT, "Const"); | |||
| GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | |||
| GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | |||
| GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | |||
| GE_REGISTER_OPTYPE(INITDATA, "InitData"); | |||
| GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); | |||
| GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); | |||
| GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); | |||
| @@ -41,11 +41,9 @@ using namespace ge; | |||
| using namespace std; | |||
| namespace ge { | |||
| namespace { | |||
| static const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
| static bool net_format_is_nd = true; | |||
| static Format g_user_set_format = FORMAT_ND; | |||
| static bool is_first_infer = true; | |||
| static RefRelations reflection_builder; | |||
| const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | |||
| const string kIsGraphInferred = "_is_graph_inferred"; | |||
| RefRelations reflection_builder; | |||
| } // namespace | |||
| graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | |||
| @@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &re | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { | |||
| graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { | |||
| // 5 meas dim num | |||
| if (node_ptr->GetType() != "BiasAdd") { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::unordered_map<string, Format> kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; | |||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { | |||
| auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); | |||
| GE_CHECK_NOTNULL(in_desc); | |||
| if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
| continue; | |||
| } | |||
| auto format = in_desc->GetOriginFormat(); | |||
| auto key = TypeUtils::FormatToSerialString(format); | |||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
| in_desc->SetOriginFormat(fixed_format); | |||
| in_desc->SetFormat(fixed_format); | |||
| GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
| } | |||
| for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { | |||
| auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); | |||
| GE_CHECK_NOTNULL(out_desc); | |||
| if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num | |||
| continue; | |||
| } | |||
| auto format = out_desc->GetOriginFormat(); | |||
| auto key = TypeUtils::FormatToSerialString(format); | |||
| auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; | |||
| out_desc->SetOriginFormat(fixed_format); | |||
| out_desc->SetFormat(fixed_format); | |||
| GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, | |||
| node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), | |||
| TypeUtils::FormatToSerialString(fixed_format).c_str()); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { | |||
| if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { | |||
| ConstGeTensorPtr tensor_value; | |||
| if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { | |||
| GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); | |||
| @@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| } | |||
| anchor_points.clear(); | |||
| // Get all anchor point nodes and switch nodes | |||
| for (const auto &node_ptr : graph->GetAllNodes()) { | |||
| for (auto &node_ptr : graph->GetAllNodes()) { | |||
| if (node_ptr == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| @@ -103,7 +141,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| if (op_desc == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| graphStatus status = RefreshConstantOutProcess(op_desc); | |||
| graphStatus status = RefreshConstantOutProcess(graph, op_desc); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); | |||
| return GRAPH_FAILED; | |||
| @@ -135,6 +173,16 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std | |||
| if (!node_is_all_nd) { | |||
| continue; | |||
| } | |||
| // special process for biasAdd op | |||
| // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg | |||
| // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism | |||
| // so here do special process | |||
| status = BiasAddFormatFixProcess(node_ptr); | |||
| if (status != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); | |||
| anchor_points.push_back(node_ptr); | |||
| } | |||
| @@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector<ge::NodePtr> &anchor | |||
| } | |||
| } | |||
| void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; } | |||
| graphStatus FormatRefiner::DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
| graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
| ge::Format data_format, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status) { | |||
| bool is_internal_format = TypeUtils::IsInternalFormat(data_format); | |||
| bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); | |||
| if (!need_process) { | |||
| GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, | |||
| if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { | |||
| GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), | |||
| TypeUtils::FormatToSerialString(data_format).c_str()); | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| std::vector<ge::NodePtr> anchor_points; | |||
| std::vector<ge::NodePtr> data_nodes; | |||
| // global net format | |||
| net_format_is_nd = true; | |||
| g_user_set_format = FORMAT_ND; | |||
| if (graph == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "input graph is null"); | |||
| @@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) | |||
| /// format for these data nodes. | |||
| /// Notice: ignore 5D formats | |||
| auto data_format = graph->GetDataFormat(); | |||
| status = DataNodeFormatProcess(data_nodes, data_format, node_status); | |||
| // Set infer flag to false | |||
| SetInferOrigineFormatFlag(false); | |||
| status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); | |||
| (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); | |||
| return status; | |||
| } | |||
| bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { | |||
| bool is_graph_inferred = false; | |||
| return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); | |||
| } | |||
| } // namespace ge | |||
| @@ -30,10 +30,9 @@ namespace ge { | |||
| class FormatRefiner { | |||
| public: | |||
| static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); | |||
| static void SetInferOrigineFormatFlag(bool is_first = true); | |||
| private: | |||
| static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc); | |||
| static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | |||
| static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector<ge::NodePtr> &anchor_points, | |||
| std::vector<ge::NodePtr> &data_nodes, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| @@ -43,8 +42,9 @@ class FormatRefiner { | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus ForwardInferProcess(std::deque<ge::NodePtr> &nodes, ge::NodePtr &node, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus DataNodeFormatProcess(std::vector<ge::NodePtr> &data_nodes, ge::Format data_format, | |||
| std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector<ge::NodePtr> &data_nodes, | |||
| ge::Format data_format, std::unordered_map<ge::NodePtr, bool> &node_status); | |||
| static bool IsGraphInferred(const ComputeGraphPtr &graph); | |||
| }; | |||
| } // namespace ge | |||
| #endif // COMMON_GRAPH_FORMAT_REFINER_H_ | |||
| @@ -121,6 +121,8 @@ const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | |||
| const std::string ATTR_NAME_AIPP_INPUTS = "_aipp_inputs"; | |||
| const std::string ATTR_NAME_AIPP_OUTPUTS = "_aipp_outputs"; | |||
| const std::string ATTR_NAME_INPUT_DIMS = "input_dims"; | |||
| const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | |||
| const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; | |||
| @@ -723,6 +725,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; | |||
| const std::string ATTR_MODEL_CORE_TYPE = "core_type"; | |||
| const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; | |||
| const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; | |||
| // Public attribute | |||
| const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; | |||
| @@ -932,7 +938,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; | |||
| const std::string MODEL_ATTR_SESSION_ID = "session_id"; | |||
| // l1 fusion and other fusion in future | |||
| // lx fusion | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; | |||
| const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | |||
| @@ -946,9 +952,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 | |||
| const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; | |||
| const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; | |||
| const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; | |||
| const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; | |||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; | |||
| const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; | |||
| const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; | |||
| const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; | |||
| // Op debug attrs | |||
| const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | |||
| const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; | |||
| // Atomic addr clean attrs | |||
| const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; | |||
| @@ -1013,4 +1027,11 @@ const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; | |||
| // used for allreduce tailing optimization | |||
| const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; | |||
| const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; | |||
| // dynamic shape attr | |||
| const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; | |||
| const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; | |||
| // for fusion op plugin | |||
| const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; | |||
| } // namespace ge | |||
| @@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; | |||
| const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; | |||
| const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; | |||
| const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; | |||
| const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; | |||
| GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} | |||
| @@ -567,6 +568,16 @@ DataType GeTensorDesc::GetOriginDataType() const { | |||
| return TypeUtils::SerialStringToDataType(origin_data_type_str); | |||
| } | |||
| std::vector<uint32_t> GeTensorDesc::GetRefPortIndex() const { | |||
| vector<uint32_t> ref_port_index; | |||
| (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); | |||
| return ref_port_index; | |||
| } | |||
| void GeTensorDesc::SetRefPortByIndex(const std::vector<uint32_t> &index) { | |||
| (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); | |||
| } | |||
| graphStatus GeTensorDesc::IsValid() const { | |||
| auto dtype = this->GetDataType(); | |||
| auto format = this->GetFormat(); | |||
| @@ -210,7 +210,7 @@ class GraphImpl { | |||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const { | |||
| auto it = op_list_.find(name); | |||
| GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str()); | |||
| GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); | |||
| op = it->second; | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -1,5 +1,5 @@ | |||
| LOCAL_PATH := $(call my-dir) | |||
| include $(LOCAL_PATH)/stub/Makefile | |||
| COMMON_LOCAL_SRC_FILES := \ | |||
| ./proto/om.proto \ | |||
| ./proto/ge_ir.proto \ | |||
| @@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -85,6 +86,54 @@ LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for host | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := stub/libgraph | |||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||
| LOCAL_CPPFLAGS += -fexceptions | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for host | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := fwk_stub/libgraph | |||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 | |||
| LOCAL_CPPFLAGS += -fexceptions | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for device | |||
| include $(CLEAR_VARS) | |||
| @@ -99,6 +148,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -111,6 +161,60 @@ LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_SHARED_LIBRARY) | |||
| #compiler for device | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := stub/libgraph | |||
| LOCAL_CFLAGS += -O2 | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| ifeq ($(device_os),android) | |||
| LOCAL_LDFLAGS := -ldl | |||
| endif | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_SHARED_LIBRARY) | |||
| #compiler for device | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := fwk_stub/libgraph | |||
| LOCAL_CFLAGS += -O2 | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := \ | |||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||
| ../../out/graph/lib64/stub/graph.cc \ | |||
| ../../out/graph/lib64/stub/operator.cc \ | |||
| ../../out/graph/lib64/stub/operator_factory.cc \ | |||
| ../../out/graph/lib64/stub/tensor.cc \ | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| ifeq ($(device_os),android) | |||
| LOCAL_LDFLAGS := -ldl | |||
| endif | |||
| LOCAL_MULTILIB := 64 | |||
| LOCAL_PROPRIETARY_MODULE := true | |||
| include $(BUILD_SHARED_LIBRARY) | |||
| # compile for ut/st | |||
| include $(CLEAR_VARS) | |||
| @@ -125,6 +229,7 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libprotobuf \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -150,6 +255,7 @@ LOCAL_STATIC_LIBRARIES := \ | |||
| LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -173,6 +279,7 @@ LOCAL_STATIC_LIBRARIES := \ | |||
| LOCAL_SHARED_LIBRARIES := \ | |||
| libc_sec \ | |||
| libslog \ | |||
| liberror_manager \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| @@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ | |||
| } | |||
| bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { | |||
| if (op_desc == nullptr || op_def_proto == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Input Para Invalid"); | |||
| return false; | |||
| } | |||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); | |||
| GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); | |||
| if (op_desc->op_def_.GetProtoMsg() != nullptr) { | |||
| *op_def_proto = *op_desc->op_def_.GetProtoMsg(); | |||
| // Delete unnecessary attr | |||
| @@ -130,16 +128,17 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op | |||
| for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { | |||
| op_def_proto->add_subgraph_name(name); | |||
| } | |||
| proto::AttrDef key; | |||
| proto::AttrDef value; | |||
| for (auto &item : op_desc->output_name_idx_) { | |||
| key.mutable_list()->add_s(item.first); | |||
| value.mutable_list()->add_i(item.second); | |||
| if (!op_desc->output_name_idx_.empty()) { | |||
| proto::AttrDef key; | |||
| proto::AttrDef value; | |||
| for (auto &item : op_desc->output_name_idx_) { | |||
| key.mutable_list()->add_s(item.first); | |||
| value.mutable_list()->add_i(item.second); | |||
| } | |||
| auto op_desc_attr = op_def_proto->mutable_attr(); | |||
| op_desc_attr->insert({"_output_name_key", key}); | |||
| op_desc_attr->insert({"_output_name_value", value}); | |||
| } | |||
| auto op_desc_attr = op_def_proto->mutable_attr(); | |||
| op_desc_attr->insert({"_output_name_key", key}); | |||
| op_desc_attr->insert({"_output_name_value", value}); | |||
| } | |||
| return true; | |||
| } | |||
| @@ -26,6 +26,7 @@ | |||
| #include "utils/ge_ir_utils.h" | |||
| #include "utils/node_utils.h" | |||
| #include "utils/op_desc_utils.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| using std::string; | |||
| using std::vector; | |||
| @@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons | |||
| const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
| const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); | |||
| if (peer_node == nullptr || r_peer_node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
| GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", | |||
| this->GetName().c_str(), i, j); | |||
| return false; | |||
| } | |||
| @@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor<AnchorPtr> Node::Get | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { | |||
| if (idx < 0 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||
| {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| return in_data_anchors_[idx]; | |||
| @@ -445,7 +449,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { | |||
| // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ | |||
| if (idx < -1 || idx >= static_cast<int>(in_data_anchors_.size())) { | |||
| GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||
| {GetName().c_str(), std::to_string(idx), "in_anchor", GetType().c_str()}); | |||
| GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| // Return control anchor | |||
| @@ -461,8 +468,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { | |||
| // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ | |||
| if (idx < -1 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, | |||
| { | |||
| GetName().c_str(), | |||
| std::to_string(idx), | |||
| "out_anchor", | |||
| GetType().c_str(), | |||
| }); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| // Return control anchor | |||
| @@ -477,8 +491,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { | |||
| if (idx < 0 || idx >= static_cast<int>(out_data_anchors_.size())) { | |||
| GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), | |||
| GetName().c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19019", {"opname", "index", "anchorname", "optype"}, | |||
| {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); | |||
| GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, | |||
| GetType().c_str()); | |||
| return nullptr; | |||
| } else { | |||
| return out_data_anchors_[idx]; | |||
| @@ -733,11 +750,15 @@ graphStatus Node::Verify() const { | |||
| GELOGW("in anchor ptr is null"); | |||
| continue; | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS( | |||
| op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || | |||
| op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || | |||
| in_anchor_ptr->GetPeerAnchors().size() > 0, | |||
| GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
| bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || | |||
| op_->GetType() == const_type || op_->GetType() == variable_type || | |||
| op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; | |||
| if (!valid_anchor) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"name", "index"}, | |||
| {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); | |||
| GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| string frameworkop_type = "FrameworkOp"; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "debug/ge_util.h" | |||
| #include "external/graph/operator.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "graph/ge_attr_value.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/operator_factory_impl.h" | |||
| @@ -470,6 +471,25 @@ GeTensorDesc OpDesc::GetInputDesc(const string &name) const { | |||
| return *(inputs_desc_[it->second].get()); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
| GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
| if (inputs_desc_[index] == nullptr) { | |||
| return nullptr; | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); | |||
| return inputs_desc_[index]; | |||
| } | |||
| GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| auto it = input_name_idx.find(name); | |||
| if (it == input_name_idx.end()) { | |||
| GELOGW("Failed to get [%s] input desc", name.c_str()); | |||
| return nullptr; | |||
| } | |||
| return MutableInputDesc(it->second); | |||
| } | |||
| GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<string> OpDesc::GetAllInputNames() const { | |||
| auto input_name_idx = GetAllInputName(); | |||
| vector<string> names; | |||
| @@ -496,15 +516,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { | |||
| GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); | |||
| if (inputs_desc_[index] == nullptr) { | |||
| return nullptr; | |||
| } | |||
| GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); | |||
| return inputs_desc_[index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor<GeTensorDesc> OpDesc::GetAllInputsDesc() const { | |||
| vector<GeTensorDesc> temp{}; | |||
| for (const auto &it : inputs_desc_) { | |||
| @@ -609,6 +620,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu | |||
| return outputs_desc_[index]; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { | |||
| auto it = output_name_idx_.find(name); | |||
| if (it == output_name_idx_.end()) { | |||
| GELOGW("Failed to get [%s] output desc", name.c_str()); | |||
| return nullptr; | |||
| } | |||
| return MutableOutputDesc(it->second); | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { | |||
| return static_cast<uint32_t>(outputs_desc_.size()); | |||
| } | |||
| @@ -882,15 +902,22 @@ graphStatus OpDesc::CommonVerify() const { | |||
| // Checking shape of all inputs | |||
| vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | |||
| for (int64_t dim : ishape) { | |||
| GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||
| iname.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19014", {"opname", "value", "reason"}, | |||
| {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); | |||
| return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), | |||
| iname.c_str()); | |||
| } | |||
| } | |||
| // Check all attributes defined | |||
| const auto &all_attributes = GetAllAttrs(); | |||
| for (const auto &name : GetAllAttrNames()) { | |||
| GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, | |||
| "operator attribute %s is empty.", name.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| all_attributes.find(name) == all_attributes.end(), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, | |||
| {GetName(), "attribute " + name, "is empty"}); | |||
| return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| @@ -36,6 +36,8 @@ | |||
| #include "graph/op_desc.h" | |||
| #include "graph/runtime_inference_context.h" | |||
| #include "graph/usr_types.h" | |||
| #include "graph/utils/node_utils.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "utils/graph_utils.h" | |||
| #include "utils/op_desc_utils.h" | |||
| #include "utils/tensor_adapter.h" | |||
| @@ -57,8 +59,7 @@ using std::vector; | |||
| namespace ge { | |||
| class OpIO { | |||
| public: | |||
| explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) | |||
| : name_(name), index_(index), owner_(owner) {} | |||
| OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} | |||
| ~OpIO() = default; | |||
| @@ -546,56 +547,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { | |||
| } | |||
| graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { | |||
| if (operator_impl_ == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); | |||
| if (node_ptr) { | |||
| GE_CHECK_NOTNULL(operator_impl_); | |||
| auto node_ptr = operator_impl_->GetNode(); | |||
| if (node_ptr != nullptr) { | |||
| // For inner compute graph | |||
| auto op_desc = node_ptr->GetOpDesc(); | |||
| if (op_desc == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "op_desc is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| auto index = op_desc->GetInputIndexByName(dst_name); | |||
| auto in_data_anchor = node_ptr->GetInDataAnchor(index); | |||
| if (in_data_anchor == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| GE_CHECK_NOTNULL(in_data_anchor); | |||
| auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (out_data_anchor == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::shared_ptr<Node> peer_node_ptr = out_data_anchor->GetOwnerNode(); | |||
| if (peer_node_ptr == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ge::OperatorImplPtr operator_impl_ptr = nullptr; | |||
| operator_impl_ptr = ComGraphMakeShared<OperatorImpl>(peer_node_ptr); | |||
| if (operator_impl_ptr == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| Operator const_op(std::move(operator_impl_ptr)); | |||
| if (peer_node_ptr->GetOpDesc() != nullptr) { | |||
| const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); | |||
| if (op_descType == CONSTANTOP) { | |||
| return const_op.GetAttr(op::Constant::name_attr_value(), data); | |||
| } else if (op_descType == CONSTANT) { | |||
| return const_op.GetAttr(op::Const::name_attr_value(), data); | |||
| GE_CHECK_NOTNULL(out_data_anchor); | |||
| auto peer_node = out_data_anchor->GetOwnerNode(); | |||
| GE_CHECK_NOTNULL(peer_node); | |||
| auto peer_op_desc = peer_node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(peer_op_desc); | |||
| auto peer_op_type = peer_op_desc->GetType(); | |||
| if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { | |||
| auto const_op_impl = ComGraphMakeShared<OperatorImpl>(peer_node); | |||
| GE_CHECK_NOTNULL(const_op_impl); | |||
| Operator const_op(std::move(const_op_impl)); | |||
| return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
| } else if (peer_op_type == DATA) { | |||
| auto parent_node = NodeUtils::GetParentInput(peer_node); | |||
| while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { | |||
| parent_node = NodeUtils::GetParentInput(parent_node); | |||
| } | |||
| if ((parent_node != nullptr) && | |||
| ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { | |||
| auto const_op_impl = ComGraphMakeShared<OperatorImpl>(parent_node); | |||
| GE_CHECK_NOTNULL(const_op_impl); | |||
| Operator const_op(std::move(const_op_impl)); | |||
| return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); | |||
| } | |||
| } | |||
| // Try get from runtime inference context | |||
| auto session_id = std::to_string(GetContext().SessionId()); | |||
| RuntimeInferenceContext *runtime_infer_ctx = nullptr; | |||
| if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { | |||
| GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); | |||
| auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
| auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||
| if (ret == GRAPH_SUCCESS) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| @@ -604,6 +595,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||
| // For outer graph | |||
| return GetInputConstDataOut(dst_name, data); | |||
| } | |||
| auto op_name = operator_impl_->GetName(); | |||
| GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { | |||
| @@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; } | |||
| uint64_t GEContext::TraceId() { return trace_id_; } | |||
| void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } | |||
| void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } | |||
| } // namespace ge | |||
| @@ -242,6 +242,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r | |||
| int sub_graph_idx = 0; | |||
| for (const auto &name : sub_graph_names) { | |||
| auto sub_graph = root_graph.GetSubgraph(name); | |||
| if (sub_graph == nullptr) { | |||
| GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); | |||
| continue; | |||
| } | |||
| for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { | |||
| auto sub_graph_node_type = sub_graph_node->GetType(); | |||
| @@ -37,6 +37,115 @@ | |||
| namespace ge { | |||
| namespace { | |||
| const uint32_t kWhileBodySubGraphIdx = 1; | |||
| graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { | |||
| GELOGD("Enter reverse brush while body subgraph process!"); | |||
| auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); | |||
| if (sub_graph_body == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Get while body graph failed!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (const auto &node_sub : sub_graph_body->GetAllNodes()) { | |||
| if (node_sub->GetInDataNodes().size() == 0) { | |||
| continue; | |||
| } | |||
| for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { | |||
| auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); | |||
| (void)input_desc->SetUnknownDimNumShape(); | |||
| } | |||
| for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { | |||
| auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); | |||
| (void)output_desc->SetUnknownDimNumShape(); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||
| GELOGD("Enter update parent node shape for class branch op process"); | |||
| // check sub_graph shape.If not same ,do unknown shape process | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); | |||
| for (auto &tensor : ref_out_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { | |||
| GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||
| shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| ref_out_tensor_shape = GeShape(UNKNOWN_RANK); | |||
| break; | |||
| } | |||
| for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { | |||
| if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { | |||
| continue; | |||
| } | |||
| GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, | |||
| j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); | |||
| (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); | |||
| } | |||
| } | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector<std::vector<GeTensorDesc>> &ref_data_tensors, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_out_tensors) { | |||
| GELOGD("Enter update parent node shape for class while op process"); | |||
| if (ref_data_tensors.size() != ref_out_tensors.size()) { | |||
| GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), | |||
| ref_data_tensors.size(), ref_out_tensors.size()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (size_t i = 0; i < ref_data_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].size() != 1) { | |||
| GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); | |||
| return GRAPH_FAILED; | |||
| } | |||
| } | |||
| bool is_need_reverse_brush = false; | |||
| // check input and output | |||
| for (size_t i = 0; i < ref_out_tensors.size(); i++) { | |||
| if (ref_out_tensors[i].empty()) { | |||
| continue; | |||
| } | |||
| auto ref_out_tensor = ref_out_tensors[i].at(0); | |||
| auto tmp_shape = ref_out_tensor.MutableShape(); | |||
| // ref_i's data and output tensor shape should be same | |||
| for (auto &tensor : ref_data_tensors[i]) { | |||
| if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { | |||
| GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| auto shape = tensor.MutableShape(); | |||
| if (shape.GetDims() != tmp_shape.GetDims()) { | |||
| ref_out_tensor.SetUnknownDimNumShape(); | |||
| is_need_reverse_brush = true; | |||
| break; | |||
| } | |||
| } | |||
| (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); | |||
| } | |||
| // reverse refresh while body shape | |||
| if (is_need_reverse_brush) { | |||
| return ReverseBrushWhileBodySubGraph(node); | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| @@ -98,6 +207,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr<ComputeGraph> &sub_graph, NodePtr &netoutput, | |||
| const ConstNodePtr &node, | |||
| std::vector<std::vector<GeTensorDesc>> &ref_data_tensors) { | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| } | |||
| if (sub_node->GetType() == DATA) { | |||
| if (sub_node->GetOpDesc() == nullptr) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| int ref_i; | |||
| if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { | |||
| GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllInDataAnchorsSize()) { | |||
| GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), | |||
| ref_i, node->GetAllInDataAnchorsSize()); | |||
| return GRAPH_FAILED; | |||
| } | |||
| ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | |||
| @@ -105,7 +245,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| std::vector<std::vector<GeTensorDesc>> ref_data_tensors(node->GetAllInDataAnchorsSize()); | |||
| std::vector<std::vector<GeTensorDesc>> ref_out_tensors(node->GetAllOutDataAnchorsSize()); | |||
| auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); | |||
| for (const auto &name : sub_graph_names) { | |||
| if (name.empty()) { | |||
| GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); | |||
| @@ -117,13 +260,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| NodePtr netoutput = nullptr; | |||
| auto sub_nodes = sub_graph->GetDirectNode(); | |||
| for (size_t i = sub_nodes.size(); i > 0; --i) { | |||
| auto sub_node = sub_nodes.at(i - 1); | |||
| if (sub_node->GetType() == NETOUTPUT) { | |||
| netoutput = sub_node; | |||
| break; | |||
| } | |||
| auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| return ret; | |||
| } | |||
| if (netoutput == nullptr) { | |||
| GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); | |||
| @@ -150,19 +289,17 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { | |||
| continue; | |||
| } | |||
| GELOGI("Parent node index of edge desc is %d", ref_i); | |||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(ref_i)); | |||
| if (output_desc == nullptr) { | |||
| GE_LOGE( | |||
| "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " | |||
| "parent node %s are incompatible, outputs num %u", | |||
| ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), | |||
| node->GetAllOutDataAnchorsSize()); | |||
| if (ref_i < 0 || static_cast<uint32_t>(ref_i) >= node->GetAllOutDataAnchorsSize()) { | |||
| return GRAPH_FAILED; | |||
| } | |||
| op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); | |||
| ref_out_tensors[ref_i].emplace_back(*edge_desc); | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| if (node->GetType() == WHILE) { | |||
| return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); | |||
| } | |||
| return UpdateParentNodeForBranch(node, ref_out_tensors); | |||
| } | |||
| } // namespace | |||
| void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { | |||
| @@ -170,6 +307,9 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||
| GELOGE(GRAPH_FAILED, "node is null"); | |||
| return; | |||
| } | |||
| if (!IsLogEnable(GE, DLOG_DEBUG)) { | |||
| return; | |||
| } | |||
| ge::OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return ); | |||
| std::string str; | |||
| @@ -1,18 +1,18 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| */ | |||
| #ifndef COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||
| #define COMMON_GRAPH_UTILS_GE_IR_UTILS_H_ | |||
| @@ -38,6 +38,7 @@ | |||
| #include "utils/ge_ir_utils.h" | |||
| #include "utils/node_utils.h" | |||
| #include "debug/ge_op_types.h" | |||
| #include "external/ge/ge_api_types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "graph/utils/tensor_utils.h" | |||
| @@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra | |||
| /// @return graphStatus | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||
| GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
| GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts, | |||
| const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { | |||
| GE_CHECK_NOTNULL(src); | |||
| GE_CHECK_NOTNULL(insert_node); | |||
| @@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||
| static int max_dumpfile_num = 0; | |||
| if (max_dumpfile_num == 0) { | |||
| string opt = "0"; | |||
| (void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
| } | |||
| if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { | |||
| @@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||
| if (maxDumpFileSize == 0) { | |||
| string opt = "0"; | |||
| // Can not check return value | |||
| (void)GetContext().GetOption("ge.maxDumpFileSize", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); | |||
| maxDumpFileSize = atol(opt.c_str()); | |||
| } | |||
| if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { | |||
| @@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||
| static int max_dumpfile_num = 0; | |||
| if (max_dumpfile_num == 0) { | |||
| string opt = "0"; | |||
| (void)GetContext().GetOption("ge.maxDumpFileNum", opt); | |||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||
| } | |||
| if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { | |||
| @@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector<int> &io_map, In | |||
| InNodesToOut GetFullConnectIONodes(const NodePtr &node) { | |||
| InNodesToOut in_nodes_to_out; | |||
| if (node == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); | |||
| GELOGE(GRAPH_FAILED, "Node is nullptr"); | |||
| return in_nodes_to_out; | |||
| } | |||
| auto in_nodes_list = node->GetInNodes(); | |||
| @@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// Copy all in-data edges from `src_node` to `dst_node`. | |||
| /// @param src_node | |||
| /// @param dst_node | |||
| /// @return | |||
| /// | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, | |||
| NodePtr &dst_node) { | |||
| if ((src_node == nullptr) || (dst_node == nullptr)) { | |||
| GELOGE(GRAPH_FAILED, "Parameter is nullptr"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto src_data_in_nodes = src_node->GetInDataNodes(); | |||
| if (src_data_in_nodes.empty()) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { | |||
| auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); | |||
| auto ret = | |||
| GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", | |||
| in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), | |||
| src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, | |||
| const NodePtr &node) { | |||
| if (graph->AddInputNode(node) == nullptr) { | |||
| @@ -1339,7 +1370,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(graph); | |||
| for (auto &node : graph->GetAllNodes()) { | |||
| for (const auto &node : graph->GetAllNodes()) { | |||
| // in_data_anchor | |||
| if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { | |||
| GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); | |||
| @@ -1396,16 +1427,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, | |||
| return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| std::string type = node->GetType(); | |||
| const std::string &type = node->GetType(); | |||
| if ((type == MERGE) || (type == STREAMMERGE)) { | |||
| return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); | |||
| } | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| const std::string &symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors[symbol] = {cur_node_info}; | |||
| anchor_to_symbol[symbol] = symbol; | |||
| @@ -1432,7 +1463,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol) { | |||
| GE_CHECK_NOTNULL(node); | |||
| for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||
| NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); | |||
| if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { | |||
| continue; | |||
| @@ -1446,7 +1477,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, | |||
| return GRAPH_FAILED; | |||
| } | |||
| } else { | |||
| std::string symbol = cur_node_info.ToString(); | |||
| const std::string &symbol = cur_node_info.ToString(); | |||
| GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); | |||
| symbol_to_anchors.emplace(std::make_pair(symbol, std::list<NodeIndexIO>{cur_node_info})); | |||
| anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); | |||
| @@ -1506,7 +1537,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| GE_CHECK_NOTNULL(node); | |||
| std::vector<NodeIndexIO> exist_node_infos; | |||
| std::vector<NodeIndexIO> cur_node_infos; | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| if (peer_out_anchor == nullptr) { | |||
| std::string next_name; | |||
| @@ -1529,10 +1560,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| size_t anchor_nums = 0; | |||
| NodeIndexIO max_node_index_io(nullptr, 0, kOut); | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| for (const auto &temp_node_info : exist_node_infos) { | |||
| auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); | |||
| if (iter1 != anchor_to_symbol.end()) { | |||
| std::string temp_symbol = iter1->second; | |||
| const std::string &temp_symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(temp_symbol); | |||
| if (iter2 != symbol_to_anchors.end()) { | |||
| if (iter2->second.size() > anchor_nums) { | |||
| @@ -1544,7 +1575,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| } | |||
| std::string symbol; | |||
| for (auto &temp_node_info : exist_node_infos) { | |||
| for (const auto &temp_node_info : exist_node_infos) { | |||
| if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != | |||
| GRAPH_SUCCESS) || | |||
| symbol.empty()) { | |||
| @@ -1556,7 +1587,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||
| auto iter = symbol_to_anchors.find(symbol); | |||
| if (iter != symbol_to_anchors.end()) { | |||
| for (auto &temp_node_info : cur_node_infos) { | |||
| for (const auto &temp_node_info : cur_node_infos) { | |||
| GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); | |||
| iter->second.emplace_back(temp_node_info); | |||
| anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); | |||
| @@ -1584,7 +1615,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
| OpDescPtr op_desc = node->GetOpDesc(); | |||
| GE_CHECK_NOTNULL(op_desc); | |||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||
| @@ -1627,8 +1658,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, | |||
| graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, | |||
| std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors, | |||
| std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol) { | |||
| std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
| std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
| const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; | |||
| const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; | |||
| if (symbol1 == symbol2) { | |||
| symbol = symbol1; | |||
| GELOGI("no need to union."); | |||
| @@ -1684,7 +1715,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const | |||
| return GRAPH_FAILED; | |||
| } | |||
| std::string symbol = iter1->second; | |||
| const std::string &symbol = iter1->second; | |||
| auto iter2 = symbol_to_anchors.find(symbol); | |||
| if (iter2 == symbol_to_anchors.end()) { | |||
| GE_LOGE("symbol %s not found.", symbol.c_str()); | |||
| @@ -1712,7 +1743,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
| // pass-through op | |||
| NodePtr node = out_data_anchor->GetOwnerNode(); | |||
| std::string type = node->GetType(); | |||
| const std::string &type = node->GetType(); | |||
| const std::set<std::string> pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; | |||
| if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { | |||
| reuse_in_index = output_index; | |||
| @@ -1755,7 +1786,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||
| uint32_t reuse_input_index = 0; | |||
| if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { | |||
| reuse_in_index = static_cast<int32_t>(reuse_input_index); | |||
| GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
| GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, | |||
| reuse_in_index); | |||
| return true; | |||
| } | |||
| @@ -2297,7 +2328,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & | |||
| return; | |||
| } | |||
| std::string name = node->GetName() + "_RetVal"; | |||
| std::string name = node->GetName() + "_RetVal_" + std::to_string(index); | |||
| OpDescPtr ret_val_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); | |||
| if (ret_val_desc == nullptr) { | |||
| error_code = GRAPH_FAILED; | |||
| @@ -296,15 +296,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| return GRAPH_FAILED; | |||
| } | |||
| for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { | |||
| GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); | |||
| ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
| output_tensor.SetOriginShape(output_tensor.GetShape()); | |||
| output_tensor.SetOriginDataType(output_tensor.GetDataType()); | |||
| auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); | |||
| ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
| bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); | |||
| if (!is_unknown_graph) { | |||
| output_tensor->SetOriginShape(output_tensor->GetShape()); | |||
| output_tensor->SetOriginDataType(output_tensor->GetDataType()); | |||
| } | |||
| GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", | |||
| node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), | |||
| TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); | |||
| (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); | |||
| node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), | |||
| TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), | |||
| TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); | |||
| for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { | |||
| if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { | |||
| GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); | |||
| @@ -316,17 +319,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||
| continue; | |||
| } | |||
| GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), | |||
| output_tensor.GetDataType(), output_tensor.GetOriginDataType()); | |||
| peer_input_desc->SetShape(output_tensor.GetShape()); | |||
| peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | |||
| peer_input_desc->SetDataType(output_tensor.GetDataType()); | |||
| peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), | |||
| output_tensor->GetDataType(), output_tensor->GetOriginDataType()); | |||
| peer_input_desc->SetShape(output_tensor->GetShape()); | |||
| peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); | |||
| peer_input_desc->SetDataType(output_tensor->GetDataType()); | |||
| peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); | |||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||
| (void)output_tensor.GetShapeRange(shape_range); | |||
| (void)output_tensor->GetShapeRange(shape_range); | |||
| peer_input_desc->SetShapeRange(shape_range); | |||
| ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | |||
| static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | |||
| static_cast<uint32_t>(output_tensor->GetShape().GetDims().size())); | |||
| GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | |||
| peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), | |||
| peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); | |||
| @@ -401,10 +404,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const | |||
| graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { | |||
| auto desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL(desc); | |||
| // check self | |||
| is_unknow = OpShapeIsUnknown(desc); | |||
| if (is_unknow) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| auto sub_graph_names = desc->GetSubgraphInstanceNames(); | |||
| if (sub_graph_names.empty()) { | |||
| is_unknow = OpShapeIsUnknown(desc); | |||
| return GRAPH_SUCCESS; | |||
| } else { | |||
| auto owner_graph = node.GetOwnerComputeGraph(); | |||
| @@ -555,6 +561,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { | |||
| return peer_out_anchor->GetOwnerNode(); | |||
| } | |||
| /// | |||
| /// @brief Check is varying_input for while node | |||
| /// @param [in] node: Data node for subgraph | |||
| /// @return bool | |||
| /// | |||
| bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { | |||
| if (node == nullptr) { | |||
| return false; | |||
| } | |||
| if (node->GetType() != DATA) { | |||
| return false; // not input_node for subgraph | |||
| } | |||
| const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||
| if (parent_node == nullptr) { | |||
| return false; // root graph | |||
| } | |||
| if (kWhileOpTypes.count(parent_node->GetType()) == 0) { | |||
| return false; // not input_node for while subgraph | |||
| } | |||
| uint32_t index_i = 0; | |||
| if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { | |||
| GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); | |||
| return false; | |||
| } | |||
| bool varying_flag = true; | |||
| for (const auto &item : node->GetOutDataNodesAndAnchors()) { | |||
| if (item.first->GetType() != NETOUTPUT) { | |||
| continue; | |||
| } | |||
| OpDescPtr op_desc = item.first->GetOpDesc(); | |||
| uint32_t index_o = 0; | |||
| if ((op_desc == nullptr) || | |||
| !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { | |||
| continue; // input for while-cond subgraph | |||
| } | |||
| if (index_i != index_o) { | |||
| continue; // varying input for while-body subgraph | |||
| } | |||
| varying_flag = false; | |||
| break; | |||
| } | |||
| return varying_flag; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input is constant. | |||
| /// @param [in] node | |||
| @@ -637,4 +690,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input data node by index. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| vector<NodePtr> NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { | |||
| vector<NodePtr> in_data_node_vec; | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); | |||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (subgraph_names.empty()) { | |||
| GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
| return in_data_node_vec; | |||
| } | |||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||
| for (const std::string &instance_name : subgraph_names) { | |||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
| int parent_index = 0; | |||
| if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { | |||
| (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); | |||
| if (parent_index == index) { | |||
| in_data_node_vec.emplace_back(node_in_subgraph); | |||
| } | |||
| } | |||
| } | |||
| } | |||
| return in_data_node_vec; | |||
| } | |||
| /// | |||
| /// @brief Get subgraph input data node by index. | |||
| /// @param [in] node | |||
| /// @return Node | |||
| /// | |||
| vector<NodePtr> NodeUtils::GetSubgraphOutputNodes(const Node &node) { | |||
| vector<NodePtr> out_data_node_vec; | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); | |||
| auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||
| if (subgraph_names.empty()) { | |||
| GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); | |||
| return out_data_node_vec; | |||
| } | |||
| auto compute_graph = node.GetOwnerComputeGraph(); | |||
| for (const std::string &instance_name : subgraph_names) { | |||
| auto subgraph = compute_graph->GetSubgraph(instance_name); | |||
| for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { | |||
| if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { | |||
| out_data_node_vec.emplace_back(node_in_subgraph); | |||
| } | |||
| } | |||
| } | |||
| return out_data_node_vec; | |||
| } | |||
| NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { | |||
| if (node.GetInDataAnchor(index) == nullptr) { | |||
| return nullptr; | |||
| } | |||
| if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { | |||
| return nullptr; | |||
| } | |||
| return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); | |||
| } | |||
| vector<NodePtr> NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { | |||
| vector<NodePtr> out_data_nodes; | |||
| auto out_data_anchor = node.GetOutDataAnchor(index); | |||
| if (out_data_anchor == nullptr) { | |||
| return out_data_nodes; | |||
| } | |||
| for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { | |||
| if (peer_in_anchor == nullptr) { | |||
| continue; | |||
| } | |||
| if (peer_in_anchor->GetOwnerNode() == nullptr) { | |||
| continue; | |||
| } | |||
| out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); | |||
| } | |||
| return out_data_nodes; | |||
| } | |||
| } // namespace ge | |||
| @@ -197,24 +197,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
| continue; | |||
| } | |||
| auto in_node = out_anchor->GetOwnerNode(); | |||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
| ret.push_back(in_node); | |||
| } else if (in_node->GetType() == DATA) { | |||
| const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); | |||
| GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); | |||
| const NodePtr &parent_node = graph->GetParentNode(); | |||
| if (parent_node == nullptr) { | |||
| continue; // Root graph. | |||
| } | |||
| if (kWhileOpTypes.count(parent_node->GetType()) > 0) { | |||
| continue; // Subgraph of While cond or body. | |||
| while (true) { | |||
| if (in_node == nullptr) { | |||
| break; | |||
| } | |||
| NodePtr input_node = NodeUtils::GetParentInput(in_node); | |||
| if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { | |||
| ret.push_back(input_node); | |||
| if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | |||
| ret.push_back(in_node); | |||
| break; | |||
| } else if (in_node->GetType() == DATA) { | |||
| if (NodeUtils::IsWhileVaryingInput(in_node)) { | |||
| break; | |||
| } | |||
| in_node = NodeUtils::GetParentInput(in_node); | |||
| } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { | |||
| bool is_constant = false; | |||
| (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); | |||
| if (!is_constant) { | |||
| break; | |||
| } | |||
| // Enter node has and only has one input | |||
| if (in_node->GetInDataNodes().size() != 1) { | |||
| GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), | |||
| in_node->GetInDataNodes().size()); | |||
| break; | |||
| } | |||
| in_node = in_node->GetInDataNodes().at(0); | |||
| } else { | |||
| break; | |||
| } | |||
| } | |||
| } | |||
| @@ -435,10 +444,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils::MutableWeights(const ge::Node &node) { | |||
| vector<GeTensorPtr> ret; | |||
| GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); | |||
| auto op_desc = node.GetOpDesc(); | |||
| GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); | |||
| // Place holder operator, try to get the weight from parent node | |||
| // when parent node is const operator | |||
| if (node.GetType() == PLACEHOLDER) { | |||
| std::string parent_op; | |||
| (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); | |||
| // This if judgment is necessary because the current subgraph optimization is multithreaded | |||
| // and the parent node of the PLD operation should be a stable type, such as const | |||
| if (parent_op == CONSTANT || parent_op == CONSTANTOP) { | |||
| NodePtr parent_node = nullptr; | |||
| parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); | |||
| if (parent_node != nullptr) { | |||
| op_desc = parent_node->GetOpDesc(); | |||
| GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); | |||
| } | |||
| } | |||
| } | |||
| // Const operator, take the weight directly | |||
| if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { | |||
| auto weight = MutableWeights(node.GetOpDesc()); | |||
| if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { | |||
| auto weight = MutableWeights(op_desc); | |||
| if (weight == nullptr) { | |||
| GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); | |||
| return ret; | |||
| @@ -19,6 +19,7 @@ | |||
| #include "debug/ge_log.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "graph/ge_tensor.h" | |||
| #include "graph/types.h" | |||
| #include "graph/utils/type_utils.h" | |||
| @@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_ | |||
| element_cnt = 1; | |||
| for (int64_t dim : dims) { | |||
| if (CheckMultiplyOverflowInt64(element_cnt, dim)) { | |||
| GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); | |||
| ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19013", {"function", "var1", "var2"}, | |||
| {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); | |||
| GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); | |||
| return GRAPH_FAILED; | |||
| } | |||
| element_cnt *= dim; | |||
| @@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| case FORMAT_FRACTAL_Z: | |||
| graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); | |||
| break; | |||
| case FORMAT_NC1HWC0_C04: | |||
| case FORMAT_FRACTAL_NZ: | |||
| case FORMAT_FRACTAL_ZZ: | |||
| case FORMAT_NDHWC: | |||
| @@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format | |||
| case FORMAT_NDC1HWC0: | |||
| case FORMAT_FRACTAL_Z_C04: | |||
| case FORMAT_FRACTAL_ZN_LSTM: | |||
| case FORMAT_NC1HWC0_C04: | |||
| graph_status = CalcElementCntByDims(dims, element_cnt); | |||
| break; | |||
| default: | |||
| @@ -147,7 +147,8 @@ static const std::map<std::string, Format> kStringToFormatMap = { | |||
| {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, | |||
| {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, | |||
| {"FORMAT_RESERVED", FORMAT_RESERVED}, | |||
| {"ALL", FORMAT_ALL}}; | |||
| {"ALL", FORMAT_ALL}, | |||
| {"NULL", FORMAT_NULL}}; | |||
| static const std::map<DataType, std::string> kDataTypeToStringMap = { | |||
| {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. | |||
| @@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/op_tiling_manager.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| @@ -94,7 +95,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/util/debug.cc" | |||
| @@ -159,8 +159,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "hybrid/node_executor/aicpu/aicpu_ext_info.cc" | |||
| "hybrid/node_executor/aicpu/aicpu_node_executor.cc" | |||
| "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" | |||
| "hybrid/node_executor/controlop/control_op_executor.cc" | |||
| "hybrid/node_executor/hccl/hccl_node_executor.cc" | |||
| "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" | |||
| "hybrid/node_executor/node_executor.cc" | |||
| "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" | |||
| "hybrid/node_executor/task_context.cc" | |||
| "init/gelib.cc" | |||
| "model/ge_model.cc" | |||
| @@ -204,6 +207,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "common/formats/formats.cc" | |||
| "common/formats/utils/formats_trans_utils.cc" | |||
| "common/fp16_t.cc" | |||
| "common/ge/op_tiling_manager.cc" | |||
| "common/ge/plugin_manager.cc" | |||
| "common/helper/model_cache_helper.cc" | |||
| "common/profiling/profiling_manager.cc" | |||
| @@ -236,7 +240,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" | |||
| "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" | |||
| "graph/load/new_model_manager/task_info/task_info.cc" | |||
| "graph/load/output/output.cc" | |||
| "graph/manager/*.cc" | |||
| "graph/manager/model_manager/event_manager.cc" | |||
| "graph/manager/util/debug.cc" | |||
| @@ -28,6 +28,7 @@ | |||
| #include "graph/opsproto_manager.h" | |||
| #include "graph/utils/type_utils.h" | |||
| #include "graph/manager/util/rt_context_util.h" | |||
| #include "graph/common/ge_call_wrapper.h" | |||
| #include "register/op_registry.h" | |||
| #include "common/ge/tbe_plugin_manager.h" | |||
| @@ -41,8 +42,8 @@ namespace { | |||
| const int32_t kMaxStrLen = 128; | |||
| } | |||
| static bool kGeInitialized = false; | |||
| static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use | |||
| static bool g_ge_initialized = false; | |||
| static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use | |||
| namespace ge { | |||
| void GetOpsProtoPath(std::string &opsproto_path) { | |||
| @@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) { | |||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||
| } | |||
| Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | |||
| const int kDecimal = 10; | |||
| auto dump_op_env = std::getenv("DUMP_OP"); | |||
| int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; | |||
| auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); | |||
| if (disableReuseMemoryIter != options.end()) { | |||
| if (disableReuseMemoryIter->second == "0") { | |||
| GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); | |||
| if (dump_op_flag) { | |||
| GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); | |||
| } | |||
| } else if (disableReuseMemoryIter->second == "1") { | |||
| GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); | |||
| } else { | |||
| GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); | |||
| return FAILED; | |||
| } | |||
| } else { | |||
| if (dump_op_flag) { | |||
| GELOGW("Will dump incorrect op data with default reuse memory"); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| // check job_id is valid | |||
| auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); | |||
| @@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| } | |||
| } | |||
| // Check ge.exec.disableReuseMemory and env DUMP_OP | |||
| if (CheckDumpAndReuseMemory(options) != SUCCESS) { | |||
| return FAILED; | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| @@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||
| Status GEInitialize(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_INIT, "GEInitialize start"); | |||
| // 0.check init status | |||
| if (kGeInitialized) { | |||
| if (g_ge_initialized) { | |||
| GELOGW("GEInitialize is called more than once"); | |||
| return SUCCESS; | |||
| } | |||
| @@ -147,9 +118,9 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| } | |||
| // 7.check return status, return | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| // Initialize success, first time calling initialize | |||
| kGeInitialized = true; | |||
| g_ge_initialized = true; | |||
| } | |||
| GELOGT(TRACE_STOP, "GEInitialize finished"); | |||
| @@ -160,12 +131,12 @@ Status GEInitialize(const std::map<string, string> &options) { | |||
| Status GEFinalize() { | |||
| GELOGT(TRACE_INIT, "GEFinalize start"); | |||
| // check init status | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGW("GEFinalize is called before GEInitialize"); | |||
| return SUCCESS; | |||
| } | |||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||
| // call Finalize | |||
| Status ret = SUCCESS; | |||
| Status middle_ret; | |||
| @@ -187,10 +158,10 @@ Status GEFinalize() { | |||
| ret = middle_ret; | |||
| } | |||
| if (kGeInitialized && ret == SUCCESS) { | |||
| if (g_ge_initialized && ret == SUCCESS) { | |||
| // Unified destruct rt_context | |||
| RtContextUtil::GetInstance().DestroyrtContexts(); | |||
| kGeInitialized = false; | |||
| RtContextUtil::GetInstance().DestroyAllRtContexts(); | |||
| g_ge_initialized = false; | |||
| } | |||
| GELOGT(TRACE_STOP, "GEFinalize finished"); | |||
| @@ -202,7 +173,7 @@ Session::Session(const std::map<string, string> &options) { | |||
| GELOGT(TRACE_INIT, "Session Constructor start"); | |||
| // check init status | |||
| sessionId_ = 0; | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED); | |||
| return; | |||
| } | |||
| @@ -232,13 +203,13 @@ Session::Session(const std::map<string, string> &options) { | |||
| Session::~Session() { | |||
| GELOGT(TRACE_INIT, "Session Destructor start"); | |||
| // 0.check init status | |||
| if (!kGeInitialized) { | |||
| if (!g_ge_initialized) { | |||
| GELOGW("GE is not yet initialized or is finalized."); | |||
| return; | |||
| } | |||
| Status ret = FAILED; | |||
| std::lock_guard<std::mutex> lock(kGeReleaseMutex); | |||
| std::lock_guard<std::mutex> lock(g_ge_release_mutex); | |||
| try { | |||
| uint64_t session_id = sessionId_; | |||
| // call DestroySession | |||
| @@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons | |||
| void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||
| bool enum2str) { | |||
| if (field == nullptr || reflection == nullptr) { | |||
| return; | |||
| } | |||
| switch (field->type()) { | |||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||
| const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||
| @@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr | |||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||
| char str[kSignificantDigits]; | |||
| sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)); | |||
| json[field->name()] = str; | |||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { | |||
| json[field->name()] = str; | |||
| } else { | |||
| json[field->name()] = reflection->GetFloat(message, field); | |||
| } | |||
| break; | |||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||
| @@ -29,7 +29,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| enum DataTypeTransMode { | |||
| kTransferWithDatatypeFloatToFloat16, | |||
| @@ -27,7 +27,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| struct CastArgs { | |||
| const uint8_t *data; | |||
| size_t src_data_size; | |||
| @@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap | |||
| dst_shape.clear(); | |||
| hw_shape.clear(); | |||
| auto w0 = GetCubeSizeByDataType(data_type); | |||
| auto h0 = GetCubeSizeByDataType(data_type); | |||
| int64_t h0 = kCubeSize; | |||
| switch (src_shape.size()) { | |||
| case 1: | |||
| dst_shape.push_back(Ceil(src_shape[0], w0)); | |||
| @@ -19,6 +19,7 @@ | |||
| #include <securec.h> | |||
| #include <memory> | |||
| #include "common/debug/log.h" | |||
| #include "common/formats/utils/formats_definitions.h" | |||
| #include "common/formats/utils/formats_trans_utils.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| int64_t hw = h * w; | |||
| int64_t chw = c * hw; | |||
| int64_t hwc0 = hw * c0; | |||
| int64_t nchw = n * chw; | |||
| int64_t hwc0 = hw * c0; | |||
| // horizontal fractal matrix count (N) | |||
| int64_t hf_cnt = Ceil(n, static_cast<int64_t>(kNiSize)); | |||
| @@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | |||
| int size = GetSizeByDataType(args.src_data_type); | |||
| int64_t dst_size = total_ele_cnt * size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { | |||
| // vertical fractal matrix base index | |||
| @@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||
| auto protected_size = dst_size - offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | |||
| ? dst_size - offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (need_pad_zero) { | |||
| ret = memset_s(dst.get() + offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||
| } else { | |||
| ret = memcpy_s(dst.get() + offset, static_cast<size_t>(protected_size), args.data + src_offset * size, | |||
| static_cast<size_t>(size)); | |||
| if (protected_size < size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_offset * size); | |||
| for (int64_t index = 0; index < size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, | |||
| @@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||
| for (int64_t hi = 0; hi < h; hi++) { | |||
| @@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||
| ? dst_size - dst_offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (pad_zero) { | |||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||
| static_cast<size_t>(data_size)); | |||
| } else { | |||
| if (protected_size < data_size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, data_size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||
| for (int64_t index = 0; index < data_size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| @@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
| dst_size *= dim; | |||
| } | |||
| dst_size *= data_size; | |||
| if (dst_size == 0) { | |||
| result.length = static_cast<size_t>(dst_size); | |||
| return SUCCESS; | |||
| } | |||
| GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast<size_t>(dst_size); return SUCCESS;); | |||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||
| if (dst == nullptr) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| dst == nullptr, | |||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | |||
| TypeUtils::FormatToSerialString(args.src_format).c_str(), | |||
| TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); | |||
| return OUT_OF_MEMORY; | |||
| } | |||
| return OUT_OF_MEMORY;); | |||
| for (int64_t c1i = 0; c1i < c1; c1i++) { | |||
| for (int64_t hi = 0; hi < h; hi++) { | |||
| @@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||
| ? dst_size - dst_offset | |||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||
| auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); | |||
| errno_t ret; | |||
| errno_t ret = EOK; | |||
| if (pad_zero) { | |||
| ret = memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, | |||
| static_cast<size_t>(data_size)); | |||
| } else { | |||
| if (protected_size < data_size) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", | |||
| protected_size, data_size); | |||
| return INTERNAL_ERROR; | |||
| } | |||
| int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); | |||
| ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), | |||
| args.data + src_idx * data_size, static_cast<size_t>(data_size)); | |||
| char *dst_data = reinterpret_cast<char *>(dst.get() + dst_offset); | |||
| const char *src_data = reinterpret_cast<const char *>(args.data + src_idx * data_size); | |||
| for (int64_t index = 0; index < data_size; index++) { | |||
| *dst_data++ = *src_data++; | |||
| } | |||
| } | |||
| if (ret != EOK) { | |||
| GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", | |||
| @@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r | |||
| return PARAM_INVALID; | |||
| } | |||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatFromNchwToFz(args, result); | |||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatNhwcToFz(args, result); | |||
| } | |||
| if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatHwcnToFz(args, result); | |||
| } | |||
| if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatNhwcToFz(args, result); | |||
| if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransFormatFromNchwToFz(args, result); | |||
| } | |||
| return UNSUPPORTED; | |||
| @@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||
| return UNSUPPORTED; | |||
| } | |||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| if (src_format == FORMAT_HWCN && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeHwcnToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| if (src_format == FORMAT_NHWC && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNhwcToFz(src_shape, data_type, dst_shape); | |||
| if (src_format == FORMAT_NCHW && dst_format == FORMAT_FRACTAL_Z) { | |||
| return TransShapeNchwToFz(src_shape, data_type, dst_shape); | |||
| } | |||
| return UNSUPPORTED; | |||
| @@ -374,6 +390,5 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector<i | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NCHW, FORMAT_FRACTAL_Z) | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_HWCN, FORMAT_FRACTAL_Z) | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferFractalZ, FORMAT_NHWC, FORMAT_FRACTAL_Z) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -39,7 +39,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| namespace { | |||
| constexpr int64_t kMaxDimsNumC = 4; | |||
| Status CheckDataTypeSupport(DataType data_type) { return GetSizeByDataType(data_type) > 0 ? SUCCESS : UNSUPPORTED; } | |||
| @@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||
| return NOT_CHANGED; | |||
| } | |||
| /* prepare for padding in chw*/ | |||
| // prepare for padding in chw | |||
| int64_t tmp = h * w * c; | |||
| int64_t n_o = Ceil(n, static_cast<int64_t>(c0)); | |||
| int64_t c_o = c0; | |||
| @@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto | |||
| } | |||
| REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) | |||
| } // namespace formats | |||
| } // namespace ge | |||
| @@ -19,7 +19,6 @@ | |||
| namespace ge { | |||
| namespace formats { | |||
| static const int kCubeSize = 16; | |||
| static const int kNiSize = 16; | |||
| static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; | |||
| @@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum }; | |||
| enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum }; | |||
| enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum }; | |||
| } // namespace formats | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ | |||
| @@ -69,7 +69,6 @@ T Ceil(T n1, T n2) { | |||
| } | |||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | |||
| } | |||
| } // namespace formats | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_ | |||
| @@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) { | |||
| } | |||
| return len; | |||
| } | |||
| }; // namespace ge | |||
| } // namespace ge | |||
| #endif // GE_COMMON_FP16_T_H_ | |||
| @@ -0,0 +1,81 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "common/ge/op_tiling_manager.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include <string> | |||
| namespace { | |||
| const char *const kEnvName = "ASCEND_OPP_PATH"; | |||
| const std::string kDefaultPath = "/usr/local/Ascend/opp"; | |||
| const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so"; | |||
| const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so"; | |||
| const uint8_t kPrefixIndex = 9; | |||
| } // namespace | |||
| namespace ge { | |||
| void OpTilingManager::ClearHandles() noexcept { | |||
| for (const auto &handle : handles_) { | |||
| if (dlclose(handle.second) != 0) { | |||
| GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror()); | |||
| } | |||
| } | |||
| handles_.clear(); | |||
| } | |||
| OpTilingManager::~OpTilingManager() { ClearHandles(); } | |||
| std::string OpTilingManager::GetPath() { | |||
| const char *opp_path_env = std::getenv(kEnvName); | |||
| std::string opp_path = kDefaultPath; | |||
| if (opp_path_env != nullptr) { | |||
| char resolved_path[PATH_MAX]; | |||
| if (realpath(opp_path_env, resolved_path) == NULL) { | |||
| GELOGE(PARAM_INVALID, "Failed load tiling lib as env 'ASCEND_OPP_PATH'(%s) is invalid path.", opp_path_env); | |||
| return std::string(); | |||
| } | |||
| opp_path = resolved_path; | |||
| } | |||
| return opp_path; | |||
| } | |||
| void OpTilingManager::LoadSo() { | |||
| std::string opp_path = GetPath(); | |||
| if (opp_path.empty()) { | |||
| GELOGW("Skip load tiling lib."); | |||
| return; | |||
| } | |||
| std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath; | |||
| std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath; | |||
| std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex); | |||
| std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex); | |||
| void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
| if (handle_bi == nullptr) { | |||
| GELOGW("Failed to dlopen %s!", dlerror()); | |||
| } else { | |||
| handles_[built_in_name] = handle_bi; | |||
| } | |||
| void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); | |||
| if (handle_ct == nullptr) { | |||
| GELOGW("Failed to dlopen %s!", dlerror()); | |||
| } else { | |||
| handles_[custom_name] = handle_ct; | |||
| } | |||
| } | |||
| } // namespace ge | |||
| @@ -14,16 +14,25 @@ | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ | |||
| #define GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ | |||
| #ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
| #define GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
| #include "graph/passes/base_pass.h" | |||
| #include <map> | |||
| namespace ge { | |||
| class IdentifyReferencePass : public BaseNodePass { | |||
| using SoToHandleMap = std::map<std::string, void *>; | |||
| class OpTilingManager { | |||
| public: | |||
| Status Run(NodePtr &node) override; | |||
| OpTilingManager() = default; | |||
| ~OpTilingManager(); | |||
| void LoadSo(); | |||
| private: | |||
| static std::string GetPath(); | |||
| void ClearHandles() noexcept; | |||
| SoToHandleMap handles_; | |||
| }; | |||
| } // namespace ge | |||
| #endif // GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ | |||
| #endif // GE_COMMON_GE_OP_TILING_MANAGER_H_ | |||
| @@ -17,6 +17,7 @@ | |||
| #include "framework/common/helper/model_helper.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include "framework/common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| @@ -267,6 +268,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c | |||
| } | |||
| auto partition_table = reinterpret_cast<ModelPartitionTable *>(model_addr_tmp_); | |||
| if (partition_table->num == kOriginalOmPartitionNum) { | |||
| model_addr_tmp_ = nullptr; | |||
| GELOGE(FAILED, "om model is error,please use executable om model"); | |||
| return FAILED; | |||
| } | |||
| @@ -390,107 +392,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo | |||
| return out_model; | |||
| } | |||
| // Transit func for model to ge_model. It will be removed when load and build support ge_model in future | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model, | |||
| GeModelPtr &ge_model) { | |||
| if (model == nullptr) { | |||
| GELOGE(FAILED, "Model is null"); | |||
| return FAILED; | |||
| } | |||
| ge_model = ge::MakeShared<ge::GeModel>(); | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| ge_model->SetGraph(model->GetGraph()); | |||
| ge_model->SetName(model->GetName()); | |||
| ge_model->SetVersion(model->GetVersion()); | |||
| ge_model->SetPlatformVersion(model->GetPlatformVersion()); | |||
| ge_model->SetAttr(model->MutableAttrMap()); | |||
| // Copy weight info | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||
| // ge::Buffer weight; | |||
| ge::Buffer weight; | |||
| (void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); | |||
| ge_model->SetWeight(weight); | |||
| // Copy task info | |||
| if (model->HasAttr(MODEL_ATTR_TASKS)) { | |||
| ge::Buffer task_buffer; | |||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, | |||
| "Get bytes failed."); | |||
| std::shared_ptr<ModelTaskDef> task = ge::MakeShared<ModelTaskDef>(); | |||
| GE_CHECK_NOTNULL(task); | |||
| GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); | |||
| GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); | |||
| GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast<int>(task_buffer.GetSize()), task.get()), | |||
| return INTERNAL_ERROR, "ReadProtoFromArray failed."); | |||
| ge_model->SetModelTaskDef(task); | |||
| } | |||
| // Copy tbe kernel info | |||
| // TBEKernelStore kernel_store; | |||
| TBEKernelStore kernel_store; | |||
| if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { | |||
| for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { | |||
| auto node_op_desc = n->GetOpDesc(); | |||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | |||
| TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | |||
| GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | |||
| kernel_store.AddTBEKernel(tbe_kernel); | |||
| GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||
| } | |||
| } | |||
| if (!kernel_store.Build()) { | |||
| GELOGE(FAILED, "TBE Kernels store build failed!"); | |||
| return FAILED; | |||
| } | |||
| ge_model->SetTBEKernelStore(kernel_store); | |||
| return SUCCESS; | |||
| } | |||
| // trasit func for ge_model to Model. will be removed when load and build support ge_model in future | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model, | |||
| ModelPtr &model) { | |||
| if (ge_model == nullptr) { | |||
| GELOGE(FAILED, "Ge_model is null"); | |||
| return FAILED; | |||
| } | |||
| model = ge::MakeShared<ge::Model>(); | |||
| GE_CHECK_NOTNULL(model); | |||
| model->SetGraph(ge_model->GetGraph()); | |||
| model->SetName(ge_model->GetName()); | |||
| model->SetVersion(ge_model->GetVersion()); | |||
| model->SetPlatformVersion(ge_model->GetPlatformVersion()); | |||
| model->SetAttr(ge_model->MutableAttrMap()); | |||
| // Copy weight info | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); | |||
| bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight()); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy weight buffer failed!"); | |||
| return FAILED; | |||
| } | |||
| // Copy task info | |||
| std::shared_ptr<ModelTaskDef> model_task = ge_model->GetModelTaskDefPtr(); | |||
| if (model_task != nullptr) { | |||
| int size = model_task->ByteSize(); | |||
| ge::Buffer buffer(static_cast<size_t>(size)); | |||
| if (buffer.GetSize() == 0) { | |||
| GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!"); | |||
| return MEMALLOC_FAILED; | |||
| } | |||
| // no need to check value | |||
| (void)model_task->SerializePartialToArray(buffer.GetData(), size); | |||
| ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer)); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy task buffer failed!"); | |||
| return FAILED; | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status ModelHelper::ReleaseLocalModelData() noexcept { | |||
| Status result = SUCCESS; | |||
| if (model_addr_tmp_ != nullptr) { | |||
| @@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2); | |||
| /// @brief Calculate the minimum fp16_t of fp1 and fp2 | |||
| /// @return Returns minimum fp16_t of fp1 and fp2 | |||
| fp16_t min(fp16_t fp1, fp16_t fp2); | |||
| }; // namespace ge | |||
| } // namespace ge | |||
| #endif // GE_COMMON_MATH_FP16_MATH_H_ | |||
| @@ -27,7 +27,6 @@ | |||
| #include "mmpa/mmpa_api.h" | |||
| namespace ge { | |||
| /** | |||
| * @ingroup domi_calibration | |||
| * @brief Initializes an input array to a specified value | |||
| @@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| } // end namespace ge | |||
| #endif // GE_COMMON_MATH_UTIL_H_ | |||
| @@ -60,8 +60,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||
| mode_t mode = S_IRUSR | S_IWUSR; | |||
| int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); | |||
| if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"filepath", "errMsg"}, {file_path, strerror(errno)}); | |||
| GELOGE(FAILED, "Open file failed. file path : %s, %s", file_path, strerror(errno)); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||
| GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); | |||
| return FAILED; | |||
| } | |||
| const char *model_char = model_str.c_str(); | |||
| @@ -69,8 +69,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi | |||
| // Write data to file | |||
| mmSsize_t mmpa_ret = mmWrite(fd, const_cast<void *>((const void *)model_char), len); | |||
| if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"mmpa_ret", "errMsg"}, | |||
| {std::to_string(mmpa_ret), strerror(errno)}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); | |||
| // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose | |||
| GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); | |||
| ret = FAILED; | |||
| @@ -336,16 +336,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||
| std::string data; | |||
| for (const auto &task : task_desc_info) { | |||
| std::string model_name = task.model_name; | |||
| std::string op_name = task.op_name; | |||
| uint32_t block_dim = task.block_dim; | |||
| uint32_t task_id = task.task_id; | |||
| uint32_t stream_id = task.stream_id; | |||
| data = op_name.append(" ").append(std::to_string(block_dim) | |||
| .append(" ") | |||
| .append(std::to_string(task_id)) | |||
| .append(" ") | |||
| .append(std::to_string(stream_id)) | |||
| .append("\n")); | |||
| data = model_name.append(" ").append(op_name).append(" ").append(std::to_string(block_dim) | |||
| .append(" ") | |||
| .append(std::to_string(task_id)) | |||
| .append(" ") | |||
| .append(std::to_string(stream_id)) | |||
| .append("\n")); | |||
| Msprof::Engine::ReporterData reporter_data{}; | |||
| reporter_data.deviceId = device_id; | |||
| @@ -376,7 +377,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin | |||
| std::string data; | |||
| for (const auto &graph : compute_graph_desc_info) { | |||
| data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); | |||
| data.append("model_name:") | |||
| .append(graph.model_name) | |||
| .append(" op_name:") | |||
| .append(graph.op_name) | |||
| .append(" op_type:") | |||
| .append(graph.op_type); | |||
| for (size_t i = 0; i < graph.input_format.size(); ++i) { | |||
| data.append(" input_id:") | |||
| .append(std::to_string(i)) | |||
| @@ -20,15 +20,204 @@ | |||
| #include <cstdio> | |||
| #include <fstream> | |||
| #include "common/ge/ge_util.h" | |||
| #include "common/util.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "framework/common/debug/log.h" | |||
| #include "framework/common/ge_types.h" | |||
| #include "framework/common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/ge_context.h" | |||
| #include "graph/utils/attr_utils.h" | |||
| namespace ge { | |||
| namespace { | |||
| const string kEnableFlag = "1"; | |||
| const uint32_t kAicoreOverflow = (0x1 << 0); | |||
| const uint32_t kAtomicOverflow = (0x1 << 1); | |||
| const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); | |||
| } // namespace | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { | |||
| CopyFrom(other); | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( | |||
| const DumpProperties &other) { | |||
| CopyFrom(other); | |||
| return *this; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { | |||
| enable_dump_.clear(); | |||
| enable_dump_debug_.clear(); | |||
| dump_path_.clear(); | |||
| dump_step_.clear(); | |||
| dump_mode_.clear(); | |||
| is_op_debug_ = false; | |||
| op_debug_mode_ = 0; | |||
| string enable_dump; | |||
| (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); | |||
| enable_dump_ = enable_dump; | |||
| string enable_dump_debug; | |||
| (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); | |||
| enable_dump_debug_ = enable_dump_debug; | |||
| if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { | |||
| string dump_path; | |||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { | |||
| if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { | |||
| dump_path = dump_path + "/"; | |||
| } | |||
| dump_path = dump_path + CurrentTimeInStr() + "/"; | |||
| GELOGI("Get dump path %s successfully", dump_path.c_str()); | |||
| SetDumpPath(dump_path); | |||
| } else { | |||
| GELOGW("DUMP_PATH is not set"); | |||
| } | |||
| } | |||
| if (enable_dump_ == kEnableFlag) { | |||
| string dump_step; | |||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { | |||
| GELOGD("Get dump step %s successfully", dump_step.c_str()); | |||
| SetDumpStep(dump_step); | |||
| } | |||
| string dump_mode; | |||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { | |||
| GELOGD("Get dump mode %s successfully", dump_mode.c_str()); | |||
| SetDumpMode(dump_mode); | |||
| } | |||
| AddPropertyValue(DUMP_ALL_MODEL, {}); | |||
| } | |||
| SetDumpDebugOptions(); | |||
| } | |||
| // The following is the new dump scenario of the fusion operator | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( | |||
| const std::string &model, const std::set<std::string> &layers) { | |||
| for (const std::string &layer : layers) { | |||
| GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); | |||
| } | |||
| model_dump_properties_map_[model] = layers; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { | |||
| auto iter = model_dump_properties_map_.find(model); | |||
| if (iter != model_dump_properties_map_.end()) { | |||
| model_dump_properties_map_.erase(iter); | |||
| } | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetAllDumpModel() const { | |||
| std::set<std::string> model_list; | |||
| for (auto &iter : model_dump_properties_map_) { | |||
| model_list.insert(iter.first); | |||
| } | |||
| return model_list; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> DumpProperties::GetPropertyValue( | |||
| const std::string &model) const { | |||
| auto iter = model_dump_properties_map_.find(model); | |||
| if (iter != model_dump_properties_map_.end()) { | |||
| return iter->second; | |||
| } | |||
| return {}; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( | |||
| const std::string &model, const std::string &om_name, const std::string &op_name) const { | |||
| // if dump all | |||
| if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { | |||
| return true; | |||
| } | |||
| // if this model need dump | |||
| auto om_name_iter = model_dump_properties_map_.find(om_name); | |||
| auto model_name_iter = model_dump_properties_map_.find(model); | |||
| if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { | |||
| // if no dump layer info, dump all layer in this model | |||
| auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; | |||
| if (model_iter->second.empty()) { | |||
| return true; | |||
| } | |||
| return model_iter->second.find(op_name) != model_iter->second.end(); | |||
| } | |||
| GELOGD("Model %s is not seated to be dump.", model.c_str()); | |||
| return false; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { | |||
| dump_path_ = path; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpPath() const { return dump_path_; } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { | |||
| dump_step_ = step; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpStep() const { return dump_step_; } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { | |||
| dump_mode_ = mode; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpMode() const { return dump_mode_; } | |||
| void DumpProperties::CopyFrom(const DumpProperties &other) { | |||
| if (&other != this) { | |||
| enable_dump_ = other.enable_dump_; | |||
| enable_dump_debug_ = other.enable_dump_debug_; | |||
| dump_path_ = other.dump_path_; | |||
| dump_step_ = other.dump_step_; | |||
| dump_mode_ = other.dump_mode_; | |||
| model_dump_properties_map_ = other.model_dump_properties_map_; | |||
| is_op_debug_ = other.is_op_debug_; | |||
| op_debug_mode_ = other.op_debug_mode_; | |||
| } | |||
| } | |||
| void DumpProperties::SetDumpDebugOptions() { | |||
| if (enable_dump_debug_ == kEnableFlag) { | |||
| string dump_debug_mode; | |||
| if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { | |||
| GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); | |||
| } else { | |||
| GELOGW("Dump debug mode is not set."); | |||
| return; | |||
| } | |||
| if (dump_debug_mode == OP_DEBUG_AICORE) { | |||
| GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open."); | |||
| is_op_debug_ = true; | |||
| op_debug_mode_ = kAicoreOverflow; | |||
| } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { | |||
| GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open."); | |||
| is_op_debug_ = true; | |||
| op_debug_mode_ = kAtomicOverflow; | |||
| } else if (dump_debug_mode == OP_DEBUG_ALL) { | |||
| GELOGD("ge.exec.dumpDebugMode=all, op debug is open."); | |||
| is_op_debug_ = true; | |||
| op_debug_mode_ = kAllOverflow; | |||
| } else { | |||
| GELOGW("ge.exec.dumpDebugMode is invalid."); | |||
| } | |||
| } else { | |||
| GELOGI("ge.exec.enableDumpDebug is false or is not set."); | |||
| } | |||
| } | |||
| PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} | |||
| PropertiesManager::~PropertiesManager() {} | |||
| @@ -159,131 +348,22 @@ PropertiesManager::GetPropertyMap() { | |||
| // Set separator | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| delimiter = de; | |||
| } | |||
| // The following is the new dump scenario of the fusion operator | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::AddDumpPropertyValue( | |||
| const std::string &model, const std::set<std::string> &layers) { | |||
| for (const std::string &layer : layers) { | |||
| GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); | |||
| } | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| model_dump_properties_map_[model] = layers; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::DeleteDumpPropertyValue( | |||
| const std::string &model) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| auto iter = model_dump_properties_map_.find(model); | |||
| if (iter != model_dump_properties_map_.end()) { | |||
| model_dump_properties_map_.erase(iter); | |||
| } | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::ClearDumpPropertyValue() { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| model_dump_properties_map_.clear(); | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> PropertiesManager::GetAllDumpModel() { | |||
| std::set<std::string> model_list; | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| for (auto &iter : model_dump_properties_map_) { | |||
| model_list.insert(iter.first); | |||
| } | |||
| return model_list; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set<std::string> PropertiesManager::GetDumpPropertyValue( | |||
| const std::string &model) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| auto iter = model_dump_properties_map_.find(model); | |||
| if (iter != model_dump_properties_map_.end()) { | |||
| return iter->second; | |||
| } | |||
| return {}; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::IsLayerNeedDump(const std::string &model, | |||
| const std::string &om_name, | |||
| const std::string &op_name) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| // if dump all | |||
| if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { | |||
| return true; | |||
| } | |||
| // if this model need dump | |||
| auto om_name_iter = model_dump_properties_map_.find(om_name); | |||
| auto model_name_iter = model_dump_properties_map_.find(model); | |||
| if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { | |||
| // if no dump layer info, dump all layer in this model | |||
| auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; | |||
| if (model_iter->second.empty()) { | |||
| return true; | |||
| } | |||
| return model_iter->second.find(op_name) != model_iter->second.end(); | |||
| } | |||
| GELOGD("Model %s is not seated to be dump.", model.c_str()); | |||
| return false; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &PropertiesManager::GetDumpProperties( | |||
| uint64_t session_id) { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| // If session_id is not found in dump_properties_map_, operator[] will insert one. | |||
| return dump_properties_map_[session_id]; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::QueryModelDumpStatus( | |||
| const std::string &model) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| auto iter = model_dump_properties_map_.find(model); | |||
| if (iter != model_dump_properties_map_.end()) { | |||
| return true; | |||
| } else if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { | |||
| return true; | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::RemoveDumpProperties(uint64_t session_id) { | |||
| std::lock_guard<std::mutex> lock(mutex_); | |||
| auto iter = dump_properties_map_.find(session_id); | |||
| if (iter != dump_properties_map_.end()) { | |||
| dump_properties_map_.erase(iter); | |||
| } | |||
| return false; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputModel( | |||
| const std::string &output_mode) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| this->output_mode_ = output_mode; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputModel() { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| return this->output_mode_; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputPath( | |||
| const std::string &output_path) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| this->output_path_ = output_path; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputPath() { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| return this->output_path_; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpStep(const std::string &dump_step) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| this->dump_step_ = dump_step; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpStep() { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| return this->dump_step_; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpMode(const std::string &dump_mode) { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| this->dump_mode_ = dump_mode; | |||
| } | |||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpMode() { | |||
| std::lock_guard<std::mutex> lock(dump_mutex_); | |||
| return this->dump_mode_; | |||
| } | |||
| } // namespace ge | |||
| @@ -32,6 +32,50 @@ static const char *USE_FUSION __attribute__((unused)) = "FMK_USE_FUSION"; | |||
| static const char *TIMESTAT_ENABLE __attribute__((unused)) = "DAVINCI_TIMESTAT_ENABLE"; | |||
| static const char *ANNDROID_DEBUG __attribute__((unused)) = "ANNDROID_DEBUG"; | |||
| class DumpProperties { | |||
| public: | |||
| DumpProperties() = default; | |||
| ~DumpProperties() = default; | |||
| DumpProperties(const DumpProperties &dump); | |||
| DumpProperties &operator=(const DumpProperties &dump); | |||
| void InitByOptions(); | |||
| void AddPropertyValue(const std::string &model, const std::set<std::string> &layers); | |||
| void DeletePropertyValue(const std::string &model); | |||
| std::set<std::string> GetAllDumpModel() const; | |||
| std::set<std::string> GetPropertyValue(const std::string &model) const; | |||
| bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const; | |||
| void SetDumpPath(const std::string &path); | |||
| std::string GetDumpPath() const; | |||
| void SetDumpStep(const std::string &step); | |||
| std::string GetDumpStep() const; | |||
| void SetDumpMode(const std::string &mode); | |||
| std::string GetDumpMode() const; | |||
| bool IsOpDebugOpen() const { return is_op_debug_; } | |||
| uint32_t GetOpDebugMode() const { return op_debug_mode_; } | |||
| private: | |||
| void CopyFrom(const DumpProperties &other); | |||
| void SetDumpDebugOptions(); | |||
| string enable_dump_; | |||
| string enable_dump_debug_; | |||
| std::string dump_path_; | |||
| std::string dump_step_; | |||
| std::string dump_mode_; | |||
| std::map<std::string, std::set<std::string>> model_dump_properties_map_; | |||
| bool is_op_debug_ = false; | |||
| uint32_t op_debug_mode_ = 0; | |||
| }; | |||
| class PropertiesManager { | |||
| public: | |||
| // Singleton | |||
| @@ -81,21 +125,8 @@ class PropertiesManager { | |||
| */ | |||
| void SetPropertyDelimiter(const std::string &de); | |||
| void AddDumpPropertyValue(const std::string &model, const std::set<std::string> &layers); | |||
| std::set<std::string> GetAllDumpModel(); | |||
| std::set<std::string> GetDumpPropertyValue(const std::string &model); | |||
| bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name); | |||
| void DeleteDumpPropertyValue(const std::string &model); | |||
| void ClearDumpPropertyValue(); | |||
| bool QueryModelDumpStatus(const std::string &model); | |||
| void SetDumpOutputModel(const std::string &output_model); | |||
| std::string GetDumpOutputModel(); | |||
| void SetDumpOutputPath(const std::string &output_path); | |||
| std::string GetDumpOutputPath(); | |||
| void SetDumpStep(const std::string &dump_step); | |||
| std::string GetDumpStep(); | |||
| void SetDumpMode(const std::string &dump_mode); | |||
| std::string GetDumpMode(); | |||
| DumpProperties &GetDumpProperties(uint64_t session_id); | |||
| void RemoveDumpProperties(uint64_t session_id); | |||
| private: | |||
| // Private construct, destructor | |||
| @@ -119,12 +150,7 @@ class PropertiesManager { | |||
| std::map<std::string, std::string> properties_map_; | |||
| std::mutex mutex_; | |||
| std::string output_mode_; | |||
| std::string output_path_; | |||
| std::string dump_step_; | |||
| std::string dump_mode_; | |||
| std::map<std::string, std::set<std::string>> model_dump_properties_map_; // model_dump_layers_map_ | |||
| std::mutex dump_mutex_; | |||
| std::map<uint64_t, DumpProperties> dump_properties_map_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -28,7 +28,6 @@ | |||
| #include "graph/op_kernel_bin.h" | |||
| namespace ge { | |||
| using TBEKernel = ge::OpKernelBin; | |||
| using TBEKernelPtr = std::shared_ptr<ge::OpKernelBin>; | |||
| @@ -26,6 +26,11 @@ const std::string DUMP_LAYER = "layer"; | |||
| const std::string DUMP_FILE_PATH = "path"; | |||
| const std::string DUMP_MODE = "dump_mode"; | |||
| // op debug mode | |||
| const std::string OP_DEBUG_AICORE = "aicore_overflow"; | |||
| const std::string OP_DEBUG_ATOMIC = "atomic_overflow"; | |||
| const std::string OP_DEBUG_ALL = "all"; | |||
| const int DEFAULT_FORMAT = static_cast<const int>(ge::FORMAT_NCHW); | |||
| // Supported public property names | |||
| const std::string PROP_OME_START_TIME = "ome_start_time"; // start time | |||
| @@ -277,8 +282,8 @@ REGISTER_OPTYPE_DEFINE(GETSPAN, "GetSpan"); | |||
| REGISTER_OPTYPE_DEFINE(STOPGRADIENT, "StopGradient"); | |||
| REGISTER_OPTYPE_DEFINE(PREVENTGRADIENT, "PreventGradient"); | |||
| REGISTER_OPTYPE_DEFINE(GUARANTEECONST, "GuaranteeConst"); | |||
| REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs") | |||
| REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs") | |||
| REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); | |||
| REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs"); | |||
| REGISTER_OPTYPE_DEFINE(CONFUSIONMATRIX, "ConfusionMatrix"); | |||
| REGISTER_OPTYPE_DEFINE(RANK, "Rank"); | |||
| REGISTER_OPTYPE_DEFINE(PLACEHOLDER, "PlaceHolder"); | |||
| @@ -286,6 +291,7 @@ REGISTER_OPTYPE_DEFINE(END, "End"); | |||
| REGISTER_OPTYPE_DEFINE(BASICLSTMCELL, "BasicLSTMCell"); | |||
| REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext"); | |||
| REGISTER_OPTYPE_DEFINE(INITDATA, "InitData"); | |||
| REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity"); | |||
| /***************Ann special operator*************************/ | |||
| REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean"); | |||
| @@ -479,72 +485,72 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. | |||
| #endif | |||
| /// | |||
| ///@brief Magic number of model file | |||
| /// @brief Magic number of model file | |||
| /// | |||
| const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number | |||
| /// | |||
| ///@brief Model head length | |||
| /// @brief Model head length | |||
| /// | |||
| const uint32_t MODEL_FILE_HEAD_LEN = 256; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Input node type | |||
| /// @ingroup domi_omg | |||
| /// @brief Input node type | |||
| /// | |||
| const std::string INPUT_TYPE = "Input"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief AIPP label, label AIPP conv operator | |||
| /// @ingroup domi_omg | |||
| /// @brief AIPP label, label AIPP conv operator | |||
| /// | |||
| const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief AIPP label, label aipp data operator | |||
| /// @ingroup domi_omg | |||
| /// @brief AIPP label, label aipp data operator | |||
| /// | |||
| const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Record the w dimension of model input corresponding to dynamic AIPP | |||
| /// @ingroup domi_omg | |||
| /// @brief Record the w dimension of model input corresponding to dynamic AIPP | |||
| /// | |||
| const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Record the H dimension of model input corresponding to dynamic AIPP | |||
| /// @ingroup domi_omg | |||
| /// @brief Record the H dimension of model input corresponding to dynamic AIPP | |||
| /// | |||
| const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator | |||
| /// @ingroup domi_omg | |||
| /// @brief The tag of the data operator. Mark this input to the dynamic AIPP operator | |||
| /// | |||
| const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief DATA node type | |||
| /// @ingroup domi_omg | |||
| /// @brief DATA node type | |||
| /// | |||
| const std::string DATA_TYPE = "Data"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief DATA node type | |||
| /// @ingroup domi_omg | |||
| /// @brief DATA node type | |||
| /// | |||
| const std::string AIPP_DATA_TYPE = "AippData"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Frame operator type | |||
| /// @ingroup domi_omg | |||
| /// @brief Frame operator type | |||
| /// | |||
| const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Data node type | |||
| /// @ingroup domi_omg | |||
| /// @brief Data node type | |||
| /// | |||
| const std::string ANN_DATA_TYPE = "AnnData"; | |||
| const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; | |||
| @@ -552,136 +558,139 @@ const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; | |||
| const std::string ANN_CONV_TYPE = "AnnConvolution"; | |||
| const std::string ANN_FC_TYPE = "AnnFullConnection"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Convolution node type | |||
| /// @ingroup domi_omg | |||
| /// @brief Convolution node type | |||
| /// | |||
| const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; | |||
| const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; | |||
| const std::string NODE_NAME_OP_DEBUG = "Node_OpDebug"; | |||
| const std::string OP_TYPE_OP_DEBUG = "Opdebug"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Convolution node type | |||
| /// @ingroup domi_omg | |||
| /// @brief Convolution node type | |||
| /// | |||
| const std::string OP_TYPE_CONVOLUTION = "Convolution"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Add convolution node name to AIPP | |||
| /// @ingroup domi_omg | |||
| /// @brief Add convolution node name to AIPP | |||
| /// | |||
| const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Operator configuration item separator | |||
| /// @ingroup domi_omg | |||
| /// @brief Operator configuration item separator | |||
| /// | |||
| const std::string OP_CONF_DELIMITER = ":"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief attr value name | |||
| /// @ingroup domi_omg | |||
| /// @brief attr value name | |||
| /// | |||
| const std::string ATTR_NAME_VALUE1 = "value1"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief attr value name, 6d_2_4d C | |||
| /// @ingroup domi_omg | |||
| /// @brief attr value name, 6d_2_4d C | |||
| /// | |||
| const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief alpha default value | |||
| /// @ingroup domi_omg | |||
| /// @brief alpha default value | |||
| /// | |||
| const float ALPHA_DEFAULT_VALUE = 1.0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief beta default value | |||
| /// @ingroup domi_omg | |||
| /// @brief beta default value | |||
| /// | |||
| const float BETA_DEFAULT_VALUE = 0.0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief coef default value | |||
| /// @ingroup domi_omg | |||
| /// @brief coef default value | |||
| /// | |||
| const float COEF_DEFAULT_VALUE = 0.0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Relu6 coef value | |||
| /// @ingroup domi_omg | |||
| /// @brief Relu6 coef value | |||
| /// | |||
| const float RELU6_COEF = 6.0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief stride default value | |||
| /// @ingroup domi_omg | |||
| /// @brief stride default value | |||
| /// | |||
| const uint32_t STRIDE_DEFAULT_VALUE = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief pad default value | |||
| /// @ingroup domi_omg | |||
| /// @brief pad default value | |||
| /// | |||
| const uint32_t PAD_DEFAULT_VALUE = 0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief dilation default value | |||
| /// @ingroup domi_omg | |||
| /// @brief dilation default value | |||
| /// | |||
| const int DILATION_DEFAULT_VALUE = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief kernel default value | |||
| /// @ingroup domi_omg | |||
| /// @brief kernel default value | |||
| /// | |||
| const uint32_t KERNEL_DEFAULT_VALUE = 0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief defaule convolution group size | |||
| /// @ingroup domi_omg | |||
| /// @brief defaule convolution group size | |||
| /// | |||
| const uint32_t DEFAULT_CONV_GROUP = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Default deconvolution adj | |||
| /// @ingroup domi_omg | |||
| /// @brief Default deconvolution adj | |||
| /// | |||
| const uint32_t DEFAULT_DECONV_ADJ = 0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Represents value 1 | |||
| /// @ingroup domi_omg | |||
| /// @brief Represents value 1 | |||
| /// | |||
| const uint32_t NUM_ONE = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief spatial dim size default value | |||
| /// @ingroup domi_omg | |||
| /// @brief spatial dim size default value | |||
| /// | |||
| const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief dim extended default value | |||
| /// @ingroup domi_omg | |||
| /// @brief dim extended default value | |||
| /// | |||
| const int32_t DIM_DEFAULT_VALUE = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief The first weight list in opdef is filter | |||
| /// @ingroup domi_omg | |||
| /// @brief The first weight list in opdef is filter | |||
| /// | |||
| const int32_t WEIGHT_FILTER_INDEX = 0; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief The second weight list in opdef is bias | |||
| /// @ingroup domi_omg | |||
| /// @brief The second weight list in opdef is bias | |||
| /// | |||
| const int32_t WEIGHT_BIAS_INDEX = 1; | |||
| const int32_t TENSOR_ND_SUPPORT_SIZE = 8; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief NCHW index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief NCHW index default value | |||
| /// | |||
| const uint32_t NCHW_DIM_N = 0; | |||
| const uint32_t NCHW_DIM_C = 1; | |||
| @@ -689,8 +698,8 @@ const uint32_t NCHW_DIM_H = 2; | |||
| const uint32_t NCHW_DIM_W = 3; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief KCHW index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief KCHW index default value | |||
| /// | |||
| const uint32_t KCHW_DIM_K = 0; | |||
| const uint32_t KCHW_DIM_C = 1; | |||
| @@ -698,8 +707,8 @@ const uint32_t KCHW_DIM_H = 2; | |||
| const uint32_t KCHW_DIM_W = 3; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief HWCK index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief HWCK index default value | |||
| /// | |||
| const uint32_t HWCK_DIM_H = 0; | |||
| const uint32_t HWCK_DIM_W = 1; | |||
| @@ -707,8 +716,8 @@ const uint32_t HWCK_DIM_C = 2; | |||
| const uint32_t HWCK_DIM_K = 3; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief NHWC index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief NHWC index default value | |||
| /// | |||
| const uint32_t NHWC_DIM_N = 0; | |||
| const uint32_t NHWC_DIM_H = 1; | |||
| @@ -716,8 +725,8 @@ const uint32_t NHWC_DIM_W = 2; | |||
| const uint32_t NHWC_DIM_C = 3; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief CHWN index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief CHWN index default value | |||
| /// | |||
| const uint32_t CHWN_DIM_N = 3; | |||
| const uint32_t CHWN_DIM_C = 0; | |||
| @@ -725,23 +734,23 @@ const uint32_t CHWN_DIM_H = 1; | |||
| const uint32_t CHWN_DIM_W = 2; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief CHW index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief CHW index default value | |||
| /// | |||
| const uint32_t CHW_DIM_C = 0; | |||
| const uint32_t CHW_DIM_H = 1; | |||
| const uint32_t CHW_DIM_W = 2; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief HWC index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief HWC index default value | |||
| /// | |||
| const uint32_t HWC_DIM_H = 0; | |||
| const uint32_t HWC_DIM_W = 1; | |||
| const uint32_t HWC_DIM_C = 2; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief Pad index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief Pad index default value | |||
| /// | |||
| const uint32_t PAD_H_HEAD = 0; | |||
| const uint32_t PAD_H_TAIL = 1; | |||
| @@ -749,35 +758,35 @@ const uint32_t PAD_W_HEAD = 2; | |||
| const uint32_t PAD_W_TAIL = 3; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief window index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief window index default value | |||
| /// | |||
| const uint32_t WINDOW_H = 0; | |||
| const uint32_t WINDOW_W = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief stride index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief stride index default value | |||
| /// | |||
| const uint32_t STRIDE_H = 0; | |||
| const uint32_t STRIDE_W = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief dilation index default value | |||
| /// @ingroup domi_omg | |||
| /// @brief dilation index default value | |||
| /// | |||
| const uint32_t DILATION_H = 0; | |||
| const uint32_t DILATION_W = 1; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief the num of XRBG channel | |||
| /// @ingroup domi_omg | |||
| /// @brief the num of XRBG channel | |||
| /// | |||
| const uint32_t XRGB_CHN_NUM = 4; | |||
| /// | |||
| ///@ingroup domi_omg | |||
| ///@brief global pooling default value | |||
| /// @ingroup domi_omg | |||
| /// @brief global pooling default value | |||
| /// | |||
| const bool DEFAULT_GLOBAL_POOLING = false; | |||
| @@ -801,4 +810,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; | |||
| const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; | |||
| const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; | |||
| }; // namespace ge | |||
| } // namespace ge | |||
| @@ -56,6 +56,7 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | |||
| /// The maximum length of the file. | |||
| /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 | |||
| const int kMaxFileSizeLimit = INT_MAX; | |||
| const char *const kPathValidReason = "The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; | |||
| } // namespace | |||
| namespace ge { | |||
| @@ -77,7 +78,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||
| std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | |||
| if (!fs.is_open()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); | |||
| GELOGE(ge::FAILED, "Open real path[%s] failed.", file); | |||
| return false; | |||
| } | |||
| @@ -90,7 +91,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||
| fs.close(); | |||
| if (!ret) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); | |||
| GELOGE(ge::FAILED, "Parse file[%s] failed.", file); | |||
| return ret; | |||
| } | |||
| @@ -114,17 +115,18 @@ long GetFileLength(const std::string &input_file) { | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | |||
| unsigned long long file_length = 0; | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file}); | |||
| return -1, "Open file[%s] failed", input_file.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); | |||
| return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"filepath"}, {input_file}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); | |||
| return -1, "File[%s] size is 0, not valid.", input_file.c_str()); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E10039", {"filepath", "filesize", "maxlen"}, | |||
| "E19016", {"filepath", "filesize", "maxlen"}, | |||
| {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||
| return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit); | |||
| return static_cast<long>(file_length); | |||
| @@ -219,7 +221,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||
| if (ret != 0) { | |||
| if (errno != EEXIST) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||
| GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| @@ -230,7 +232,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||
| if (ret != 0) { | |||
| if (errno != EEXIST) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||
| GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||
| GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||
| return ret; | |||
| } | |||
| } | |||
| @@ -258,16 +260,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||
| "incorrect parameter. nullptr == file || nullptr == message"); | |||
| std::string real_path = RealPath(file); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file}); | |||
| return false, "Get path[%s]'s real path failed", file); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( | |||
| "E19000", {"path", "errmsg"}, {file, strerror(errno)}); | |||
| return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | |||
| std::ifstream fs(real_path.c_str(), std::ifstream::in); | |||
| if (!fs.is_open()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); | |||
| GELOGE(ge::FAILED, "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), | |||
| file); | |||
| return false; | |||
| @@ -275,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||
| google::protobuf::io::IstreamInputStream input(&fs); | |||
| bool ret = google::protobuf::TextFormat::Parse(&input, message); | |||
| GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"protofile"}, {file}); | |||
| GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); | |||
| GELOGE(ret, | |||
| "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " | |||
| "please check whether the file is a valid protobuf format file.", | |||
| @@ -360,14 +362,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||
| // The specified path is empty | |||
| std::map<std::string, std::string> args_map; | |||
| if (file_path.empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | |||
| GELOGW("Input parameter's value is empty."); | |||
| return false; | |||
| } | |||
| std::string real_path = RealPath(file_path.c_str()); | |||
| // Unable to get absolute path (does not exist or does not have permission to access) | |||
| if (real_path.empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)}); | |||
| GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||
| return false; | |||
| } | |||
| @@ -380,16 +382,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| !ValidateStr(real_path, mode), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); | |||
| return false, | |||
| "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " | |||
| "and chinese character.", | |||
| atc_param.c_str(), real_path.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {atc_param, real_path, kPathValidReason}); | |||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | |||
| // The absolute path points to a file that is not readable | |||
| if (access(real_path.c_str(), R_OK) != 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | |||
| GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)}); | |||
| GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||
| return false; | |||
| } | |||
| @@ -400,7 +400,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||
| const std::string &atc_param) { | |||
| // The specified path is empty | |||
| if (file_path.empty()) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); | |||
| GELOGW("Input parameter's value is empty."); | |||
| return false; | |||
| } | |||
| @@ -416,18 +416,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||
| !ValidateStr(real_path, mode), | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); | |||
| return false, | |||
| "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " | |||
| "and chinese character.", | |||
| atc_param.c_str(), real_path.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, | |||
| {atc_param, real_path, kPathValidReason}); | |||
| return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); | |||
| // File is not readable or writable | |||
| if (access(real_path.c_str(), W_OK | F_OK) != 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"}, | |||
| {real_path, file_path, strerror(errno)}); | |||
| GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(), | |||
| strerror(errno)); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)}); | |||
| GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); | |||
| return false; | |||
| } | |||
| } else { | |||
| @@ -445,8 +441,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||
| std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | |||
| // Determine whether the specified path is valid by creating the path | |||
| if (CreateDirectory(prefix_path) != 0) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"path"}, {file_path}); | |||
| GELOGW("Can not create prefix path for path[%s].", file_path.c_str()); | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path}); | |||
| GELOGW("Can not create directory[%s].", file_path.c_str()); | |||
| return false; | |||
| } | |||
| } | |||
| @@ -24,6 +24,7 @@ | |||
| #include "common/debug/log.h" | |||
| #include "common/ge/ge_util.h" | |||
| #include "common/util/error_manager/error_manager.h" | |||
| #include "framework/common/debug/ge_log.h" | |||
| #include "graph/ge_context.h" | |||
| #include "init/gelib.h" | |||
| @@ -161,6 +162,10 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) { | |||
| return false; | |||
| } | |||
| void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); } | |||
| const map<string, uint64_t> &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; } | |||
| std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); | |||
| return ""); | |||
| @@ -194,15 +199,20 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||
| if (kernel_info_store != kernel_map.end()) { | |||
| std::string unsupported_reason; | |||
| // It will be replaced by engine' checksupport | |||
| uint64_t start_time = GetCurrentTimestap(); | |||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||
| checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | |||
| op_desc->SetOpEngineName(it.engine); | |||
| op_desc->SetOpKernelLibName(kernel_name); | |||
| GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | |||
| it.engine.c_str(), op_desc->GetName().c_str()); | |||
| return it.engine; | |||
| } else { | |||
| checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | |||
| bool is_custom_op = false; | |||
| if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, | |||
| {kernel_name, op_desc->GetType(), op_desc->GetName()}); | |||
| GELOGE(FAILED, | |||
| "The custom operator registered by the user does not support the logic function delivered by this " | |||
| "network. Check support failed, kernel_name is %s, op type is %s, op name is %s", | |||
| @@ -221,9 +231,13 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||
| } | |||
| } | |||
| for (const auto &it : unsupported_reasons) { | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, | |||
| {op_desc->GetType(), it.first, it.second}); | |||
| GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", | |||
| op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); | |||
| } | |||
| ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, | |||
| {op_desc->GetName(), op_desc->GetType()}); | |||
| GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", | |||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||
| return ""; | |||
| @@ -384,7 +398,13 @@ Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle h | |||
| return FAILED; | |||
| } | |||
| ifs >> *json_file; | |||
| try { | |||
| ifs >> *json_file; | |||
| } catch (const json::exception &e) { | |||
| GELOGE(FAILED, "Read json file failed"); | |||
| ifs.close(); | |||
| return FAILED; | |||
| } | |||
| ifs.close(); | |||
| GELOGI("Read json file success"); | |||
| return SUCCESS; | |||
| @@ -63,6 +63,8 @@ class DNNEngineManager { | |||
| // If can't find appropriate engine name, return "", report error | |||
| string GetDNNEngineName(const OpDescPtr &op_desc); | |||
| const map<string, SchedulerConf> &GetSchedulers() const; | |||
| const map<string, uint64_t> &GetCheckSupportCost() const; | |||
| void InitPerformanceStaistic(); | |||
| private: | |||
| DNNEngineManager(); | |||
| @@ -78,6 +80,7 @@ class DNNEngineManager { | |||
| std::map<std::string, DNNEnginePtr> engines_map_; | |||
| std::map<std::string, ge::DNNEngineAttribute> engines_attrs_map_; | |||
| std::map<string, SchedulerConf> schedulers_; | |||
| std::map<string, uint64_t> checksupport_cost_; | |||
| bool init_flag_; | |||
| }; | |||
| } // namespace ge | |||
| @@ -26,6 +26,7 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "ge_executor.cc" | |||
| "../common/ge/op_tiling_manager.cc" | |||
| "../common/ge/plugin_manager.cc" | |||
| "../common/profiling/profiling_manager.cc" | |||
| "../graph/execute/graph_execute.cc" | |||
| @@ -59,7 +60,6 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||
| "../graph/load/new_model_manager/task_info/task_info.cc" | |||
| "../graph/load/new_model_manager/tbe_handle_store.cc" | |||
| "../graph/load/new_model_manager/zero_copy_task.cc" | |||
| "../graph/load/output/output.cc" | |||
| "../graph/manager/graph_caching_allocator.cc" | |||
| "../graph/manager/graph_manager_utils.cc" | |||
| "../graph/manager/graph_mem_allocator.cc" | |||
| @@ -452,7 +452,7 @@ Status GeExecutor::RunModel(const ge::RunModelData &input_data, ge::RunModelData | |||
| // Get input and output descriptor | |||
| Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDesc> &input_desc, | |||
| std::vector<ge::TensorDesc> &output_desc) { | |||
| std::vector<ge::TensorDesc> &output_desc, bool new_model_desc) { | |||
| GELOGI("get model desc info begin."); | |||
| if (!isInit_) { | |||
| GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); | |||
| @@ -464,8 +464,8 @@ Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector<ge::TensorDes | |||
| std::vector<uint32_t> input_formats; | |||
| std::vector<uint32_t> output_formats; | |||
| Status ret = | |||
| GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats, output_formats); | |||
| Status ret = GraphExecutor::GetInputOutputDescInfo(model_id, input_desc_infos, output_desc_infos, input_formats, | |||
| output_formats, new_model_desc); | |||
| if (ret != domi::SUCCESS) { | |||
| GELOGE(ret, "GetInputOutputDescInfo failed. ret = %u", ret); | |||
| return TransferDomiErrorCode(ret); | |||
| @@ -854,5 +854,4 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, | |||
| GELOGI("GetAllAippInputOutputDims succ."); | |||
| return SUCCESS; | |||
| } | |||
| } // namespace ge | |||
| @@ -4,6 +4,7 @@ local_ge_executor_src_files := \ | |||
| ge_executor.cc \ | |||
| ../common/profiling/profiling_manager.cc \ | |||
| ../common/ge/plugin_manager.cc \ | |||
| ../common/ge/op_tiling_manager.cc \ | |||
| ../graph/load/graph_loader.cc \ | |||
| ../graph/execute/graph_execute.cc \ | |||
| ../omm/csa_interact.cc \ | |||
| @@ -44,7 +45,6 @@ local_ge_executor_src_files := \ | |||
| ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | |||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | |||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | |||
| ../graph/load/output/output.cc \ | |||
| ../single_op/single_op_manager.cc \ | |||
| ../single_op/single_op_model.cc \ | |||
| ../single_op/single_op.cc \ | |||
| @@ -53,6 +53,7 @@ local_ge_executor_src_files := \ | |||
| ../single_op/task/build_task_utils.cc \ | |||
| ../single_op/task/tbe_task_builder.cc \ | |||
| ../single_op/task/aicpu_task_builder.cc \ | |||
| ../single_op/task/aicpu_kernel_task_builder.cc \ | |||
| ../hybrid/hybrid_davinci_model_stub.cc\ | |||
| local_ge_executor_c_include := \ | |||
| @@ -1,5 +1,5 @@ | |||
| LOCAL_PATH := $(call my-dir) | |||
| include $(LOCAL_PATH)/stub/Makefile | |||
| COMMON_LOCAL_SRC_FILES := \ | |||
| proto/fusion_model.proto \ | |||
| proto/optimizer_priority.proto \ | |||
| @@ -32,6 +32,7 @@ COMMON_LOCAL_SRC_FILES := \ | |||
| GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||
| common/ge/plugin_manager.cc\ | |||
| common/ge/op_tiling_manager.cc\ | |||
| init/gelib.cc \ | |||
| session/inner_session.cc \ | |||
| session/session_manager.cc \ | |||
| @@ -91,6 +92,7 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/no_use_reshape_remove_pass.cc \ | |||
| graph/passes/iterator_op_pass.cc \ | |||
| graph/passes/atomic_addr_clean_pass.cc \ | |||
| graph/passes/mark_same_addr_pass.cc \ | |||
| graph/common/omg_util.cc \ | |||
| graph/common/bcast.cc \ | |||
| graph/passes/dimension_compute_pass.cc \ | |||
| @@ -145,6 +147,7 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/stop_gradient_pass.cc \ | |||
| graph/passes/prevent_gradient_pass.cc \ | |||
| graph/passes/identity_pass.cc \ | |||
| graph/passes/ref_identity_delete_op_pass.cc \ | |||
| graph/passes/placeholder_with_default_pass.cc \ | |||
| graph/passes/snapshot_pass.cc \ | |||
| graph/passes/guarantee_const_pass.cc \ | |||
| @@ -153,7 +156,9 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/folding_pass.cc \ | |||
| graph/passes/cast_translate_pass.cc \ | |||
| graph/passes/prune_pass.cc \ | |||
| graph/passes/switch_op_pass.cc \ | |||
| graph/passes/merge_to_stream_merge_pass.cc \ | |||
| graph/passes/switch_to_stream_switch_pass.cc \ | |||
| graph/passes/attach_stream_label_pass.cc \ | |||
| graph/passes/multi_batch_pass.cc \ | |||
| graph/passes/next_iteration_pass.cc \ | |||
| graph/passes/control_trigger_pass.cc \ | |||
| @@ -173,7 +178,6 @@ OMG_HOST_SRC_FILES := \ | |||
| graph/passes/variable_op_pass.cc \ | |||
| graph/passes/cast_remove_pass.cc \ | |||
| graph/passes/transpose_transdata_pass.cc \ | |||
| graph/passes/identify_reference_pass.cc \ | |||
| graph/passes/hccl_memcpy_pass.cc \ | |||
| graph/passes/flow_ctrl_pass.cc \ | |||
| graph/passes/link_gen_mask_nodes_pass.cc \ | |||
| @@ -199,7 +203,6 @@ OME_HOST_SRC_FILES := \ | |||
| graph/load/new_model_manager/tbe_handle_store.cc \ | |||
| graph/load/new_model_manager/cpu_queue_schedule.cc \ | |||
| graph/load/new_model_manager/zero_copy_task.cc \ | |||
| graph/load/output/output.cc \ | |||
| graph/load/new_model_manager/data_dumper.cc \ | |||
| graph/load/new_model_manager/task_info/task_info.cc \ | |||
| graph/load/new_model_manager/task_info/event_record_task_info.cc \ | |||
| @@ -224,6 +227,7 @@ OME_HOST_SRC_FILES := \ | |||
| single_op/task/build_task_utils.cc \ | |||
| single_op/task/tbe_task_builder.cc \ | |||
| single_op/task/aicpu_task_builder.cc \ | |||
| single_op/task/aicpu_kernel_task_builder.cc \ | |||
| single_op/single_op.cc \ | |||
| single_op/single_op_model.cc \ | |||
| single_op/stream_resource.cc \ | |||
| @@ -353,6 +357,28 @@ LOCAL_SHARED_LIBRARIES := \ | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for host infer | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := stub/libge_compiler | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 | |||
| LOCAL_CFLAGS += -DFMK_HOST_INFER -DFMK_SUPPORT_DUMP | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| endif | |||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_ir_build.cc | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for device | |||
| @@ -23,6 +23,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| common/formats/utils/formats_trans_utils.cc \ | |||
| common/fp16_t.cc \ | |||
| common/ge/plugin_manager.cc\ | |||
| common/ge/op_tiling_manager.cc\ | |||
| common/helper/model_cache_helper.cc \ | |||
| common/profiling/profiling_manager.cc \ | |||
| engine_manager/dnnengine_manager.cc \ | |||
| @@ -77,7 +78,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/load/new_model_manager/task_info/task_info.cc \ | |||
| graph/load/new_model_manager/tbe_handle_store.cc \ | |||
| graph/load/new_model_manager/zero_copy_task.cc \ | |||
| graph/load/output/output.cc \ | |||
| graph/manager/graph_context.cc \ | |||
| graph/manager/graph_manager.cc \ | |||
| graph/manager/graph_manager_utils.cc \ | |||
| @@ -99,6 +99,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/aicpu_constant_folding_pass.cc \ | |||
| graph/passes/assert_pass.cc \ | |||
| graph/passes/atomic_addr_clean_pass.cc \ | |||
| graph/passes/mark_same_addr_pass.cc \ | |||
| graph/partition/dynamic_shape_partition.cc \ | |||
| graph/passes/base_pass.cc \ | |||
| graph/passes/cast_remove_pass.cc \ | |||
| @@ -158,8 +159,8 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/get_original_format_pass.cc \ | |||
| graph/passes/guarantee_const_pass.cc \ | |||
| graph/passes/hccl_memcpy_pass.cc \ | |||
| graph/passes/identify_reference_pass.cc \ | |||
| graph/passes/identity_pass.cc \ | |||
| graph/passes/ref_identity_delete_op_pass.cc \ | |||
| graph/passes/infershape_pass.cc \ | |||
| graph/passes/isolated_op_remove_pass.cc \ | |||
| graph/passes/iterator_op_pass.cc \ | |||
| @@ -191,7 +192,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| graph/passes/data_pass.cc \ | |||
| graph/passes/switch_data_edges_bypass.cc \ | |||
| graph/passes/switch_logic_remove_pass.cc \ | |||
| graph/passes/switch_op_pass.cc \ | |||
| graph/passes/merge_to_stream_merge_pass.cc \ | |||
| graph/passes/switch_to_stream_switch_pass.cc \ | |||
| graph/passes/attach_stream_label_pass.cc \ | |||
| graph/passes/switch_dead_branch_elimination.cc \ | |||
| graph/passes/replace_transshape_pass.cc \ | |||
| graph/passes/transop_breadth_fusion_pass.cc \ | |||
| @@ -230,6 +233,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| single_op/task/op_task.cc \ | |||
| single_op/task/tbe_task_builder.cc \ | |||
| single_op/task/aicpu_task_builder.cc \ | |||
| single_op/task/aicpu_kernel_task_builder.cc \ | |||
| hybrid/common/tensor_value.cc \ | |||
| hybrid/common/npu_memory_allocator.cc \ | |||
| hybrid/executor/rt_callback_manager.cc \ | |||
| @@ -239,12 +243,15 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| hybrid/executor/hybrid_model_executor.cc \ | |||
| hybrid/executor/hybrid_model_async_executor.cc \ | |||
| hybrid/executor/hybrid_execution_context.cc \ | |||
| hybrid/executor/subgraph_context.cc \ | |||
| hybrid/executor/subgraph_executor.cc \ | |||
| hybrid/executor/worker/task_compile_engine.cc \ | |||
| hybrid/executor/worker/shape_inference_engine.cc \ | |||
| hybrid/executor/worker/execution_engine.cc \ | |||
| hybrid/model/hybrid_model.cc \ | |||
| hybrid/model/hybrid_model_builder.cc \ | |||
| hybrid/model/node_item.cc \ | |||
| hybrid/model/graph_item.cc \ | |||
| hybrid/node_executor/aicore/aicore_node_executor.cc \ | |||
| hybrid/node_executor/aicore/aicore_op_task.cc \ | |||
| hybrid/node_executor/aicore/aicore_task_builder.cc \ | |||
| @@ -253,6 +260,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||
| hybrid/node_executor/aicpu/aicpu_node_executor.cc \ | |||
| hybrid/node_executor/compiledsubgraph/known_node_executor.cc \ | |||
| hybrid/node_executor/hostcpu/ge_local_node_executor.cc \ | |||
| hybrid/node_executor/controlop/control_op_executor.cc \ | |||
| hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \ | |||
| hybrid/node_executor/hccl/hccl_node_executor.cc \ | |||
| hybrid/node_executor/node_executor.cc \ | |||
| hybrid/node_executor/task_context.cc \ | |||
| hybrid/hybrid_davinci_model.cc \ | |||
| @@ -338,6 +348,28 @@ LOCAL_SHARED_LIBRARIES += \ | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| #compiler for GeRunner | |||
| include $(CLEAR_VARS) | |||
| LOCAL_MODULE := stub/libge_runner | |||
| LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 | |||
| LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD | |||
| ifeq ($(DEBUG), 1) | |||
| LOCAL_CFLAGS += -g -O0 | |||
| endif | |||
| LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc | |||
| LOCAL_SHARED_LIBRARIES := | |||
| LOCAL_LDFLAGS := -lrt -ldl | |||
| include $(BUILD_HOST_SHARED_LIBRARY) | |||
| # add engine_conf.json to host | |||
| include $(CLEAR_VARS) | |||
| @@ -407,6 +439,7 @@ LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD | |||
| LOCAL_CFLAGS += -g -O0 | |||
| LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) | |||
| LOCAL_SRC_FILES := $(LIBGE_LOCAL_SRC_FILES) | |||
| LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) | |||
| @@ -49,6 +49,15 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint | |||
| return true; | |||
| } | |||
| bool ModelRunner::LoadModelComplete(uint32_t model_id) { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| return false; | |||
| } | |||
| return model_iter->second->LoadComplete(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| @@ -60,6 +69,28 @@ const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const | |||
| return model_iter->second->GetTaskIdList(); | |||
| } | |||
| const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); | |||
| static const std::vector<uint32_t> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return model_iter->second->GetStreamIdList(); | |||
| } | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { | |||
| auto model_iter = runtime_models_.find(model_id); | |||
| if (model_iter == runtime_models_.end()) { | |||
| GELOGW("Model id %u not found.", model_id); | |||
| static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret; | |||
| return empty_ret; | |||
| } | |||
| return model_iter->second->GetRuntimeInfoMap(); | |||
| } | |||
| bool ModelRunner::UnloadModel(uint32_t model_id) { | |||
| auto iter = runtime_models_.find(model_id); | |||
| if (iter != runtime_models_.end()) { | |||
| @@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde | |||
| DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | |||
| bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | |||
| if (!ret) { | |||
| GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||
| GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||
| return ret; | |||
| } | |||
| data_index = data_begin + data_count; | |||
| @@ -28,7 +28,6 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| RuntimeModel::~RuntimeModel() { | |||
| GELOGI("RuntimeModel destructor start"); | |||
| @@ -116,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { | |||
| return true; | |||
| } | |||
| bool RuntimeModel::InitLabel(uint32_t batch_num) { | |||
| GELOGI("batch number:%u.", batch_num); | |||
| for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { | |||
| rtLabel_t rt_lLabel = nullptr; | |||
| rtError_t rt_ret = rtLabelCreate(&rt_lLabel); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); | |||
| return false; | |||
| bool RuntimeModel::InitLabel(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| GELOGI("batch number:%u.", davinci_model->GetBatchNum()); | |||
| label_list_.resize(davinci_model->GetBatchNum()); | |||
| for (auto &task_info : davinci_model->GetTaskInfoList()) { | |||
| if (task_info == nullptr) { | |||
| GELOGE(PARAM_INVALID, "task_info is null."); | |||
| continue; | |||
| } | |||
| if (task_info->type() != TaskInfoType::LABEL_SET) { | |||
| continue; | |||
| } | |||
| auto label_set_task_info = std::static_pointer_cast<LabelSetTaskInfo>(task_info); | |||
| if (rt_lLabel == nullptr) { | |||
| GELOGE(RT_FAILED, "rtLabel is nullptr!"); | |||
| if (label_set_task_info->stream_id() >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "Invalid stream id."); | |||
| return false; | |||
| } | |||
| label_list_.emplace_back(rt_lLabel); | |||
| rtLabel_t rt_label = nullptr; | |||
| rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| label_list_[label_set_task_info->label_id()] = rt_label; | |||
| } | |||
| return true; | |||
| } | |||
| @@ -164,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr<DavinciModel> &davinci_model) { | |||
| return false; | |||
| } | |||
| if (!InitLabel(davinci_model->GetBatchNum())) { | |||
| if (!InitLabel(davinci_model)) { | |||
| return false; | |||
| } | |||
| @@ -209,20 +219,41 @@ bool RuntimeModel::LoadTask() { | |||
| return false; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| if (task->Args() != nullptr) { | |||
| std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr; | |||
| GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false); | |||
| auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); | |||
| if (!emplace_ret.second) { | |||
| GELOGW("Task name exist:%s", task->task_name().c_str()); | |||
| } | |||
| } | |||
| } | |||
| if (task_list_.empty()) { | |||
| GELOGE(FAILED, "Task list is empty"); | |||
| return false; | |||
| } | |||
| GELOGI("Distribute task succ."); | |||
| auto rt_ret = rtModelLoadComplete(rt_model_handle_); | |||
| GELOGI("LoadTask succ."); | |||
| return true; | |||
| } | |||
| bool RuntimeModel::LoadComplete() { | |||
| uint32_t task_id = 0; | |||
| uint32_t stream_id = 0; | |||
| auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); | |||
| return RT_FAILED; | |||
| } | |||
| task_id_list_.push_back(task_id); | |||
| stream_id_list_.push_back(stream_id); | |||
| rt_ret = rtModelLoadComplete(rt_model_handle_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("LoadTask succ."); | |||
| return true; | |||
| } | |||
| @@ -270,10 +301,14 @@ bool RuntimeModel::Run() { | |||
| return false; | |||
| } | |||
| GELOGI("Run rtModelExecute success"); | |||
| GELOGI("Run rtModelExecute success, ret = 0x%X", ret); | |||
| ret = rtStreamSynchronize(rt_model_stream_); | |||
| if (ret != RT_ERROR_NONE) { | |||
| if (ret == RT_ERROR_END_OF_SEQUENCE) { | |||
| GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); | |||
| return true; | |||
| } | |||
| GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); | |||
| return false; | |||
| } | |||
| @@ -433,7 +468,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||
| } | |||
| if (constant->output_tensors[0].size < constant->weight_data.size()) { | |||
| GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, | |||
| GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size, | |||
| constant->weight_data.size()); | |||
| return false; | |||
| } | |||
| @@ -448,11 +483,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||
| /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | |||
| /// and that of unknown shape is zero too. | |||
| /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | |||
| int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||
| if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||
| elem_num = 1; | |||
| } | |||
| int64_t elem_num = | |||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | |||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | |||
| return false; | |||
| @@ -495,5 +527,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp | |||
| const std::vector<uint32_t> &RuntimeModel::GetTaskIdList() const { return task_id_list_; } | |||
| const std::vector<uint32_t> &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -27,7 +27,7 @@ | |||
| namespace ge { | |||
| namespace model_runner { | |||
| using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; | |||
| class Task; | |||
| class RuntimeModel { | |||
| public: | |||
| @@ -35,7 +35,10 @@ class RuntimeModel { | |||
| ~RuntimeModel(); | |||
| bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool LoadComplete(); | |||
| const std::vector<uint32_t> &GetTaskIdList() const; | |||
| const std::vector<uint32_t> &GetStreamIdList() const; | |||
| const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } | |||
| bool Run(); | |||
| bool CopyInputData(const InputData &input_data); | |||
| bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||
| @@ -48,7 +51,7 @@ class RuntimeModel { | |||
| bool LoadTask(); | |||
| bool InitStream(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitEvent(uint32_t event_num); | |||
| bool InitLabel(uint32_t batch_num); | |||
| bool InitLabel(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitDataInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitOutputInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| bool InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model); | |||
| @@ -77,6 +80,8 @@ class RuntimeModel { | |||
| std::vector<std::shared_ptr<OpInfo>> constant_info_list_{}; | |||
| std::vector<uint32_t> task_id_list_{}; | |||
| std::vector<uint32_t> stream_id_list_{}; | |||
| std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; | |||
| }; | |||
| } // namespace model_runner | |||
| @@ -85,11 +85,15 @@ bool AicpuTask::Distribute() { | |||
| return false; | |||
| } | |||
| GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size, | |||
| io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data()); | |||
| rt_ret = rtCpuKernelLaunch(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||
| reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, args_size, | |||
| nullptr, stream_); | |||
| input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset); | |||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||
| GELOGI( | |||
| "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", | |||
| args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); | |||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()), | |||
| reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, | |||
| args_size, nullptr, stream_, dump_flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| @@ -18,6 +18,7 @@ | |||
| #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ | |||
| #include <memory> | |||
| #include <string> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| @@ -30,12 +31,17 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> { | |||
| bool Distribute() override; | |||
| void *Args() override { return input_output_addr_; } | |||
| std::string task_name() const override { return task_info_->op_name(); } | |||
| private: | |||
| static void ReleaseRtMem(void **ptr) noexcept; | |||
| std::shared_ptr<AicpuTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *args_; | |||
| void *input_output_addr_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { | |||
| rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| (void)rtStreamDestroy(stream); | |||
| return false; | |||
| } | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_goto_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info) | |||
| : TaskRepeater<LabelGotoTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelGotoTask::~LabelGotoTask() {} | |||
| bool LabelGotoTask::Distribute() { | |||
| GELOGI("LabelGotoTask Distribute start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelGotoEx(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelGotoTask : public TaskRepeater<LabelGotoTaskInfo> { | |||
| public: | |||
| LabelGotoTask(const ModelContext &model_context, const std::shared_ptr<LabelGotoTaskInfo> &task_info); | |||
| ~LabelGotoTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelGotoTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ | |||
| @@ -0,0 +1,70 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_set_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSetTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| label_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| auto stream_list = model_context.stream_list(); | |||
| auto label_list = model_context.label_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| uint32_t label_id = task_info->label_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); | |||
| if (stream_id >= stream_list.size() || label_id >= label_list.size()) { | |||
| GELOGW("Stream/Label id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| label_ = label_list[label_id]; | |||
| } | |||
| LabelSetTask::~LabelSetTask() {} | |||
| bool LabelSetTask::Distribute() { | |||
| GELOGI("LabelSetTask Distribute start."); | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (label_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "label is null!"); | |||
| return false; | |||
| } | |||
| rtError_t rt_ret = rtLabelSet(label_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSetTask : public TaskRepeater<LabelSetTaskInfo> { | |||
| public: | |||
| LabelSetTask(const ModelContext &model_context, const std::shared_ptr<LabelSetTaskInfo> &task_info); | |||
| ~LabelSetTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| std::shared_ptr<LabelSetTaskInfo> task_info_; | |||
| void *stream_; | |||
| void *label_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ | |||
| @@ -0,0 +1,131 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_runtime/task/label_switch_task.h" | |||
| #include "ge_runtime/task/task_factory.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, | |||
| const std::shared_ptr<LabelSwitchTaskInfo> &task_info) | |||
| : TaskRepeater<LabelSwitchTaskInfo>(model_context, task_info), | |||
| task_info_(task_info), | |||
| stream_(nullptr), | |||
| all_label_resource_(), | |||
| label_info_(nullptr) { | |||
| if (task_info_ == nullptr) { | |||
| GELOGW("task_info_ is null!"); | |||
| return; | |||
| } | |||
| all_label_resource_ = model_context.label_list(); | |||
| auto stream_list = model_context.stream_list(); | |||
| uint32_t stream_id = task_info->stream_id(); | |||
| GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); | |||
| if (stream_id >= stream_list.size()) { | |||
| GELOGW("Stream id invalid."); | |||
| return; | |||
| } | |||
| stream_ = stream_list[stream_id]; | |||
| } | |||
| LabelSwitchTask::~LabelSwitchTask() { | |||
| if (label_info_ != nullptr) { | |||
| rtError_t rt_ret = rtFree(label_info_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); | |||
| } | |||
| label_info_ = nullptr; | |||
| } | |||
| } | |||
| bool LabelSwitchTask::Distribute() { | |||
| GELOGI("LabelSwitchTask Distribute start."); | |||
| if (!CheckParamValid()) { | |||
| return false; | |||
| } | |||
| const std::vector<uint32_t> &label_index_list = task_info_->label_list(); | |||
| std::vector<void *> label_list(task_info_->label_size(), nullptr); | |||
| for (size_t i = 0; i < task_info_->label_size(); ++i) { | |||
| uint32_t label_index = label_index_list[i]; | |||
| if (label_index >= all_label_resource_.size()) { | |||
| GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, | |||
| all_label_resource_.size()); | |||
| return false; | |||
| } | |||
| label_list[i] = all_label_resource_[label_index]; | |||
| GELOGI("Case %zu: label id %zu.", i, label_index); | |||
| } | |||
| uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); | |||
| rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTask end."); | |||
| return true; | |||
| } | |||
| bool LabelSwitchTask::CheckParamValid() { | |||
| if (stream_ == nullptr) { | |||
| GELOGE(PARAM_INVALID, "stream is null!"); | |||
| return false; | |||
| } | |||
| if (task_info_->label_list().empty()) { | |||
| GELOGE(PARAM_INVALID, "label_list is empty."); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() != task_info_->label_list().size()) { | |||
| GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), | |||
| task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { | |||
| GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); | |||
| return false; | |||
| } | |||
| if (label_info_ != nullptr) { | |||
| GELOGE(PARAM_INVALID, "label_info_ has dirty data."); | |||
| return false; | |||
| } | |||
| return true; | |||
| } | |||
| REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| @@ -0,0 +1,44 @@ | |||
| /** | |||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| #include <memory> | |||
| #include "ge_runtime/task/task.h" | |||
| namespace ge { | |||
| namespace model_runner { | |||
| class LabelSwitchTask : public TaskRepeater<LabelSwitchTaskInfo> { | |||
| public: | |||
| LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr<LabelSwitchTaskInfo> &task_info); | |||
| ~LabelSwitchTask() override; | |||
| bool Distribute() override; | |||
| private: | |||
| bool CheckParamValid(); | |||
| std::shared_ptr<LabelSwitchTaskInfo> task_info_; | |||
| void *stream_; | |||
| std::vector<void *> all_label_resource_; | |||
| void *label_info_; | |||
| }; | |||
| } // namespace model_runner | |||
| } // namespace ge | |||
| #endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ | |||
| @@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() { | |||
| } | |||
| if (static_cast<uint64_t>(task_info_->true_stream_id()) >= stream_list_.size()) { | |||
| GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||
| GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(), | |||
| stream_list_.size()); | |||
| return false; | |||
| } | |||
| @@ -18,7 +18,9 @@ | |||
| #define GE_GE_RUNTIME_TASK_TASK_H_ | |||
| #include <memory> | |||
| #include <utility> | |||
| #include <vector> | |||
| #include <string> | |||
| #include "runtime/rt_model.h" | |||
| #include "ge_runtime/model_context.h" | |||
| #include "ge_runtime/task_info.h" | |||
| @@ -32,6 +34,10 @@ class Task { | |||
| virtual ~Task() {} | |||
| virtual bool Distribute() = 0; | |||
| virtual void *Args() { return nullptr; } | |||
| virtual std::string task_name() const { return ""; } | |||
| }; | |||
| template <class T> | |||
| @@ -95,15 +95,14 @@ bool TbeTask::Distribute() { | |||
| return false; | |||
| } | |||
| GELOGI("InitTbeTask end."); | |||
| GELOGI("DistributeTbeTask start."); | |||
| rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_); | |||
| auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; | |||
| rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); | |||
| if (rt_ret != RT_ERROR_NONE) { | |||
| GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); | |||
| return false; | |||
| } | |||
| GELOGI("DistributeTbeTask end."); | |||
| GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); | |||
| return true; | |||
| } | |||
| @@ -30,6 +30,10 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> { | |||
| bool Distribute() override; | |||
| void *Args() override { return args_; } | |||
| std::string task_name() const override { return task_info_->op_name(); } | |||
| private: | |||
| std::shared_ptr<TbeTaskInfo> task_info_; | |||
| void *stream_; | |||